149 lines
4.8 KiB
Go
149 lines
4.8 KiB
Go
// Copyright (C) MongoDB, Inc. 2023-present.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
|
// not use this file except in compliance with the License. You may obtain
|
|
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
package credproviders
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"time"
|
|
|
|
"gitea.psichedelico.com/go/bson/internal/aws/credentials"
|
|
"gitea.psichedelico.com/go/bson/internal/uuid"
|
|
)
|
|
|
|
const (
|
|
// assumeRoleProviderName provides a name of assume role provider
|
|
assumeRoleProviderName = "AssumeRoleProvider"
|
|
|
|
stsURI = `https://sts.amazonaws.com/?Action=AssumeRoleWithWebIdentity&RoleSessionName=%s&RoleArn=%s&WebIdentityToken=%s&Version=2011-06-15`
|
|
)
|
|
|
|
// An AssumeRoleProvider retrieves credentials for assume role with web identity.
|
|
type AssumeRoleProvider struct {
|
|
AwsRoleArnEnv EnvVar
|
|
AwsWebIdentityTokenFileEnv EnvVar
|
|
AwsRoleSessionNameEnv EnvVar
|
|
|
|
httpClient *http.Client
|
|
expiration time.Time
|
|
|
|
// expiryWindow will allow the credentials to trigger refreshing prior to the credentials actually expiring.
|
|
// This is beneficial so expiring credentials do not cause request to fail unexpectedly due to exceptions.
|
|
//
|
|
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
|
|
// 10 seconds before the credentials are actually expired.
|
|
expiryWindow time.Duration
|
|
}
|
|
|
|
// NewAssumeRoleProvider returns a pointer to an assume role provider.
|
|
func NewAssumeRoleProvider(httpClient *http.Client, expiryWindow time.Duration) *AssumeRoleProvider {
|
|
return &AssumeRoleProvider{
|
|
// AwsRoleArnEnv is the environment variable for AWS_ROLE_ARN
|
|
AwsRoleArnEnv: EnvVar("AWS_ROLE_ARN"),
|
|
// AwsWebIdentityTokenFileEnv is the environment variable for AWS_WEB_IDENTITY_TOKEN_FILE
|
|
AwsWebIdentityTokenFileEnv: EnvVar("AWS_WEB_IDENTITY_TOKEN_FILE"),
|
|
// AwsRoleSessionNameEnv is the environment variable for AWS_ROLE_SESSION_NAME
|
|
AwsRoleSessionNameEnv: EnvVar("AWS_ROLE_SESSION_NAME"),
|
|
httpClient: httpClient,
|
|
expiryWindow: expiryWindow,
|
|
}
|
|
}
|
|
|
|
// RetrieveWithContext retrieves the keys from the AWS service.
|
|
func (a *AssumeRoleProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
|
|
const defaultHTTPTimeout = 10 * time.Second
|
|
|
|
v := credentials.Value{ProviderName: assumeRoleProviderName}
|
|
|
|
roleArn := a.AwsRoleArnEnv.Get()
|
|
tokenFile := a.AwsWebIdentityTokenFileEnv.Get()
|
|
if tokenFile == "" && roleArn == "" {
|
|
return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_ARN are missing")
|
|
}
|
|
if tokenFile != "" && roleArn == "" {
|
|
return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE is set, but AWS_ROLE_ARN is missing")
|
|
}
|
|
if tokenFile == "" && roleArn != "" {
|
|
return v, errors.New("AWS_ROLE_ARN is set, but AWS_WEB_IDENTITY_TOKEN_FILE is missing")
|
|
}
|
|
token, err := ioutil.ReadFile(tokenFile)
|
|
if err != nil {
|
|
return v, err
|
|
}
|
|
|
|
sessionName := a.AwsRoleSessionNameEnv.Get()
|
|
if sessionName == "" {
|
|
// Use a UUID if the RoleSessionName is not given.
|
|
id, err := uuid.New()
|
|
if err != nil {
|
|
return v, err
|
|
}
|
|
sessionName = id.String()
|
|
}
|
|
|
|
fullURI := fmt.Sprintf(stsURI, sessionName, roleArn, string(token))
|
|
|
|
req, err := http.NewRequest(http.MethodPost, fullURI, nil)
|
|
if err != nil {
|
|
return v, err
|
|
}
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
|
|
defer cancel()
|
|
resp, err := a.httpClient.Do(req.WithContext(ctx))
|
|
if err != nil {
|
|
return v, err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
return v, fmt.Errorf("response failure: %s", resp.Status)
|
|
}
|
|
|
|
var stsResp struct {
|
|
Response struct {
|
|
Result struct {
|
|
Credentials struct {
|
|
AccessKeyID string `json:"AccessKeyId"`
|
|
SecretAccessKey string `json:"SecretAccessKey"`
|
|
Token string `json:"SessionToken"`
|
|
Expiration float64 `json:"Expiration"`
|
|
} `json:"Credentials"`
|
|
} `json:"AssumeRoleWithWebIdentityResult"`
|
|
} `json:"AssumeRoleWithWebIdentityResponse"`
|
|
}
|
|
|
|
err = json.NewDecoder(resp.Body).Decode(&stsResp)
|
|
if err != nil {
|
|
return v, err
|
|
}
|
|
v.AccessKeyID = stsResp.Response.Result.Credentials.AccessKeyID
|
|
v.SecretAccessKey = stsResp.Response.Result.Credentials.SecretAccessKey
|
|
v.SessionToken = stsResp.Response.Result.Credentials.Token
|
|
if !v.HasKeys() {
|
|
return v, errors.New("failed to retrieve web identity keys")
|
|
}
|
|
sec := int64(stsResp.Response.Result.Credentials.Expiration)
|
|
a.expiration = time.Unix(sec, 0).Add(-a.expiryWindow)
|
|
|
|
return v, nil
|
|
}
|
|
|
|
// Retrieve retrieves the keys from the AWS service.
|
|
func (a *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
|
|
return a.RetrieveWithContext(context.Background())
|
|
}
|
|
|
|
// IsExpired returns true if the credentials are expired.
|
|
func (a *AssumeRoleProvider) IsExpired() bool {
|
|
return a.expiration.Before(time.Now())
|
|
}
|