bson/internal/aws/credentials/credentials_test.go
2025-03-17 20:58:26 +01:00

193 lines
4.3 KiB
Go

// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/credentials_test.go
// See THIRD-PARTY-NOTICES for original license terms
package credentials
import (
"context"
"sync"
"testing"
"time"
"gitea.psichedelico.com/go/bson/internal/aws/awserr"
)
// isExpired reports whether c's cached credentials are currently considered
// expired. It takes c.m's read lock around the isExpiredLocked check so tests
// can call it from several goroutines at once.
func isExpired(c *Credentials) (expired bool) {
	c.m.RLock()
	defer c.m.RUnlock()

	expired = c.isExpiredLocked(c.creds)
	return expired
}
// stubProvider is a test Provider that hands back a fixed Value and error and
// lets tests control what IsExpired reports.
type stubProvider struct {
	creds          Value
	retrievedCount int  // number of times Retrieve has been called
	expired        bool // value reported by IsExpired; reset by Retrieve
	err            error
}

// Retrieve records the call, marks the stub as no longer expired, stamps
// "stubProvider" as the provider name on the stored credentials, and returns
// those credentials together with s.err.
func (s *stubProvider) Retrieve() (Value, error) {
	s.creds.ProviderName = "stubProvider"
	s.expired = false
	s.retrievedCount++

	return s.creds, s.err
}

// IsExpired returns the stub's current expired flag.
func (s *stubProvider) IsExpired() bool {
	expired := s.expired
	return expired
}
func TestCredentialsGet(t *testing.T) {
c := NewCredentials(&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
expired: true,
})
creds, err := c.GetWithContext(context.Background())
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("Expect access key ID to match, %v got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("Expect secret access key to match, %v got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect session token to be empty, %v", v)
}
}
// TestCredentialsGetWithError checks that an error returned by the provider's
// Retrieve is surfaced by GetWithContext with its awserr code intact.
func TestCredentialsGetWithError(t *testing.T) {
	c := NewCredentials(&stubProvider{err: awserr.New("provider error", "", nil), expired: true})

	_, err := c.GetWithContext(context.Background())
	// Guard before asserting the concrete type: an unchecked err.(awserr.Error)
	// panics on a nil or foreign error and aborts the test binary instead of
	// reporting a failure.
	if err == nil {
		t.Fatal("Expected provider error, got nil")
	}
	aerr, ok := err.(awserr.Error)
	if !ok {
		t.Fatalf("Expected awserr.Error, got %T: %v", err, err)
	}
	if e, a := "provider error", aerr.Code(); e != a {
		t.Errorf("Expected provider error, %v got %v", e, a)
	}
}
// TestCredentialsExpire walks a Credentials value through its expiry states:
// expired before any retrieval, fresh after GetWithContext, and expired again
// once the provider reports expiry.
func TestCredentialsExpire(t *testing.T) {
	provider := &stubProvider{}
	creds := NewCredentials(provider)

	// The test expects the Credentials to start out expired even though the
	// provider itself reports "not expired" (nothing has been retrieved yet).
	provider.expired = false
	if !isExpired(creds) {
		t.Errorf("Expected to start out expired")
	}

	if _, err := creds.GetWithContext(context.Background()); err != nil {
		t.Errorf("Expected no err, got %v", err)
	}
	if isExpired(creds) {
		t.Errorf("Expected not to be expired")
	}

	// Flipping the provider's flag must be reflected by isExpired.
	provider.expired = true
	if !isExpired(creds) {
		t.Errorf("Expected to be expired")
	}
}
// TestCredentialsGetWithProviderName checks that GetWithContext returns the
// provider name stamped onto the credentials by the provider's Retrieve.
func TestCredentialsGetWithProviderName(t *testing.T) {
	stub := &stubProvider{}
	c := NewCredentials(stub)

	creds, err := c.GetWithContext(context.Background())
	if err != nil {
		// Stop here: the returned Value is meaningless on error, and checking
		// it would only add a misleading second failure.
		t.Fatalf("Expected no error, got %v", err)
	}
	// e is the expected value and a the actual one, matching the
	// "<expected> got <actual>" order of the failure message.
	if e, a := "stubProvider", creds.ProviderName; e != a {
		t.Errorf("Expected provider name to match, %v got %v", e, a)
	}
}
// MockProvider is a test Provider whose expiry is decided by comparing a fixed
// expiration time against a (possibly mocked) clock.
type MockProvider struct {
	// The date/time when to expire on
	expiration time.Time

	// If set will be used by IsExpired to determine the current time.
	// Defaults to time.Now if CurrentTime is not set. Available for testing
	// to be able to mock out the current time.
	CurrentTime func() time.Time
}

// IsExpired returns if the credentials are expired, i.e. whether the
// expiration time lies strictly before the current time.
func (m *MockProvider) IsExpired() bool {
	now := time.Now
	if m.CurrentTime != nil {
		now = m.CurrentTime
	}
	return m.expiration.Before(now())
}
// Retrieve always succeeds and returns an empty Value.
func (*MockProvider) Retrieve() (Value, error) {
	var empty Value
	return empty, nil
}
// TestCredentialsIsExpired_Race calls isExpired from many goroutines released
// at the same moment. It makes no assertions; presumably it exists to let the
// race detector (go test -race) flag unsynchronized access.
func TestCredentialsIsExpired_Race(_ *testing.T) {
	creds := NewChainCredentials([]Provider{&MockProvider{}})

	const workers, iterations = 10, 100
	start := make(chan struct{})
	var wg sync.WaitGroup

	for w := 0; w < workers; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			<-start // wait so all workers begin together
			for n := 0; n < iterations; n++ {
				isExpired(creds)
			}
		}()
	}

	close(start)
	wg.Wait()
}
// stubProviderConcurrent wraps stubProvider so that Retrieve blocks until a
// value is received on done, letting a test hold a Retrieve call open while
// other callers queue up behind it.
type stubProviderConcurrent struct {
	stubProvider
	done chan struct{}
}

// Retrieve waits for one signal on done, then delegates to the embedded
// stubProvider.
func (s *stubProviderConcurrent) Retrieve() (Value, error) {
	<-s.done
	inner := &s.stubProvider
	return inner.Retrieve()
}
// TestCredentialsGetConcurrent validates that a single call to Retrieve is
// shared between two concurrent calls to GetWithContext.
//
// stub.done is unbuffered and receives exactly one value, so a second Retrieve
// would block forever. Every wait below is bounded by a timeout, which turns
// such a regression into a failure instead of a hang. The Retrieve count is
// then asserted explicitly once both callers have returned.
func TestCredentialsGetConcurrent(t *testing.T) {
	stub := &stubProviderConcurrent{
		done: make(chan struct{}),
	}
	c := NewCredentials(stub)

	const callers = 2
	// Buffered so callers can always report completion, even after the test
	// has given up on a timeout and stopped receiving.
	done := make(chan struct{}, callers)
	for i := 0; i < callers; i++ {
		go func() {
			_, err := c.GetWithContext(context.Background())
			if err != nil {
				t.Errorf("Expected no err, got %v", err)
			}
			done <- struct{}{}
		}()
	}

	timeout := time.After(5 * time.Second)

	// Release the single blocked Retrieve.
	select {
	case stub.done <- struct{}{}:
	case <-timeout:
		t.Fatal("timed out waiting for Retrieve to be called")
	}

	for i := 0; i < callers; i++ {
		select {
		case <-done:
		case <-timeout:
			t.Fatal("timed out waiting for GetWithContext to return; Retrieve may have been called more than once")
		}
	}

	// Both callers have signalled on done, so any Retrieve whose result they
	// received has finished before this read of retrievedCount.
	if e, a := 1, stub.retrievedCount; e != a {
		t.Errorf("Expected Retrieve to be called %v time, got %v", e, a)
	}
}