bson/internal/credproviders/assume_role_provider.go
2025-03-17 20:58:26 +01:00

149 lines
4.8 KiB
Go

// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package credproviders
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"time"

	"gitea.psichedelico.com/go/bson/internal/aws/credentials"
	"gitea.psichedelico.com/go/bson/internal/uuid"
)
const (
	// assumeRoleProviderName is the ProviderName stamped on every
	// credentials.Value produced by AssumeRoleProvider.
	assumeRoleProviderName = "AssumeRoleProvider"

	// stsURI is the STS AssumeRoleWithWebIdentity request URL template. Its
	// three %s verbs receive, in order, the role session name, the role ARN
	// and the web identity token.
	stsURI = "https://sts.amazonaws.com/?Action=AssumeRoleWithWebIdentity&RoleSessionName=%s&RoleArn=%s&WebIdentityToken=%s&Version=2011-06-15"
)
// AssumeRoleProvider retrieves temporary AWS credentials by assuming a role
// with a web identity token (STS AssumeRoleWithWebIdentity).
type AssumeRoleProvider struct {
	// Environment variables consulted on each retrieval, in order: the role
	// ARN, the path of the web identity token file, and the optional role
	// session name.
	AwsRoleArnEnv, AwsWebIdentityTokenFileEnv, AwsRoleSessionNameEnv EnvVar

	// httpClient is the client used to send the STS request.
	httpClient *http.Client

	// expiration is the instant after which IsExpired reports true. It is
	// stored already shifted earlier by expiryWindow.
	expiration time.Time

	// expiryWindow makes IsExpired report true this long before the
	// credentials really expire, so they are refreshed ahead of time rather
	// than causing a request to fail unexpectedly. For example, a window of
	// 10s makes IsExpired return true 10 seconds before actual expiry.
	expiryWindow time.Duration
}
// NewAssumeRoleProvider creates an AssumeRoleProvider that reads its settings
// from the AWS_ROLE_ARN, AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_SESSION_NAME
// environment variables. It sends STS requests through httpClient and makes
// IsExpired report expiry expiryWindow ahead of the real expiration.
func NewAssumeRoleProvider(httpClient *http.Client, expiryWindow time.Duration) *AssumeRoleProvider {
	p := new(AssumeRoleProvider)
	p.AwsRoleArnEnv = EnvVar("AWS_ROLE_ARN")
	p.AwsWebIdentityTokenFileEnv = EnvVar("AWS_WEB_IDENTITY_TOKEN_FILE")
	p.AwsRoleSessionNameEnv = EnvVar("AWS_ROLE_SESSION_NAME")
	p.httpClient = httpClient
	p.expiryWindow = expiryWindow
	return p
}
// RetrieveWithContext retrieves temporary credentials from AWS STS using the
// AssumeRoleWithWebIdentity action.
//
// It reads the role ARN from AwsRoleArnEnv and the token file path from
// AwsWebIdentityTokenFileEnv. Both must be set. The session name comes from
// AwsRoleSessionNameEnv, falling back to a random UUID when that is empty.
// The STS call is bounded by ctx and by an additional 10 second timeout. If no
// HTTP client was configured, http.DefaultClient is used.
//
// On success the provider's expiration is updated so that IsExpired reports
// true expiryWindow before the credentials actually expire. The returned Value
// always carries ProviderName, even when an error is returned.
func (a *AssumeRoleProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
	const defaultHTTPTimeout = 10 * time.Second

	v := credentials.Value{ProviderName: assumeRoleProviderName}

	roleArn := a.AwsRoleArnEnv.Get()
	tokenFile := a.AwsWebIdentityTokenFileEnv.Get()
	switch {
	case tokenFile == "" && roleArn == "":
		return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_ARN are missing")
	case roleArn == "":
		return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE is set, but AWS_ROLE_ARN is missing")
	case tokenFile == "":
		return v, errors.New("AWS_ROLE_ARN is set, but AWS_WEB_IDENTITY_TOKEN_FILE is missing")
	}

	token, err := ioutil.ReadFile(tokenFile)
	if err != nil {
		return v, fmt.Errorf("reading web identity token file: %w", err)
	}

	sessionName := a.AwsRoleSessionNameEnv.Get()
	if sessionName == "" {
		// Use a UUID if the RoleSessionName is not given.
		id, err := uuid.New()
		if err != nil {
			return v, fmt.Errorf("generating role session name: %w", err)
		}
		sessionName = id.String()
	}

	// Query-escape every value. Role ARNs and session names may legally
	// contain characters such as '+' (which would otherwise decode as a
	// space), '=' or '@'. The token is file content the user controls.
	fullURI := fmt.Sprintf(stsURI,
		url.QueryEscape(sessionName),
		url.QueryEscape(roleArn),
		url.QueryEscape(string(token)),
	)

	// net/http reports failures as *url.Error, whose message embeds the full
	// request URL. Strip the query string so the web identity token, a bearer
	// credential, does not leak into error messages or logs.
	redact := func(err error) error {
		var uerr *url.Error
		if errors.As(err, &uerr) {
			if u, perr := url.Parse(uerr.URL); perr == nil {
				u.RawQuery = ""
				uerr.URL = u.String()
			} else {
				uerr.URL = "(redacted)"
			}
		}
		return err
	}

	ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURI, nil)
	if err != nil {
		return v, redact(err)
	}
	req.Header.Set("Accept", "application/json")

	client := a.httpClient
	if client == nil {
		// Avoid a nil-pointer panic when the provider was built without a client.
		client = http.DefaultClient
	}
	resp, err := client.Do(req)
	if err != nil {
		return v, redact(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return v, fmt.Errorf("response failure: %s", resp.Status)
	}

	var stsResp struct {
		Response struct {
			Result struct {
				Credentials struct {
					AccessKeyID     string  `json:"AccessKeyId"`
					SecretAccessKey string  `json:"SecretAccessKey"`
					Token           string  `json:"SessionToken"`
					Expiration      float64 `json:"Expiration"`
				} `json:"Credentials"`
			} `json:"AssumeRoleWithWebIdentityResult"`
		} `json:"AssumeRoleWithWebIdentityResponse"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&stsResp); err != nil {
		return v, fmt.Errorf("decoding STS response: %w", err)
	}

	creds := stsResp.Response.Result.Credentials
	v.AccessKeyID = creds.AccessKeyID
	v.SecretAccessKey = creds.SecretAccessKey
	v.SessionToken = creds.Token
	if !v.HasKeys() {
		return v, errors.New("failed to retrieve web identity keys")
	}

	// Expiration is decoded as epoch seconds. Any fractional part is dropped,
	// then the expiry window is subtracted so IsExpired fires early.
	sec := int64(creds.Expiration)
	a.expiration = time.Unix(sec, 0).Add(-a.expiryWindow)

	return v, nil
}
// Retrieve fetches credentials from AWS STS by calling RetrieveWithContext
// with a background context. The request is still bounded by the fixed
// timeout applied inside RetrieveWithContext.
func (a *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
	bg := context.Background()
	return a.RetrieveWithContext(bg)
}
// IsExpired reports whether the current time is past the stored expiration.
// That value is recorded by the last successful retrieval and is already
// moved earlier by expiryWindow. A provider that has never retrieved
// credentials has a zero expiration and therefore always reports true.
func (a *AssumeRoleProvider) IsExpired() bool {
	now := time.Now()
	return now.After(a.expiration)
}