bson/internal/aws/credentials/chain_provider_test.go
2025-03-17 20:58:26 +01:00

177 lines
4.4 KiB
Go

// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/chain_provider_test.go
// See THIRD-PARTY-NOTICES for original license terms
package credentials
import (
"reflect"
"testing"
"gitea.psichedelico.com/go/bson/internal/aws/awserr"
)
// secondStubProvider is a test double used alongside stubProvider so that
// tests can tell, via the reported provider name, which entry in a chain
// produced the credentials.
type secondStubProvider struct {
	// creds is the value handed back by Retrieve.
	creds Value
	// expired is the value reported by IsExpired; Retrieve resets it to false.
	expired bool
	// err is returned unchanged by Retrieve.
	err error
}

// Retrieve stamps the stored credentials with the provider name
// "secondStubProvider", clears the expired flag, and returns the stored
// credentials together with the configured error.
func (sp *secondStubProvider) Retrieve() (Value, error) {
	sp.creds.ProviderName = "secondStubProvider"
	sp.expired = false
	return sp.creds, sp.err
}

// IsExpired reports the current value of the expired flag.
func (sp *secondStubProvider) IsExpired() bool {
	return sp.expired
}
// TestChainProviderWithNames builds a chain of two failing stub providers
// followed by a secondStubProvider and a stubProvider that both hold
// credentials. It expects Retrieve to succeed and to return the credentials
// and provider name of the secondStubProvider entry.
func TestChainProviderWithNames(t *testing.T) {
	p := &ChainProvider{
		Providers: []Provider{
			&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
			&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
			&secondStubProvider{
				creds: Value{
					AccessKeyID:     "AKIF",
					SecretAccessKey: "NOSECRET",
					SessionToken:    "",
				},
			},
			&stubProvider{
				creds: Value{
					AccessKeyID:     "AKID",
					SecretAccessKey: "SECRET",
					SessionToken:    "",
				},
			},
		},
	}

	creds, err := p.Retrieve()
	if err != nil {
		// Stop here: the field checks below are meaningless without a
		// successful retrieval and would only add cascading failures.
		t.Fatalf("Expect no error, got %v", err)
	}

	if e, a := "secondStubProvider", creds.ProviderName; e != a {
		t.Errorf("Expect provider name to match, %v got %v", e, a)
	}

	// Also check credentials
	if e, a := "AKIF", creds.AccessKeyID; e != a {
		t.Errorf("Expect access key ID to match, %v got %v", e, a)
	}
	if e, a := "NOSECRET", creds.SecretAccessKey; e != a {
		t.Errorf("Expect secret access key to match, %v got %v", e, a)
	}
	if v := creds.SessionToken; len(v) != 0 {
		t.Errorf("Expect session token to be empty, %v", v)
	}
}
func TestChainProviderGet(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
},
},
}
creds, err := p.Retrieve()
if err != nil {
t.Errorf("Expect no error, got %v", err)
}
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("Expect access key ID to match, %v got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("Expect secret access key to match, %v got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect session token to be empty, %v", v)
}
}
// TestChainProviderIsExpired walks a single-provider chain through two
// expire/retrieve cycles: the chain must report expired before the first
// Retrieve, not expired after each Retrieve, and expired again once the
// underlying stub is flagged as expired.
func TestChainProviderIsExpired(t *testing.T) {
	// Named stub rather than stubProvider so the local does not shadow the type.
	stub := &stubProvider{expired: true}
	chain := &ChainProvider{
		Providers: []Provider{stub},
	}

	if expired := chain.IsExpired(); !expired {
		t.Errorf("Expect expired to be true before any Retrieve")
	}

	if _, err := chain.Retrieve(); err != nil {
		t.Errorf("Expect no error, got %v", err)
	}
	if expired := chain.IsExpired(); expired {
		t.Errorf("Expect not expired after retrieve")
	}

	// Flag the underlying provider as expired again and repeat the cycle.
	stub.expired = true
	if expired := chain.IsExpired(); !expired {
		t.Errorf("Expect return of expired provider")
	}

	if _, err := chain.Retrieve(); err != nil {
		t.Errorf("Expect no error, got %v", err)
	}
	if expired := chain.IsExpired(); expired {
		t.Errorf("Expect not expired after retrieve")
	}
}
// TestChainProviderWithNoProvider expects a chain with an empty provider list
// to report itself as expired and expects Retrieve to return an error whose
// message is "NoCredentialProviders: no valid providers in chain".
func TestChainProviderWithNoProvider(t *testing.T) {
	p := &ChainProvider{
		Providers: []Provider{},
	}

	if !p.IsExpired() {
		t.Errorf("Expect expired with no providers")
	}

	_, err := p.Retrieve()
	if err == nil {
		// Guard before calling err.Error(): a nil error must fail the test
		// with a clear message instead of panicking on a nil dereference.
		t.Fatalf("Expect no providers error returned, got %v", err)
	}
	if err.Error() != "NoCredentialProviders: no valid providers in chain" {
		t.Errorf("Expect no providers error returned, got %v", err)
	}
}
// TestChainProviderWithNoValidProvider builds a chain in which every provider
// returns an error. It expects the chain to report itself as expired and
// expects Retrieve to return a value deeply equal to a "NoCredentialProviders"
// batch error wrapping the providers' errors in order.
func TestChainProviderWithNoValidProvider(t *testing.T) {
	errs := []error{
		awserr.New("FirstError", "first provider error", nil),
		awserr.New("SecondError", "second provider error", nil),
	}

	// One failing stub per error, in the same order as errs.
	providers := make([]Provider, 0, len(errs))
	for i := range errs {
		providers = append(providers, &stubProvider{err: errs[i]})
	}
	chain := &ChainProvider{Providers: providers}

	if expired := chain.IsExpired(); !expired {
		t.Errorf("Expect expired with no providers")
	}

	_, got := chain.Retrieve()
	want := awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs)
	if !reflect.DeepEqual(want, got) {
		t.Errorf("Expect no providers error returned, %v, got %v", want, got)
	}
}