bson/internal/errutil/join_test.go
2025-03-17 20:58:26 +01:00

244 lines
6.7 KiB
Go

// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package errutil
import (
"context"
"errors"
"fmt"
"testing"
"gitea.psichedelico.com/go/bson/internal/assert"
)
// TestJoin_Nil checks that join and [errors.Join] return equal results when
// called with no arguments, with a single nil, and with two nils — the inputs
// for which [errors.Join] returns a nil error.
func TestJoin_Nil(t *testing.T) {
	t.Parallel()

	cases := []struct {
		errs []error
		msg  string
	}{
		{errs: nil, msg: "errors.Join() != join()"},
		{errs: []error{nil}, msg: "errors.Join(nil) != join(nil)"},
		{errs: []error{nil, nil}, msg: "errors.Join(nil, nil) != join(nil, nil)"},
	}

	for _, c := range cases {
		// Both functions receive exactly the same argument list.
		assert.Equal(t, errors.Join(c.errs...), join(c.errs...), c.msg)
	}
}
// TestJoin_Error checks that the error produced by join carries the same
// message as the error produced by [errors.Join] for the same inputs.
func TestJoin_Error(t *testing.T) {
	t.Parallel()

	first := errors.New("err1")
	second := errors.New("err2")

	type testCase struct {
		desc string
		errs []error
	}

	cases := []testCase{
		{desc: "single error", errs: []error{first}},
		{desc: "two errors", errs: []error{first, second}},
		{desc: "two errors and a nil value", errs: []error{first, nil, second}},
	}

	for i := range cases {
		// Take a per-iteration copy: the subtest below runs in parallel and
		// therefore outlives the loop iteration that started it.
		tc := cases[i]

		t.Run(tc.desc, func(t *testing.T) {
			t.Parallel()

			expected := errors.Join(tc.errs...).Error()
			actual := join(tc.errs...).Error()

			assert.Equal(t, expected, actual,
				"errors.Join().Error() != join().Error() for input %v", tc.errs)
		})
	}
}
// TestJoin_ErrorsIs checks that an error produced by join gives the same
// [errors.Is] result as the error produced by [errors.Join] for the same
// inputs, both when inspected directly and when wrapped once more with %w.
func TestJoin_ErrorsIs(t *testing.T) {
	t.Parallel()

	errA := errors.New("err1")
	errB := errors.New("err2")

	// wrap adds a single layer of %w wrapping around err.
	wrap := func(err error) error {
		return fmt.Errorf("error: %w", err)
	}

	type testCase struct {
		desc   string
		errs   []error
		target error
	}

	cases := []testCase{
		{
			desc:   "one error with a matching target",
			errs:   []error{errA},
			target: errA,
		},
		{
			desc:   "one error with a non-matching target",
			errs:   []error{errA},
			target: errB,
		},
		{
			desc:   "nil error",
			errs:   []error{nil},
			target: errA,
		},
		{
			desc:   "no errors",
			errs:   []error{},
			target: errA,
		},
		{
			desc:   "two different errors with a matching target",
			errs:   []error{errA, errB},
			target: errB,
		},
		{
			desc:   "two identical errors with a matching target",
			errs:   []error{errA, errA},
			target: errA,
		},
		{
			desc:   "wrapped error with a matching target",
			errs:   []error{wrap(errA)},
			target: errA,
		},
		{
			desc:   "nested joined error with a matching target",
			errs:   []error{errA, join(errB, errors.New("nope"))},
			target: errB,
		},
		{
			desc:   "nested joined error with no matching targets",
			errs:   []error{errA, join(errors.New("nope"), errors.New("nope 2"))},
			target: errB,
		},
		{
			desc:   "nested joined error with a wrapped matching target",
			errs:   []error{join(wrap(errA), errors.New("nope")), errB},
			target: errA,
		},
		{
			desc:   "context.DeadlineExceeded",
			errs:   []error{errA, nil, context.DeadlineExceeded, errB},
			target: context.DeadlineExceeded,
		},
		{
			desc:   "wrapped context.DeadlineExceeded",
			errs:   []error{errA, nil, wrap(context.DeadlineExceeded), errB},
			target: context.DeadlineExceeded,
		},
	}

	for i := range cases {
		tc := cases[i] // Per-iteration copy captured by the subtest closure.

		t.Run(tc.desc, func(t *testing.T) {
			// Compare the joined errors directly.
			stdErr := errors.Join(tc.errs...)
			ourErr := join(tc.errs...)
			wantDirect := errors.Is(stdErr, tc.target)
			gotDirect := errors.Is(ourErr, tc.target)
			assert.Equal(t, wantDirect, gotDirect,
				"errors.Join() and join() behave differently with errors.Is")

			// Compare again with each joined error buried one level deeper.
			stdErr = wrap(errors.Join(tc.errs...))
			ourErr = wrap(join(tc.errs...))
			wantWrapped := errors.Is(stdErr, tc.target)
			gotWrapped := errors.Is(ourErr, tc.target)
			assert.Equal(t, wantWrapped, gotWrapped,
				"errors.Join() and join(), when wrapped, behave differently with errors.Is")
		})
	}
}
// errType1 is a field-less error type. The errors.As test in this file uses it
// as a concrete type that a target pointer can match.
type errType1 struct{}

// Error implements the error interface; the message is always empty.
func (errType1) Error() string {
	return ""
}
// errType2 is a second field-less error type, distinct from errType1, so the
// errors.As test in this file can exercise a non-matching target type.
type errType2 struct{}

// Error implements the error interface; the message is always empty.
func (errType2) Error() string {
	return ""
}
// TestJoin_ErrorsAs asserts that join returns an error that behaves identically
// to the error returned by [errors.Join] when passed to [errors.As].
//
// Each case joins errs with both functions and checks that errors.As reports
// the same boolean result for target, first on the joined errors directly and
// then with each joined error wrapped once more via %w. Only the boolean
// result is compared; the value stored into target is not inspected.
func TestJoin_ErrorsAs(t *testing.T) {
	t.Parallel()

	err1 := errType1{}
	err2 := errType2{}

	tests := []struct {
		desc string
		errs []error
		// target is handed to errors.As; every case supplies a non-nil
		// pointer to one of the error types declared in this file.
		target any
	}{{
		desc:   "one error with a matching target",
		errs:   []error{err1},
		target: &errType1{},
	}, {
		desc:   "one error with a non-matching target",
		errs:   []error{err1},
		target: &errType2{},
	}, {
		desc:   "nil error",
		errs:   []error{nil},
		target: &errType1{},
	}, {
		desc:   "no errors",
		errs:   []error{},
		target: &errType1{},
	}, {
		desc:   "two different errors with a matching target",
		errs:   []error{err1, err2},
		target: &errType2{},
	}, {
		desc:   "two identical errors with a matching target",
		errs:   []error{err1, err1},
		target: &errType1{},
	}, {
		desc:   "wrapped error with a matching target",
		errs:   []error{fmt.Errorf("error: %w", err1)},
		target: &errType1{},
	}, {
		desc:   "nested joined error with a matching target",
		errs:   []error{err1, join(err2, errors.New("nope"))},
		target: &errType2{},
	}, {
		desc:   "nested joined error with no matching targets",
		errs:   []error{err1, join(errors.New("nope"), errors.New("nope 2"))},
		target: &errType2{},
	}, {
		desc:   "nested joined error with a wrapped matching target",
		errs:   []error{join(fmt.Errorf("error: %w", err1), errors.New("nope")), err2},
		target: &errType1{},
	}, {
		desc:   "context.DeadlineExceeded",
		errs:   []error{err1, nil, context.DeadlineExceeded, err2},
		target: &errType2{},
	}}

	for _, test := range tests {
		test := test // Capture range variable.

		t.Run(test.desc, func(t *testing.T) {
			// Assert that top-level errors returned by errors.Join and join
			// behave the same with errors.As.
			want := errors.Join(test.errs...)
			got := join(test.errs...)
			assert.Equal(t,
				errors.As(want, test.target),
				errors.As(got, test.target),
				"errors.Join() and join() behave differently with errors.As")

			// Assert that wrapped errors returned by errors.Join and join
			// behave the same with errors.As.
			want = fmt.Errorf("error: %w", errors.Join(test.errs...))
			got = fmt.Errorf("error: %w", join(test.errs...))
			assert.Equal(t,
				errors.As(want, test.target),
				errors.As(got, test.target),
				"errors.Join() and join(), when wrapped, behave differently with errors.As")
		})
	}
}