193 lines
4.3 KiB
Go
193 lines
4.3 KiB
Go
// Copyright (C) MongoDB, Inc. 2023-present.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
|
// not use this file except in compliance with the License. You may obtain
|
|
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
|
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/credentials_test.go
|
|
// See THIRD-PARTY-NOTICES for original license terms
|
|
|
|
package credentials
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"gitea.psichedelico.com/go/bson/internal/aws/awserr"
|
|
)
|
|
|
|
func isExpired(c *Credentials) bool {
|
|
c.m.RLock()
|
|
defer c.m.RUnlock()
|
|
|
|
return c.isExpiredLocked(c.creds)
|
|
}
|
|
|
|
type stubProvider struct {
|
|
creds Value
|
|
retrievedCount int
|
|
expired bool
|
|
err error
|
|
}
|
|
|
|
func (s *stubProvider) Retrieve() (Value, error) {
|
|
s.retrievedCount++
|
|
s.expired = false
|
|
s.creds.ProviderName = "stubProvider"
|
|
return s.creds, s.err
|
|
}
|
|
func (s *stubProvider) IsExpired() bool {
|
|
return s.expired
|
|
}
|
|
|
|
func TestCredentialsGet(t *testing.T) {
|
|
c := NewCredentials(&stubProvider{
|
|
creds: Value{
|
|
AccessKeyID: "AKID",
|
|
SecretAccessKey: "SECRET",
|
|
SessionToken: "",
|
|
},
|
|
expired: true,
|
|
})
|
|
|
|
creds, err := c.GetWithContext(context.Background())
|
|
if err != nil {
|
|
t.Errorf("Expected no error, got %v", err)
|
|
}
|
|
if e, a := "AKID", creds.AccessKeyID; e != a {
|
|
t.Errorf("Expect access key ID to match, %v got %v", e, a)
|
|
}
|
|
if e, a := "SECRET", creds.SecretAccessKey; e != a {
|
|
t.Errorf("Expect secret access key to match, %v got %v", e, a)
|
|
}
|
|
if v := creds.SessionToken; len(v) != 0 {
|
|
t.Errorf("Expect session token to be empty, %v", v)
|
|
}
|
|
}
|
|
|
|
func TestCredentialsGetWithError(t *testing.T) {
|
|
c := NewCredentials(&stubProvider{err: awserr.New("provider error", "", nil), expired: true})
|
|
|
|
_, err := c.GetWithContext(context.Background())
|
|
if e, a := "provider error", err.(awserr.Error).Code(); e != a {
|
|
t.Errorf("Expected provider error, %v got %v", e, a)
|
|
}
|
|
}
|
|
|
|
func TestCredentialsExpire(t *testing.T) {
|
|
stub := &stubProvider{}
|
|
c := NewCredentials(stub)
|
|
|
|
stub.expired = false
|
|
if !isExpired(c) {
|
|
t.Errorf("Expected to start out expired")
|
|
}
|
|
|
|
_, err := c.GetWithContext(context.Background())
|
|
if err != nil {
|
|
t.Errorf("Expected no err, got %v", err)
|
|
}
|
|
if isExpired(c) {
|
|
t.Errorf("Expected not to be expired")
|
|
}
|
|
|
|
stub.expired = true
|
|
if !isExpired(c) {
|
|
t.Errorf("Expected to be expired")
|
|
}
|
|
}
|
|
|
|
func TestCredentialsGetWithProviderName(t *testing.T) {
|
|
stub := &stubProvider{}
|
|
|
|
c := NewCredentials(stub)
|
|
|
|
creds, err := c.GetWithContext(context.Background())
|
|
if err != nil {
|
|
t.Errorf("Expected no error, got %v", err)
|
|
}
|
|
if e, a := creds.ProviderName, "stubProvider"; e != a {
|
|
t.Errorf("Expected provider name to match, %v got %v", e, a)
|
|
}
|
|
}
|
|
|
|
// MockProvider is a test Provider whose expiry is decided by comparing a fixed
// expiration instant against a (possibly mocked) current time.
type MockProvider struct {
	// CurrentTime, when non-nil, is used by IsExpired as the source of the
	// current time; otherwise time.Now is used. It lets tests mock the clock.
	CurrentTime func() time.Time

	// expiration is the instant IsExpired compares against: the provider is
	// reported expired once expiration is strictly before the current time.
	expiration time.Time
}
|
|
|
|
// IsExpired returns if the credentials are expired.
|
|
func (e *MockProvider) IsExpired() bool {
|
|
curTime := e.CurrentTime
|
|
if curTime == nil {
|
|
curTime = time.Now
|
|
}
|
|
return e.expiration.Before(curTime())
|
|
}
|
|
|
|
func (*MockProvider) Retrieve() (Value, error) {
|
|
return Value{}, nil
|
|
}
|
|
|
|
func TestCredentialsIsExpired_Race(_ *testing.T) {
|
|
creds := NewChainCredentials([]Provider{&MockProvider{}})
|
|
|
|
starter := make(chan struct{})
|
|
var wg sync.WaitGroup
|
|
wg.Add(10)
|
|
for i := 0; i < 10; i++ {
|
|
go func() {
|
|
defer wg.Done()
|
|
<-starter
|
|
for i := 0; i < 100; i++ {
|
|
isExpired(creds)
|
|
}
|
|
}()
|
|
}
|
|
close(starter)
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
type stubProviderConcurrent struct {
|
|
stubProvider
|
|
done chan struct{}
|
|
}
|
|
|
|
func (s *stubProviderConcurrent) Retrieve() (Value, error) {
|
|
<-s.done
|
|
return s.stubProvider.Retrieve()
|
|
}
|
|
|
|
func TestCredentialsGetConcurrent(t *testing.T) {
|
|
stub := &stubProviderConcurrent{
|
|
done: make(chan struct{}),
|
|
}
|
|
|
|
c := NewCredentials(stub)
|
|
done := make(chan struct{})
|
|
|
|
for i := 0; i < 2; i++ {
|
|
go func() {
|
|
_, err := c.GetWithContext(context.Background())
|
|
if err != nil {
|
|
t.Errorf("Expected no err, got %v", err)
|
|
}
|
|
done <- struct{}{}
|
|
}()
|
|
}
|
|
|
|
// Validates that a single call to Retrieve is shared between two calls to Get
|
|
stub.done <- struct{}{}
|
|
<-done
|
|
<-done
|
|
}
|