Initial commit

Author: Francesco Bellini
Date:   2025-03-17 20:58:26 +01:00
commit 4b4cceb81c
2206 changed files with 469613 additions and 0 deletions

LICENSE (new file, 201 lines)

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

THIRD-PARTY-NOTICES (new file, 1692 lines)

File diff suppressed because it is too large.

array_codec.go (new file, 42 lines)

@@ -0,0 +1,42 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
)
// arrayCodec is the Codec used for bsoncore.Array values.
type arrayCodec struct{}
// EncodeValue is the ValueEncoder for bsoncore.Array values.
func (ac *arrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tCoreArray {
return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: val}
}
arr := val.Interface().(bsoncore.Array)
return copyArrayFromBytes(vw, arr)
}
// DecodeValue is the ValueDecoder for bsoncore.Array values.
func (ac *arrayCodec) DecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tCoreArray {
return ValueDecoderError{Name: "CoreArrayDecodeValue", Types: []reflect.Type{tCoreArray}, Received: val}
}
if val.IsNil() {
val.Set(reflect.MakeSlice(val.Type(), 0, 0))
}
val.SetLen(0)
arr, err := appendArrayBytes(val.Interface().(bsoncore.Array), vr)
val.Set(reflect.ValueOf(arr))
return err
}

benchmark_test.go (new file, 449 lines)

@@ -0,0 +1,449 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"compress/gzip"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"os"
"path"
"sync"
"testing"
)
type encodetest struct {
Field1String string
Field1Int64 int64
Field1Float64 float64
Field2String string
Field2Int64 int64
Field2Float64 float64
Field3String string
Field3Int64 int64
Field3Float64 float64
Field4String string
Field4Int64 int64
Field4Float64 float64
}
type nestedtest1 struct {
Nested nestedtest2
}
type nestedtest2 struct {
Nested nestedtest3
}
type nestedtest3 struct {
Nested nestedtest4
}
type nestedtest4 struct {
Nested nestedtest5
}
type nestedtest5 struct {
Nested nestedtest6
}
type nestedtest6 struct {
Nested nestedtest7
}
type nestedtest7 struct {
Nested nestedtest8
}
type nestedtest8 struct {
Nested nestedtest9
}
type nestedtest9 struct {
Nested nestedtest10
}
type nestedtest10 struct {
Nested nestedtest11
}
type nestedtest11 struct {
Nested encodetest
}
var encodetestInstance = encodetest{
Field1String: "foo",
Field1Int64: 1,
Field1Float64: 3.0,
Field2String: "bar",
Field2Int64: 2,
Field2Float64: 3.1,
Field3String: "baz",
Field3Int64: 3,
Field3Float64: 3.14,
Field4String: "qux",
Field4Int64: 4,
Field4Float64: 3.141,
}
var nestedInstance = nestedtest1{
nestedtest2{
nestedtest3{
nestedtest4{
nestedtest5{
nestedtest6{
nestedtest7{
nestedtest8{
nestedtest9{
nestedtest10{
nestedtest11{
encodetest{
Field1String: "foo",
Field1Int64: 1,
Field1Float64: 3.0,
Field2String: "bar",
Field2Int64: 2,
Field2Float64: 3.1,
Field3String: "baz",
Field3Int64: 3,
Field3Float64: 3.14,
Field4String: "qux",
Field4Int64: 4,
Field4Float64: 3.141,
},
},
},
},
},
},
},
},
},
},
},
}
const extendedBSONDir = "./testdata/extended_bson"
var (
extJSONFiles map[string]map[string]interface{}
extJSONFilesMu sync.Mutex
)
// readExtJSONFile reads the GZIP-compressed extended JSON document from the given filename in the
// "extended BSON" test data directory (./testdata/extended_bson) and returns it as a
// map[string]interface{}. It panics on any errors.
func readExtJSONFile(filename string) map[string]interface{} {
extJSONFilesMu.Lock()
defer extJSONFilesMu.Unlock()
if v, ok := extJSONFiles[filename]; ok {
return v
}
filePath := path.Join(extendedBSONDir, filename)
file, err := os.Open(filePath)
if err != nil {
panic(fmt.Sprintf("error opening file %q: %s", filePath, err))
}
defer func() {
_ = file.Close()
}()
gz, err := gzip.NewReader(file)
if err != nil {
panic(fmt.Sprintf("error creating GZIP reader: %s", err))
}
defer func() {
_ = gz.Close()
}()
data, err := ioutil.ReadAll(gz)
if err != nil {
panic(fmt.Sprintf("error reading GZIP contents of file: %s", err))
}
var v map[string]interface{}
err = UnmarshalExtJSON(data, false, &v)
if err != nil {
panic(fmt.Sprintf("error unmarshalling extended JSON: %s", err))
}
if extJSONFiles == nil {
extJSONFiles = make(map[string]map[string]interface{})
}
extJSONFiles[filename] = v
return v
}
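`readExtJSONFile` lazily decompresses each fixture once and caches it behind a mutex, so repeated benchmark cases do not pay the gzip cost again. A self-contained sketch of that memoize-under-lock pattern using only the standard library (names here are illustrative, not from this package):

```go
package main

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"io"
	"sync"
)

var (
	cache   map[string][]byte
	cacheMu sync.Mutex
)

// decompress gunzips blob at most once per key, caching the result the same way
// readExtJSONFile does. Concurrent callers serialize on cacheMu, so the work
// happens exactly once even under parallel benchmarks.
func decompress(key string, blob []byte) []byte {
	cacheMu.Lock()
	defer cacheMu.Unlock()
	if v, ok := cache[key]; ok {
		return v
	}
	gz, err := gzip.NewReader(bytes.NewReader(blob))
	if err != nil {
		panic(err)
	}
	data, err := io.ReadAll(gz)
	if err != nil {
		panic(err)
	}
	if cache == nil {
		cache = make(map[string][]byte)
	}
	cache[key] = data
	return data
}

// gzipBytes builds a small in-memory gzip blob for the demo.
func gzipBytes(s string) []byte {
	var buf bytes.Buffer
	w := gzip.NewWriter(&buf)
	w.Write([]byte(s))
	w.Close()
	return buf.Bytes()
}

func main() {
	blob := gzipBytes(`{"hello":"world"}`)
	fmt.Println(string(decompress("doc", blob)))
	fmt.Println(string(decompress("doc", nil))) // second call is served from cache; blob is ignored
}
```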
func BenchmarkMarshal(b *testing.B) {
cases := []struct {
desc string
value interface{}
}{
{
desc: "simple struct",
value: encodetestInstance,
},
{
desc: "nested struct",
value: nestedInstance,
},
{
desc: "deep_bson.json.gz",
value: readExtJSONFile("deep_bson.json.gz"),
},
{
desc: "flat_bson.json.gz",
value: readExtJSONFile("flat_bson.json.gz"),
},
{
desc: "full_bson.json.gz",
value: readExtJSONFile("full_bson.json.gz"),
},
}
for _, tc := range cases {
b.Run(tc.desc, func(b *testing.B) {
b.Run("BSON", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := Marshal(tc.value)
if err != nil {
b.Errorf("error marshalling BSON: %s", err)
}
}
})
b.Run("extJSON", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := MarshalExtJSON(tc.value, true, false)
if err != nil {
b.Errorf("error marshalling extended JSON: %s", err)
}
}
})
b.Run("JSON", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := json.Marshal(tc.value)
if err != nil {
b.Errorf("error marshalling JSON: %s", err)
}
}
})
})
}
}
func BenchmarkUnmarshal(b *testing.B) {
cases := []struct {
desc string
value interface{}
}{
{
desc: "simple struct",
value: encodetestInstance,
},
{
desc: "nested struct",
value: nestedInstance,
},
{
desc: "deep_bson.json.gz",
value: readExtJSONFile("deep_bson.json.gz"),
},
{
desc: "flat_bson.json.gz",
value: readExtJSONFile("flat_bson.json.gz"),
},
{
desc: "full_bson.json.gz",
value: readExtJSONFile("full_bson.json.gz"),
},
}
for _, tc := range cases {
b.Run(tc.desc, func(b *testing.B) {
b.Run("BSON", func(b *testing.B) {
data, err := Marshal(tc.value)
if err != nil {
b.Errorf("error marshalling BSON: %s", err)
return
}
b.ResetTimer()
var v2 map[string]interface{}
for i := 0; i < b.N; i++ {
err := Unmarshal(data, &v2)
if err != nil {
b.Errorf("error unmarshalling BSON: %s", err)
}
}
})
b.Run("extJSON", func(b *testing.B) {
data, err := MarshalExtJSON(tc.value, true, false)
if err != nil {
b.Errorf("error marshalling extended JSON: %s", err)
return
}
b.ResetTimer()
var v2 map[string]interface{}
for i := 0; i < b.N; i++ {
err := UnmarshalExtJSON(data, true, &v2)
if err != nil {
b.Errorf("error unmarshalling extended JSON: %s", err)
}
}
})
b.Run("JSON", func(b *testing.B) {
data, err := json.Marshal(tc.value)
if err != nil {
b.Errorf("error marshalling JSON: %s", err)
return
}
b.ResetTimer()
var v2 map[string]interface{}
for i := 0; i < b.N; i++ {
err := json.Unmarshal(data, &v2)
if err != nil {
b.Errorf("error unmarshalling JSON: %s", err)
}
}
})
})
}
}
// The following benchmarks are copied from the Go standard library's
// encoding/json package.
type codeResponse struct {
Tree *codeNode `json:"tree"`
Username string `json:"username"`
}
type codeNode struct {
Name string `json:"name"`
Kids []*codeNode `json:"kids"`
CLWeight float64 `json:"cl_weight"`
Touches int `json:"touches"`
MinT int64 `json:"min_t"`
MaxT int64 `json:"max_t"`
MeanT int64 `json:"mean_t"`
}
var codeJSON []byte
var codeBSON []byte
var codeStruct codeResponse
func codeInit() {
f, err := os.Open("testdata/code.json.gz")
if err != nil {
panic(err)
}
defer f.Close()
gz, err := gzip.NewReader(f)
if err != nil {
panic(err)
}
data, err := io.ReadAll(gz)
if err != nil {
panic(err)
}
codeJSON = data
if err := json.Unmarshal(codeJSON, &codeStruct); err != nil {
panic("json.Unmarshal code.json: " + err.Error())
}
if data, err = json.Marshal(&codeStruct); err != nil {
panic("json.Marshal code.json: " + err.Error())
}
if codeBSON, err = Marshal(&codeStruct); err != nil {
panic("Marshal code.json: " + err.Error())
}
if !bytes.Equal(data, codeJSON) {
println("different lengths", len(data), len(codeJSON))
for i := 0; i < len(data) && i < len(codeJSON); i++ {
if data[i] != codeJSON[i] {
println("re-marshal: changed at byte", i)
println("orig: ", string(codeJSON[i-10:i+10]))
println("new: ", string(data[i-10:i+10]))
break
}
}
panic("re-marshal code.json: different result")
}
}
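`codeInit` does more than load fixtures: it asserts that re-marshalling the decoded struct reproduces `code.json` byte-for-byte, and on mismatch pinpoints the first differing byte. The invariant it enforces can be sketched as a small helper (hypothetical name, stdlib only):

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

// roundTripStable reports whether unmarshalling orig into v and re-marshalling
// v reproduces orig exactly — the same invariant codeInit enforces for
// code.json before trusting the benchmark fixtures.
func roundTripStable(orig []byte, v interface{}) (bool, error) {
	if err := json.Unmarshal(orig, v); err != nil {
		return false, err
	}
	out, err := json.Marshal(v)
	if err != nil {
		return false, err
	}
	return bytes.Equal(out, orig), nil
}

func main() {
	type pair struct {
		A int    `json:"a"`
		B string `json:"b"`
	}
	var p pair
	ok, err := roundTripStable([]byte(`{"a":1,"b":"x"}`), &p)
	fmt.Println(ok, err)
}
```

Checking stability up front means a benchmark regression cannot be masked by the codec silently dropping or reordering fields.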
func BenchmarkCodeUnmarshal(b *testing.B) {
b.ReportAllocs()
if codeJSON == nil {
b.StopTimer()
codeInit()
b.StartTimer()
}
b.Run("BSON", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
var r codeResponse
if err := Unmarshal(codeBSON, &r); err != nil {
b.Fatal("Unmarshal:", err)
}
}
})
b.SetBytes(int64(len(codeBSON)))
})
b.Run("JSON", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
var r codeResponse
if err := json.Unmarshal(codeJSON, &r); err != nil {
b.Fatal("json.Unmarshal:", err)
}
}
})
b.SetBytes(int64(len(codeJSON)))
})
}
func BenchmarkCodeMarshal(b *testing.B) {
b.ReportAllocs()
if codeJSON == nil {
b.StopTimer()
codeInit()
b.StartTimer()
}
b.Run("BSON", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if _, err := Marshal(&codeStruct); err != nil {
b.Fatal("Marshal:", err)
}
}
})
b.SetBytes(int64(len(codeBSON)))
})
b.Run("JSON", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if _, err := json.Marshal(&codeStruct); err != nil {
b.Fatal("json.Marshal:", err)
}
}
})
b.SetBytes(int64(len(codeJSON)))
})
}
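The benchmarks above combine `b.RunParallel` (to exercise the codec across GOMAXPROCS goroutines) with `b.SetBytes` (so results report MB/s). The same body can be driven outside `go test` via `testing.Benchmark`; a minimal sketch against `encoding/json` only, since this snippet cannot import the package under test:

```go
package main

import (
	"encoding/json"
	"fmt"
	"testing"
)

// benchUnmarshal runs one benchmark body standalone, mirroring the
// RunParallel + SetBytes shape used in BenchmarkCodeUnmarshal above.
func benchUnmarshal(payload []byte) testing.BenchmarkResult {
	return testing.Benchmark(func(b *testing.B) {
		b.SetBytes(int64(len(payload)))
		b.RunParallel(func(pb *testing.PB) {
			for pb.Next() {
				var v map[string]int
				if err := json.Unmarshal(payload, &v); err != nil {
					b.Fatal(err)
				}
			}
		})
	})
}

func main() {
	res := benchUnmarshal([]byte(`{"a":1,"b":2}`))
	fmt.Println(res) // e.g. iterations and ns/op; exact numbers vary by machine
}
```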

(file name not shown in this diff view; new file, 191 lines)

@@ -0,0 +1,191 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding/hex"
"encoding/json"
"math"
"os"
"path"
"testing"
"gitea.psichedelico.com/go/bson/internal/require"
)
const bsonBinaryVectorDir = "./testdata/bson-binary-vector/"
type bsonBinaryVectorTests struct {
Description string `json:"description"`
TestKey string `json:"test_key"`
Tests []bsonBinaryVectorTestCase `json:"tests"`
}
type bsonBinaryVectorTestCase struct {
Description string `json:"description"`
Valid bool `json:"valid"`
Vector []interface{} `json:"vector"`
DtypeHex string `json:"dtype_hex"`
DtypeAlias string `json:"dtype_alias"`
Padding int `json:"padding"`
CanonicalBson string `json:"canonical_bson"`
}
func TestBsonBinaryVectorSpec(t *testing.T) {
t.Parallel()
jsonFiles, err := findJSONFilesInDir(bsonBinaryVectorDir)
require.NoErrorf(t, err, "error finding JSON files in %s: %v", bsonBinaryVectorDir, err)
for _, file := range jsonFiles {
filepath := path.Join(bsonBinaryVectorDir, file)
content, err := os.ReadFile(filepath)
require.NoErrorf(t, err, "reading test file %s", filepath)
var tests bsonBinaryVectorTests
require.NoErrorf(t, json.Unmarshal(content, &tests), "parsing test file %s", filepath)
t.Run(tests.Description, func(t *testing.T) {
t.Parallel()
for _, test := range tests.Tests {
test := test
t.Run(test.Description, func(t *testing.T) {
t.Parallel()
runBsonBinaryVectorTest(t, tests.TestKey, test)
})
}
})
}
t.Run("Padding specified with no vector data PACKED_BIT", func(t *testing.T) {
t.Parallel()
t.Run("Marshaling", func(t *testing.T) {
_, err := NewPackedBitVector(nil, 1)
require.EqualError(t, err, errNonZeroVectorPadding.Error())
})
})
t.Run("Exceeding maximum padding PACKED_BIT", func(t *testing.T) {
t.Parallel()
t.Run("Marshaling", func(t *testing.T) {
_, err := NewPackedBitVector(nil, 8)
require.EqualError(t, err, errVectorPaddingTooLarge.Error())
})
})
}
func convertSlice[T int8 | float32 | byte](s []interface{}) []T {
v := make([]T, len(s))
for i, e := range s {
f := math.NaN()
switch val := e.(type) {
case float64:
f = val
case string:
if val == "inf" {
f = math.Inf(0)
} else if val == "-inf" {
f = math.Inf(-1)
}
}
v[i] = T(f)
}
return v
}
func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVectorTestCase) {
testVector := make(map[string]Vector)
switch alias := test.DtypeHex; alias {
case "0x03":
testVector[testKey] = Vector{
dType: Int8Vector,
int8Data: convertSlice[int8](test.Vector),
}
case "0x27":
testVector[testKey] = Vector{
dType: Float32Vector,
float32Data: convertSlice[float32](test.Vector),
}
case "0x10":
testVector[testKey] = Vector{
dType: PackedBitVector,
bitData: convertSlice[byte](test.Vector),
bitPadding: uint8(test.Padding),
}
default:
t.Fatalf("unsupported vector type: %s", alias)
}
testBSON, err := hex.DecodeString(test.CanonicalBson)
require.NoError(t, err, "decoding canonical BSON")
t.Run("Unmarshaling", func(t *testing.T) {
skipCases := map[string]string{
"Overflow Vector INT8": "compile-time restriction",
"Underflow Vector INT8": "compile-time restriction",
"INT8 with float inputs": "compile-time restriction",
"Overflow Vector PACKED_BIT": "compile-time restriction",
"Underflow Vector PACKED_BIT": "compile-time restriction",
"Vector with float values PACKED_BIT": "compile-time restriction",
"Negative padding PACKED_BIT": "compile-time restriction",
}
if reason, ok := skipCases[test.Description]; ok {
t.Skipf("skip test case %s: %s", test.Description, reason)
}
errMap := map[string]string{
"FLOAT32 with padding": "padding must be 0",
"INT8 with padding": "padding must be 0",
"Padding specified with no vector data PACKED_BIT": "padding must be 0",
"Exceeding maximum padding PACKED_BIT": "padding cannot be larger than 7",
}
t.Parallel()
var got map[string]Vector
err := Unmarshal(testBSON, &got)
if test.Valid {
require.NoError(t, err)
require.Equal(t, testVector, got)
} else if errMsg, ok := errMap[test.Description]; ok {
require.ErrorContains(t, err, errMsg)
} else {
require.Error(t, err)
}
})
t.Run("Marshaling", func(t *testing.T) {
skipCases := map[string]string{
"FLOAT32 with padding": "private padding field",
"Insufficient vector data with 3 bytes FLOAT32": "invalid case",
"Insufficient vector data with 5 bytes FLOAT32": "invalid case",
"Overflow Vector INT8": "compile-time restriction",
"Underflow Vector INT8": "compile-time restriction",
"INT8 with padding": "private padding field",
"INT8 with float inputs": "compile-time restriction",
"Overflow Vector PACKED_BIT": "compile-time restriction",
"Underflow Vector PACKED_BIT": "compile-time restriction",
"Vector with float values PACKED_BIT": "compile-time restriction",
"Padding specified with no vector data PACKED_BIT": "run in alternative case",
"Exceeding maximum padding PACKED_BIT": "run in alternative case",
"Negative padding PACKED_BIT": "compile-time restriction",
}
if reason, ok := skipCases[test.Description]; ok {
t.Skipf("skip test case %s: %s", test.Description, reason)
}
t.Parallel()
got, err := Marshal(testVector)
require.NoError(t, err)
require.Equal(t, testBSON, got)
})
}

bson_corpus_spec_test.go (new file, 504 lines)

@@ -0,0 +1,504 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path"
"reflect"
"strconv"
"strings"
"testing"
"unicode"
"unicode/utf8"
"gitea.psichedelico.com/go/bson/internal/assert"
"gitea.psichedelico.com/go/bson/internal/require"
"github.com/google/go-cmp/cmp"
)
type testCase struct {
Description string `json:"description"`
BsonType string `json:"bson_type"`
TestKey *string `json:"test_key"`
Valid []validityTestCase `json:"valid"`
DecodeErrors []decodeErrorTestCase `json:"decodeErrors"`
ParseErrors []parseErrorTestCase `json:"parseErrors"`
Deprecated *bool `json:"deprecated"`
}
type validityTestCase struct {
Description string `json:"description"`
CanonicalBson string `json:"canonical_bson"`
CanonicalExtJSON string `json:"canonical_extjson"`
RelaxedExtJSON *string `json:"relaxed_extjson"`
DegenerateBSON *string `json:"degenerate_bson"`
DegenerateExtJSON *string `json:"degenerate_extjson"`
ConvertedBSON *string `json:"converted_bson"`
ConvertedExtJSON *string `json:"converted_extjson"`
Lossy *bool `json:"lossy"`
}
type decodeErrorTestCase struct {
Description string `json:"description"`
Bson string `json:"bson"`
}
type parseErrorTestCase struct {
Description string `json:"description"`
String string `json:"string"`
}
const dataDir = "./testdata/bson-corpus/"
func findJSONFilesInDir(dir string) ([]string, error) {
files := make([]string, 0)
entries, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
for _, entry := range entries {
if entry.IsDir() || path.Ext(entry.Name()) != ".json" {
continue
}
files = append(files, entry.Name())
}
return files, nil
}
// seedExtJSON will add the byte representation of the "extJSON" string to the fuzzer's corpus.
func seedExtJSON(f *testing.F, extJSON string, extJSONType string, desc string) {
jbytes, err := jsonToBytes(extJSON, extJSONType, desc)
if err != nil {
f.Fatalf("failed to convert JSON to bytes: %v", err)
}
f.Add(jbytes)
}
// seedTestCase will add the byte representation for each "extJSON" string of each valid test case to the fuzzer's
// corpus.
func seedTestCase(f *testing.F, tcase *testCase) {
for _, vtc := range tcase.Valid {
seedExtJSON(f, vtc.CanonicalExtJSON, "canonical", vtc.Description)
// Seed the relaxed extended JSON.
if vtc.RelaxedExtJSON != nil {
seedExtJSON(f, *vtc.RelaxedExtJSON, "relaxed", vtc.Description)
}
// Seed the degenerate extended JSON.
if vtc.DegenerateExtJSON != nil {
seedExtJSON(f, *vtc.DegenerateExtJSON, "degenerate", vtc.Description)
}
// Seed the converted extended JSON.
if vtc.ConvertedExtJSON != nil {
seedExtJSON(f, *vtc.ConvertedExtJSON, "converted", vtc.Description)
}
}
}
// seedBSONCorpus will unmarshal the data from "testdata/bson-corpus" into a slice of "testCase" structs and then
// marshal the "*_extjson" field of each "validityTestCase" into a slice of bytes to seed the fuzz corpus.
func seedBSONCorpus(f *testing.F) {
fileNames, err := findJSONFilesInDir(dataDir)
if err != nil {
f.Fatalf("failed to find JSON files in directory %q: %v", dataDir, err)
}
for _, fileName := range fileNames {
filePath := path.Join(dataDir, fileName)
file, err := os.Open(filePath)
if err != nil {
f.Fatalf("failed to open file %q: %v", filePath, err)
}
var tcase testCase
if err := json.NewDecoder(file).Decode(&tcase); err != nil {
f.Fatal(err)
}
seedTestCase(f, &tcase)
}
}
func needsEscapedUnicode(bsonType string) bool {
return bsonType == "0x02" || bsonType == "0x0D" || bsonType == "0x0E" || bsonType == "0x0F"
}
func unescapeUnicode(s, bsonType string) string {
if !needsEscapedUnicode(bsonType) {
return s
}
newS := ""
for i := 0; i < len(s); i++ {
c := s[i]
switch c {
case '\\':
switch s[i+1] {
case 'u':
us := s[i : i+6]
u, err := strconv.Unquote(strings.Replace(strconv.Quote(us), `\\u`, `\u`, 1))
if err != nil {
return ""
}
for _, r := range u {
if r < ' ' {
newS += fmt.Sprintf(`\u%04x`, r)
} else {
newS += string(r)
}
}
i += 5
default:
newS += string(c)
}
default:
if c > unicode.MaxASCII {
r, size := utf8.DecodeRune([]byte(s[i:]))
newS += string(r)
i += size - 1
} else {
newS += string(c)
}
}
}
return newS
}
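The `\uXXXX` branch above uses a quote/replace/unquote round trip through `strconv` to turn a literal six-character escape sequence into a real rune. Isolated as a self-contained helper (hypothetical name):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// unescapeOne shows the strconv trick used in unescapeUnicode: Quote turns the
// literal backslash into `\\u`, Replace restores a single `\u`, and Unquote
// then interprets the escape, yielding the decoded rune.
func unescapeOne(us string) (string, error) {
	return strconv.Unquote(strings.Replace(strconv.Quote(us), `\\u`, `\u`, 1))
}

func main() {
	s, err := unescapeOne(`\u00e9`)
	fmt.Println(s, err) // é <nil>
}
```

Note the caller in `unescapeUnicode` re-escapes control characters (`r < ' '`) afterwards, so only printable runes are emitted literally.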
func normalizeCanonicalDouble(t *testing.T, key string, cEJ string) string {
// Unmarshal string into map
cEJMap := make(map[string]map[string]string)
err := json.Unmarshal([]byte(cEJ), &cEJMap)
require.NoError(t, err)
// Parse the float contained by the map.
expectedString := cEJMap[key]["$numberDouble"]
expectedFloat, err := strconv.ParseFloat(expectedString, 64)
require.NoError(t, err)
// Normalize the string
return fmt.Sprintf(`{"%s":{"$numberDouble":"%s"}}`, key, formatDouble(expectedFloat))
}
func normalizeRelaxedDouble(t *testing.T, key string, rEJ string) string {
// Unmarshal string into map
rEJMap := make(map[string]float64)
err := json.Unmarshal([]byte(rEJ), &rEJMap)
if err != nil {
return normalizeCanonicalDouble(t, key, rEJ)
}
// Parse the float contained by the map.
expectedFloat := rEJMap[key]
// Normalize the string
return fmt.Sprintf(`{"%s":%s}`, key, formatDouble(expectedFloat))
}
// bsonToNative decodes the BSON bytes (b) into a native Document
func bsonToNative(t *testing.T, b []byte, bType, testDesc string) D {
var doc D
err := Unmarshal(b, &doc)
require.NoErrorf(t, err, "%s: decoding %s BSON", testDesc, bType)
return doc
}
// nativeToBSON encodes the native Document (doc) into canonical BSON and compares it to the expected
// canonical BSON (cB)
func nativeToBSON(t *testing.T, cB []byte, doc D, testDesc, bType, docSrcDesc string) {
actual, err := Marshal(doc)
require.NoErrorf(t, err, "%s: encoding %s BSON", testDesc, bType)
if diff := cmp.Diff(cB, actual); diff != "" {
t.Errorf("%s: 'native_to_bson(%s) = cB' failed (-want, +got):\n-%v\n+%v\n",
testDesc, docSrcDesc, cB, actual)
t.FailNow()
}
}
// jsonToNative decodes the extended JSON string (ej) into a native Document
func jsonToNative(ej, ejType, testDesc string) (D, error) {
var doc D
if err := UnmarshalExtJSON([]byte(ej), ejType != "relaxed", &doc); err != nil {
return nil, fmt.Errorf("%s: decoding %s extended JSON: %w", testDesc, ejType, err)
}
return doc, nil
}
// jsonToBytes decodes the extended JSON string (ej) into canonical BSON and then encodes it into a byte slice.
func jsonToBytes(ej, ejType, testDesc string) ([]byte, error) {
native, err := jsonToNative(ej, ejType, testDesc)
if err != nil {
return nil, err
}
b, err := Marshal(native)
if err != nil {
return nil, fmt.Errorf("%s: encoding %s BSON: %w", testDesc, ejType, err)
}
return b, nil
}
// nativeToJSON encodes the native Document (doc) into an extended JSON string
func nativeToJSON(t *testing.T, ej string, doc D, testDesc, ejType, ejShortName, docSrcDesc string) {
actualEJ, err := MarshalExtJSON(doc, ejType != "relaxed", true)
require.NoErrorf(t, err, "%s: encoding %s extended JSON", testDesc, ejType)
if diff := cmp.Diff(ej, string(actualEJ)); diff != "" {
t.Errorf("%s: 'native_to_%s_extended_json(%s) = %s' failed (-want, +got):\n%s\n",
testDesc, ejType, docSrcDesc, ejShortName, diff)
t.FailNow()
}
}
func runTest(t *testing.T, file string) {
filepath := path.Join(dataDir, file)
content, err := os.ReadFile(filepath)
require.NoError(t, err)
// Remove ".json" from filename.
file = file[:len(file)-5]
testName := "bson_corpus--" + file
t.Run(testName, func(t *testing.T) {
var test testCase
require.NoError(t, json.Unmarshal(content, &test))
t.Run("valid", func(t *testing.T) {
for _, v := range test.Valid {
t.Run(v.Description, func(t *testing.T) {
// get canonical BSON
cB, err := hex.DecodeString(v.CanonicalBson)
require.NoErrorf(t, err, "%s: reading canonical BSON", v.Description)
// get canonical extended JSON
var compactEJ bytes.Buffer
require.NoError(t, json.Compact(&compactEJ, []byte(v.CanonicalExtJSON)))
cEJ := unescapeUnicode(compactEJ.String(), test.BsonType)
if test.BsonType == "0x01" {
cEJ = normalizeCanonicalDouble(t, *test.TestKey, cEJ)
}
/*** canonical BSON round-trip tests ***/
doc := bsonToNative(t, cB, "canonical", v.Description)
// native_to_bson(bson_to_native(cB)) = cB
nativeToBSON(t, cB, doc, v.Description, "canonical", "bson_to_native(cB)")
// native_to_canonical_extended_json(bson_to_native(cB)) = cEJ
nativeToJSON(t, cEJ, doc, v.Description, "canonical", "cEJ", "bson_to_native(cB)")
// native_to_relaxed_extended_json(bson_to_native(cB)) = rEJ (if rEJ exists)
if v.RelaxedExtJSON != nil {
var compactEJ bytes.Buffer
require.NoError(t, json.Compact(&compactEJ, []byte(*v.RelaxedExtJSON)))
rEJ := unescapeUnicode(compactEJ.String(), test.BsonType)
if test.BsonType == "0x01" {
rEJ = normalizeRelaxedDouble(t, *test.TestKey, rEJ)
}
nativeToJSON(t, rEJ, doc, v.Description, "relaxed", "rEJ", "bson_to_native(cB)")
/*** relaxed extended JSON round-trip tests (if exists) ***/
doc, err = jsonToNative(rEJ, "relaxed", v.Description)
require.NoError(t, err)
// native_to_relaxed_extended_json(json_to_native(rEJ)) = rEJ
nativeToJSON(t, rEJ, doc, v.Description, "relaxed", "rEJ", "json_to_native(rEJ)")
}
/*** canonical extended JSON round-trip tests ***/
doc, err = jsonToNative(cEJ, "canonical", v.Description)
require.NoError(t, err)
// native_to_canonical_extended_json(json_to_native(cEJ)) = cEJ
nativeToJSON(t, cEJ, doc, v.Description, "canonical", "cEJ", "json_to_native(cEJ)")
// native_to_bson(json_to_native(cEJ)) = cB (unless lossy)
if v.Lossy == nil || !*v.Lossy {
nativeToBSON(t, cB, doc, v.Description, "canonical", "json_to_native(cEJ)")
}
/*** degenerate BSON round-trip tests (if exists) ***/
if v.DegenerateBSON != nil {
dB, err := hex.DecodeString(*v.DegenerateBSON)
require.NoErrorf(t, err, "%s: reading degenerate BSON", v.Description)
doc = bsonToNative(t, dB, "degenerate", v.Description)
// native_to_bson(bson_to_native(dB)) = cB
nativeToBSON(t, cB, doc, v.Description, "degenerate", "bson_to_native(dB)")
}
/*** degenerate JSON round-trip tests (if exists) ***/
if v.DegenerateExtJSON != nil {
var compactEJ bytes.Buffer
require.NoError(t, json.Compact(&compactEJ, []byte(*v.DegenerateExtJSON)))
dEJ := unescapeUnicode(compactEJ.String(), test.BsonType)
if test.BsonType == "0x01" {
dEJ = normalizeCanonicalDouble(t, *test.TestKey, dEJ)
}
doc, err = jsonToNative(dEJ, "degenerate canonical", v.Description)
require.NoError(t, err)
// native_to_canonical_extended_json(json_to_native(dEJ)) = cEJ
nativeToJSON(t, cEJ, doc, v.Description, "degenerate canonical", "cEJ", "json_to_native(dEJ)")
// native_to_bson(json_to_native(dEJ)) = cB (unless lossy)
if v.Lossy == nil || !*v.Lossy {
nativeToBSON(t, cB, doc, v.Description, "canonical", "json_to_native(dEJ)")
}
}
})
}
})
t.Run("decode error", func(t *testing.T) {
for _, d := range test.DecodeErrors {
t.Run(d.Description, func(t *testing.T) {
b, err := hex.DecodeString(d.Bson)
require.NoError(t, err, d.Description)
var doc D
err = Unmarshal(b, &doc)
// The driver unmarshals invalid UTF-8 strings without error. Loop over the unmarshalled elements
// and assert that there was no error if any of the string or DBPointer values contain invalid UTF-8
// characters.
for _, elem := range doc {
value := reflect.ValueOf(elem.Value)
invalidString := (value.Kind() == reflect.String) && !utf8.ValidString(value.String())
dbPtr, ok := elem.Value.(DBPointer)
invalidDBPtr := ok && !utf8.ValidString(dbPtr.DB)
if invalidString || invalidDBPtr {
require.NoError(t, err, d.Description)
return
}
}
require.Errorf(t, err, "%s: expected decode error", d.Description)
})
}
})
t.Run("parse error", func(t *testing.T) {
for _, p := range test.ParseErrors {
t.Run(p.Description, func(t *testing.T) {
s := unescapeUnicode(p.String, test.BsonType)
if test.BsonType == "0x13" {
s = fmt.Sprintf(`{"decimal128": {"$numberDecimal": "%s"}}`, s)
}
switch test.BsonType {
case "0x00", "0x05", "0x13":
var doc D
err := UnmarshalExtJSON([]byte(s), true, &doc)
// Null bytes are validated when marshaling to BSON
if strings.Contains(p.Description, "Null") {
_, err = Marshal(doc)
}
require.Errorf(t, err, "%s: expected parse error", p.Description)
default:
t.Fatalf("Update test to check for parse errors for type %s", test.BsonType)
}
})
}
})
})
}
func Test_BsonCorpus(t *testing.T) {
jsonFiles, err := findJSONFilesInDir(dataDir)
require.NoErrorf(t, err, "error finding JSON files in %s: %v", dataDir, err)
for _, file := range jsonFiles {
runTest(t, file)
}
}
func TestRelaxedUUIDValidation(t *testing.T) {
testCases := []struct {
description string
canonicalExtJSON string
degenerateExtJSON string
expectedErr string
}{
{
"valid uuid",
"{\"x\" : { \"$binary\" : {\"base64\" : \"c//SZESzTGmQ6OfR38A11A==\", \"subType\" : \"04\"}}}",
"{\"x\" : { \"$uuid\" : \"73ffd264-44b3-4c69-90e8-e7d1dfc035d4\"}}",
"",
},
{
"invalid uuid--no hyphens",
"",
"{\"x\" : { \"$uuid\" : \"73ffd26444b34c6990e8e7d1dfc035d4\"}}",
"$uuid value does not follow RFC 4122 format regarding length and hyphens",
},
{
"invalid uuid--trailing hyphens",
"",
"{\"x\" : { \"$uuid\" : \"73ffd264-44b3-4c69-90e8-e7d1dfc035--\"}}",
"$uuid value does not follow RFC 4122 format regarding length and hyphens",
},
{
"invalid uuid--malformed hex",
"",
"{\"x\" : { \"$uuid\" : \"q3@fd26l-44b3-4c69-90e8-e7d1dfc035d4\"}}",
"$uuid value does not follow RFC 4122 format regarding hex bytes: encoding/hex: invalid byte: U+0071 'q'",
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
// get canonical extended JSON (if provided)
cEJ := ""
if tc.canonicalExtJSON != "" {
var compactCEJ bytes.Buffer
require.NoError(t, json.Compact(&compactCEJ, []byte(tc.canonicalExtJSON)))
cEJ = unescapeUnicode(compactCEJ.String(), "0x05")
}
// get degenerate extended JSON
var compactDEJ bytes.Buffer
require.NoError(t, json.Compact(&compactDEJ, []byte(tc.degenerateExtJSON)))
dEJ := unescapeUnicode(compactDEJ.String(), "0x05")
// convert dEJ to native doc
var doc D
err := UnmarshalExtJSON([]byte(dEJ), true, &doc)
if tc.expectedErr != "" {
require.Error(t, err, "expected error %v, got nil", tc.expectedErr)
assert.Equal(t, tc.expectedErr, err.Error(), "expected error %v, got %v", tc.expectedErr, err)
} else {
assert.Nil(t, err, "expected no error, got error: %v", err)
// Marshal doc into extended JSON and compare with cEJ
nativeToJSON(t, cEJ, doc, tc.description, "degenerate canonical", "cEJ", "json_to_native(dEJ)")
}
})
}
}
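The $uuid parse errors asserted above boil down to RFC 4122 shape checks: 36 characters, hyphens at fixed positions, and valid hex everywhere else. A self-contained sketch of such a check (the helper name validUUIDFormat is illustrative, not part of this package):

```go
package main

import (
	"encoding/hex"
	"fmt"
	"strings"
)

// validUUIDFormat is an illustrative helper, not part of the bson package.
// It returns an error describing how s fails the 8-4-4-4-12 hyphenated
// RFC 4122 layout, or nil if s conforms.
func validUUIDFormat(s string) error {
	if len(s) != 36 || strings.Count(s, "-") != 4 ||
		s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
		return fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens")
	}
	// Re-join the five hex groups and verify every byte is valid hex.
	hexOnly := s[:8] + s[9:13] + s[14:18] + s[19:23] + s[24:]
	if _, err := hex.DecodeString(hexOnly); err != nil {
		return fmt.Errorf("$uuid value does not follow RFC 4122 format regarding hex bytes: %w", err)
	}
	return nil
}

func main() {
	fmt.Println(validUUIDFormat("73ffd264-44b3-4c69-90e8-e7d1dfc035d4")) // <nil>
	fmt.Println(validUUIDFormat("73ffd26444b34c6990e8e7d1dfc035d4"))     // length/hyphen error
}
```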

bson_test.go Normal file
@@ -0,0 +1,679 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"strconv"
"strings"
"testing"
"time"
"gitea.psichedelico.com/go/bson/internal/assert"
"gitea.psichedelico.com/go/bson/internal/require"
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
"github.com/google/go-cmp/cmp"
)
func noerr(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Errorf("Unexpected error: (%T)%v", err, err)
t.FailNow()
}
}
func TestTimestamp(t *testing.T) {
t.Parallel()
testCases := []struct {
description string
tp Timestamp
tp2 Timestamp
expectedAfter bool
expectedBefore bool
expectedEqual bool
expectedCompare int
}{
{
description: "equal",
tp: Timestamp{T: 12345, I: 67890},
tp2: Timestamp{T: 12345, I: 67890},
expectedBefore: false,
expectedAfter: false,
expectedEqual: true,
expectedCompare: 0,
},
{
description: "T greater than",
tp: Timestamp{T: 12345, I: 67890},
tp2: Timestamp{T: 2345, I: 67890},
expectedBefore: false,
expectedAfter: true,
expectedEqual: false,
expectedCompare: 1,
},
{
description: "I greater than",
tp: Timestamp{T: 12345, I: 67890},
tp2: Timestamp{T: 12345, I: 7890},
expectedBefore: false,
expectedAfter: true,
expectedEqual: false,
expectedCompare: 1,
},
{
description: "T less than",
tp: Timestamp{T: 12345, I: 67890},
tp2: Timestamp{T: 112345, I: 67890},
expectedBefore: true,
expectedAfter: false,
expectedEqual: false,
expectedCompare: -1,
},
{
description: "I less than",
tp: Timestamp{T: 12345, I: 67890},
tp2: Timestamp{T: 12345, I: 167890},
expectedBefore: true,
expectedAfter: false,
expectedEqual: false,
expectedCompare: -1,
},
}
for _, tc := range testCases {
tc := tc // Capture range variable.
t.Run(tc.description, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tc.expectedAfter, tc.tp.After(tc.tp2), "expected After results to be the same")
assert.Equal(t, tc.expectedBefore, tc.tp.Before(tc.tp2), "expected Before results to be the same")
assert.Equal(t, tc.expectedEqual, tc.tp.Equal(tc.tp2), "expected Equal results to be the same")
assert.Equal(t, tc.expectedCompare, tc.tp.Compare(tc.tp2), "expected Compare result to be the same")
})
}
}
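The cases above imply that Timestamp ordering compares T first and breaks ties with I. A stand-alone sketch of that comparator (the local Timestamp type and compare function here are illustrative, not the package's):

```go
package main

import "fmt"

// Timestamp mirrors the shape used in the test cases above: a seconds
// component T and an ordinal increment I, both unsigned.
type Timestamp struct{ T, I uint32 }

// compare orders timestamps by T first, then by I, returning -1, 0, or 1.
func compare(a, b Timestamp) int {
	if a.T != b.T {
		if a.T > b.T {
			return 1
		}
		return -1
	}
	if a.I != b.I {
		if a.I > b.I {
			return 1
		}
		return -1
	}
	return 0
}

func main() {
	fmt.Println(compare(Timestamp{12345, 67890}, Timestamp{12345, 67890}))  // 0
	fmt.Println(compare(Timestamp{12345, 67890}, Timestamp{2345, 67890}))   // 1
	fmt.Println(compare(Timestamp{12345, 67890}, Timestamp{12345, 167890})) // -1
}
```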
func TestPrimitiveIsZero(t *testing.T) {
testcases := []struct {
name string
zero Zeroer
nonzero Zeroer
}{
{"binary", Binary{}, Binary{Data: []byte{0x01, 0x02, 0x03}, Subtype: 0xFF}},
{"decimal128", Decimal128{}, NewDecimal128(1, 2)},
{"objectID", ObjectID{}, NewObjectID()},
{"regex", Regex{}, Regex{Pattern: "foo", Options: "bar"}},
{"dbPointer", DBPointer{}, DBPointer{DB: "foobar", Pointer: ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}}},
{"timestamp", Timestamp{}, Timestamp{T: 12345, I: 67890}},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
require.True(t, tc.zero.IsZero())
require.False(t, tc.nonzero.IsZero())
})
}
}
func TestRegexCompare(t *testing.T) {
testcases := []struct {
name string
r1 Regex
r2 Regex
eq bool
}{
{"equal", Regex{Pattern: "foo1", Options: "bar1"}, Regex{Pattern: "foo1", Options: "bar1"}, true},
{"not equal", Regex{Pattern: "foo1", Options: "bar1"}, Regex{Pattern: "foo2", Options: "bar2"}, false},
{"not equal options", Regex{Pattern: "foo1", Options: "bar1"}, Regex{Pattern: "foo1", Options: "bar2"}, false},
{"not equal pattern", Regex{Pattern: "foo1", Options: "bar1"}, Regex{Pattern: "foo2", Options: "bar1"}, false},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
require.True(t, tc.r1.Equal(tc.r2) == tc.eq)
})
}
}
func TestDateTime(t *testing.T) {
t.Run("json", func(t *testing.T) {
t.Run("round trip", func(t *testing.T) {
original := DateTime(1000)
jsonBytes, err := json.Marshal(original)
assert.Nil(t, err, "Marshal error: %v", err)
var unmarshalled DateTime
err = json.Unmarshal(jsonBytes, &unmarshalled)
assert.Nil(t, err, "Unmarshal error: %v", err)
assert.Equal(t, original, unmarshalled, "expected DateTime %v, got %v", original, unmarshalled)
})
t.Run("decode null", func(t *testing.T) {
jsonBytes := []byte("null")
var dt DateTime
err := json.Unmarshal(jsonBytes, &dt)
assert.Nil(t, err, "Unmarshal error: %v", err)
assert.Equal(t, DateTime(0), dt, "expected DateTime value to be 0, got %v", dt)
})
t.Run("UTC", func(t *testing.T) {
dt := DateTime(1681145535123)
jsonBytes, err := json.Marshal(dt)
assert.Nil(t, err, "Marshal error: %v", err)
assert.Equal(t, `"2023-04-10T16:52:15.123Z"`, string(jsonBytes))
})
})
t.Run("NewDateTimeFromTime", func(t *testing.T) {
t.Run("range is not limited", func(t *testing.T) {
// If the implementation internally calls time.Time.UnixNano(), the constructor cannot handle times after
// the year 2262.
timeFormat := "2006-01-02T15:04:05.999Z07:00"
timeString := "3001-01-01T00:00:00Z"
tt, err := time.Parse(timeFormat, timeString)
assert.Nil(t, err, "Parse error: %v", err)
dt := NewDateTimeFromTime(tt)
assert.True(t, dt > 0, "expected a valid DateTime greater than 0, got %v", dt)
})
})
}
func TestTimeRoundTrip(t *testing.T) {
val := struct {
Value time.Time
ID string
}{
ID: "time-rt-test",
}
if !val.Value.IsZero() {
t.Errorf("Did not get zero time as expected.")
}
bsonOut, err := Marshal(val)
noerr(t, err)
rtval := struct {
Value time.Time
ID string
}{}
err = Unmarshal(bsonOut, &rtval)
noerr(t, err)
if !cmp.Equal(val, rtval) {
t.Errorf("Did not round trip properly. got %v; want %v", val, rtval)
}
if !rtval.Value.IsZero() {
t.Errorf("Did not get zero time as expected.")
}
}
func TestNonNullTimeRoundTrip(t *testing.T) {
now := time.Now()
now = time.Unix(now.Unix(), 0)
val := struct {
Value time.Time
ID string
}{
ID: "time-rt-test",
Value: now,
}
bsonOut, err := Marshal(val)
noerr(t, err)
rtval := struct {
Value time.Time
ID string
}{}
err = Unmarshal(bsonOut, &rtval)
noerr(t, err)
if !cmp.Equal(val, rtval) {
t.Errorf("Did not round trip properly. got %v; want %v", val, rtval)
}
}
func TestD(t *testing.T) {
t.Run("can marshal", func(t *testing.T) {
d := D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
idx, want := bsoncore.AppendDocumentStart(nil)
want = bsoncore.AppendStringElement(want, "foo", "bar")
want = bsoncore.AppendStringElement(want, "hello", "world")
want = bsoncore.AppendDoubleElement(want, "pi", 3.14159)
want, err := bsoncore.AppendDocumentEnd(want, idx)
noerr(t, err)
got, err := Marshal(d)
noerr(t, err)
if !bytes.Equal(got, want) {
t.Errorf("Marshaled documents do not match. got %v; want %v", Raw(got), Raw(want))
}
})
t.Run("can unmarshal", func(t *testing.T) {
want := D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
idx, doc := bsoncore.AppendDocumentStart(nil)
doc = bsoncore.AppendStringElement(doc, "foo", "bar")
doc = bsoncore.AppendStringElement(doc, "hello", "world")
doc = bsoncore.AppendDoubleElement(doc, "pi", 3.14159)
doc, err := bsoncore.AppendDocumentEnd(doc, idx)
noerr(t, err)
var got D
err = Unmarshal(doc, &got)
noerr(t, err)
if !cmp.Equal(got, want) {
t.Errorf("Unmarshaled documents do not match. got %v; want %v", got, want)
}
})
}
func TestDStringer(t *testing.T) {
got := D{{"a", 1}, {"b", 2}}.String()
want := `{"a":{"$numberInt":"1"},"b":{"$numberInt":"2"}}`
assert.Equal(t, want, got, "expected: %s, got: %s", want, got)
}
func TestMStringer(t *testing.T) {
type msg struct {
A json.RawMessage `json:"a"`
B json.RawMessage `json:"b"`
}
var res msg
got := M{"a": 1, "b": 2}.String()
err := json.Unmarshal([]byte(got), &res)
require.NoError(t, err, "Unmarshal error")
want := msg{
A: json.RawMessage(`{"$numberInt":"1"}`),
B: json.RawMessage(`{"$numberInt":"2"}`),
}
assert.Equal(t, want, res, "returned string did not unmarshal to the expected document, returned string: %s", got)
}
func TestD_MarshalJSON(t *testing.T) {
t.Parallel()
testcases := []struct {
name string
test D
expected interface{}
}{
{
"nil",
nil,
nil,
},
{
"empty",
D{},
struct{}{},
},
{
"non-empty",
D{
{"a", 42},
{"b", true},
{"c", "answer"},
{"d", nil},
{"e", 2.71828},
{"f", A{42, true, "answer", nil, 2.71828}},
{"g", D{{"foo", "bar"}}},
},
struct {
A int `json:"a"`
B bool `json:"b"`
C string `json:"c"`
D interface{} `json:"d"`
E float32 `json:"e"`
F []interface{} `json:"f"`
G map[string]interface{} `json:"g"`
}{
A: 42,
B: true,
C: "answer",
D: nil,
E: 2.71828,
F: []interface{}{42, true, "answer", nil, 2.71828},
G: map[string]interface{}{"foo": "bar"},
},
},
}
for _, tc := range testcases {
tc := tc
t.Run("json.Marshal "+tc.name, func(t *testing.T) {
t.Parallel()
got, err := json.Marshal(tc.test)
assert.NoError(t, err)
want, _ := json.Marshal(tc.expected)
assert.Equal(t, want, got)
})
}
for _, tc := range testcases {
tc := tc
t.Run("json.MarshalIndent "+tc.name, func(t *testing.T) {
t.Parallel()
got, err := json.MarshalIndent(tc.test, "<prefix>", "<indent>")
assert.NoError(t, err)
want, _ := json.MarshalIndent(tc.expected, "<prefix>", "<indent>")
assert.Equal(t, want, got)
})
}
}
func TestD_UnmarshalJSON(t *testing.T) {
t.Parallel()
t.Run("success", func(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
test []byte
expected D
}{
{
"nil",
[]byte(`null`),
nil,
},
{
"empty",
[]byte(`{}`),
D{},
},
{
"non-empty",
[]byte(`{"hello":"world","pi":3.142,"boolean":true,"nothing":null,"list":["hello world",3.142,false,null,{"Lorem":"ipsum"}],"document":{"foo":"bar"}}`),
D{
{"hello", "world"},
{"pi", 3.142},
{"boolean", true},
{"nothing", nil},
{"list", []interface{}{"hello world", 3.142, false, nil, D{{"Lorem", "ipsum"}}}},
{"document", D{{"foo", "bar"}}},
},
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
var got D
err := json.Unmarshal(tc.test, &got)
assert.NoError(t, err)
assert.Equal(t, tc.expected, got)
})
}
})
t.Run("failure", func(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
test string
}{
{
"illegal",
`nil`,
},
{
"invalid",
`{"pi": 3.142ipsum}`,
},
{
"malformatted",
`{"pi", 3.142}`,
},
{
"truncated",
`{"pi": 3.142`,
},
{
"array type",
`["pi", 3.142]`,
},
{
"boolean type",
`true`,
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
var a map[string]interface{}
want := json.Unmarshal([]byte(tc.test), &a)
var b D
got := json.Unmarshal([]byte(tc.test), &b)
switch w := want.(type) {
case *json.UnmarshalTypeError:
w.Type = reflect.TypeOf(b)
require.IsType(t, want, got)
g := got.(*json.UnmarshalTypeError)
assert.Equal(t, w, g)
default:
assert.Equal(t, want, got)
}
})
}
})
}
type stringerString string
func (ss stringerString) String() string {
return "bar"
}
type keyBool bool
func (kb keyBool) MarshalKey() (string, error) {
return fmt.Sprintf("%v", kb), nil
}
func (kb *keyBool) UnmarshalKey(key string) error {
switch key {
case "true":
*kb = true
case "false":
*kb = false
default:
return fmt.Errorf("invalid bool value %v", key)
}
return nil
}
type keyStruct struct {
val int64
}
func (k keyStruct) MarshalText() (text []byte, err error) {
str := strconv.FormatInt(k.val, 10)
return []byte(str), nil
}
func (k *keyStruct) UnmarshalText(text []byte) error {
val, err := strconv.ParseInt(string(text), 10, 64)
if err != nil {
return err
}
*k = keyStruct{
val: val,
}
return nil
}
func TestMapCodec(t *testing.T) {
t.Run("EncodeKeysWithStringer", func(t *testing.T) {
strstr := stringerString("foo")
mapObj := map[stringerString]int{strstr: 1}
testCases := []struct {
name string
mapCodec *mapCodec
key string
}{
{"default", &mapCodec{}, "foo"},
{"true", &mapCodec{encodeKeysWithStringer: true}, "bar"},
{"false", &mapCodec{encodeKeysWithStringer: false}, "foo"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mapRegistry := NewRegistry()
mapRegistry.RegisterKindEncoder(reflect.Map, tc.mapCodec)
buf := new(bytes.Buffer)
vw := NewDocumentWriter(buf)
enc := NewEncoder(vw)
enc.SetRegistry(mapRegistry)
err := enc.Encode(mapObj)
assert.Nil(t, err, "Encode error: %v", err)
str := buf.String()
assert.True(t, strings.Contains(str, tc.key), "expected result to contain %v, got: %v", tc.key, str)
})
}
})
t.Run("keys implements keyMarshaler and keyUnmarshaler", func(t *testing.T) {
mapObj := map[keyBool]int{keyBool(true): 1}
doc, err := Marshal(mapObj)
assert.Nil(t, err, "Marshal error: %v", err)
idx, want := bsoncore.AppendDocumentStart(nil)
want = bsoncore.AppendInt32Element(want, "true", 1)
want, _ = bsoncore.AppendDocumentEnd(want, idx)
assert.Equal(t, want, doc, "expected result %v, got %v", string(want), string(doc))
var got map[keyBool]int
err = Unmarshal(doc, &got)
assert.Nil(t, err, "Unmarshal error: %v", err)
assert.Equal(t, mapObj, got, "expected result %v, got %v", mapObj, got)
})
t.Run("keys implements encoding.TextMarshaler and encoding.TextUnmarshaler", func(t *testing.T) {
mapObj := map[keyStruct]int{
{val: 10}: 100,
}
doc, err := Marshal(mapObj)
assert.Nil(t, err, "Marshal error: %v", err)
idx, want := bsoncore.AppendDocumentStart(nil)
want = bsoncore.AppendInt32Element(want, "10", 100)
want, _ = bsoncore.AppendDocumentEnd(want, idx)
assert.Equal(t, want, doc, "expected result %v, got %v", string(want), string(doc))
var got map[keyStruct]int
err = Unmarshal(doc, &got)
assert.Nil(t, err, "Unmarshal error: %v", err)
assert.Equal(t, mapObj, got, "expected result %v, got %v", mapObj, got)
})
}
func TestExtJSONEscapeKey(t *testing.T) {
doc := D{
{
Key: "\\usb#",
Value: int32(1),
},
{
Key: "regex",
Value: Regex{Pattern: "ab\\\\\\\"ab", Options: "\""},
},
}
b, err := MarshalExtJSON(&doc, false, false)
noerr(t, err)
want := `{"\\usb#":1,"regex":{"$regularExpression":{"pattern":"ab\\\\\\\"ab","options":"\""}}}`
if diff := cmp.Diff(want, string(b)); diff != "" {
t.Errorf("Marshaled documents do not match. got %v, want %v", string(b), want)
}
var got D
err = UnmarshalExtJSON(b, false, &got)
noerr(t, err)
if !cmp.Equal(got, doc) {
t.Errorf("Unmarshaled documents do not match. got %v; want %v", got, doc)
}
}
func TestBsoncoreArray(t *testing.T) {
type BSONDocumentArray struct {
Array []D `bson:"array"`
}
type BSONArray struct {
Array bsoncore.Array `bson:"array"`
}
bda := BSONDocumentArray{
Array: []D{
{{"x", 1}},
{{"x", 2}},
{{"x", 3}},
},
}
expectedBSON, err := Marshal(bda)
assert.Nil(t, err, "Marshal bsoncore.Document array error: %v", err)
var ba BSONArray
err = Unmarshal(expectedBSON, &ba)
assert.Nil(t, err, "Unmarshal error: %v", err)
actualBSON, err := Marshal(ba)
assert.Nil(t, err, "Marshal bsoncore.Array error: %v", err)
assert.Equal(t, expectedBSON, actualBSON,
"expected BSON to be %v after Marshalling again; got %v", expectedBSON, actualBSON)
doc := bsoncore.Document(actualBSON)
v := doc.Lookup("array")
assert.Equal(t, bsoncore.TypeArray, v.Type, "expected type array, got %v", v.Type)
}
var baseTime = time.Date(2024, 10, 11, 12, 13, 14, 12345678, time.UTC)
func BenchmarkDateTimeMarshalJSON(b *testing.B) {
t := NewDateTimeFromTime(baseTime)
data, err := t.MarshalJSON()
if err != nil {
b.Fatal(err)
}
b.ReportAllocs()
b.SetBytes(int64(len(data)))
for i := 0; i < b.N; i++ {
if _, err := t.MarshalJSON(); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkDateTimeUnmarshalJSON(b *testing.B) {
t := NewDateTimeFromTime(baseTime)
data, err := t.MarshalJSON()
if err != nil {
b.Fatal(err)
}
b.ReportAllocs()
b.SetBytes(int64(len(data)))
for i := 0; i < b.N; i++ {
var dt DateTime
if err := dt.UnmarshalJSON(data); err != nil {
b.Fatal(err)
}
}
}

bsoncodec.go Normal file
@@ -0,0 +1,199 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"fmt"
"reflect"
"strings"
)
var (
emptyValue = reflect.Value{}
)
// ValueEncoderError is an error returned from a ValueEncoder when the provided value can't be
// encoded by the ValueEncoder.
type ValueEncoderError struct {
Name string
Types []reflect.Type
Kinds []reflect.Kind
Received reflect.Value
}
func (vee ValueEncoderError) Error() string {
typeKinds := make([]string, 0, len(vee.Types)+len(vee.Kinds))
for _, t := range vee.Types {
typeKinds = append(typeKinds, t.String())
}
for _, k := range vee.Kinds {
if k == reflect.Map {
typeKinds = append(typeKinds, "map[string]*")
continue
}
typeKinds = append(typeKinds, k.String())
}
received := vee.Received.Kind().String()
if vee.Received.IsValid() {
received = vee.Received.Type().String()
}
return fmt.Sprintf("%s can only encode valid %s, but got %s", vee.Name, strings.Join(typeKinds, ", "), received)
}
// ValueDecoderError is an error returned from a ValueDecoder when the provided value can't be
// decoded by the ValueDecoder.
type ValueDecoderError struct {
Name string
Types []reflect.Type
Kinds []reflect.Kind
Received reflect.Value
}
func (vde ValueDecoderError) Error() string {
typeKinds := make([]string, 0, len(vde.Types)+len(vde.Kinds))
for _, t := range vde.Types {
typeKinds = append(typeKinds, t.String())
}
for _, k := range vde.Kinds {
if k == reflect.Map {
typeKinds = append(typeKinds, "map[string]*")
continue
}
typeKinds = append(typeKinds, k.String())
}
received := vde.Received.Kind().String()
if vde.Received.IsValid() {
received = vde.Received.Type().String()
}
return fmt.Sprintf("%s can only decode valid and settable %s, but got %s", vde.Name, strings.Join(typeKinds, ", "), received)
}
// EncodeContext is the contextual information required for a Codec to encode a
// value.
type EncodeContext struct {
*Registry
// minSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64,
// uint, uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits)
// that can represent the integer value.
minSize bool
errorOnInlineDuplicates bool
stringifyMapKeysWithFmt bool
nilMapAsEmpty bool
nilSliceAsEmpty bool
nilByteSliceAsEmpty bool
omitZeroStruct bool
useJSONStructTags bool
}
// DecodeContext is the contextual information required for a Codec to decode a
// value.
type DecodeContext struct {
*Registry
// truncate, if true, instructs decoders to truncate the fractional part of BSON "double"
// values when attempting to unmarshal them into a Go integer (int, int8, int16, int32, int64,
// uint, uint8, uint16, uint32, or uint64) struct field. The truncation logic does not apply to
// BSON "decimal128" values.
truncate bool
// defaultDocumentType specifies the Go type to decode top-level and nested BSON documents into. In particular, the
// usage for this field is restricted to data typed as "interface{}" or "map[string]interface{}". If DocumentType is
// set to a type that a BSON document cannot be unmarshaled into (e.g. "string"), unmarshalling will result in an
// error.
defaultDocumentType reflect.Type
binaryAsSlice bool
// objectIDAsHexString causes the Decoder to decode BSON ObjectIDs into their
// hexadecimal string representation when the destination is a string;
// a false value results in a decoding error.
objectIDAsHexString bool
useJSONStructTags bool
useLocalTimeZone bool
zeroMaps bool
zeroStructs bool
}
// ValueEncoder is the interface implemented by types that can encode a provided Go type to BSON.
// The value to encode is provided as a reflect.Value and a bson.ValueWriter is used within the
// EncodeValue method to actually create the BSON representation. For convenience, ValueEncoderFunc
// is provided to allow use of a function with the correct signature as a ValueEncoder. An
// EncodeContext instance is provided to allow implementations to lookup further ValueEncoders and
// to provide configuration information.
type ValueEncoder interface {
EncodeValue(EncodeContext, ValueWriter, reflect.Value) error
}
// ValueEncoderFunc is an adapter function that allows a function with the correct signature to be
// used as a ValueEncoder.
type ValueEncoderFunc func(EncodeContext, ValueWriter, reflect.Value) error
// EncodeValue implements the ValueEncoder interface.
func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
return fn(ec, vw, val)
}
// ValueDecoder is the interface implemented by types that can decode BSON to a provided Go type.
// Implementations should ensure that the value they receive is settable. Similar to ValueEncoderFunc,
// ValueDecoderFunc is provided to allow the use of a function with the correct signature as a
// ValueDecoder. A DecodeContext instance is provided and serves similar functionality to the
// EncodeContext.
type ValueDecoder interface {
DecodeValue(DecodeContext, ValueReader, reflect.Value) error
}
// ValueDecoderFunc is an adapter function that allows a function with the correct signature to be
// used as a ValueDecoder.
type ValueDecoderFunc func(DecodeContext, ValueReader, reflect.Value) error
// DecodeValue implements the ValueDecoder interface.
func (fn ValueDecoderFunc) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
return fn(dc, vr, val)
}
// typeDecoder is the interface implemented by types that can handle the decoding of a value given its type.
type typeDecoder interface {
decodeType(DecodeContext, ValueReader, reflect.Type) (reflect.Value, error)
}
// typeDecoderFunc is an adapter function that allows a function with the correct signature to be used as a typeDecoder.
type typeDecoderFunc func(DecodeContext, ValueReader, reflect.Type) (reflect.Value, error)
func (fn typeDecoderFunc) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
return fn(dc, vr, t)
}
// decodeAdapter allows two functions with the correct signatures to be used as both a ValueDecoder and typeDecoder.
type decodeAdapter struct {
ValueDecoderFunc
typeDecoderFunc
}
var _ ValueDecoder = decodeAdapter{}
var _ typeDecoder = decodeAdapter{}
func decodeTypeOrValueWithInfo(vd ValueDecoder, dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
if td, _ := vd.(typeDecoder); td != nil {
val, err := td.decodeType(dc, vr, t)
if err == nil && val.Type() != t {
// This conversion step is necessary for slices and maps. If a user declares variables like:
//
// type myBool bool
// var m map[string]myBool
//
// and tries to decode BSON bytes into the map, the decoding will fail if this conversion is not present
// because we'll try to assign a value of type bool to one of type myBool.
val = val.Convert(t)
}
return val, err
}
val := reflect.New(t).Elem()
err := vd.DecodeValue(dc, vr, val)
return val, err
}

bsoncodec_test.go Normal file
@@ -0,0 +1,72 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"fmt"
"reflect"
"testing"
)
func ExampleValueEncoder() {
var _ ValueEncoderFunc = func(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if val.Kind() != reflect.String {
return ValueEncoderError{Name: "StringEncodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val}
}
return vw.WriteString(val.String())
}
}
func ExampleValueDecoder() {
var _ ValueDecoderFunc = func(_ DecodeContext, vr ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Kind() != reflect.String {
return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val}
}
if vr.Type() != TypeString {
return fmt.Errorf("cannot decode %v into a string type", vr.Type())
}
str, err := vr.ReadString()
if err != nil {
return err
}
val.SetString(str)
return nil
}
}
type llCodec struct {
t *testing.T
decodeval interface{}
encodeval interface{}
err error
}
func (llc *llCodec) EncodeValue(_ EncodeContext, _ ValueWriter, i interface{}) error {
if llc.err != nil {
return llc.err
}
llc.encodeval = i
return nil
}
func (llc *llCodec) DecodeValue(_ DecodeContext, _ ValueReader, val reflect.Value) error {
if llc.err != nil {
return llc.err
}
if !reflect.TypeOf(llc.decodeval).AssignableTo(val.Type()) {
llc.t.Errorf("decodeval must be assignable to val provided to DecodeValue, but is not. decodeval %T; val %T", llc.decodeval, val)
return nil
}
val.Set(reflect.ValueOf(llc.decodeval))
return nil
}

846
bsonrw_test.go Normal file
View File

@@ -0,0 +1,846 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"testing"
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
)
var (
_ ValueReader = &valueReaderWriter{}
_ ValueWriter = &valueReaderWriter{}
)
// invoked is a type used to indicate what method was called last.
type invoked byte
// These are the different methods that can be invoked.
const (
nothing invoked = iota
readArray
readBinary
readBoolean
readDocument
readCodeWithScope
readDBPointer
readDateTime
readDecimal128
readDouble
readInt32
readInt64
readJavascript
readMaxKey
readMinKey
readNull
readObjectID
readRegex
readString
readSymbol
readTimestamp
readUndefined
readElement
readValue
writeArray
writeBinary
writeBinaryWithSubtype
writeBoolean
writeCodeWithScope
writeDBPointer
writeDateTime
writeDecimal128
writeDouble
writeInt32
writeInt64
writeJavascript
writeMaxKey
writeMinKey
writeNull
writeObjectID
writeRegex
writeString
writeDocument
writeSymbol
writeTimestamp
writeUndefined
writeDocumentElement
writeDocumentEnd
writeArrayElement
writeArrayEnd
skip
)
func (i invoked) String() string {
switch i {
case nothing:
return "Nothing"
case readArray:
return "ReadArray"
case readBinary:
return "ReadBinary"
case readBoolean:
return "ReadBoolean"
case readDocument:
return "ReadDocument"
case readCodeWithScope:
return "ReadCodeWithScope"
case readDBPointer:
return "ReadDBPointer"
case readDateTime:
return "ReadDateTime"
case readDecimal128:
return "ReadDecimal128"
case readDouble:
return "ReadDouble"
case readInt32:
return "ReadInt32"
case readInt64:
return "ReadInt64"
case readJavascript:
return "ReadJavascript"
case readMaxKey:
return "ReadMaxKey"
case readMinKey:
return "ReadMinKey"
case readNull:
return "ReadNull"
case readObjectID:
return "ReadObjectID"
case readRegex:
return "ReadRegex"
case readString:
return "ReadString"
case readSymbol:
return "ReadSymbol"
case readTimestamp:
return "ReadTimestamp"
case readUndefined:
return "ReadUndefined"
case readElement:
return "ReadElement"
case readValue:
return "ReadValue"
case writeArray:
return "WriteArray"
case writeBinary:
return "WriteBinary"
case writeBinaryWithSubtype:
return "WriteBinaryWithSubtype"
case writeBoolean:
return "WriteBoolean"
case writeCodeWithScope:
return "WriteCodeWithScope"
case writeDBPointer:
return "WriteDBPointer"
case writeDateTime:
return "WriteDateTime"
case writeDecimal128:
return "WriteDecimal128"
case writeDouble:
return "WriteDouble"
case writeInt32:
return "WriteInt32"
case writeInt64:
return "WriteInt64"
case writeJavascript:
return "WriteJavascript"
case writeMaxKey:
return "WriteMaxKey"
case writeMinKey:
return "WriteMinKey"
case writeNull:
return "WriteNull"
case writeObjectID:
return "WriteObjectID"
case writeRegex:
return "WriteRegex"
case writeString:
return "WriteString"
case writeDocument:
return "WriteDocument"
case writeSymbol:
return "WriteSymbol"
case writeTimestamp:
return "WriteTimestamp"
case writeUndefined:
return "WriteUndefined"
case writeDocumentElement:
return "WriteDocumentElement"
case writeDocumentEnd:
return "WriteDocumentEnd"
case writeArrayElement:
return "WriteArrayElement"
case writeArrayEnd:
return "WriteArrayEnd"
default:
return "<unknown>"
}
}
// valueReaderWriter is a test implementation of the ValueReader and ValueWriter interfaces.
type valueReaderWriter struct {
T *testing.T
invoked invoked
Return interface{} // Can be a primitive or a bsoncore.Value
BSONType Type
Err error
ErrAfter invoked // error after this method is called
depth uint64
}
// checkdepth guards against infinite recursion by panicking once the call depth exceeds 1000.
func (llvrw *valueReaderWriter) checkdepth() {
llvrw.depth++
if llvrw.depth > 1000 {
panic("max depth exceeded")
}
}
// Type implements the ValueReader interface.
func (llvrw *valueReaderWriter) Type() Type {
llvrw.checkdepth()
return llvrw.BSONType
}
// Skip implements the ValueReader interface.
func (llvrw *valueReaderWriter) Skip() error {
llvrw.checkdepth()
llvrw.invoked = skip
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// ReadArray implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadArray() (ArrayReader, error) {
llvrw.checkdepth()
llvrw.invoked = readArray
if llvrw.ErrAfter == llvrw.invoked {
return nil, llvrw.Err
}
return llvrw, nil
}
// ReadBinary implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadBinary() (b []byte, btype byte, err error) {
llvrw.checkdepth()
llvrw.invoked = readBinary
if llvrw.ErrAfter == llvrw.invoked {
return nil, 0x00, llvrw.Err
}
switch tt := llvrw.Return.(type) {
case bsoncore.Value:
subtype, data, _, ok := bsoncore.ReadBinary(tt.Data)
if !ok {
llvrw.T.Error("Invalid Value provided for return value of ReadBinary.")
return nil, 0x00, nil
}
return data, subtype, nil
default:
llvrw.T.Errorf("Incorrect type provided for return value of ReadBinary: %T", llvrw.Return)
return nil, 0x00, nil
}
}
// ReadBoolean implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadBoolean() (bool, error) {
llvrw.checkdepth()
llvrw.invoked = readBoolean
if llvrw.ErrAfter == llvrw.invoked {
return false, llvrw.Err
}
switch tt := llvrw.Return.(type) {
case bool:
return tt, nil
case bsoncore.Value:
b, _, ok := bsoncore.ReadBoolean(tt.Data)
if !ok {
llvrw.T.Error("Invalid Value provided for return value of ReadBoolean.")
return false, nil
}
return b, nil
default:
llvrw.T.Errorf("Incorrect type provided for return value of ReadBoolean: %T", llvrw.Return)
return false, nil
}
}
// ReadDocument implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadDocument() (DocumentReader, error) {
llvrw.checkdepth()
llvrw.invoked = readDocument
if llvrw.ErrAfter == llvrw.invoked {
return nil, llvrw.Err
}
return llvrw, nil
}
// ReadCodeWithScope implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadCodeWithScope() (code string, dr DocumentReader, err error) {
llvrw.checkdepth()
llvrw.invoked = readCodeWithScope
if llvrw.ErrAfter == llvrw.invoked {
return "", nil, llvrw.Err
}
return "", llvrw, nil
}
// ReadDBPointer implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadDBPointer() (ns string, oid ObjectID, err error) {
llvrw.checkdepth()
llvrw.invoked = readDBPointer
if llvrw.ErrAfter == llvrw.invoked {
return "", ObjectID{}, llvrw.Err
}
switch tt := llvrw.Return.(type) {
case bsoncore.Value:
ns, oid, _, ok := bsoncore.ReadDBPointer(tt.Data)
if !ok {
llvrw.T.Error("Invalid Value instance provided for return value of ReadDBPointer")
return "", ObjectID{}, nil
}
return ns, oid, nil
default:
llvrw.T.Errorf("Incorrect type provided for return value of ReadDBPointer: %T", llvrw.Return)
return "", ObjectID{}, nil
}
}
// ReadDateTime implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadDateTime() (int64, error) {
llvrw.checkdepth()
llvrw.invoked = readDateTime
if llvrw.ErrAfter == llvrw.invoked {
return 0, llvrw.Err
}
dt, ok := llvrw.Return.(int64)
if !ok {
llvrw.T.Errorf("Incorrect type provided for return value of ReadDateTime: %T", llvrw.Return)
return 0, nil
}
return dt, nil
}
// ReadDecimal128 implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadDecimal128() (Decimal128, error) {
llvrw.checkdepth()
llvrw.invoked = readDecimal128
if llvrw.ErrAfter == llvrw.invoked {
return Decimal128{}, llvrw.Err
}
d128, ok := llvrw.Return.(Decimal128)
if !ok {
llvrw.T.Errorf("Incorrect type provided for return value of ReadDecimal128: %T", llvrw.Return)
return Decimal128{}, nil
}
return d128, nil
}
// ReadDouble implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadDouble() (float64, error) {
llvrw.checkdepth()
llvrw.invoked = readDouble
if llvrw.ErrAfter == llvrw.invoked {
return 0, llvrw.Err
}
f64, ok := llvrw.Return.(float64)
if !ok {
llvrw.T.Errorf("Incorrect type provided for return value of ReadDouble: %T", llvrw.Return)
return 0, nil
}
return f64, nil
}
// ReadInt32 implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadInt32() (int32, error) {
llvrw.checkdepth()
llvrw.invoked = readInt32
if llvrw.ErrAfter == llvrw.invoked {
return 0, llvrw.Err
}
i32, ok := llvrw.Return.(int32)
if !ok {
llvrw.T.Errorf("Incorrect type provided for return value of ReadInt32: %T", llvrw.Return)
return 0, nil
}
return i32, nil
}
// ReadInt64 implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadInt64() (int64, error) {
llvrw.checkdepth()
llvrw.invoked = readInt64
if llvrw.ErrAfter == llvrw.invoked {
return 0, llvrw.Err
}
i64, ok := llvrw.Return.(int64)
if !ok {
llvrw.T.Errorf("Incorrect type provided for return value of ReadInt64: %T", llvrw.Return)
return 0, nil
}
return i64, nil
}
// ReadJavascript implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadJavascript() (code string, err error) {
llvrw.checkdepth()
llvrw.invoked = readJavascript
if llvrw.ErrAfter == llvrw.invoked {
return "", llvrw.Err
}
js, ok := llvrw.Return.(string)
if !ok {
llvrw.T.Errorf("Incorrect type provided for return value of ReadJavascript: %T", llvrw.Return)
return "", nil
}
return js, nil
}
// ReadMaxKey implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadMaxKey() error {
llvrw.checkdepth()
llvrw.invoked = readMaxKey
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// ReadMinKey implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadMinKey() error {
llvrw.checkdepth()
llvrw.invoked = readMinKey
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// ReadNull implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadNull() error {
llvrw.checkdepth()
llvrw.invoked = readNull
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// ReadObjectID implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadObjectID() (ObjectID, error) {
llvrw.checkdepth()
llvrw.invoked = readObjectID
if llvrw.ErrAfter == llvrw.invoked {
return ObjectID{}, llvrw.Err
}
oid, ok := llvrw.Return.(ObjectID)
if !ok {
llvrw.T.Errorf("Incorrect type provided for return value of ReadObjectID: %T", llvrw.Return)
return ObjectID{}, nil
}
return oid, nil
}
// ReadRegex implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadRegex() (pattern string, options string, err error) {
llvrw.checkdepth()
llvrw.invoked = readRegex
if llvrw.ErrAfter == llvrw.invoked {
return "", "", llvrw.Err
}
switch tt := llvrw.Return.(type) {
case bsoncore.Value:
pattern, options, _, ok := bsoncore.ReadRegex(tt.Data)
if !ok {
llvrw.T.Error("Invalid Value instance provided for ReadRegex")
return "", "", nil
}
return pattern, options, nil
default:
llvrw.T.Errorf("Incorrect type provided for return value of ReadRegex: %T", llvrw.Return)
return "", "", nil
}
}
// ReadString implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadString() (string, error) {
llvrw.checkdepth()
llvrw.invoked = readString
if llvrw.ErrAfter == llvrw.invoked {
return "", llvrw.Err
}
str, ok := llvrw.Return.(string)
if !ok {
llvrw.T.Errorf("Incorrect type provided for return value of ReadString: %T", llvrw.Return)
return "", nil
}
return str, nil
}
// ReadSymbol implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadSymbol() (symbol string, err error) {
llvrw.checkdepth()
llvrw.invoked = readSymbol
if llvrw.ErrAfter == llvrw.invoked {
return "", llvrw.Err
}
switch tt := llvrw.Return.(type) {
case string:
return tt, nil
case bsoncore.Value:
symbol, _, ok := bsoncore.ReadSymbol(tt.Data)
if !ok {
llvrw.T.Error("Invalid Value instance provided for ReadSymbol")
return "", nil
}
return symbol, nil
default:
llvrw.T.Errorf("Incorrect type provided for return value of ReadSymbol: %T", llvrw.Return)
return "", nil
}
}
// ReadTimestamp implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadTimestamp() (t uint32, i uint32, err error) {
llvrw.checkdepth()
llvrw.invoked = readTimestamp
if llvrw.ErrAfter == llvrw.invoked {
return 0, 0, llvrw.Err
}
switch tt := llvrw.Return.(type) {
case bsoncore.Value:
t, i, _, ok := bsoncore.ReadTimestamp(tt.Data)
if !ok {
llvrw.T.Errorf("Invalid Value instance provided for return value of ReadTimestamp")
return 0, 0, nil
}
return t, i, nil
default:
llvrw.T.Errorf("Incorrect type provided for return value of ReadTimestamp: %T", llvrw.Return)
return 0, 0, nil
}
}
// ReadUndefined implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadUndefined() error {
llvrw.checkdepth()
llvrw.invoked = readUndefined
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// WriteArray implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteArray() (ArrayWriter, error) {
llvrw.checkdepth()
llvrw.invoked = writeArray
if llvrw.ErrAfter == llvrw.invoked {
return nil, llvrw.Err
}
return llvrw, nil
}
// WriteBinary implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteBinary([]byte) error {
llvrw.checkdepth()
llvrw.invoked = writeBinary
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// WriteBinaryWithSubtype implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteBinaryWithSubtype([]byte, byte) error {
llvrw.checkdepth()
llvrw.invoked = writeBinaryWithSubtype
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// WriteBoolean implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteBoolean(bool) error {
llvrw.checkdepth()
llvrw.invoked = writeBoolean
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// WriteCodeWithScope implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteCodeWithScope(string) (DocumentWriter, error) {
llvrw.checkdepth()
llvrw.invoked = writeCodeWithScope
if llvrw.ErrAfter == llvrw.invoked {
return nil, llvrw.Err
}
return llvrw, nil
}
// WriteDBPointer implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteDBPointer(string, ObjectID) error {
llvrw.checkdepth()
llvrw.invoked = writeDBPointer
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// WriteDateTime implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteDateTime(int64) error {
llvrw.checkdepth()
llvrw.invoked = writeDateTime
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// WriteDecimal128 implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteDecimal128(Decimal128) error {
llvrw.checkdepth()
llvrw.invoked = writeDecimal128
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// WriteDouble implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteDouble(float64) error {
llvrw.checkdepth()
llvrw.invoked = writeDouble
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// WriteInt32 implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteInt32(int32) error {
llvrw.checkdepth()
llvrw.invoked = writeInt32
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// WriteInt64 implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteInt64(int64) error {
llvrw.checkdepth()
llvrw.invoked = writeInt64
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// WriteJavascript implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteJavascript(string) error {
llvrw.checkdepth()
llvrw.invoked = writeJavascript
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// WriteMaxKey implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteMaxKey() error {
llvrw.checkdepth()
llvrw.invoked = writeMaxKey
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// WriteMinKey implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteMinKey() error {
llvrw.checkdepth()
llvrw.invoked = writeMinKey
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// WriteNull implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteNull() error {
llvrw.checkdepth()
llvrw.invoked = writeNull
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// WriteObjectID implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteObjectID(ObjectID) error {
llvrw.checkdepth()
llvrw.invoked = writeObjectID
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// WriteRegex implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteRegex(string, string) error {
llvrw.checkdepth()
llvrw.invoked = writeRegex
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// WriteString implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteString(string) error {
llvrw.checkdepth()
llvrw.invoked = writeString
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// WriteDocument implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteDocument() (DocumentWriter, error) {
llvrw.checkdepth()
llvrw.invoked = writeDocument
if llvrw.ErrAfter == llvrw.invoked {
return nil, llvrw.Err
}
return llvrw, nil
}
// WriteSymbol implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteSymbol(string) error {
llvrw.checkdepth()
llvrw.invoked = writeSymbol
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// WriteTimestamp implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteTimestamp(uint32, uint32) error {
llvrw.checkdepth()
llvrw.invoked = writeTimestamp
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// WriteUndefined implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteUndefined() error {
llvrw.checkdepth()
llvrw.invoked = writeUndefined
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// ReadElement implements the DocumentReader interface.
func (llvrw *valueReaderWriter) ReadElement() (string, ValueReader, error) {
llvrw.checkdepth()
llvrw.invoked = readElement
if llvrw.ErrAfter == llvrw.invoked {
return "", nil, llvrw.Err
}
return "", llvrw, nil
}
// WriteDocumentElement implements the DocumentWriter interface.
func (llvrw *valueReaderWriter) WriteDocumentElement(string) (ValueWriter, error) {
llvrw.checkdepth()
llvrw.invoked = writeDocumentElement
if llvrw.ErrAfter == llvrw.invoked {
return nil, llvrw.Err
}
return llvrw, nil
}
// WriteDocumentEnd implements the DocumentWriter interface.
func (llvrw *valueReaderWriter) WriteDocumentEnd() error {
llvrw.checkdepth()
llvrw.invoked = writeDocumentEnd
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}
// ReadValue implements the ArrayReader interface.
func (llvrw *valueReaderWriter) ReadValue() (ValueReader, error) {
llvrw.checkdepth()
llvrw.invoked = readValue
if llvrw.ErrAfter == llvrw.invoked {
return nil, llvrw.Err
}
return llvrw, nil
}
// WriteArrayElement implements the ArrayWriter interface.
func (llvrw *valueReaderWriter) WriteArrayElement() (ValueWriter, error) {
llvrw.checkdepth()
llvrw.invoked = writeArrayElement
if llvrw.ErrAfter == llvrw.invoked {
return nil, llvrw.Err
}
return llvrw, nil
}
// WriteArrayEnd implements the ArrayWriter interface.
func (llvrw *valueReaderWriter) WriteArrayEnd() error {
llvrw.checkdepth()
llvrw.invoked = writeArrayEnd
if llvrw.ErrAfter == llvrw.invoked {
return llvrw.Err
}
return nil
}

97
byte_slice_codec.go Normal file

@@ -0,0 +1,97 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"fmt"
"reflect"
)
// byteSliceCodec is the Codec used for []byte values.
type byteSliceCodec struct {
// encodeNilAsEmpty causes EncodeValue to marshal nil Go byte slices as empty BSON binary values
// instead of BSON null.
encodeNilAsEmpty bool
}
// Assert that byteSliceCodec satisfies the typeDecoder interface, which allows it to be
// used by collection type decoders (e.g. map, slice) to set individual values in a
// collection.
var _ typeDecoder = &byteSliceCodec{}
// EncodeValue is the ValueEncoder for []byte.
func (bsc *byteSliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tByteSlice {
return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val}
}
if val.IsNil() && !bsc.encodeNilAsEmpty && !ec.nilByteSliceAsEmpty {
return vw.WriteNull()
}
return vw.WriteBinary(val.Interface().([]byte))
}
func (bsc *byteSliceCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
if t != tByteSlice {
return emptyValue, ValueDecoderError{
Name: "ByteSliceDecodeValue",
Types: []reflect.Type{tByteSlice},
Received: reflect.Zero(t),
}
}
var data []byte
var err error
switch vrType := vr.Type(); vrType {
case TypeString:
str, err := vr.ReadString()
if err != nil {
return emptyValue, err
}
data = []byte(str)
case TypeSymbol:
sym, err := vr.ReadSymbol()
if err != nil {
return emptyValue, err
}
data = []byte(sym)
case TypeBinary:
var subtype byte
data, subtype, err = vr.ReadBinary()
if err != nil {
return emptyValue, err
}
if subtype != TypeBinaryGeneric && subtype != TypeBinaryBinaryOld {
return emptyValue, decodeBinaryError{subtype: subtype, typeName: "[]byte"}
}
case TypeNull:
err = vr.ReadNull()
case TypeUndefined:
err = vr.ReadUndefined()
default:
return emptyValue, fmt.Errorf("cannot decode %v into a []byte", vrType)
}
if err != nil {
return emptyValue, err
}
return reflect.ValueOf(data), nil
}
// DecodeValue is the ValueDecoder for []byte.
func (bsc *byteSliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tByteSlice {
return ValueDecoderError{Name: "ByteSliceDecodeValue", Types: []reflect.Type{tByteSlice}, Received: val}
}
elem, err := bsc.decodeType(dc, vr, tByteSlice)
if err != nil {
return err
}
val.Set(elem)
return nil
}

166
codec_cache.go Normal file

@@ -0,0 +1,166 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
"sync"
"sync/atomic"
)
// Runtime check that the kind encoder and decoder caches can store any valid
// reflect.Kind constant.
func init() {
if s := reflect.Kind(len(kindEncoderCache{}.entries)).String(); s != "kind27" {
panic("The capacity of kindEncoderCache is too small.\n" +
"This is due to a new type being added to reflect.Kind.")
}
}
// Statically assert the kind cache array sizes.
var _ = (kindEncoderCache{}).entries[reflect.UnsafePointer]
var _ = (kindDecoderCache{}).entries[reflect.UnsafePointer]
type typeEncoderCache struct {
cache sync.Map // map[reflect.Type]ValueEncoder
}
func (c *typeEncoderCache) Store(rt reflect.Type, enc ValueEncoder) {
c.cache.Store(rt, enc)
}
func (c *typeEncoderCache) Load(rt reflect.Type) (ValueEncoder, bool) {
if v, _ := c.cache.Load(rt); v != nil {
return v.(ValueEncoder), true
}
return nil, false
}
func (c *typeEncoderCache) LoadOrStore(rt reflect.Type, enc ValueEncoder) ValueEncoder {
if v, loaded := c.cache.LoadOrStore(rt, enc); loaded {
enc = v.(ValueEncoder)
}
return enc
}
func (c *typeEncoderCache) Clone() *typeEncoderCache {
cc := new(typeEncoderCache)
c.cache.Range(func(k, v interface{}) bool {
if k != nil && v != nil {
cc.cache.Store(k, v)
}
return true
})
return cc
}
type typeDecoderCache struct {
cache sync.Map // map[reflect.Type]ValueDecoder
}
func (c *typeDecoderCache) Store(rt reflect.Type, dec ValueDecoder) {
c.cache.Store(rt, dec)
}
func (c *typeDecoderCache) Load(rt reflect.Type) (ValueDecoder, bool) {
if v, _ := c.cache.Load(rt); v != nil {
return v.(ValueDecoder), true
}
return nil, false
}
func (c *typeDecoderCache) LoadOrStore(rt reflect.Type, dec ValueDecoder) ValueDecoder {
if v, loaded := c.cache.LoadOrStore(rt, dec); loaded {
dec = v.(ValueDecoder)
}
return dec
}
func (c *typeDecoderCache) Clone() *typeDecoderCache {
cc := new(typeDecoderCache)
c.cache.Range(func(k, v interface{}) bool {
if k != nil && v != nil {
cc.cache.Store(k, v)
}
return true
})
return cc
}
// atomic.Value requires that all calls to Store() have the same concrete type
// so we wrap the ValueEncoder with a kindEncoderCacheEntry to ensure the type
// is always the same (since different concrete types may implement the
// ValueEncoder interface).
type kindEncoderCacheEntry struct {
enc ValueEncoder
}
type kindEncoderCache struct {
entries [reflect.UnsafePointer + 1]atomic.Value // *kindEncoderCacheEntry
}
func (c *kindEncoderCache) Store(rt reflect.Kind, enc ValueEncoder) {
if enc != nil && rt < reflect.Kind(len(c.entries)) {
c.entries[rt].Store(&kindEncoderCacheEntry{enc: enc})
}
}
func (c *kindEncoderCache) Load(rt reflect.Kind) (ValueEncoder, bool) {
if rt < reflect.Kind(len(c.entries)) {
if ent, ok := c.entries[rt].Load().(*kindEncoderCacheEntry); ok {
return ent.enc, ent.enc != nil
}
}
return nil, false
}
func (c *kindEncoderCache) Clone() *kindEncoderCache {
cc := new(kindEncoderCache)
for i, v := range c.entries {
if val := v.Load(); val != nil {
cc.entries[i].Store(val)
}
}
return cc
}
// atomic.Value requires that all calls to Store() have the same concrete type
// so we wrap the ValueDecoder with a kindDecoderCacheEntry to ensure the type
// is always the same (since different concrete types may implement the
// ValueDecoder interface).
type kindDecoderCacheEntry struct {
dec ValueDecoder
}
type kindDecoderCache struct {
entries [reflect.UnsafePointer + 1]atomic.Value // *kindDecoderCacheEntry
}
func (c *kindDecoderCache) Store(rt reflect.Kind, dec ValueDecoder) {
if rt < reflect.Kind(len(c.entries)) {
c.entries[rt].Store(&kindDecoderCacheEntry{dec: dec})
}
}
func (c *kindDecoderCache) Load(rt reflect.Kind) (ValueDecoder, bool) {
if rt < reflect.Kind(len(c.entries)) {
if ent, ok := c.entries[rt].Load().(*kindDecoderCacheEntry); ok {
return ent.dec, ent.dec != nil
}
}
return nil, false
}
func (c *kindDecoderCache) Clone() *kindDecoderCache {
cc := new(kindDecoderCache)
for i, v := range c.entries {
if val := v.Load(); val != nil {
cc.entries[i].Store(val)
}
}
return cc
}

176
codec_cache_test.go Normal file

@@ -0,0 +1,176 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
"strconv"
"strings"
"testing"
)
// NB(charlie): the array size is a power of 2 because we use the remainder of
// it (mod) in benchmarks and that is faster when the size is a power of 2.
var codecCacheTestTypes = [16]reflect.Type{
reflect.TypeOf(uint8(0)),
reflect.TypeOf(uint16(0)),
reflect.TypeOf(uint32(0)),
reflect.TypeOf(uint64(0)),
reflect.TypeOf(uint(0)),
reflect.TypeOf(uintptr(0)),
reflect.TypeOf(int8(0)),
reflect.TypeOf(int16(0)),
reflect.TypeOf(int32(0)),
reflect.TypeOf(int64(0)),
reflect.TypeOf(int(0)),
reflect.TypeOf(float32(0)),
reflect.TypeOf(float64(0)),
reflect.TypeOf(true),
reflect.TypeOf(struct{ A int }{}),
reflect.TypeOf(map[int]int{}),
}
func TestTypeCache(t *testing.T) {
rt := reflect.TypeOf(int(0))
ec := new(typeEncoderCache)
dc := new(typeDecoderCache)
codec := new(fakeCodec)
ec.Store(rt, codec)
dc.Store(rt, codec)
if v, ok := ec.Load(rt); !ok || !reflect.DeepEqual(v, codec) {
t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, codec, true)
}
if v, ok := dc.Load(rt); !ok || !reflect.DeepEqual(v, codec) {
t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, codec, true)
}
// Make sure we overwrite the stored value with nil
ec.Store(rt, nil)
dc.Store(rt, nil)
if v, ok := ec.Load(rt); ok || v != nil {
t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, nil, false)
}
if v, ok := dc.Load(rt); ok || v != nil {
t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, nil, false)
}
}
func TestTypeCacheClone(t *testing.T) {
codec := new(fakeCodec)
ec1 := new(typeEncoderCache)
dc1 := new(typeDecoderCache)
for _, rt := range codecCacheTestTypes {
ec1.Store(rt, codec)
dc1.Store(rt, codec)
}
ec2 := ec1.Clone()
dc2 := dc1.Clone()
for _, rt := range codecCacheTestTypes {
if v, _ := ec2.Load(rt); !reflect.DeepEqual(v, codec) {
t.Errorf("Load(%s) = %#v; want: %#v", rt, v, codec)
}
if v, _ := dc2.Load(rt); !reflect.DeepEqual(v, codec) {
t.Errorf("Load(%s) = %#v; want: %#v", rt, v, codec)
}
}
}
func TestKindCacheArray(t *testing.T) {
// Check array bounds
var c kindEncoderCache
codec := new(fakeCodec)
c.Store(reflect.UnsafePointer, codec) // valid
c.Store(reflect.UnsafePointer+1, codec) // ignored
if v, ok := c.Load(reflect.UnsafePointer); !ok || v != codec {
t.Errorf("Load(reflect.UnsafePointer) = %v, %t; want: %v, %t", v, ok, codec, true)
}
if v, ok := c.Load(reflect.UnsafePointer + 1); ok || v != nil {
t.Errorf("Load(reflect.UnsafePointer + 1) = %v, %t; want: %v, %t", v, ok, nil, false)
}
	// Make sure that reflect.UnsafePointer is the last/largest reflect.Kind.
	//
	// The String() method of an invalid reflect.Kind value returns a string of
	// the format "kind{NUMBER}".
for rt := reflect.UnsafePointer + 1; rt < reflect.UnsafePointer+16; rt++ {
s := rt.String()
if !strings.Contains(s, strconv.Itoa(int(rt))) {
t.Errorf("reflect.Type(%d) appears to be valid: %q", rt, s)
}
}
}
func TestKindCacheClone(t *testing.T) {
e1 := new(kindEncoderCache)
d1 := new(kindDecoderCache)
codec := new(fakeCodec)
for k := reflect.Invalid; k <= reflect.UnsafePointer; k++ {
e1.Store(k, codec)
d1.Store(k, codec)
}
e2 := e1.Clone()
for k := reflect.Invalid; k <= reflect.UnsafePointer; k++ {
v1, ok1 := e1.Load(k)
v2, ok2 := e2.Load(k)
if ok1 != ok2 || !reflect.DeepEqual(v1, v2) || v1 == nil || v2 == nil {
t.Errorf("Encoder(%s): %#v, %t != %#v, %t", k, v1, ok1, v2, ok2)
}
}
d2 := d1.Clone()
for k := reflect.Invalid; k <= reflect.UnsafePointer; k++ {
v1, ok1 := d1.Load(k)
v2, ok2 := d2.Load(k)
if ok1 != ok2 || !reflect.DeepEqual(v1, v2) || v1 == nil || v2 == nil {
t.Errorf("Decoder(%s): %#v, %t != %#v, %t", k, v1, ok1, v2, ok2)
}
}
}
func TestKindCacheEncoderNilEncoder(t *testing.T) {
t.Run("Encoder", func(t *testing.T) {
c := new(kindEncoderCache)
c.Store(reflect.Invalid, ValueEncoder(nil))
v, ok := c.Load(reflect.Invalid)
if v != nil || ok {
t.Errorf("Load of nil ValueEncoder should return: nil, false; got: %v, %t", v, ok)
}
})
t.Run("Decoder", func(t *testing.T) {
c := new(kindDecoderCache)
c.Store(reflect.Invalid, ValueDecoder(nil))
v, ok := c.Load(reflect.Invalid)
if v != nil || ok {
t.Errorf("Load of nil ValueDecoder should return: nil, false; got: %v, %t", v, ok)
}
})
}
func BenchmarkEncoderCacheLoad(b *testing.B) {
c := new(typeEncoderCache)
codec := new(fakeCodec)
typs := codecCacheTestTypes
for _, t := range typs {
c.Store(t, codec)
}
b.RunParallel(func(pb *testing.PB) {
for i := 0; pb.Next(); i++ {
c.Load(typs[i%len(typs)])
}
})
}
func BenchmarkEncoderCacheStore(b *testing.B) {
c := new(typeEncoderCache)
codec := new(fakeCodec)
b.RunParallel(func(pb *testing.PB) {
typs := codecCacheTestTypes
for i := 0; pb.Next(); i++ {
c.Store(typs[i%len(typs)], codec)
}
})
}

61
cond_addr_codec.go Normal file

@@ -0,0 +1,61 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
)
// condAddrEncoder is the encoder used when a pointer to the encoding value has an encoder.
type condAddrEncoder struct {
canAddrEnc ValueEncoder
elseEnc ValueEncoder
}
var _ ValueEncoder = &condAddrEncoder{}
// newCondAddrEncoder returns a condAddrEncoder.
func newCondAddrEncoder(canAddrEnc, elseEnc ValueEncoder) *condAddrEncoder {
encoder := condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc}
return &encoder
}
// EncodeValue is the ValueEncoderFunc for a value that may be addressable.
func (cae *condAddrEncoder) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
if val.CanAddr() {
return cae.canAddrEnc.EncodeValue(ec, vw, val)
}
if cae.elseEnc != nil {
return cae.elseEnc.EncodeValue(ec, vw, val)
}
return errNoEncoder{Type: val.Type()}
}
// condAddrDecoder is the decoder used when a pointer to the value has a decoder.
type condAddrDecoder struct {
canAddrDec ValueDecoder
elseDec ValueDecoder
}
var _ ValueDecoder = &condAddrDecoder{}
// newCondAddrDecoder returns a condAddrDecoder.
func newCondAddrDecoder(canAddrDec, elseDec ValueDecoder) *condAddrDecoder {
decoder := condAddrDecoder{canAddrDec: canAddrDec, elseDec: elseDec}
return &decoder
}
// DecodeValue is the ValueDecoderFunc for a value that may be addressable.
func (cad *condAddrDecoder) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
if val.CanAddr() {
return cad.canAddrDec.DecodeValue(dc, vr, val)
}
if cad.elseDec != nil {
return cad.elseDec.DecodeValue(dc, vr, val)
}
return errNoDecoder{Type: val.Type()}
}

cond_addr_codec_test.go (new file, 95 lines)
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
"testing"
"gitea.psichedelico.com/go/bson/internal/assert"
)
func TestCondAddrCodec(t *testing.T) {
var inner int
canAddrVal := reflect.ValueOf(&inner)
addressable := canAddrVal.Elem()
unaddressable := reflect.ValueOf(inner)
rw := &valueReaderWriter{}
t.Run("addressEncode", func(t *testing.T) {
invoked := 0
encode1 := ValueEncoderFunc(func(EncodeContext, ValueWriter, reflect.Value) error {
invoked = 1
return nil
})
encode2 := ValueEncoderFunc(func(EncodeContext, ValueWriter, reflect.Value) error {
invoked = 2
return nil
})
condEncoder := newCondAddrEncoder(encode1, encode2)
testCases := []struct {
name string
val reflect.Value
invoked int
}{
{"canAddr", addressable, 1},
{"else", unaddressable, 2},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := condEncoder.EncodeValue(EncodeContext{}, rw, tc.val)
assert.Nil(t, err, "CondAddrEncoder error: %v", err)
assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked)
})
}
t.Run("error", func(t *testing.T) {
errEncoder := newCondAddrEncoder(encode1, nil)
err := errEncoder.EncodeValue(EncodeContext{}, rw, unaddressable)
want := errNoEncoder{Type: unaddressable.Type()}
assert.Equal(t, err, want, "expected error %v, got %v", want, err)
})
})
t.Run("addressDecode", func(t *testing.T) {
invoked := 0
decode1 := ValueDecoderFunc(func(DecodeContext, ValueReader, reflect.Value) error {
invoked = 1
return nil
})
decode2 := ValueDecoderFunc(func(DecodeContext, ValueReader, reflect.Value) error {
invoked = 2
return nil
})
condDecoder := newCondAddrDecoder(decode1, decode2)
testCases := []struct {
name string
val reflect.Value
invoked int
}{
{"canAddr", addressable, 1},
{"else", unaddressable, 2},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := condDecoder.DecodeValue(DecodeContext{}, rw, tc.val)
assert.Nil(t, err, "CondAddrDecoder error: %v", err)
assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked)
})
}
t.Run("error", func(t *testing.T) {
errDecoder := newCondAddrDecoder(decode1, nil)
err := errDecoder.DecodeValue(DecodeContext{}, rw, unaddressable)
want := errNoDecoder{Type: unaddressable.Type()}
assert.Equal(t, err, want, "expected error %v, got %v", want, err)
})
})
}

copier.go (new file, 431 lines)
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"errors"
"fmt"
"io"
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
)
// copyDocument handles copying one document from the src to the dst.
func copyDocument(dst ValueWriter, src ValueReader) error {
dr, err := src.ReadDocument()
if err != nil {
return err
}
dw, err := dst.WriteDocument()
if err != nil {
return err
}
return copyDocumentCore(dw, dr)
}
// copyArrayFromBytes copies the values from a BSON array represented as a
// []byte to a ValueWriter.
func copyArrayFromBytes(dst ValueWriter, src []byte) error {
aw, err := dst.WriteArray()
if err != nil {
return err
}
err = copyBytesToArrayWriter(aw, src)
if err != nil {
return err
}
return aw.WriteArrayEnd()
}
// copyDocumentFromBytes copies the values from a BSON document represented as a
// []byte to a ValueWriter.
func copyDocumentFromBytes(dst ValueWriter, src []byte) error {
dw, err := dst.WriteDocument()
if err != nil {
return err
}
err = copyBytesToDocumentWriter(dw, src)
if err != nil {
return err
}
return dw.WriteDocumentEnd()
}
type writeElementFn func(key string) (ValueWriter, error)
// copyBytesToArrayWriter copies the values from a BSON Array represented as a []byte to an
// ArrayWriter.
func copyBytesToArrayWriter(dst ArrayWriter, src []byte) error {
wef := func(_ string) (ValueWriter, error) {
return dst.WriteArrayElement()
}
return copyBytesToValueWriter(src, wef)
}
// copyBytesToDocumentWriter copies the values from a BSON document represented as a []byte to a
// DocumentWriter.
func copyBytesToDocumentWriter(dst DocumentWriter, src []byte) error {
wef := func(key string) (ValueWriter, error) {
return dst.WriteDocumentElement(key)
}
return copyBytesToValueWriter(src, wef)
}
func copyBytesToValueWriter(src []byte, wef writeElementFn) error {
// TODO(skriptble): Create errors types here. Anything that is a tag should be a property.
length, rem, ok := bsoncore.ReadLength(src)
if !ok {
return fmt.Errorf("couldn't read length from src, not enough bytes. length=%d", len(src))
}
if len(src) < int(length) {
return fmt.Errorf("length read exceeds number of bytes available. length=%d bytes=%d", length, len(src))
}
rem = rem[:length-4]
var t bsoncore.Type
var key string
var val bsoncore.Value
for {
t, rem, ok = bsoncore.ReadType(rem)
if !ok {
return io.EOF
}
if t == bsoncore.Type(0) {
if len(rem) != 0 {
return fmt.Errorf("document end byte found before end of document. remaining bytes=%v", rem)
}
break
}
key, rem, ok = bsoncore.ReadKey(rem)
if !ok {
return fmt.Errorf("invalid key found. remaining bytes=%v", rem)
}
// write as either array element or document element using writeElementFn
vw, err := wef(key)
if err != nil {
return err
}
val, rem, ok = bsoncore.ReadValue(rem, t)
if !ok {
return fmt.Errorf("not enough bytes available to read value. bytes=%d type=%s", len(rem), t)
}
err = copyValueFromBytes(vw, Type(t), val.Data)
if err != nil {
return err
}
}
return nil
}
// copyDocumentToBytes copies an entire document from the ValueReader and
// returns it as bytes.
func copyDocumentToBytes(src ValueReader) ([]byte, error) {
return appendDocumentBytes(nil, src)
}
// appendDocumentBytes functions the same as copyDocumentToBytes, but will
// append the result to dst.
func appendDocumentBytes(dst []byte, src ValueReader) ([]byte, error) {
if br, ok := src.(bytesReader); ok {
_, dst, err := br.readValueBytes(dst)
return dst, err
}
vw := vwPool.Get().(*valueWriter)
defer putValueWriter(vw)
vw.reset(dst)
err := copyDocument(vw, src)
dst = vw.buf
return dst, err
}
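appendDocumentBytes and its siblings reuse *valueWriter instances from a sync.Pool (vwPool, defined elsewhere in the package) to avoid a heap allocation per call. A generic sketch of the same get/reset/put pattern with a bytes.Buffer (`bufPool` and `appendGreeting` are illustrative names, not package API):

```go
package main

import (
	"bytes"
	"fmt"
	"sync"
)

// bufPool mirrors the vwPool pattern: Get a reusable writer, reset it,
// and return it to the pool when done.
var bufPool = sync.Pool{
	New: func() interface{} { return new(bytes.Buffer) },
}

// appendGreeting appends its result to dst, like appendDocumentBytes does.
func appendGreeting(dst []byte, name string) []byte {
	buf := bufPool.Get().(*bytes.Buffer)
	defer bufPool.Put(buf)
	buf.Reset() // pooled objects keep their old contents; always reset
	buf.WriteString("hello, ")
	buf.WriteString(name)
	return append(dst, buf.Bytes()...)
}

func main() {
	out := appendGreeting([]byte("> "), "world")
	fmt.Println(string(out))
}
```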
// appendArrayBytes copies an array from the ValueReader to dst.
func appendArrayBytes(dst []byte, src ValueReader) ([]byte, error) {
if br, ok := src.(bytesReader); ok {
_, dst, err := br.readValueBytes(dst)
return dst, err
}
vw := vwPool.Get().(*valueWriter)
defer putValueWriter(vw)
vw.reset(dst)
err := copyArray(vw, src)
dst = vw.buf
return dst, err
}
// copyValueFromBytes will write the value represented by t and src to dst.
func copyValueFromBytes(dst ValueWriter, t Type, src []byte) error {
if wvb, ok := dst.(bytesWriter); ok {
return wvb.writeValueBytes(t, src)
}
vr := newDocumentReader(bytes.NewReader(src))
vr.pushElement(t)
return copyValue(dst, vr)
}
// copyValueToBytes copies a value from src and returns it as a Type and a
// []byte.
func copyValueToBytes(src ValueReader) (Type, []byte, error) {
if br, ok := src.(bytesReader); ok {
return br.readValueBytes(nil)
}
vw := vwPool.Get().(*valueWriter)
defer putValueWriter(vw)
vw.reset(nil)
vw.push(mElement)
err := copyValue(vw, src)
if err != nil {
return 0, nil, err
}
return Type(vw.buf[0]), vw.buf[2:], nil
}
// copyValue will copy a single value from src to dst.
func copyValue(dst ValueWriter, src ValueReader) error {
var err error
switch src.Type() {
case TypeDouble:
var f64 float64
f64, err = src.ReadDouble()
if err != nil {
break
}
err = dst.WriteDouble(f64)
case TypeString:
var str string
str, err = src.ReadString()
if err != nil {
return err
}
err = dst.WriteString(str)
case TypeEmbeddedDocument:
err = copyDocument(dst, src)
case TypeArray:
err = copyArray(dst, src)
case TypeBinary:
var data []byte
var subtype byte
data, subtype, err = src.ReadBinary()
if err != nil {
break
}
err = dst.WriteBinaryWithSubtype(data, subtype)
case TypeUndefined:
err = src.ReadUndefined()
if err != nil {
break
}
err = dst.WriteUndefined()
case TypeObjectID:
var oid ObjectID
oid, err = src.ReadObjectID()
if err != nil {
break
}
err = dst.WriteObjectID(oid)
case TypeBoolean:
var b bool
b, err = src.ReadBoolean()
if err != nil {
break
}
err = dst.WriteBoolean(b)
case TypeDateTime:
var dt int64
dt, err = src.ReadDateTime()
if err != nil {
break
}
err = dst.WriteDateTime(dt)
case TypeNull:
err = src.ReadNull()
if err != nil {
break
}
err = dst.WriteNull()
case TypeRegex:
var pattern, options string
pattern, options, err = src.ReadRegex()
if err != nil {
break
}
err = dst.WriteRegex(pattern, options)
case TypeDBPointer:
var ns string
var pointer ObjectID
ns, pointer, err = src.ReadDBPointer()
if err != nil {
break
}
err = dst.WriteDBPointer(ns, pointer)
case TypeJavaScript:
var js string
js, err = src.ReadJavascript()
if err != nil {
break
}
err = dst.WriteJavascript(js)
case TypeSymbol:
var symbol string
symbol, err = src.ReadSymbol()
if err != nil {
break
}
err = dst.WriteSymbol(symbol)
case TypeCodeWithScope:
var code string
var srcScope DocumentReader
code, srcScope, err = src.ReadCodeWithScope()
if err != nil {
break
}
var dstScope DocumentWriter
dstScope, err = dst.WriteCodeWithScope(code)
if err != nil {
break
}
err = copyDocumentCore(dstScope, srcScope)
case TypeInt32:
var i32 int32
i32, err = src.ReadInt32()
if err != nil {
break
}
err = dst.WriteInt32(i32)
case TypeTimestamp:
var t, i uint32
t, i, err = src.ReadTimestamp()
if err != nil {
break
}
err = dst.WriteTimestamp(t, i)
case TypeInt64:
var i64 int64
i64, err = src.ReadInt64()
if err != nil {
break
}
err = dst.WriteInt64(i64)
case TypeDecimal128:
var d128 Decimal128
d128, err = src.ReadDecimal128()
if err != nil {
break
}
err = dst.WriteDecimal128(d128)
case TypeMinKey:
err = src.ReadMinKey()
if err != nil {
break
}
err = dst.WriteMinKey()
case TypeMaxKey:
err = src.ReadMaxKey()
if err != nil {
break
}
err = dst.WriteMaxKey()
default:
err = fmt.Errorf("cannot copy unknown BSON type %s", src.Type())
}
return err
}
func copyArray(dst ValueWriter, src ValueReader) error {
ar, err := src.ReadArray()
if err != nil {
return err
}
aw, err := dst.WriteArray()
if err != nil {
return err
}
for {
vr, err := ar.ReadValue()
if errors.Is(err, ErrEOA) {
break
}
if err != nil {
return err
}
vw, err := aw.WriteArrayElement()
if err != nil {
return err
}
err = copyValue(vw, vr)
if err != nil {
return err
}
}
return aw.WriteArrayEnd()
}
func copyDocumentCore(dw DocumentWriter, dr DocumentReader) error {
for {
key, vr, err := dr.ReadElement()
if errors.Is(err, ErrEOD) {
break
}
if err != nil {
return err
}
vw, err := dw.WriteDocumentElement(key)
if err != nil {
return err
}
err = copyValue(vw, vr)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
// bytesReader is the interface used to read BSON bytes from a valueReader.
//
// The bytes of the value will be appended to dst.
type bytesReader interface {
readValueBytes(dst []byte) (Type, []byte, error)
}
// bytesWriter is the interface used to write BSON bytes to a valueWriter.
type bytesWriter interface {
writeValueBytes(t Type, b []byte) error
}

copier_test.go (new file, 528 lines)
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"errors"
"fmt"
"testing"
"gitea.psichedelico.com/go/bson/internal/assert"
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
)
func TestCopier(t *testing.T) {
t.Run("CopyDocument", func(t *testing.T) {
t.Run("ReadDocument Error", func(t *testing.T) {
want := errors.New("ReadDocumentError")
src := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwReadDocument}
got := copyDocument(nil, src)
if !assert.CompareErrors(got, want) {
t.Errorf("Did not receive correct error. got %v; want %v", got, want)
}
})
t.Run("WriteDocument Error", func(t *testing.T) {
want := errors.New("WriteDocumentError")
src := &TestValueReaderWriter{}
dst := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwWriteDocument}
got := copyDocument(dst, src)
if !assert.CompareErrors(got, want) {
t.Errorf("Did not receive correct error. got %v; want %v", got, want)
}
})
t.Run("success", func(t *testing.T) {
idx, doc := bsoncore.AppendDocumentStart(nil)
doc = bsoncore.AppendStringElement(doc, "Hello", "world")
doc, err := bsoncore.AppendDocumentEnd(doc, idx)
noerr(t, err)
src := newDocumentReader(bytes.NewReader(doc))
dst := newValueWriterFromSlice(make([]byte, 0))
want := doc
err = copyDocument(dst, src)
noerr(t, err)
got := dst.buf
if !bytes.Equal(got, want) {
t.Errorf("Bytes are not equal. got %v; want %v", got, want)
}
})
})
t.Run("copyArray", func(t *testing.T) {
t.Run("ReadArray Error", func(t *testing.T) {
want := errors.New("ReadArrayError")
src := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwReadArray}
got := copyArray(nil, src)
if !assert.CompareErrors(got, want) {
t.Errorf("Did not receive correct error. got %v; want %v", got, want)
}
})
t.Run("WriteArray Error", func(t *testing.T) {
want := errors.New("WriteArrayError")
src := &TestValueReaderWriter{}
dst := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwWriteArray}
got := copyArray(dst, src)
if !assert.CompareErrors(got, want) {
t.Errorf("Did not receive correct error. got %v; want %v", got, want)
}
})
t.Run("success", func(t *testing.T) {
idx, doc := bsoncore.AppendDocumentStart(nil)
aidx, doc := bsoncore.AppendArrayElementStart(doc, "foo")
doc = bsoncore.AppendStringElement(doc, "0", "Hello, world!")
doc, err := bsoncore.AppendArrayEnd(doc, aidx)
noerr(t, err)
doc, err = bsoncore.AppendDocumentEnd(doc, idx)
noerr(t, err)
src := newDocumentReader(bytes.NewReader(doc))
_, err = src.ReadDocument()
noerr(t, err)
_, _, err = src.ReadElement()
noerr(t, err)
dst := newValueWriterFromSlice(make([]byte, 0))
_, err = dst.WriteDocument()
noerr(t, err)
_, err = dst.WriteDocumentElement("foo")
noerr(t, err)
want := doc
err = copyArray(dst, src)
noerr(t, err)
err = dst.WriteDocumentEnd()
noerr(t, err)
got := dst.buf
if !bytes.Equal(got, want) {
t.Errorf("Bytes are not equal. got %v; want %v", got, want)
}
})
})
t.Run("CopyValue", func(t *testing.T) {
testCases := []struct {
name string
dst *TestValueReaderWriter
src *TestValueReaderWriter
err error
}{
{
"Double/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeDouble, err: errors.New("1"), errAfter: llvrwReadDouble},
errors.New("1"),
},
{
"Double/dst/error",
&TestValueReaderWriter{bsontype: TypeDouble, err: errors.New("2"), errAfter: llvrwWriteDouble},
&TestValueReaderWriter{bsontype: TypeDouble, readval: float64(3.14159)},
errors.New("2"),
},
{
"String/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeString, err: errors.New("1"), errAfter: llvrwReadString},
errors.New("1"),
},
{
"String/dst/error",
&TestValueReaderWriter{bsontype: TypeString, err: errors.New("2"), errAfter: llvrwWriteString},
&TestValueReaderWriter{bsontype: TypeString, readval: "hello, world"},
errors.New("2"),
},
{
"Document/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeEmbeddedDocument, err: errors.New("1"), errAfter: llvrwReadDocument},
errors.New("1"),
},
{
"Array/dst/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeArray, err: errors.New("2"), errAfter: llvrwReadArray},
errors.New("2"),
},
{
"Binary/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeBinary, err: errors.New("1"), errAfter: llvrwReadBinary},
errors.New("1"),
},
{
"Binary/dst/error",
&TestValueReaderWriter{bsontype: TypeBinary, err: errors.New("2"), errAfter: llvrwWriteBinaryWithSubtype},
&TestValueReaderWriter{
bsontype: TypeBinary,
readval: bsoncore.Value{
Type: bsoncore.TypeBinary,
Data: []byte{0x03, 0x00, 0x00, 0x00, 0xFF, 0x01, 0x02, 0x03},
},
},
errors.New("2"),
},
{
"Undefined/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeUndefined, err: errors.New("1"), errAfter: llvrwReadUndefined},
errors.New("1"),
},
{
"Undefined/dst/error",
&TestValueReaderWriter{bsontype: TypeUndefined, err: errors.New("2"), errAfter: llvrwWriteUndefined},
&TestValueReaderWriter{bsontype: TypeUndefined},
errors.New("2"),
},
{
"ObjectID/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeObjectID, err: errors.New("1"), errAfter: llvrwReadObjectID},
errors.New("1"),
},
{
"ObjectID/dst/error",
&TestValueReaderWriter{bsontype: TypeObjectID, err: errors.New("2"), errAfter: llvrwWriteObjectID},
&TestValueReaderWriter{bsontype: TypeObjectID, readval: ObjectID{0x01, 0x02, 0x03}},
errors.New("2"),
},
{
"Boolean/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeBoolean, err: errors.New("1"), errAfter: llvrwReadBoolean},
errors.New("1"),
},
{
"Boolean/dst/error",
&TestValueReaderWriter{bsontype: TypeBoolean, err: errors.New("2"), errAfter: llvrwWriteBoolean},
&TestValueReaderWriter{bsontype: TypeBoolean, readval: bool(true)},
errors.New("2"),
},
{
"DateTime/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeDateTime, err: errors.New("1"), errAfter: llvrwReadDateTime},
errors.New("1"),
},
{
"DateTime/dst/error",
&TestValueReaderWriter{bsontype: TypeDateTime, err: errors.New("2"), errAfter: llvrwWriteDateTime},
&TestValueReaderWriter{bsontype: TypeDateTime, readval: int64(1234567890)},
errors.New("2"),
},
{
"Null/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeNull, err: errors.New("1"), errAfter: llvrwReadNull},
errors.New("1"),
},
{
"Null/dst/error",
&TestValueReaderWriter{bsontype: TypeNull, err: errors.New("2"), errAfter: llvrwWriteNull},
&TestValueReaderWriter{bsontype: TypeNull},
errors.New("2"),
},
{
"Regex/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeRegex, err: errors.New("1"), errAfter: llvrwReadRegex},
errors.New("1"),
},
{
"Regex/dst/error",
&TestValueReaderWriter{bsontype: TypeRegex, err: errors.New("2"), errAfter: llvrwWriteRegex},
&TestValueReaderWriter{
bsontype: TypeRegex,
readval: bsoncore.Value{
Type: bsoncore.TypeRegex,
Data: bsoncore.AppendRegex(nil, "hello", "world"),
},
},
errors.New("2"),
},
{
"DBPointer/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeDBPointer, err: errors.New("1"), errAfter: llvrwReadDBPointer},
errors.New("1"),
},
{
"DBPointer/dst/error",
&TestValueReaderWriter{bsontype: TypeDBPointer, err: errors.New("2"), errAfter: llvrwWriteDBPointer},
&TestValueReaderWriter{
bsontype: TypeDBPointer,
readval: bsoncore.Value{
Type: bsoncore.TypeDBPointer,
Data: bsoncore.AppendDBPointer(nil, "foo", ObjectID{0x01, 0x02, 0x03}),
},
},
errors.New("2"),
},
{
"Javascript/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeJavaScript, err: errors.New("1"), errAfter: llvrwReadJavascript},
errors.New("1"),
},
{
"Javascript/dst/error",
&TestValueReaderWriter{bsontype: TypeJavaScript, err: errors.New("2"), errAfter: llvrwWriteJavascript},
&TestValueReaderWriter{bsontype: TypeJavaScript, readval: "hello, world"},
errors.New("2"),
},
{
"Symbol/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeSymbol, err: errors.New("1"), errAfter: llvrwReadSymbol},
errors.New("1"),
},
{
"Symbol/dst/error",
&TestValueReaderWriter{bsontype: TypeSymbol, err: errors.New("2"), errAfter: llvrwWriteSymbol},
&TestValueReaderWriter{
bsontype: TypeSymbol,
readval: bsoncore.Value{
Type: bsoncore.TypeSymbol,
Data: bsoncore.AppendSymbol(nil, "hello, world"),
},
},
errors.New("2"),
},
{
"CodeWithScope/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeCodeWithScope, err: errors.New("1"), errAfter: llvrwReadCodeWithScope},
errors.New("1"),
},
{
"CodeWithScope/dst/error",
&TestValueReaderWriter{bsontype: TypeCodeWithScope, err: errors.New("2"), errAfter: llvrwWriteCodeWithScope},
&TestValueReaderWriter{bsontype: TypeCodeWithScope},
errors.New("2"),
},
{
"CodeWithScope/dst/copyDocumentCore error",
&TestValueReaderWriter{err: errors.New("3"), errAfter: llvrwWriteDocumentElement},
&TestValueReaderWriter{bsontype: TypeCodeWithScope},
errors.New("3"),
},
{
"Int32/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeInt32, err: errors.New("1"), errAfter: llvrwReadInt32},
errors.New("1"),
},
{
"Int32/dst/error",
&TestValueReaderWriter{bsontype: TypeInt32, err: errors.New("2"), errAfter: llvrwWriteInt32},
&TestValueReaderWriter{bsontype: TypeInt32, readval: int32(12345)},
errors.New("2"),
},
{
"Timestamp/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeTimestamp, err: errors.New("1"), errAfter: llvrwReadTimestamp},
errors.New("1"),
},
{
"Timestamp/dst/error",
&TestValueReaderWriter{bsontype: TypeTimestamp, err: errors.New("2"), errAfter: llvrwWriteTimestamp},
&TestValueReaderWriter{
bsontype: TypeTimestamp,
readval: bsoncore.Value{
Type: bsoncore.TypeTimestamp,
Data: bsoncore.AppendTimestamp(nil, 12345, 67890),
},
},
errors.New("2"),
},
{
"Int64/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeInt64, err: errors.New("1"), errAfter: llvrwReadInt64},
errors.New("1"),
},
{
"Int64/dst/error",
&TestValueReaderWriter{bsontype: TypeInt64, err: errors.New("2"), errAfter: llvrwWriteInt64},
&TestValueReaderWriter{bsontype: TypeInt64, readval: int64(1234567890)},
errors.New("2"),
},
{
"Decimal128/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeDecimal128, err: errors.New("1"), errAfter: llvrwReadDecimal128},
errors.New("1"),
},
{
"Decimal128/dst/error",
&TestValueReaderWriter{bsontype: TypeDecimal128, err: errors.New("2"), errAfter: llvrwWriteDecimal128},
&TestValueReaderWriter{bsontype: TypeDecimal128, readval: NewDecimal128(12345, 67890)},
errors.New("2"),
},
{
"MinKey/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeMinKey, err: errors.New("1"), errAfter: llvrwReadMinKey},
errors.New("1"),
},
{
"MinKey/dst/error",
&TestValueReaderWriter{bsontype: TypeMinKey, err: errors.New("2"), errAfter: llvrwWriteMinKey},
&TestValueReaderWriter{bsontype: TypeMinKey},
errors.New("2"),
},
{
"MaxKey/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: TypeMaxKey, err: errors.New("1"), errAfter: llvrwReadMaxKey},
errors.New("1"),
},
{
"MaxKey/dst/error",
&TestValueReaderWriter{bsontype: TypeMaxKey, err: errors.New("2"), errAfter: llvrwWriteMaxKey},
&TestValueReaderWriter{bsontype: TypeMaxKey},
errors.New("2"),
},
{
"Unknown BSON type error",
&TestValueReaderWriter{},
&TestValueReaderWriter{},
fmt.Errorf("cannot copy unknown BSON type %s", Type(0)),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.dst.t, tc.src.t = t, t
err := copyValue(tc.dst, tc.src)
if !assert.CompareErrors(err, tc.err) {
t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err)
}
})
}
})
t.Run("CopyValueFromBytes", func(t *testing.T) {
t.Run("BytesWriter", func(t *testing.T) {
vw := newValueWriterFromSlice(make([]byte, 0))
_, err := vw.WriteDocument()
noerr(t, err)
_, err = vw.WriteDocumentElement("foo")
noerr(t, err)
err = copyValueFromBytes(vw, TypeString, bsoncore.AppendString(nil, "bar"))
noerr(t, err)
err = vw.WriteDocumentEnd()
noerr(t, err)
var idx int32
want, err := bsoncore.AppendDocumentEnd(
bsoncore.AppendStringElement(
bsoncore.AppendDocumentStartInline(nil, &idx),
"foo", "bar",
),
idx,
)
noerr(t, err)
got := vw.buf
if !bytes.Equal(got, want) {
t.Errorf("Bytes are not equal. got %v; want %v", got, want)
}
})
t.Run("Non BytesWriter", func(t *testing.T) {
llvrw := &TestValueReaderWriter{t: t}
err := copyValueFromBytes(llvrw, TypeString, bsoncore.AppendString(nil, "bar"))
noerr(t, err)
got, want := llvrw.invoked, llvrwWriteString
if got != want {
t.Errorf("Incorrect method invoked on llvrw. got %v; want %v", got, want)
}
})
})
t.Run("CopyValueToBytes", func(t *testing.T) {
t.Run("BytesReader", func(t *testing.T) {
var idx int32
b, err := bsoncore.AppendDocumentEnd(
bsoncore.AppendStringElement(
bsoncore.AppendDocumentStartInline(nil, &idx),
"hello", "world",
),
idx,
)
noerr(t, err)
vr := newDocumentReader(bytes.NewReader(b))
_, err = vr.ReadDocument()
noerr(t, err)
_, _, err = vr.ReadElement()
noerr(t, err)
btype, got, err := copyValueToBytes(vr)
noerr(t, err)
want := bsoncore.AppendString(nil, "world")
if btype != TypeString {
t.Errorf("Incorrect type returned. got %v; want %v", btype, TypeString)
}
if !bytes.Equal(got, want) {
t.Errorf("Bytes do not match. got %v; want %v", got, want)
}
})
t.Run("Non BytesReader", func(t *testing.T) {
llvrw := &TestValueReaderWriter{t: t, bsontype: TypeString, readval: "Hello, world!"}
btype, got, err := copyValueToBytes(llvrw)
noerr(t, err)
want := bsoncore.AppendString(nil, "Hello, world!")
if btype != TypeString {
t.Errorf("Incorrect type returned. got %v; want %v", btype, TypeString)
}
if !bytes.Equal(got, want) {
t.Errorf("Bytes do not match. got %v; want %v", got, want)
}
})
})
t.Run("AppendValueBytes", func(t *testing.T) {
t.Run("BytesReader", func(t *testing.T) {
var idx int32
b, err := bsoncore.AppendDocumentEnd(
bsoncore.AppendStringElement(
bsoncore.AppendDocumentStartInline(nil, &idx),
"hello", "world",
),
idx,
)
noerr(t, err)
vr := newDocumentReader(bytes.NewReader(b))
_, err = vr.ReadDocument()
noerr(t, err)
_, _, err = vr.ReadElement()
noerr(t, err)
btype, got, err := copyValueToBytes(vr)
noerr(t, err)
want := bsoncore.AppendString(nil, "world")
if btype != TypeString {
t.Errorf("Incorrect type returned. got %v; want %v", btype, TypeString)
}
if !bytes.Equal(got, want) {
t.Errorf("Bytes do not match. got %v; want %v", got, want)
}
})
t.Run("Non BytesReader", func(t *testing.T) {
llvrw := &TestValueReaderWriter{t: t, bsontype: TypeString, readval: "Hello, world!"}
btype, got, err := copyValueToBytes(llvrw)
noerr(t, err)
want := bsoncore.AppendString(nil, "Hello, world!")
if btype != TypeString {
t.Errorf("Incorrect type returned. got %v; want %v", btype, TypeString)
}
if !bytes.Equal(got, want) {
t.Errorf("Bytes do not match. got %v; want %v", got, want)
}
})
t.Run("CopyValue error", func(t *testing.T) {
want := errors.New("CopyValue error")
llvrw := &TestValueReaderWriter{t: t, bsontype: TypeString, err: want, errAfter: llvrwReadString}
_, _, got := copyValueToBytes(llvrw)
if !assert.CompareErrors(got, want) {
t.Errorf("Errors do not match. got %v; want %v", got, want)
}
})
})
}

decimal.go (new file, 339 lines)
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on gopkg.in/mgo.v2/bson by Gustavo Niemeyer
// See THIRD-PARTY-NOTICES for original license terms.
package bson
import (
"encoding/json"
"errors"
"fmt"
"math/big"
"regexp"
"strconv"
"strings"
"gitea.psichedelico.com/go/bson/internal/decimal128"
)
// These constants are the maximum and minimum values for the exponent field in a decimal128 value.
const (
MaxDecimal128Exp = 6111
MinDecimal128Exp = -6176
)
// These errors are returned when an invalid value is parsed as a big.Int.
var (
ErrParseNaN = errors.New("cannot parse NaN as a *big.Int")
ErrParseInf = errors.New("cannot parse Infinity as a *big.Int")
ErrParseNegInf = errors.New("cannot parse -Infinity as a *big.Int")
)
// Decimal128 holds decimal128 BSON values.
type Decimal128 struct {
h, l uint64
}
// NewDecimal128 creates a Decimal128 using the provided high and low uint64s.
func NewDecimal128(h, l uint64) Decimal128 {
return Decimal128{h: h, l: l}
}
// GetBytes returns the underlying bytes of the BSON decimal value as two uint64 values. The first
// contains the first 8 bytes of the value and the second contains the latter 8 bytes.
func (d Decimal128) GetBytes() (uint64, uint64) {
return d.h, d.l
}
// String returns a string representation of the decimal value.
func (d Decimal128) String() string {
return decimal128.String(d.h, d.l)
}
// BigInt returns the significand as a *big.Int and the exponent exp such that the value equals bi * 10^exp.
func (d Decimal128) BigInt() (*big.Int, int, error) {
high, low := d.GetBytes()
posSign := high>>63&1 == 0 // positive sign
switch high >> 58 & (1<<5 - 1) {
case 0x1F:
return nil, 0, ErrParseNaN
case 0x1E:
if posSign {
return nil, 0, ErrParseInf
}
return nil, 0, ErrParseNegInf
}
var exp int
if high>>61&3 == 3 {
// Bits: 1*sign 2*ignored 14*exponent 111*significand.
// Implicit 0b100 prefix in significand.
exp = int(high >> 47 & (1<<14 - 1))
// Spec says all of these values are out of range.
high, low = 0, 0
} else {
// Bits: 1*sign 14*exponent 113*significand
exp = int(high >> 49 & (1<<14 - 1))
high &= (1<<49 - 1)
}
exp += MinDecimal128Exp
// The zero value would be handled by the logic below, but it's trivial and common, so return early.
if high == 0 && low == 0 && exp == 0 {
return new(big.Int), 0, nil
}
bi := big.NewInt(0)
const host32bit = ^uint(0)>>32 == 0
if host32bit {
bi.SetBits([]big.Word{big.Word(low), big.Word(low >> 32), big.Word(high), big.Word(high >> 32)})
} else {
bi.SetBits([]big.Word{big.Word(low), big.Word(high)})
}
if !posSign {
return bi.Neg(bi), exp, nil
}
return bi, exp, nil
}
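The bit-twiddling in BigInt follows the IEEE 754-2008 decimal128 binary-integer (BID) layout: 1 sign bit, a combination field, and a 14-bit biased exponent (bias 6176), with the significand in the remaining bits. A hedged round-trip sketch of just the exponent extraction for the common (non-0b11-prefix) case; `encodeSmall` and `decodeExp` are illustrative helpers, not package API:

```go
package main

import "fmt"

const minDecimal128Exp = -6176 // decimal128 exponent bias (BID encoding)

// encodeSmall packs a small positive significand and exponent into the
// high/low words, assuming the significand fits in the low bits and its
// top bits are not 0b11 (the common case in the BigInt code above).
func encodeSmall(sig uint64, exp int) (high, low uint64) {
	biased := uint64(exp - minDecimal128Exp)
	return biased << 49, sig
}

// decodeExp reverses the exponent extraction done in Decimal128.BigInt.
func decodeExp(high uint64) int {
	if high>>61&3 == 3 {
		// 0b11 prefix: exponent shifted down 47 bits instead of 49.
		return int(high>>47&(1<<14-1)) + minDecimal128Exp
	}
	return int(high>>49&(1<<14-1)) + minDecimal128Exp
}

func main() {
	h, l := encodeSmall(1, -2) // represents 1 * 10^-2 = 0.01
	fmt.Println(decodeExp(h), l)
}
```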
// IsNaN returns whether d is NaN.
func (d Decimal128) IsNaN() bool {
return d.h>>58&(1<<5-1) == 0x1F
}
// IsInf returns:
//
// +1 d == Infinity
// 0 other case
// -1 d == -Infinity
func (d Decimal128) IsInf() int {
if d.h>>58&(1<<5-1) != 0x1E {
return 0
}
if d.h>>63&1 == 0 {
return 1
}
return -1
}
// IsZero returns true if d is the empty Decimal128.
func (d Decimal128) IsZero() bool {
return d.h == 0 && d.l == 0
}
// MarshalJSON returns Decimal128 as a string.
func (d Decimal128) MarshalJSON() ([]byte, error) {
return json.Marshal(d.String())
}
// UnmarshalJSON creates a Decimal128 from a JSON string, an extended JSON $numberDecimal value, or the string
// "null". If b is a JSON string or extended JSON value, d will have the value of that string, and if b is "null", d will
// be unchanged.
func (d *Decimal128) UnmarshalJSON(b []byte) error {
// Ignore "null" to keep parity with the standard library. Decoding a JSON null into a non-pointer Decimal128 field
// will leave the field unchanged. For pointer values, encoding/json will set the pointer to nil and will not
// enter the UnmarshalJSON hook.
if string(b) == "null" {
return nil
}
var res interface{}
err := json.Unmarshal(b, &res)
if err != nil {
return err
}
str, ok := res.(string)
// Extended JSON
if !ok {
m, ok := res.(map[string]interface{})
if !ok {
return errors.New("not an extended JSON Decimal128: expected document")
}
d128, ok := m["$numberDecimal"]
if !ok {
return errors.New("not an extended JSON Decimal128: expected key $numberDecimal")
}
str, ok = d128.(string)
if !ok {
return errors.New("not an extended JSON Decimal128: expected decimal to be string")
}
}
*d, err = ParseDecimal128(str)
return err
}
var dNaN = Decimal128{0x1F << 58, 0}
var dPosInf = Decimal128{0x1E << 58, 0}
var dNegInf = Decimal128{0x3E << 58, 0}
func dErr(s string) (Decimal128, error) {
return dNaN, fmt.Errorf("cannot parse %q as a decimal128", s)
}
// normalNumber matches a decimal number in ordinary or scientific notation, e.g. -10.15e-18.
var normalNumber = regexp.MustCompile(`^(?P<int>[-+]?\d*)?(?:\.(?P<dec>\d*))?(?:[Ee](?P<exp>[-+]?\d+))?$`)
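The named groups split a number into its integer part, fractional digits, and explicit exponent. A standalone sketch (not part of the package) showing the submatches the pattern produces for the example from the comment:

```go
package main

import (
	"fmt"
	"regexp"
)

// Same pattern as normalNumber: optional signed integer part, optional
// fraction, optional exponent.
var num = regexp.MustCompile(`^(?P<int>[-+]?\d*)?(?:\.(?P<dec>\d*))?(?:[Ee](?P<exp>[-+]?\d+))?$`)

func main() {
	m := num.FindStringSubmatch("-10.15e-18")
	// m[1] = integer part, m[2] = fraction digits, m[3] = exponent
	fmt.Println(m[1], m[2], m[3])
}
```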
// ParseDecimal128 takes the given string and attempts to parse it into a valid
// Decimal128 value.
func ParseDecimal128(s string) (Decimal128, error) {
if s == "" {
return dErr(s)
}
matches := normalNumber.FindStringSubmatch(s)
if len(matches) == 0 {
orig := s
neg := s[0] == '-'
if neg || s[0] == '+' {
s = s[1:]
}
if strings.EqualFold(s, "nan") {
return dNaN, nil
}
if strings.EqualFold(s, "inf") || strings.EqualFold(s, "infinity") {
if neg {
return dNegInf, nil
}
return dPosInf, nil
}
return dErr(orig)
}
intPart := matches[1]
decPart := matches[2]
expPart := matches[3]
var err error
exp := 0
if expPart != "" {
exp, err = strconv.Atoi(expPart)
if err != nil {
return dErr(s)
}
}
if decPart != "" {
exp -= len(decPart)
}
if len(strings.Trim(intPart+decPart, "-0")) > 35 {
return dErr(s)
}
// Parse the significand (i.e. the non-exponent part) as a big.Int.
bi, ok := new(big.Int).SetString(intPart+decPart, 10)
if !ok {
return dErr(s)
}
d, ok := ParseDecimal128FromBigInt(bi, exp)
if !ok {
return dErr(s)
}
if bi.Sign() == 0 && s[0] == '-' {
d.h |= 1 << 63
}
return d, nil
}
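ParseDecimal128 folds the fraction into the exponent: the fraction digits are appended to the integer digits to form the significand, and the exponent is reduced by the fraction length (exp -= len(decPart)). A standalone sketch (not part of the package, using a hypothetical helper name) of that arithmetic for significands that fit in an int64:

```go
package main

import (
	"fmt"
	"strconv"
)

// significandAndExp mirrors the exponent arithmetic in ParseDecimal128:
// concatenate integer and fraction digits, then subtract the fraction
// length from the explicit exponent.
func significandAndExp(intPart, decPart, expPart string) (int64, int) {
	exp := 0
	if expPart != "" {
		exp, _ = strconv.Atoi(expPart)
	}
	exp -= len(decPart)
	sig, _ := strconv.ParseInt(intPart+decPart, 10, 64)
	return sig, exp
}

func main() {
	// "-10.15e-18" has significand -1015 and exponent -18-2 = -20.
	sig, exp := significandAndExp("-10", "15", "-18")
	fmt.Println(sig, exp)
}
```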
var (
ten = big.NewInt(10)
zero = new(big.Int)
maxS, _ = new(big.Int).SetString("9999999999999999999999999999999999", 10)
)
// ParseDecimal128FromBigInt attempts to parse the given significand and exponent into a valid Decimal128 value.
func ParseDecimal128FromBigInt(bi *big.Int, exp int) (Decimal128, bool) {
// Copy the input so the caller's big.Int is never modified.
bi = new(big.Int).Set(bi)
q := new(big.Int)
r := new(big.Int)
// If the significand is zero, the logical value will always be zero, independent of the
// exponent. However, the loops for handling out-of-range exponent values below may be extremely
// slow for zero values because the significand never changes. Limit the exponent value to the
// supported range here to prevent entering the loops below.
if bi.Cmp(zero) == 0 {
if exp > MaxDecimal128Exp {
exp = MaxDecimal128Exp
}
if exp < MinDecimal128Exp {
exp = MinDecimal128Exp
}
}
for bigIntCmpAbs(bi, maxS) == 1 {
bi, _ = q.QuoRem(bi, ten, r)
if r.Cmp(zero) != 0 {
return Decimal128{}, false
}
exp++
if exp > MaxDecimal128Exp {
return Decimal128{}, false
}
}
for exp < MinDecimal128Exp {
// Subnormal.
bi, _ = q.QuoRem(bi, ten, r)
if r.Cmp(zero) != 0 {
return Decimal128{}, false
}
exp++
}
for exp > MaxDecimal128Exp {
// Clamped.
bi.Mul(bi, ten)
if bigIntCmpAbs(bi, maxS) == 1 {
return Decimal128{}, false
}
exp--
}
b := bi.Bytes()
var h, l uint64
for i := 0; i < len(b); i++ {
if i < len(b)-8 {
h = h<<8 | uint64(b[i])
continue
}
l = l<<8 | uint64(b[i])
}
h |= uint64(exp-MinDecimal128Exp) & uint64(1<<14-1) << 49
if bi.Sign() == -1 {
h |= 1 << 63
}
return Decimal128{h: h, l: l}, true
}
// bigIntCmpAbs computes big.Int.Cmp(absoluteValue(x), absoluteValue(y)).
func bigIntCmpAbs(x, y *big.Int) int {
xAbs := bigIntAbsValue(x)
yAbs := bigIntAbsValue(y)
return xAbs.Cmp(yAbs)
}
// bigIntAbsValue returns a big.Int containing the absolute value of b.
// If b is already a non-negative number, it is returned without any changes or copies.
func bigIntAbsValue(b *big.Int) *big.Int {
if b.Sign() >= 0 {
return b // already non-negative
}
return new(big.Int).Abs(b)
}

decimal_test.go (new file, 236 lines)
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding/json"
"fmt"
"math/big"
"testing"
"gitea.psichedelico.com/go/bson/internal/assert"
"gitea.psichedelico.com/go/bson/internal/require"
)
type bigIntTestCase struct {
s string
h uint64
l uint64
bi *big.Int
exp int
remark string
}
func parseBigInt(s string) *big.Int {
bi, _ := new(big.Int).SetString(s, 10)
return bi
}
var (
one = big.NewInt(1)
biMaxS = new(big.Int).SetBytes([]byte{0x1, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
biNMaxS = new(big.Int).Neg(biMaxS)
biOverflow = new(big.Int).Add(biMaxS, one)
biNOverflow = new(big.Int).Neg(biOverflow)
bi12345 = parseBigInt("12345")
biN12345 = parseBigInt("-12345")
bi9_14 = parseBigInt("90123456789012")
biN9_14 = parseBigInt("-90123456789012")
bi9_34 = parseBigInt("9999999999999999999999999999999999")
biN9_34 = parseBigInt("-9999999999999999999999999999999999")
)
var bigIntTestCases = []bigIntTestCase{
{s: "12345", h: 0x3040000000000000, l: 12345, bi: bi12345},
{s: "-12345", h: 0xB040000000000000, l: 12345, bi: biN12345},
{s: "90123456.789012", h: 0x3034000000000000, l: 90123456789012, bi: bi9_14, exp: -6},
{s: "-90123456.789012", h: 0xB034000000000000, l: 90123456789012, bi: biN9_14, exp: -6},
{s: "9.0123456789012E+22", h: 0x3052000000000000, l: 90123456789012, bi: bi9_14, exp: 9},
{s: "-9.0123456789012E+22", h: 0xB052000000000000, l: 90123456789012, bi: biN9_14, exp: 9},
{s: "9.0123456789012E-8", h: 0x3016000000000000, l: 90123456789012, bi: bi9_14, exp: -21},
{s: "-9.0123456789012E-8", h: 0xB016000000000000, l: 90123456789012, bi: biN9_14, exp: -21},
{s: "9999999999999999999999999999999999", h: 3477321013416265664, l: 4003012203950112767, bi: bi9_34},
{s: "-9999999999999999999999999999999999", h: 12700693050271041472, l: 4003012203950112767, bi: biN9_34},
{s: "0.9999999999999999999999999999999999", h: 3458180714999941056, l: 4003012203950112767, bi: bi9_34, exp: -34},
{s: "-0.9999999999999999999999999999999999", h: 12681552751854716864, l: 4003012203950112767, bi: biN9_34, exp: -34},
{s: "99999999999999999.99999999999999999", h: 3467750864208103360, l: 4003012203950112767, bi: bi9_34, exp: -17},
{s: "-99999999999999999.99999999999999999", h: 12691122901062879168, l: 4003012203950112767, bi: biN9_34, exp: -17},
{s: "9.999999999999999999999999999999999E+35", h: 3478446913323108288, l: 4003012203950112767, bi: bi9_34, exp: 2},
{s: "-9.999999999999999999999999999999999E+35", h: 12701818950177884096, l: 4003012203950112767, bi: biN9_34, exp: 2},
{s: "9.999999999999999999999999999999999E+40", h: 3481261663090214848, l: 4003012203950112767, bi: bi9_34, exp: 7},
{s: "-9.999999999999999999999999999999999E+40", h: 12704633699944990656, l: 4003012203950112767, bi: biN9_34, exp: 7},
{s: "99999999999999999999999999999.99999", h: 3474506263649159104, l: 4003012203950112767, bi: bi9_34, exp: -5},
{s: "-99999999999999999999999999999.99999", h: 12697878300503934912, l: 4003012203950112767, bi: biN9_34, exp: -5},
{s: "1.038459371706965525706099265844019E-6143", remark: "subnormal", h: 0x333333333333, l: 0x3333333333333333, bi: parseBigInt("10384593717069655257060992658440190"), exp: MinDecimal128Exp - 1},
{s: "-1.038459371706965525706099265844019E-6143", remark: "subnormal", h: 0x8000333333333333, l: 0x3333333333333333, bi: parseBigInt("-10384593717069655257060992658440190"), exp: MinDecimal128Exp - 1},
{s: "rounding overflow 1", remark: "overflow", bi: parseBigInt("103845937170696552570609926584401910"), exp: MaxDecimal128Exp},
{s: "rounding overflow 2", remark: "overflow", bi: parseBigInt("103845937170696552570609926584401910"), exp: MaxDecimal128Exp},
{s: "subnormal overflow 1", remark: "overflow", bi: biMaxS, exp: MinDecimal128Exp - 1},
{s: "subnormal overflow 2", remark: "overflow", bi: biNMaxS, exp: MinDecimal128Exp - 1},
{s: "clamped overflow 1", remark: "overflow", bi: biMaxS, exp: MaxDecimal128Exp + 1},
{s: "clamped overflow 2", remark: "overflow", bi: biNMaxS, exp: MaxDecimal128Exp + 1},
{s: "biMaxS+1 overflow", remark: "overflow", bi: biOverflow, exp: MaxDecimal128Exp},
{s: "biNMaxS-1 overflow", remark: "overflow", bi: biNOverflow, exp: MaxDecimal128Exp},
{s: "NaN", h: 0x7c00000000000000, l: 0, remark: "NaN"},
{s: "Infinity", h: 0x7800000000000000, l: 0, remark: "Infinity"},
{s: "-Infinity", h: 0xf800000000000000, l: 0, remark: "-Infinity"},
}
func TestDecimal128_BigInt(t *testing.T) {
for _, c := range bigIntTestCases {
t.Run(c.s, func(t *testing.T) {
switch c.remark {
case "NaN", "Infinity", "-Infinity":
d128 := NewDecimal128(c.h, c.l)
_, _, err := d128.BigInt()
require.Error(t, err, "case %s", c.s)
case "":
d128 := NewDecimal128(c.h, c.l)
bi, e, err := d128.BigInt()
require.NoError(t, err, "case %s", c.s)
require.Equal(t, 0, c.bi.Cmp(bi), "case %s e:%s a:%s", c.s, c.bi.String(), bi.String())
require.Equal(t, c.exp, e, "case %s (%s)", c.s, d128.String())
}
})
}
}
func TestParseDecimal128FromBigInt(t *testing.T) {
for _, c := range bigIntTestCases {
switch c.remark {
case "overflow":
d128, ok := ParseDecimal128FromBigInt(c.bi, c.exp)
require.Equal(t, false, ok, "case %s (%s) %s", c.s, d128.String(), c.remark)
case "", "rounding", "subnormal", "clamped":
d128, ok := ParseDecimal128FromBigInt(c.bi, c.exp)
require.Equal(t, true, ok, "case %s", c.s)
require.Equal(t, c.s, d128.String(), "case %s", c.s)
require.Equal(t, c.h, d128.h, "case %s l:%d", c.s, d128.l)
require.Equal(t, c.l, d128.l, "case %s h:%d", c.s, d128.h)
}
}
}
func TestParseDecimal128(t *testing.T) {
cases := make([]bigIntTestCase, 0, len(bigIntTestCases))
cases = append(cases, bigIntTestCases...)
cases = append(cases,
bigIntTestCase{s: "-0001231.453454000000565600000000E-21", h: 0xafe6000003faa269, l: 0x81cfeceaabdb1800},
bigIntTestCase{s: "12345E+21", h: 0x306a000000000000, l: 12345},
bigIntTestCase{s: "0.10000000000000000000000000000000000000000001", remark: "parse fail"},
bigIntTestCase{s: ".125e1", h: 0x303c000000000000, l: 125},
bigIntTestCase{s: ".125", h: 0x303a000000000000, l: 125},
// Test that parsing negative zero returns negative zero with a zero exponent.
bigIntTestCase{s: "-0", h: 0xb040000000000000, l: 0},
// Test that parsing negative zero with an in-range exponent returns negative zero and
// preserves the specified exponent value.
bigIntTestCase{s: "-0E999", h: 0xb80e000000000000, l: 0},
// Test that parsing zero with an out-of-range positive exponent returns zero with the
// maximum positive exponent (i.e. 0e+6111).
bigIntTestCase{s: "0E2000000000000", h: 0x5ffe000000000000, l: 0},
// Test that parsing zero with an out-of-range negative exponent returns zero with the
// minimum negative exponent (i.e. 0e-6176).
bigIntTestCase{s: "-0E2000000000000", h: 0xdffe000000000000, l: 0},
bigIntTestCase{s: "", remark: "parse fail"})
for _, c := range cases {
t.Run(c.s, func(t *testing.T) {
switch c.remark {
case "overflow", "parse fail":
_, err := ParseDecimal128(c.s)
assert.Error(t, err, "ParseDecimal128(%q) should return an error", c.s)
default:
got, err := ParseDecimal128(c.s)
require.NoError(t, err, "ParseDecimal128(%q) error", c.s)
want := Decimal128{h: c.h, l: c.l}
// Decimal128 doesn't implement an equality function, so compare the expected
// low/high uint64 values directly. Also print the string representation of each
// number to make debugging failures easier.
assert.Equal(t, want, got, "ParseDecimal128(%q) = %s, want %s", c.s, got, want)
}
})
}
}
func TestDecimal128_JSON(t *testing.T) {
t.Run("roundTrip", func(t *testing.T) {
decimal := NewDecimal128(0x3040000000000000, 12345)
bytes, err := json.Marshal(decimal)
assert.Nil(t, err, "json.Marshal error: %v", err)
got := NewDecimal128(0, 0)
err = json.Unmarshal(bytes, &got)
assert.Nil(t, err, "json.Unmarshal error: %v", err)
assert.Equal(t, decimal.h, got.h, "expected h: %v got: %v", decimal.h, got.h)
assert.Equal(t, decimal.l, got.l, "expected l: %v got: %v", decimal.l, got.l)
})
t.Run("unmarshal extendedJSON", func(t *testing.T) {
want := NewDecimal128(0x3040000000000000, 12345)
extJSON := fmt.Sprintf(`{"$numberDecimal": %q}`, want.String())
got := NewDecimal128(0, 0)
err := json.Unmarshal([]byte(extJSON), &got)
assert.Nil(t, err, "json.Unmarshal error: %v", err)
assert.Equal(t, want.h, got.h, "expected h: %v got: %v", want.h, got.h)
assert.Equal(t, want.l, got.l, "expected l: %v got: %v", want.l, got.l)
})
t.Run("unmarshal null", func(t *testing.T) {
want := NewDecimal128(0, 0)
extJSON := `null`
got := NewDecimal128(0, 0)
err := json.Unmarshal([]byte(extJSON), &got)
assert.Nil(t, err, "json.Unmarshal error: %v", err)
assert.Equal(t, want.h, got.h, "expected h: %v got: %v", want.h, got.h)
assert.Equal(t, want.l, got.l, "expected l: %v got: %v", want.l, got.l)
})
t.Run("unmarshal", func(t *testing.T) {
cases := make([]bigIntTestCase, 0, len(bigIntTestCases))
cases = append(cases, bigIntTestCases...)
cases = append(cases,
bigIntTestCase{s: "-0001231.453454000000565600000000E-21", h: 0xafe6000003faa269, l: 0x81cfeceaabdb1800},
bigIntTestCase{s: "12345E+21", h: 0x306a000000000000, l: 12345},
bigIntTestCase{s: "0.10000000000000000000000000000000000000000001", remark: "parse fail"},
bigIntTestCase{s: ".125e1", h: 0x303c000000000000, l: 125},
bigIntTestCase{s: ".125", h: 0x303a000000000000, l: 125})
for _, c := range cases {
t.Run(c.s, func(t *testing.T) {
input := fmt.Sprintf(`{"foo": %q}`, c.s)
var got map[string]Decimal128
err := json.Unmarshal([]byte(input), &got)
switch c.remark {
case "overflow", "parse fail":
assert.NotNil(t, err, "expected Unmarshal error, got nil")
default:
assert.Nil(t, err, "Unmarshal error: %v", err)
gotDecimal := got["foo"]
assert.Equal(t, c.h, gotDecimal.h, "expected h: %v got: %v", c.h, gotDecimal.h)
assert.Equal(t, c.l, gotDecimal.l, "expected l: %v got: %v", c.l, gotDecimal.l)
}
})
}
})
}

decoder.go (new file, 136 lines)
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"fmt"
"reflect"
"sync"
)
// ErrDecodeToNil is the error returned when trying to decode to a nil value
var ErrDecodeToNil = errors.New("cannot Decode to nil value")
// This pool is used to keep the allocations of Decoders down. This is only used for the Marshal*
// methods and is not consumable from outside of this package. The Decoders retrieved from this pool
// must have both Reset and SetRegistry called on them.
var decPool = sync.Pool{
New: func() interface{} {
return new(Decoder)
},
}
// A Decoder reads and decodes BSON documents from a stream. It reads from a ValueReader as
// the source of BSON data.
type Decoder struct {
dc DecodeContext
vr ValueReader
}
// NewDecoder returns a new decoder that reads from vr.
func NewDecoder(vr ValueReader) *Decoder {
return &Decoder{
dc: DecodeContext{Registry: defaultRegistry},
vr: vr,
}
}
// Decode reads the next BSON document from the stream and decodes it into the
// value pointed to by val.
//
// See [Unmarshal] for details about BSON unmarshaling behavior.
func (d *Decoder) Decode(val interface{}) error {
if unmarshaler, ok := val.(Unmarshaler); ok {
// TODO(skriptble): Reuse a []byte here and use the AppendDocumentBytes method.
buf, err := copyDocumentToBytes(d.vr)
if err != nil {
return err
}
return unmarshaler.UnmarshalBSON(buf)
}
rval := reflect.ValueOf(val)
switch rval.Kind() {
case reflect.Ptr:
if rval.IsNil() {
return ErrDecodeToNil
}
rval = rval.Elem()
case reflect.Map:
if rval.IsNil() {
return ErrDecodeToNil
}
default:
return fmt.Errorf("argument to Decode must be a pointer or a map, but got %v", rval)
}
decoder, err := d.dc.LookupDecoder(rval.Type())
if err != nil {
return err
}
return decoder.DecodeValue(d.dc, d.vr, rval)
}
// Reset will reset the state of the decoder, using the same *DecodeContext used in
// the original construction but using vr for reading.
func (d *Decoder) Reset(vr ValueReader) {
d.vr = vr
}
// SetRegistry replaces the current registry of the decoder with r.
func (d *Decoder) SetRegistry(r *Registry) {
d.dc.Registry = r
}
// DefaultDocumentM causes the Decoder to always unmarshal documents into the bson.M type. This
// behavior is restricted to data typed as "interface{}" or "map[string]interface{}".
func (d *Decoder) DefaultDocumentM() {
d.dc.defaultDocumentType = reflect.TypeOf(M{})
}
// AllowTruncatingDoubles causes the Decoder to truncate the fractional part of BSON "double" values
// when attempting to unmarshal them into a Go integer (int, int8, int16, int32, or int64) struct
// field. The truncation logic does not apply to BSON "decimal128" values.
func (d *Decoder) AllowTruncatingDoubles() {
d.dc.truncate = true
}
// BinaryAsSlice causes the Decoder to unmarshal BSON binary field values that are the "Generic" or
// "Old" BSON binary subtype as a Go byte slice instead of a bson.Binary.
func (d *Decoder) BinaryAsSlice() {
d.dc.binaryAsSlice = true
}
// ObjectIDAsHexString causes the Decoder to decode object IDs to their hex representation.
func (d *Decoder) ObjectIDAsHexString() {
d.dc.objectIDAsHexString = true
}
// UseJSONStructTags causes the Decoder to fall back to using the "json" struct tag if a "bson"
// struct tag is not specified.
func (d *Decoder) UseJSONStructTags() {
d.dc.useJSONStructTags = true
}
// UseLocalTimeZone causes the Decoder to unmarshal time.Time values in the local timezone instead
// of the UTC timezone.
func (d *Decoder) UseLocalTimeZone() {
d.dc.useLocalTimeZone = true
}
// ZeroMaps causes the Decoder to delete any existing values from Go maps in the destination value
// passed to Decode before unmarshaling BSON documents into them.
func (d *Decoder) ZeroMaps() {
d.dc.zeroMaps = true
}
// ZeroStructs causes the Decoder to delete any existing values from Go structs in the destination
// value passed to Decode before unmarshaling BSON documents into them.
func (d *Decoder) ZeroStructs() {
d.dc.zeroStructs = true
}

decoder_example_test.go (new file, 208 lines)
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson_test
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"gitea.psichedelico.com/go/bson"
)
func ExampleDecoder() {
// Marshal a BSON document that contains the name, SKU, and price (in cents)
// of a product.
doc := bson.D{
{Key: "name", Value: "Cereal Rounds"},
{Key: "sku", Value: "AB12345"},
{Key: "price_cents", Value: 399},
}
data, err := bson.Marshal(doc)
if err != nil {
panic(err)
}
// Create a Decoder that reads the marshaled BSON document and use it to
// unmarshal the document into a Product struct.
decoder := bson.NewDecoder(bson.NewDocumentReader(bytes.NewReader(data)))
type Product struct {
Name string `bson:"name"`
SKU string `bson:"sku"`
Price int64 `bson:"price_cents"`
}
var res Product
err = decoder.Decode(&res)
if err != nil {
panic(err)
}
fmt.Printf("%+v\n", res)
// Output: {Name:Cereal Rounds SKU:AB12345 Price:399}
}
func ExampleDecoder_DefaultDocumentM() {
// Marshal a BSON document that contains a city name and a nested document
// with various city properties.
doc := bson.D{
{Key: "name", Value: "New York"},
{Key: "properties", Value: bson.D{
{Key: "state", Value: "NY"},
{Key: "population", Value: 8_804_190},
{Key: "elevation", Value: 10},
}},
}
data, err := bson.Marshal(doc)
if err != nil {
panic(err)
}
// Create a Decoder that reads the marshaled BSON document and use it to unmarshal the document
// into a City struct.
decoder := bson.NewDecoder(bson.NewDocumentReader(bytes.NewReader(data)))
type City struct {
Name string `bson:"name"`
Properties interface{} `bson:"properties"`
}
// Configure the Decoder to default to decoding BSON documents as the M
// type if the decode destination has no type information. The Properties
// field in the City struct will be decoded as an "M" (i.e. a map) instead
// of the default "D".
decoder.DefaultDocumentM()
var res City
err = decoder.Decode(&res)
if err != nil {
panic(err)
}
data, err = json.Marshal(res)
if err != nil {
panic(err)
}
fmt.Printf("%+v\n", string(data))
// Output: {"Name":"New York","Properties":{"elevation":10,"population":8804190,"state":"NY"}}
}
func ExampleDecoder_UseJSONStructTags() {
// Marshal a BSON document that contains the name, SKU, and price (in cents)
// of a product.
doc := bson.D{
{Key: "name", Value: "Cereal Rounds"},
{Key: "sku", Value: "AB12345"},
{Key: "price_cents", Value: 399},
}
data, err := bson.Marshal(doc)
if err != nil {
panic(err)
}
// Create a Decoder that reads the marshaled BSON document and use it to
// unmarshal the document into a Product struct.
decoder := bson.NewDecoder(bson.NewDocumentReader(bytes.NewReader(data)))
type Product struct {
Name string `json:"name"`
SKU string `json:"sku"`
Price int64 `json:"price_cents"`
}
// Configure the Decoder to use "json" struct tags when decoding if "bson"
// struct tags are not present.
decoder.UseJSONStructTags()
var res Product
err = decoder.Decode(&res)
if err != nil {
panic(err)
}
fmt.Printf("%+v\n", res)
// Output: {Name:Cereal Rounds SKU:AB12345 Price:399}
}
func ExampleDecoder_extendedJSON() {
// Define an Extended JSON document that contains the name, SKU, and price
// (in cents) of a product.
data := []byte(`{"name":"Cereal Rounds","sku":"AB12345","price_cents":{"$numberLong":"399"}}`)
// Create a Decoder that reads the Extended JSON document and use it to
// unmarshal the document into a Product struct.
vr, err := bson.NewExtJSONValueReader(bytes.NewReader(data), true)
if err != nil {
panic(err)
}
decoder := bson.NewDecoder(vr)
type Product struct {
Name string `bson:"name"`
SKU string `bson:"sku"`
Price int64 `bson:"price_cents"`
}
var res Product
err = decoder.Decode(&res)
if err != nil {
panic(err)
}
fmt.Printf("%+v\n", res)
// Output: {Name:Cereal Rounds SKU:AB12345 Price:399}
}
func ExampleDecoder_multipleExtendedJSONDocuments() {
// Define a newline-separated sequence of Extended JSON documents that
// contain X,Y coordinates.
data := []byte(`
{"x":{"$numberInt":"0"},"y":{"$numberInt":"0"}}
{"x":{"$numberInt":"1"},"y":{"$numberInt":"1"}}
{"x":{"$numberInt":"2"},"y":{"$numberInt":"2"}}
{"x":{"$numberInt":"3"},"y":{"$numberInt":"3"}}
{"x":{"$numberInt":"4"},"y":{"$numberInt":"4"}}
`)
// Create a Decoder that reads the Extended JSON documents and use it to
// unmarshal the documents into Coordinate structs.
vr, err := bson.NewExtJSONValueReader(bytes.NewReader(data), true)
if err != nil {
panic(err)
}
decoder := bson.NewDecoder(vr)
type Coordinate struct {
X int
Y int
}
// Read and unmarshal each Extended JSON document from the sequence. If
// Decode returns error io.EOF, that means the Decoder has reached the end
// of the input, so break the loop.
for {
var res Coordinate
err = decoder.Decode(&res)
if errors.Is(err, io.EOF) {
break
}
if err != nil {
panic(err)
}
fmt.Printf("%+v\n", res)
}
// Output:
// {X:0 Y:0}
// {X:1 Y:1}
// {X:2 Y:2}
// {X:3 Y:3}
// {X:4 Y:4}
}

decoder_test.go (new file, 699 lines)
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"errors"
"reflect"
"testing"
"time"
"gitea.psichedelico.com/go/bson/internal/assert"
"gitea.psichedelico.com/go/bson/internal/require"
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
)
func TestDecodeValue(t *testing.T) {
t.Parallel()
for _, tc := range unmarshalingTestCases() {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := reflect.New(tc.sType).Elem()
vr := NewDocumentReader(bytes.NewReader(tc.data))
reg := defaultRegistry
decoder, err := reg.LookupDecoder(got.Type())
noerr(t, err)
err = decoder.DecodeValue(DecodeContext{Registry: reg}, vr, got)
noerr(t, err)
assert.Equal(t, tc.want, got.Addr().Interface(), "Results do not match.")
})
}
}
func TestDecodingInterfaces(t *testing.T) {
t.Parallel()
type testCase struct {
name string
stub func() ([]byte, interface{}, func(*testing.T))
}
testCases := []testCase{
{
name: "struct with interface containing a concrete value",
stub: func() ([]byte, interface{}, func(*testing.T)) {
type testStruct struct {
Value interface{}
}
var value string
data := docToBytes(struct {
Value string
}{
Value: "foo",
})
receiver := testStruct{&value}
check := func(t *testing.T) {
t.Helper()
assert.Equal(t, "foo", value)
}
return data, &receiver, check
},
},
{
name: "struct with interface containing a struct",
stub: func() ([]byte, interface{}, func(*testing.T)) {
type demo struct {
Data string
}
type testStruct struct {
Value interface{}
}
var value demo
data := docToBytes(struct {
Value demo
}{
Value: demo{"foo"},
})
receiver := testStruct{&value}
check := func(t *testing.T) {
t.Helper()
assert.Equal(t, "foo", value.Data)
}
return data, &receiver, check
},
},
{
name: "struct with interface containing a slice",
stub: func() ([]byte, interface{}, func(*testing.T)) {
type testStruct struct {
Values interface{}
}
var values []string
data := docToBytes(struct {
Values []string
}{
Values: []string{"foo", "bar"},
})
receiver := testStruct{&values}
check := func(t *testing.T) {
t.Helper()
assert.Equal(t, []string{"foo", "bar"}, values)
}
return data, &receiver, check
},
},
{
name: "struct with interface containing an array",
stub: func() ([]byte, interface{}, func(*testing.T)) {
type testStruct struct {
Values interface{}
}
var values [2]string
data := docToBytes(struct {
Values []string
}{
Values: []string{"foo", "bar"},
})
receiver := testStruct{&values}
check := func(t *testing.T) {
t.Helper()
assert.Equal(t, [2]string{"foo", "bar"}, values)
}
return data, &receiver, check
},
},
{
name: "struct with interface array containing concrete values",
stub: func() ([]byte, interface{}, func(*testing.T)) {
type testStruct struct {
Values [3]interface{}
}
var str string
var i, j int
data := docToBytes(struct {
Values []interface{}
}{
Values: []interface{}{"foo", 42, nil},
})
receiver := testStruct{[3]interface{}{&str, &i, &j}}
check := func(t *testing.T) {
t.Helper()
assert.Equal(t, "foo", str)
assert.Equal(t, 42, i)
assert.Equal(t, 0, j)
assert.Equal(t, testStruct{[3]interface{}{&str, &i, nil}}, receiver)
}
return data, &receiver, check
},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
data, receiver, check := tc.stub()
got := reflect.ValueOf(receiver).Elem()
vr := NewDocumentReader(bytes.NewReader(data))
reg := defaultRegistry
decoder, err := reg.LookupDecoder(got.Type())
noerr(t, err)
err = decoder.DecodeValue(DecodeContext{Registry: reg}, vr, got)
noerr(t, err)
check(t)
})
}
}
func TestDecoder(t *testing.T) {
t.Parallel()
t.Run("Decode", func(t *testing.T) {
t.Parallel()
t.Run("basic", func(t *testing.T) {
t.Parallel()
for _, tc := range unmarshalingTestCases() {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := reflect.New(tc.sType).Interface()
vr := NewDocumentReader(bytes.NewReader(tc.data))
dec := NewDecoder(vr)
err := dec.Decode(got)
noerr(t, err)
assert.Equal(t, tc.want, got, "Results do not match.")
})
}
})
t.Run("stream", func(t *testing.T) {
t.Parallel()
var buf bytes.Buffer
vr := NewDocumentReader(&buf)
dec := NewDecoder(vr)
for _, tc := range unmarshalingTestCases() {
tc := tc
t.Run(tc.name, func(t *testing.T) {
buf.Write(tc.data)
got := reflect.New(tc.sType).Interface()
err := dec.Decode(got)
noerr(t, err)
assert.Equal(t, tc.want, got, "Results do not match.")
})
}
})
t.Run("lookup error", func(t *testing.T) {
t.Parallel()
type certainlydoesntexistelsewhereihope func(string, string) string
// Avoid unused code lint error.
_ = certainlydoesntexistelsewhereihope(func(string, string) string { return "" })
cdeih := func(string, string) string { return "certainlydoesntexistelsewhereihope" }
dec := NewDecoder(NewDocumentReader(bytes.NewReader([]byte{})))
want := errNoDecoder{Type: reflect.TypeOf(cdeih)}
got := dec.Decode(&cdeih)
assert.Equal(t, want, got, "Received unexpected error.")
})
t.Run("Unmarshaler", func(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
err error
vr ValueReader
invoked bool
}{
{
"error",
errors.New("Unmarshaler error"),
&valueReaderWriter{BSONType: TypeEmbeddedDocument, Err: ErrEOD, ErrAfter: readElement},
true,
},
{
"copy error",
errors.New("copy error"),
&valueReaderWriter{Err: errors.New("copy error"), ErrAfter: readDocument},
false,
},
{
"success",
nil,
&valueReaderWriter{BSONType: TypeEmbeddedDocument, Err: ErrEOD, ErrAfter: readElement},
true,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
unmarshaler := &testUnmarshaler{Err: tc.err}
dec := NewDecoder(tc.vr)
got := dec.Decode(unmarshaler)
want := tc.err
if !assert.CompareErrors(got, want) {
t.Errorf("Did not receive expected error. got %v; want %v", got, want)
}
if unmarshaler.Invoked != tc.invoked {
if tc.invoked {
t.Error("Expected to have UnmarshalBSON invoked, but it wasn't.")
} else {
t.Error("Expected UnmarshalBSON to not be invoked, but it was.")
}
}
})
}
t.Run("Unmarshaler/success ValueReader", func(t *testing.T) {
t.Parallel()
want := bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159))
unmarshaler := &testUnmarshaler{}
vr := NewDocumentReader(bytes.NewReader(want))
dec := NewDecoder(vr)
err := dec.Decode(unmarshaler)
noerr(t, err)
got := unmarshaler.Val
if !bytes.Equal(got, want) {
t.Errorf("Did not unmarshal properly. got %v; want %v", got, want)
}
})
})
})
t.Run("NewDecoder", func(t *testing.T) {
t.Parallel()
t.Run("success", func(t *testing.T) {
t.Parallel()
got := NewDecoder(NewDocumentReader(bytes.NewReader([]byte{})))
if got == nil {
t.Errorf("Was expecting a non-nil Decoder, but got <nil>")
}
})
})
t.Run("NewDecoderWithContext", func(t *testing.T) {
t.Parallel()
t.Run("success", func(t *testing.T) {
t.Parallel()
got := NewDecoder(NewDocumentReader(bytes.NewReader([]byte{})))
if got == nil {
t.Errorf("Was expecting a non-nil Decoder, but got <nil>")
}
})
})
t.Run("Decode doesn't zero struct", func(t *testing.T) {
t.Parallel()
type foo struct {
Item string
Qty int
Bonus int
}
var got foo
got.Item = "apple"
got.Bonus = 2
data := docToBytes(D{{"item", "canvas"}, {"qty", 4}})
vr := NewDocumentReader(bytes.NewReader(data))
dec := NewDecoder(vr)
err := dec.Decode(&got)
noerr(t, err)
want := foo{Item: "canvas", Qty: 4, Bonus: 2}
assert.Equal(t, want, got, "Results do not match.")
})
t.Run("Reset", func(t *testing.T) {
t.Parallel()
vr1, vr2 := NewDocumentReader(bytes.NewReader([]byte{})), NewDocumentReader(bytes.NewReader([]byte{}))
dec := NewDecoder(vr1)
if dec.vr != vr1 {
t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr1)
}
dec.Reset(vr2)
if dec.vr != vr2 {
t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr2)
}
})
t.Run("SetRegistry", func(t *testing.T) {
t.Parallel()
r1, r2 := defaultRegistry, NewRegistry()
dc1 := DecodeContext{Registry: r1}
dc2 := DecodeContext{Registry: r2}
dec := NewDecoder(NewDocumentReader(bytes.NewReader([]byte{})))
if !reflect.DeepEqual(dec.dc, dc1) {
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc1)
}
dec.SetRegistry(r2)
if !reflect.DeepEqual(dec.dc, dc2) {
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc2)
}
})
t.Run("DecodeToNil", func(t *testing.T) {
t.Parallel()
data := docToBytes(D{{"item", "canvas"}, {"qty", 4}})
vr := NewDocumentReader(bytes.NewReader(data))
dec := NewDecoder(vr)
var got *D
err := dec.Decode(got)
if !errors.Is(err, ErrDecodeToNil) {
t.Fatalf("Decode error mismatch; expected %v, got %v", ErrDecodeToNil, err)
}
})
}
type testUnmarshaler struct {
Invoked bool
Val []byte
Err error
}
func (tu *testUnmarshaler) UnmarshalBSON(d []byte) error {
tu.Invoked = true
tu.Val = d
return tu.Err
}
func TestDecoderConfiguration(t *testing.T) {
type truncateDoublesTest struct {
MyInt int
MyInt8 int8
MyInt16 int16
MyInt32 int32
MyInt64 int64
MyUint uint
MyUint8 uint8
MyUint16 uint16
MyUint32 uint32
MyUint64 uint64
}
type objectIDTest struct {
ID string
}
type jsonStructTest struct {
StructFieldName string `json:"jsonFieldName"`
}
type localTimeZoneTest struct {
MyTime time.Time
}
type zeroMapsTest struct {
MyMap map[string]string
}
type zeroStructsTest struct {
MyString string
MyInt int
}
testCases := []struct {
description string
configure func(*Decoder)
input []byte
decodeInto func() interface{}
want interface{}
}{
// Test that AllowTruncatingDoubles causes the Decoder to unmarshal BSON doubles with
// fractional parts into Go integer types by truncating the fractional part.
{
description: "AllowTruncatingDoubles",
configure: func(dec *Decoder) {
dec.AllowTruncatingDoubles()
},
input: bsoncore.NewDocumentBuilder().
AppendDouble("myInt", 1.999).
AppendDouble("myInt8", 1.999).
AppendDouble("myInt16", 1.999).
AppendDouble("myInt32", 1.999).
AppendDouble("myInt64", 1.999).
AppendDouble("myUint", 1.999).
AppendDouble("myUint8", 1.999).
AppendDouble("myUint16", 1.999).
AppendDouble("myUint32", 1.999).
AppendDouble("myUint64", 1.999).
Build(),
decodeInto: func() interface{} { return &truncateDoublesTest{} },
want: &truncateDoublesTest{
MyInt: 1,
MyInt8: 1,
MyInt16: 1,
MyInt32: 1,
MyInt64: 1,
MyUint: 1,
MyUint8: 1,
MyUint16: 1,
MyUint32: 1,
MyUint64: 1,
},
},
// Test that BinaryAsSlice causes the Decoder to unmarshal BSON binary fields into Go byte
	// slices when there is no type information (e.g. when unmarshaling into a bson.D).
{
description: "BinaryAsSlice",
configure: func(dec *Decoder) {
dec.BinaryAsSlice()
},
input: bsoncore.NewDocumentBuilder().
AppendBinary("myBinary", TypeBinaryGeneric, []byte{}).
Build(),
decodeInto: func() interface{} { return &D{} },
want: &D{{Key: "myBinary", Value: []byte{}}},
},
// Test that the default decoder always decodes BSON documents into bson.D values,
// independent of the top-level Go value type.
{
description: "DocumentD nested by default",
configure: func(_ *Decoder) {},
input: bsoncore.NewDocumentBuilder().
AppendDocument("myDocument", bsoncore.NewDocumentBuilder().
AppendString("myString", "test value").
Build()).
Build(),
decodeInto: func() interface{} { return M{} },
want: M{
"myDocument": D{{Key: "myString", Value: "test value"}},
},
},
// Test that DefaultDocumentM always decodes BSON documents into bson.M values,
// independent of the top-level Go value type.
{
description: "DefaultDocumentM nested",
configure: func(dec *Decoder) {
dec.DefaultDocumentM()
},
input: bsoncore.NewDocumentBuilder().
AppendDocument("myDocument", bsoncore.NewDocumentBuilder().
AppendString("myString", "test value").
Build()).
Build(),
decodeInto: func() interface{} { return &D{} },
want: &D{
{Key: "myDocument", Value: M{"myString": "test value"}},
},
},
	// Test that ObjectIDAsHexString causes the Decoder to decode BSON ObjectIDs into hexadecimal
	// strings.
{
description: "ObjectIDAsHexString",
configure: func(dec *Decoder) {
dec.ObjectIDAsHexString()
},
input: bsoncore.NewDocumentBuilder().
AppendObjectID("id", func() ObjectID {
id, _ := ObjectIDFromHex("5ef7fdd91c19e3222b41b839")
return id
}()).
Build(),
decodeInto: func() interface{} { return &objectIDTest{} },
want: &objectIDTest{ID: "5ef7fdd91c19e3222b41b839"},
},
// Test that UseJSONStructTags causes the Decoder to fall back to "json" struct tags if
// "bson" struct tags are not available.
{
description: "UseJSONStructTags",
configure: func(dec *Decoder) {
dec.UseJSONStructTags()
},
input: bsoncore.NewDocumentBuilder().
AppendString("jsonFieldName", "test value").
Build(),
decodeInto: func() interface{} { return &jsonStructTest{} },
want: &jsonStructTest{StructFieldName: "test value"},
},
// Test that UseLocalTimeZone causes the Decoder to use the local time zone for decoded
// time.Time values instead of UTC.
{
description: "UseLocalTimeZone",
configure: func(dec *Decoder) {
dec.UseLocalTimeZone()
},
input: bsoncore.NewDocumentBuilder().
AppendDateTime("myTime", 1684349179939).
Build(),
decodeInto: func() interface{} { return &localTimeZoneTest{} },
want: &localTimeZoneTest{MyTime: time.UnixMilli(1684349179939)},
},
// Test that ZeroMaps causes the Decoder to empty any Go map values before decoding BSON
// documents into them.
{
description: "ZeroMaps",
configure: func(dec *Decoder) {
dec.ZeroMaps()
},
input: bsoncore.NewDocumentBuilder().
AppendDocument("myMap", bsoncore.NewDocumentBuilder().
AppendString("myString", "test value").
Build()).
Build(),
decodeInto: func() interface{} {
return &zeroMapsTest{MyMap: map[string]string{"myExtraValue": "extra value"}}
},
want: &zeroMapsTest{MyMap: map[string]string{"myString": "test value"}},
},
// Test that ZeroStructs causes the Decoder to empty any Go struct values before decoding
// BSON documents into them.
{
description: "ZeroStructs",
configure: func(dec *Decoder) {
dec.ZeroStructs()
},
input: bsoncore.NewDocumentBuilder().
AppendString("myString", "test value").
Build(),
decodeInto: func() interface{} {
return &zeroStructsTest{MyInt: 1}
},
want: &zeroStructsTest{MyString: "test value"},
},
}
for _, tc := range testCases {
tc := tc // Capture range variable.
t.Run(tc.description, func(t *testing.T) {
t.Parallel()
dec := NewDecoder(NewDocumentReader(bytes.NewReader(tc.input)))
tc.configure(dec)
got := tc.decodeInto()
err := dec.Decode(got)
require.NoError(t, err, "Decode error")
assert.Equal(t, tc.want, got, "expected and actual decode results do not match")
})
}
t.Run("Decoding an object ID to string", func(t *testing.T) {
t.Parallel()
type objectIDTest struct {
ID string
}
doc := bsoncore.NewDocumentBuilder().
AppendObjectID("id", func() ObjectID {
id, _ := ObjectIDFromHex("5ef7fdd91c19e3222b41b839")
return id
}()).
Build()
dec := NewDecoder(NewDocumentReader(bytes.NewReader(doc)))
var got objectIDTest
err := dec.Decode(&got)
const want = "error decoding key id: decoding an object ID into a string is not supported by default (set Decoder.ObjectIDAsHexString to enable decoding as a hexadecimal string)"
assert.EqualError(t, err, want)
})
t.Run("DefaultDocumentM top-level", func(t *testing.T) {
t.Parallel()
input := bsoncore.NewDocumentBuilder().
AppendDocument("myDocument", bsoncore.NewDocumentBuilder().
AppendString("myString", "test value").
Build()).
Build()
dec := NewDecoder(NewDocumentReader(bytes.NewReader(input)))
dec.DefaultDocumentM()
var got interface{}
err := dec.Decode(&got)
require.NoError(t, err, "Decode error")
want := M{
"myDocument": M{
"myString": "test value",
},
}
assert.Equal(t, want, got, "expected and actual decode results do not match")
})
t.Run("Default decodes DocumentD for top-level", func(t *testing.T) {
t.Parallel()
input := bsoncore.NewDocumentBuilder().
AppendDocument("myDocument", bsoncore.NewDocumentBuilder().
AppendString("myString", "test value").
Build()).
Build()
dec := NewDecoder(NewDocumentReader(bytes.NewReader(input)))
var got interface{}
err := dec.Decode(&got)
require.NoError(t, err, "Decode error")
want := D{
{Key: "myDocument", Value: D{
{Key: "myString", Value: "test value"},
}},
}
assert.Equal(t, want, got, "expected and actual decode results do not match")
})
}

1497
default_value_decoders.go Normal file

File diff suppressed because it is too large

File diff suppressed because it is too large

517
default_value_encoders.go Normal file
View File

@ -0,0 +1,517 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding/json"
"errors"
"math"
"net/url"
"reflect"
"sync"
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
)
var bvwPool = sync.Pool{
New: func() interface{} {
return new(valueWriter)
},
}
var errInvalidValue = errors.New("cannot encode invalid element")
var sliceWriterPool = sync.Pool{
New: func() interface{} {
sw := make(sliceWriter, 0)
return &sw
},
}
func encodeElement(ec EncodeContext, dw DocumentWriter, e E) error {
vw, err := dw.WriteDocumentElement(e.Key)
if err != nil {
return err
}
if e.Value == nil {
return vw.WriteNull()
}
encoder, err := ec.LookupEncoder(reflect.TypeOf(e.Value))
if err != nil {
return err
}
err = encoder.EncodeValue(ec, vw, reflect.ValueOf(e.Value))
if err != nil {
return err
}
return nil
}
// registerDefaultEncoders registers the default encoder implementations for the standard BSON and
// primitive Go types with the provided Registry.
func registerDefaultEncoders(reg *Registry) {
mapEncoder := &mapCodec{}
uintCodec := &uintCodec{}
reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{})
reg.RegisterTypeEncoder(tTime, &timeCodec{})
reg.RegisterTypeEncoder(tEmpty, &emptyInterfaceCodec{})
reg.RegisterTypeEncoder(tCoreArray, &arrayCodec{})
reg.RegisterTypeEncoder(tOID, ValueEncoderFunc(objectIDEncodeValue))
reg.RegisterTypeEncoder(tDecimal, ValueEncoderFunc(decimal128EncodeValue))
reg.RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(jsonNumberEncodeValue))
reg.RegisterTypeEncoder(tURL, ValueEncoderFunc(urlEncodeValue))
reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue))
reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue))
reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue))
reg.RegisterTypeEncoder(tVector, ValueEncoderFunc(vectorEncodeValue))
reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue))
reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue))
reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue))
reg.RegisterTypeEncoder(tRegex, ValueEncoderFunc(regexEncodeValue))
reg.RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dbPointerEncodeValue))
reg.RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(timestampEncodeValue))
reg.RegisterTypeEncoder(tMinKey, ValueEncoderFunc(minKeyEncodeValue))
reg.RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(maxKeyEncodeValue))
reg.RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(coreDocumentEncodeValue))
reg.RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(codeWithScopeEncodeValue))
reg.RegisterKindEncoder(reflect.Bool, ValueEncoderFunc(booleanEncodeValue))
reg.RegisterKindEncoder(reflect.Int, ValueEncoderFunc(intEncodeValue))
reg.RegisterKindEncoder(reflect.Int8, ValueEncoderFunc(intEncodeValue))
reg.RegisterKindEncoder(reflect.Int16, ValueEncoderFunc(intEncodeValue))
reg.RegisterKindEncoder(reflect.Int32, ValueEncoderFunc(intEncodeValue))
reg.RegisterKindEncoder(reflect.Int64, ValueEncoderFunc(intEncodeValue))
reg.RegisterKindEncoder(reflect.Uint, uintCodec)
reg.RegisterKindEncoder(reflect.Uint8, uintCodec)
reg.RegisterKindEncoder(reflect.Uint16, uintCodec)
reg.RegisterKindEncoder(reflect.Uint32, uintCodec)
reg.RegisterKindEncoder(reflect.Uint64, uintCodec)
reg.RegisterKindEncoder(reflect.Float32, ValueEncoderFunc(floatEncodeValue))
reg.RegisterKindEncoder(reflect.Float64, ValueEncoderFunc(floatEncodeValue))
reg.RegisterKindEncoder(reflect.Array, ValueEncoderFunc(arrayEncodeValue))
reg.RegisterKindEncoder(reflect.Map, mapEncoder)
reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{})
reg.RegisterKindEncoder(reflect.String, &stringCodec{})
reg.RegisterKindEncoder(reflect.Struct, newStructCodec(mapEncoder))
reg.RegisterKindEncoder(reflect.Ptr, &pointerCodec{})
reg.RegisterInterfaceEncoder(tValueMarshaler, ValueEncoderFunc(valueMarshalerEncodeValue))
reg.RegisterInterfaceEncoder(tMarshaler, ValueEncoderFunc(marshalerEncodeValue))
}
// booleanEncodeValue is the ValueEncoderFunc for bool types.
func booleanEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Bool {
return ValueEncoderError{Name: "BooleanEncodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val}
}
return vw.WriteBoolean(val.Bool())
}
func fitsIn32Bits(i int64) bool {
return math.MinInt32 <= i && i <= math.MaxInt32
}
// intEncodeValue is the ValueEncoderFunc for int types.
func intEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
switch val.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32:
return vw.WriteInt32(int32(val.Int()))
case reflect.Int:
i64 := val.Int()
if fitsIn32Bits(i64) {
return vw.WriteInt32(int32(i64))
}
return vw.WriteInt64(i64)
case reflect.Int64:
i64 := val.Int()
if ec.minSize && fitsIn32Bits(i64) {
return vw.WriteInt32(int32(i64))
}
return vw.WriteInt64(i64)
}
return ValueEncoderError{
Name: "IntEncodeValue",
Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int},
Received: val,
}
}
// floatEncodeValue is the ValueEncoderFunc for float types.
func floatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
switch val.Kind() {
case reflect.Float32, reflect.Float64:
return vw.WriteDouble(val.Float())
}
return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: val}
}
// objectIDEncodeValue is the ValueEncoderFunc for ObjectID.
func objectIDEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tOID {
return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: val}
}
return vw.WriteObjectID(val.Interface().(ObjectID))
}
// decimal128EncodeValue is the ValueEncoderFunc for Decimal128.
func decimal128EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tDecimal {
return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: val}
}
return vw.WriteDecimal128(val.Interface().(Decimal128))
}
// jsonNumberEncodeValue is the ValueEncoderFunc for json.Number.
func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tJSONNumber {
return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val}
}
jsnum := val.Interface().(json.Number)
// Attempt int first, then float64
if i64, err := jsnum.Int64(); err == nil {
return intEncodeValue(ec, vw, reflect.ValueOf(i64))
}
f64, err := jsnum.Float64()
if err != nil {
return err
}
return floatEncodeValue(ec, vw, reflect.ValueOf(f64))
}
// urlEncodeValue is the ValueEncoderFunc for url.URL.
func urlEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tURL {
return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val}
}
u := val.Interface().(url.URL)
return vw.WriteString(u.String())
}
// arrayEncodeValue is the ValueEncoderFunc for array types.
func arrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Array {
return ValueEncoderError{Name: "ArrayEncodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val}
}
// If we have a []E we want to treat it as a document instead of as an array.
if val.Type().Elem() == tE {
dw, err := vw.WriteDocument()
if err != nil {
return err
}
for idx := 0; idx < val.Len(); idx++ {
e := val.Index(idx).Interface().(E)
err = encodeElement(ec, dw, e)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
// If we have a []byte we want to treat it as a binary instead of as an array.
if val.Type().Elem() == tByte {
var byteSlice []byte
for idx := 0; idx < val.Len(); idx++ {
byteSlice = append(byteSlice, val.Index(idx).Interface().(byte))
}
return vw.WriteBinary(byteSlice)
}
aw, err := vw.WriteArray()
if err != nil {
return err
}
elemType := val.Type().Elem()
encoder, err := ec.LookupEncoder(elemType)
if err != nil && elemType.Kind() != reflect.Interface {
return err
}
for idx := 0; idx < val.Len(); idx++ {
currEncoder, currVal, lookupErr := lookupElementEncoder(ec, encoder, val.Index(idx))
if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) {
return lookupErr
}
vw, err := aw.WriteArrayElement()
if err != nil {
return err
}
if errors.Is(lookupErr, errInvalidValue) {
err = vw.WriteNull()
if err != nil {
return err
}
continue
}
err = currEncoder.EncodeValue(ec, vw, currVal)
if err != nil {
return err
}
}
return aw.WriteArrayEnd()
}
func lookupElementEncoder(ec EncodeContext, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) {
if origEncoder != nil || (currVal.Kind() != reflect.Interface) {
return origEncoder, currVal, nil
}
currVal = currVal.Elem()
if !currVal.IsValid() {
return nil, currVal, errInvalidValue
}
currEncoder, err := ec.LookupEncoder(currVal.Type())
return currEncoder, currVal, err
}
// valueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations.
func valueMarshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
// Either val or a pointer to val must implement ValueMarshaler
switch {
case !val.IsValid():
return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val}
case val.Type().Implements(tValueMarshaler):
// If ValueMarshaler is implemented on a concrete type, make sure that val isn't a nil pointer
if isImplementationNil(val, tValueMarshaler) {
return vw.WriteNull()
}
case reflect.PtrTo(val.Type()).Implements(tValueMarshaler) && val.CanAddr():
val = val.Addr()
default:
return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val}
}
m, ok := val.Interface().(ValueMarshaler)
if !ok {
return vw.WriteNull()
}
t, data, err := m.MarshalBSONValue()
if err != nil {
return err
}
return copyValueFromBytes(vw, Type(t), data)
}
// marshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations.
func marshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
// Either val or a pointer to val must implement Marshaler
switch {
case !val.IsValid():
return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val}
case val.Type().Implements(tMarshaler):
// If Marshaler is implemented on a concrete type, make sure that val isn't a nil pointer
if isImplementationNil(val, tMarshaler) {
return vw.WriteNull()
}
case reflect.PtrTo(val.Type()).Implements(tMarshaler) && val.CanAddr():
val = val.Addr()
default:
return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val}
}
m, ok := val.Interface().(Marshaler)
if !ok {
return vw.WriteNull()
}
data, err := m.MarshalBSON()
if err != nil {
return err
}
return copyValueFromBytes(vw, TypeEmbeddedDocument, data)
}
// javaScriptEncodeValue is the ValueEncoderFunc for the JavaScript type.
func javaScriptEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tJavaScript {
return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: val}
}
return vw.WriteJavascript(val.String())
}
// symbolEncodeValue is the ValueEncoderFunc for the Symbol type.
func symbolEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tSymbol {
return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: val}
}
return vw.WriteSymbol(val.String())
}
// binaryEncodeValue is the ValueEncoderFunc for Binary.
func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tBinary {
return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: val}
}
b := val.Interface().(Binary)
return vw.WriteBinaryWithSubtype(b.Data, b.Subtype)
}
// vectorEncodeValue is the ValueEncoderFunc for Vector.
func vectorEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
t := val.Type()
if !val.IsValid() || t != tVector {
return ValueEncoderError{Name: "VectorEncodeValue",
Types: []reflect.Type{tVector},
Received: val,
}
}
v := val.Interface().(Vector)
b := v.Binary()
return vw.WriteBinaryWithSubtype(b.Data, b.Subtype)
}
// undefinedEncodeValue is the ValueEncoderFunc for Undefined.
func undefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tUndefined {
return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: val}
}
return vw.WriteUndefined()
}
// dateTimeEncodeValue is the ValueEncoderFunc for DateTime.
func dateTimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tDateTime {
return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: val}
}
return vw.WriteDateTime(val.Int())
}
// nullEncodeValue is the ValueEncoderFunc for Null.
func nullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tNull {
return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: val}
}
return vw.WriteNull()
}
// regexEncodeValue is the ValueEncoderFunc for Regex.
func regexEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tRegex {
return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: val}
}
regex := val.Interface().(Regex)
return vw.WriteRegex(regex.Pattern, regex.Options)
}
// dbPointerEncodeValue is the ValueEncoderFunc for DBPointer.
func dbPointerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tDBPointer {
return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: val}
}
dbp := val.Interface().(DBPointer)
return vw.WriteDBPointer(dbp.DB, dbp.Pointer)
}
// timestampEncodeValue is the ValueEncoderFunc for Timestamp.
func timestampEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tTimestamp {
return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: val}
}
ts := val.Interface().(Timestamp)
return vw.WriteTimestamp(ts.T, ts.I)
}
// minKeyEncodeValue is the ValueEncoderFunc for MinKey.
func minKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tMinKey {
return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: val}
}
return vw.WriteMinKey()
}
// maxKeyEncodeValue is the ValueEncoderFunc for MaxKey.
func maxKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tMaxKey {
return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: val}
}
return vw.WriteMaxKey()
}
// coreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document.
func coreDocumentEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tCoreDocument {
return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: val}
}
cdoc := val.Interface().(bsoncore.Document)
return copyDocumentFromBytes(vw, cdoc)
}
// codeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope.
func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tCodeWithScope {
return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val}
}
cws := val.Interface().(CodeWithScope)
dw, err := vw.WriteCodeWithScope(string(cws.Code))
if err != nil {
return err
}
sw := sliceWriterPool.Get().(*sliceWriter)
defer sliceWriterPool.Put(sw)
*sw = (*sw)[:0]
scopeVW := bvwPool.Get().(*valueWriter)
scopeVW.reset(scopeVW.buf[:0])
scopeVW.w = sw
defer bvwPool.Put(scopeVW)
encoder, err := ec.LookupEncoder(reflect.TypeOf(cws.Scope))
if err != nil {
return err
}
err = encoder.EncodeValue(ec, scopeVW, reflect.ValueOf(cws.Scope))
if err != nil {
return err
}
err = copyBytesToDocumentWriter(dw, *sw)
if err != nil {
return err
}
return dw.WriteDocumentEnd()
}
// isImplementationNil reports whether val is a nil pointer whose underlying concrete type
// implements inter.
func isImplementationNil(val reflect.Value, inter reflect.Type) bool {
vt := val.Type()
for vt.Kind() == reflect.Ptr {
vt = vt.Elem()
}
return vt.Implements(inter) && val.Kind() == reflect.Ptr && val.IsNil()
}

File diff suppressed because it is too large

155
doc.go Normal file
View File

@ -0,0 +1,155 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package bson is a library for reading, writing, and manipulating BSON. BSON is a binary serialization format used to
// store documents and make remote procedure calls in MongoDB. The BSON specification is located at https://bsonspec.org.
// The BSON library handles marshaling and unmarshaling of values through a configurable codec system. For a description
// of the codec system and examples of registering custom codecs, see the bsoncodec package. For additional information
// and usage examples, check out the [Work with BSON] page in the Go Driver docs site.
//
// # Raw BSON
//
// The Raw family of types is used to validate and retrieve elements from a slice of bytes. This
// type is most useful when you want to do lookups on BSON bytes without unmarshaling it into another
// type.
//
// Example:
//
// var raw bson.Raw = ... // bytes from somewhere
// err := raw.Validate()
// if err != nil { return err }
// val := raw.Lookup("foo")
// i32, ok := val.Int32OK()
// // do something with i32...
//
// # Native Go Types
//
// The D and M types defined in this package can be used to build representations of BSON using native Go types. D is a
// slice and M is a map. For more information about the use cases for these types, see the documentation on the type
// definitions.
//
// Note that a D should not be constructed with duplicate key names, as that can cause undefined server behavior.
//
// Example:
//
// bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
// bson.M{"foo": "bar", "hello": "world", "pi": 3.14159}
//
// When decoding BSON to a D or M, the following type mappings apply when unmarshaling:
//
// 1. BSON int32 unmarshals to an int32.
// 2. BSON int64 unmarshals to an int64.
// 3. BSON double unmarshals to a float64.
// 4. BSON string unmarshals to a string.
// 5. BSON boolean unmarshals to a bool.
// 6. BSON embedded document unmarshals to the parent type (i.e. D for a D, M for an M).
// 7. BSON array unmarshals to a bson.A.
// 8. BSON ObjectId unmarshals to a bson.ObjectID.
// 9. BSON datetime unmarshals to a bson.DateTime.
// 10. BSON binary unmarshals to a bson.Binary.
// 11. BSON regular expression unmarshals to a bson.Regex.
// 12. BSON JavaScript unmarshals to a bson.JavaScript.
// 13. BSON code with scope unmarshals to a bson.CodeWithScope.
//  14. BSON timestamp unmarshals to a bson.Timestamp.
//  15. BSON 128-bit decimal unmarshals to a bson.Decimal128.
//  16. BSON min key unmarshals to a bson.MinKey.
//  17. BSON max key unmarshals to a bson.MaxKey.
// 18. BSON undefined unmarshals to a bson.Undefined.
// 19. BSON null unmarshals to nil.
// 20. BSON DBPointer unmarshals to a bson.DBPointer.
// 21. BSON symbol unmarshals to a bson.Symbol.
//
// The above mappings also apply when marshaling a D or M to BSON. Some other useful marshaling mappings are:
//
// 1. time.Time marshals to a BSON datetime.
// 2. int8, int16, and int32 marshal to a BSON int32.
// 3. int marshals to a BSON int32 if the value is between math.MinInt32 and math.MaxInt32, inclusive, and a BSON int64
// otherwise.
// 4. int64 marshals to BSON int64 (unless [Encoder.IntMinSize] is set).
// 5. uint8 and uint16 marshal to a BSON int32.
// 6. uint, uint32, and uint64 marshal to a BSON int64 (unless [Encoder.IntMinSize] is set).
// 7. BSON null and undefined values will unmarshal into the zero value of a field (e.g. unmarshaling a BSON null or
//     undefined value into a string will yield the empty string).
//
// # Structs
//
// Structs can be marshaled/unmarshaled to/from BSON or Extended JSON. When transforming structs to/from BSON or Extended
// JSON, the following rules apply:
//
// 1. Only exported fields in structs will be marshaled or unmarshaled.
//
// 2. When marshaling a struct, each field will be lowercased to generate the key for the corresponding BSON element.
// For example, a struct field named "Foo" will generate key "foo". This can be overridden via a struct tag (e.g.
// `bson:"fooField"` to generate key "fooField" instead).
//
// 3. An embedded struct field is marshaled as a subdocument. The key will be the lowercased name of the field's type.
//
// 4. A pointer field is marshaled as the underlying type if the pointer is non-nil. If the pointer is nil, it is
// marshaled as a BSON null value.
//
// 5. When unmarshaling, a field of type interface{} will follow the D/M type mappings listed above. BSON documents
// unmarshaled into an interface{} field will be unmarshaled as a D.
//
// The encoding of each struct field can be customized by the "bson" struct tag.
//
// This tag behavior is configurable, and different struct tag behavior can be configured by initializing a new
// bsoncodec.StructCodec with the desired tag parser and registering that StructCodec onto the Registry. By default, JSON
// tags are not honored, but that can be enabled by creating a StructCodec with JSONFallbackStructTagParser, like below:
//
// Example:
//
// structcodec, _ := bsoncodec.NewStructCodec(bsoncodec.JSONFallbackStructTagParser)
//
// The bson tag gives the name of the field, possibly followed by a comma-separated list of options.
// The name may be empty in order to specify options without overriding the default field name. The following options can
// be used to configure behavior:
//
// 1. omitempty: If the "omitempty" struct tag is specified on a field, the field will not be marshaled if it is set to
// an "empty" value. Numbers, booleans, and strings are considered empty if their value is equal to the zero value for
// the type (i.e. 0 for numbers, false for booleans, and "" for strings). Slices, maps, and arrays are considered
// empty if they are of length zero. Interfaces and pointers are considered empty if their value is nil. By default,
// structs are only considered empty if the struct type implements [bsoncodec.Zeroer] and the IsZero
// method returns true. Struct types that do not implement [bsoncodec.Zeroer] are never considered empty and will be
// marshaled as embedded documents. NOTE: It is recommended that this tag be used for all slice and map fields.
//
// 2. minsize: If the minsize struct tag is specified on a field of type int64, uint, uint32, or uint64 and the value of
// the field can fit in a signed int32, the field will be serialized as a BSON int32 rather than a BSON int64. For
// other types, this tag is ignored.
//
// 3. truncate: If the truncate struct tag is specified on a field with a non-float numeric type, BSON doubles
// unmarshaled into that field will be truncated at the decimal point. For example, if 3.14 is unmarshaled into a
// field of type int, it will be unmarshaled as 3. If this tag is not specified, the decoder will throw an error if
// the value cannot be decoded without losing precision. For float64 or non-numeric types, this tag is ignored.
//
// 4. inline: If the inline struct tag is specified for a struct or map field, the field will be "flattened" when
// marshaling and "un-flattened" when unmarshaling. This means that all of the fields in that struct/map will be
// pulled up one level and will become top-level fields rather than being fields in a nested document. For example,
// if a map field named "Map" with value map[string]interface{}{"foo": "bar"} is inlined, the resulting document will
// be {"foo": "bar"} instead of {"map": {"foo": "bar"}}. There can only be one inlined map field in a struct. If
// there are duplicated fields in the resulting document when an inlined struct is marshaled, the inlined field will
// be overwritten. If there are duplicated fields in the resulting document when an inlined map is marshaled, an
// error will be returned. This tag can be used with fields that are pointers to structs. If an inlined pointer field
// is nil, it will not be marshaled. For fields that are not maps or structs, this tag is ignored.
//
// # Marshaling and Unmarshaling
//
// Manually marshaling and unmarshaling can be done with the Marshal and Unmarshal family of functions.
//
// The bson package provides a system for encoding values to BSON representations and decoding
// values from BSON representations. It considers both binary BSON and Extended JSON to be
// BSON representations. The types in this package enable a flexible system for handling this
// encoding and decoding.
//
// The codec system is composed of two parts:
//
// 1) [ValueEncoder] and [ValueDecoder] that handle encoding and decoding Go values to and from BSON
// representations.
//
// 2) A [Registry] that holds these ValueEncoders and ValueDecoders and provides methods for
// retrieving them.
//
// [Work with BSON]: https://www.mongodb.com/docs/drivers/go/current/fundamentals/bson/
package bson

empty_interface_codec.go
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
)
// emptyInterfaceCodec is the Codec used for interface{} values.
type emptyInterfaceCodec struct {
// decodeBinaryAsSlice causes DecodeValue to unmarshal BSON binary field values that are the
// "Generic" or "Old" BSON binary subtype as a Go byte slice instead of a Binary.
decodeBinaryAsSlice bool
}
// Assert that emptyInterfaceCodec satisfies the typeDecoder interface, which allows it
// to be used by collection type decoders (e.g. map, slice, etc) to set individual values in a
// collection.
var _ typeDecoder = &emptyInterfaceCodec{}
// EncodeValue is the ValueEncoderFunc for interface{}.
func (eic *emptyInterfaceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tEmpty {
return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
encoder, err := ec.LookupEncoder(val.Elem().Type())
if err != nil {
return err
}
return encoder.EncodeValue(ec, vw, val.Elem())
}
func (eic *emptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, valueType Type) (reflect.Type, error) {
isDocument := valueType == Type(0) || valueType == TypeEmbeddedDocument
if isDocument {
if dc.defaultDocumentType != nil {
// If the bsontype is an embedded document and the DocumentType is set on the DecodeContext, then return
// that type.
return dc.defaultDocumentType, nil
}
}
rtype, err := dc.LookupTypeMapEntry(valueType)
if err == nil {
return rtype, nil
}
if isDocument {
// For documents, fallback to looking up a type map entry for Type(0) or TypeEmbeddedDocument,
// depending on the original valueType.
var lookupType Type
switch valueType {
case Type(0):
lookupType = TypeEmbeddedDocument
case TypeEmbeddedDocument:
lookupType = Type(0)
}
rtype, err = dc.LookupTypeMapEntry(lookupType)
if err == nil {
return rtype, nil
}
// fallback to bson.D
return tD, nil
}
return nil, err
}
func (eic *emptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
if t != tEmpty {
return emptyValue, ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.Zero(t)}
}
rtype, err := eic.getEmptyInterfaceDecodeType(dc, vr.Type())
if err != nil {
switch vr.Type() {
case TypeNull:
return reflect.Zero(t), vr.ReadNull()
default:
return emptyValue, err
}
}
decoder, err := dc.LookupDecoder(rtype)
if err != nil {
return emptyValue, err
}
elem, err := decodeTypeOrValueWithInfo(decoder, dc, vr, rtype)
if err != nil {
return emptyValue, err
}
if (eic.decodeBinaryAsSlice || dc.binaryAsSlice) && rtype == tBinary {
binElem := elem.Interface().(Binary)
if binElem.Subtype == TypeBinaryGeneric || binElem.Subtype == TypeBinaryBinaryOld {
elem = reflect.ValueOf(binElem.Data)
}
}
return elem, nil
}
// DecodeValue is the ValueDecoderFunc for interface{}.
func (eic *emptyInterfaceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tEmpty {
return ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: val}
}
elem, err := eic.decodeType(dc, vr, val.Type())
if err != nil {
return err
}
val.Set(elem)
return nil
}
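The `decodeBinaryAsSlice` behavior above amounts to a small post-decode rewrite: when the decoded value is a `Binary` with the "Generic" or "Old" subtype, the codec substitutes the raw bytes. A stdlib-only sketch of that rule (the `binary` type and subtype constants here are stand-ins, not the driver's types):

```go
package main

import "fmt"

// binary is a stand-in for the driver's Binary value.
type binary struct {
	Subtype byte
	Data    []byte
}

const (
	subtypeGeneric   = 0x00 // "Generic" BSON binary subtype
	subtypeBinaryOld = 0x02 // "Old" BSON binary subtype
)

// asSliceIfGeneric mirrors the codec's rule: generic/old binary values
// are replaced with a plain []byte; everything else is left alone.
func asSliceIfGeneric(v interface{}) interface{} {
	if b, ok := v.(binary); ok && (b.Subtype == subtypeGeneric || b.Subtype == subtypeBinaryOld) {
		return b.Data
	}
	return v
}

func main() {
	fmt.Printf("%T\n", asSliceIfGeneric(binary{Subtype: subtypeGeneric, Data: []byte{1, 2}})) // []uint8
	fmt.Printf("%T\n", asSliceIfGeneric(binary{Subtype: 0x04, Data: []byte{1, 2}}))           // main.binary
}
```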

encoder.go
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"reflect"
"sync"
)
// This pool is used to keep the allocations of Encoders down. This is only used for the Marshal*
// methods and is not consumable from outside of this package. The Encoders retrieved from this pool
// must have both Reset and SetRegistry called on them.
var encPool = sync.Pool{
New: func() interface{} {
return new(Encoder)
},
}
// An Encoder writes a serialization format to an output stream. It writes to a ValueWriter
// as the destination of BSON data.
type Encoder struct {
ec EncodeContext
vw ValueWriter
}
// NewEncoder returns a new encoder that writes to vw.
func NewEncoder(vw ValueWriter) *Encoder {
return &Encoder{
ec: EncodeContext{Registry: defaultRegistry},
vw: vw,
}
}
// Encode writes the BSON encoding of val to the stream.
//
// See [Marshal] for details about BSON marshaling behavior.
func (e *Encoder) Encode(val interface{}) error {
if marshaler, ok := val.(Marshaler); ok {
// TODO(skriptble): Should we have a MarshalAppender interface so that we can have []byte reuse?
buf, err := marshaler.MarshalBSON()
if err != nil {
return err
}
return copyDocumentFromBytes(e.vw, buf)
}
encoder, err := e.ec.LookupEncoder(reflect.TypeOf(val))
if err != nil {
return err
}
return encoder.EncodeValue(e.ec, e.vw, reflect.ValueOf(val))
}
// Reset will reset the state of the Encoder, using the same EncodeContext used in
// the original construction but writing to vw instead.
func (e *Encoder) Reset(vw ValueWriter) {
e.vw = vw
}
// SetRegistry replaces the current registry of the Encoder with r.
func (e *Encoder) SetRegistry(r *Registry) {
e.ec.Registry = r
}
// ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in
// the marshaled BSON when the "inline" struct tag option is set.
func (e *Encoder) ErrorOnInlineDuplicates() {
e.ec.errorOnInlineDuplicates = true
}
// IntMinSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, uint,
// uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) that can
// represent the integer value.
func (e *Encoder) IntMinSize() {
e.ec.minSize = true
}
// StringifyMapKeysWithFmt causes the Encoder to convert Go map keys to BSON document field name
// strings using fmt.Sprint instead of the default string conversion logic.
func (e *Encoder) StringifyMapKeysWithFmt() {
e.ec.stringifyMapKeysWithFmt = true
}
// NilMapAsEmpty causes the Encoder to marshal nil Go maps as empty BSON documents instead of BSON
// null.
func (e *Encoder) NilMapAsEmpty() {
e.ec.nilMapAsEmpty = true
}
// NilSliceAsEmpty causes the Encoder to marshal nil Go slices as empty BSON arrays instead of BSON
// null.
func (e *Encoder) NilSliceAsEmpty() {
e.ec.nilSliceAsEmpty = true
}
// NilByteSliceAsEmpty causes the Encoder to marshal nil Go byte slices as empty BSON binary values
// instead of BSON null.
func (e *Encoder) NilByteSliceAsEmpty() {
e.ec.nilByteSliceAsEmpty = true
}
// TODO(GODRIVER-2820): Update the description to remove the note about only examining exported
// struct fields once the logic is updated to also inspect private struct fields.
// OmitZeroStruct causes the Encoder to consider the zero value for a struct (e.g. MyStruct{})
// as empty and omit it from the marshaled BSON when the "omitempty" struct tag option is set.
//
// Note that the Encoder only examines exported struct fields when determining if a struct is the
// zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty.
func (e *Encoder) OmitZeroStruct() {
e.ec.omitZeroStruct = true
}
// UseJSONStructTags causes the Encoder to fall back to using the "json" struct tag if a "bson"
// struct tag is not specified.
func (e *Encoder) UseJSONStructTags() {
e.ec.useJSONStructTags = true
}
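The `encPool` at the top of encoder.go follows a common Go pattern: pool the struct, and require callers to reset its state on checkout (hence the comment that both Reset and SetRegistry must be called). A stdlib-only sketch of that pattern, using a hypothetical encoder type rather than the driver's:

```go
package main

import (
	"bytes"
	"fmt"
	"sync"
)

// enc is a hypothetical encoder whose state must be reset on reuse.
type enc struct{ out *bytes.Buffer }

func (e *enc) Reset(out *bytes.Buffer) { e.out = out }

var pool = sync.Pool{New: func() interface{} { return new(enc) }}

// marshal checks an encoder out of the pool, resets it, and returns it
// when done, mirroring how the Marshal* helpers use encPool.
func marshal(s string) []byte {
	e := pool.Get().(*enc)
	defer pool.Put(e)
	buf := new(bytes.Buffer)
	e.Reset(buf) // mandatory: pooled values keep their old state
	e.out.WriteString(s)
	return buf.Bytes()
}

func main() {
	fmt.Println(string(marshal("hello")))
}
```

The design choice here is that `sync.Pool` never clears values for you, so a pooled encoder that skips `Reset` would silently write to the previous caller's destination.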

encoder_example_test.go
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson_test
import (
"bytes"
"errors"
"fmt"
"io"
"gitea.psichedelico.com/go/bson"
)
func ExampleEncoder() {
// Create an Encoder that writes BSON values to a bytes.Buffer.
buf := new(bytes.Buffer)
vw := bson.NewDocumentWriter(buf)
encoder := bson.NewEncoder(vw)
type Product struct {
Name string `bson:"name"`
SKU string `bson:"sku"`
Price int64 `bson:"price_cents"`
}
// Use the Encoder to marshal a BSON document that contains the name, SKU,
// and price (in cents) of a product.
product := Product{
Name: "Cereal Rounds",
SKU: "AB12345",
Price: 399,
}
err := encoder.Encode(product)
if err != nil {
panic(err)
}
// Print the BSON document as Extended JSON by converting it to bson.Raw.
fmt.Println(bson.Raw(buf.Bytes()).String())
// Output: {"name": "Cereal Rounds","sku": "AB12345","price_cents": {"$numberLong":"399"}}
}
type CityState struct {
City string
State string
}
func (k CityState) String() string {
return fmt.Sprintf("%s, %s", k.City, k.State)
}
func ExampleEncoder_StringifyMapKeysWithFmt() {
// Create an Encoder that writes BSON values to a bytes.Buffer.
buf := new(bytes.Buffer)
vw := bson.NewDocumentWriter(buf)
encoder := bson.NewEncoder(vw)
// Configure the Encoder to convert Go map keys to BSON document field names
// using fmt.Sprint instead of the default string conversion logic.
encoder.StringifyMapKeysWithFmt()
// Use the Encoder to marshal a BSON document that is a map of
// city and state to a list of zip codes in that city.
zipCodes := map[CityState][]int{
{City: "New York", State: "NY"}: {10001, 10301, 10451},
}
err := encoder.Encode(zipCodes)
if err != nil {
panic(err)
}
// Print the BSON document as Extended JSON by converting it to bson.Raw.
fmt.Println(bson.Raw(buf.Bytes()).String())
// Output: {"New York, NY": [{"$numberInt":"10001"},{"$numberInt":"10301"},{"$numberInt":"10451"}]}
}
func ExampleEncoder_UseJSONStructTags() {
// Create an Encoder that writes BSON values to a bytes.Buffer.
buf := new(bytes.Buffer)
vw := bson.NewDocumentWriter(buf)
encoder := bson.NewEncoder(vw)
type Product struct {
Name string `json:"name"`
SKU string `json:"sku"`
Price int64 `json:"price_cents"`
}
// Configure the Encoder to use "json" struct tags when encoding if "bson"
// struct tags are not present.
encoder.UseJSONStructTags()
// Use the Encoder to marshal a BSON document that contains the name, SKU,
// and price (in cents) of a product.
product := Product{
Name: "Cereal Rounds",
SKU: "AB12345",
Price: 399,
}
err := encoder.Encode(product)
if err != nil {
panic(err)
}
// Print the BSON document as Extended JSON by converting it to bson.Raw.
fmt.Println(bson.Raw(buf.Bytes()).String())
// Output: {"name": "Cereal Rounds","sku": "AB12345","price_cents": {"$numberLong":"399"}}
}
func ExampleEncoder_multipleBSONDocuments() {
// Create an Encoder that writes BSON values to a bytes.Buffer.
buf := new(bytes.Buffer)
vw := bson.NewDocumentWriter(buf)
encoder := bson.NewEncoder(vw)
type Coordinate struct {
X int
Y int
}
// Use the encoder to marshal 5 Coordinate values as a sequence of BSON
// documents.
for i := 0; i < 5; i++ {
err := encoder.Encode(Coordinate{
X: i,
Y: i + 1,
})
if err != nil {
panic(err)
}
}
// Read each marshaled BSON document from the buffer and print them as
// Extended JSON by converting them to bson.Raw.
for {
doc, err := bson.ReadDocument(buf)
if errors.Is(err, io.EOF) {
return
}
if err != nil {
panic(err)
}
fmt.Println(doc.String())
}
// Output:
// {"x": {"$numberInt":"0"},"y": {"$numberInt":"1"}}
// {"x": {"$numberInt":"1"},"y": {"$numberInt":"2"}}
// {"x": {"$numberInt":"2"},"y": {"$numberInt":"3"}}
// {"x": {"$numberInt":"3"},"y": {"$numberInt":"4"}}
// {"x": {"$numberInt":"4"},"y": {"$numberInt":"5"}}
}
func ExampleEncoder_extendedJSON() {
// Create an Encoder that writes canonical Extended JSON values to a
// bytes.Buffer.
buf := new(bytes.Buffer)
vw := bson.NewExtJSONValueWriter(buf, true, false)
encoder := bson.NewEncoder(vw)
type Product struct {
Name string `bson:"name"`
SKU string `bson:"sku"`
Price int64 `bson:"price_cents"`
}
// Use the Encoder to marshal a BSON document that contains the name, SKU,
// and price (in cents) of a product.
product := Product{
Name: "Cereal Rounds",
SKU: "AB12345",
Price: 399,
}
err := encoder.Encode(product)
if err != nil {
panic(err)
}
fmt.Println(buf.String())
// Output: {"name":"Cereal Rounds","sku":"AB12345","price_cents":{"$numberLong":"399"}}
}
func ExampleEncoder_multipleExtendedJSONDocuments() {
// Create an Encoder that writes canonical Extended JSON values to a
// bytes.Buffer.
buf := new(bytes.Buffer)
vw := bson.NewExtJSONValueWriter(buf, true, false)
encoder := bson.NewEncoder(vw)
type Coordinate struct {
X int
Y int
}
// Use the encoder to marshal 5 Coordinate values as a sequence of Extended
// JSON documents.
for i := 0; i < 5; i++ {
err := encoder.Encode(Coordinate{
X: i,
Y: i + 1,
})
if err != nil {
panic(err)
}
}
fmt.Println(buf.String())
// Output:
// {"x":{"$numberInt":"0"},"y":{"$numberInt":"1"}}
// {"x":{"$numberInt":"1"},"y":{"$numberInt":"2"}}
// {"x":{"$numberInt":"2"},"y":{"$numberInt":"3"}}
// {"x":{"$numberInt":"3"},"y":{"$numberInt":"4"}}
// {"x":{"$numberInt":"4"},"y":{"$numberInt":"5"}}
}
func ExampleEncoder_IntMinSize() {
// Create an encoder that will marshal integers as the minimum BSON int size
// (either 32 or 64 bits) that can represent the integer value.
type foo struct {
Bar uint32
}
buf := new(bytes.Buffer)
vw := bson.NewDocumentWriter(buf)
enc := bson.NewEncoder(vw)
enc.IntMinSize()
err := enc.Encode(foo{2})
if err != nil {
panic(err)
}
fmt.Println(bson.Raw(buf.Bytes()).String())
// Output:
// {"bar": {"$numberInt":"2"}}
}

encoder_test.go
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"errors"
"reflect"
"testing"
"gitea.psichedelico.com/go/bson/internal/assert"
"gitea.psichedelico.com/go/bson/internal/require"
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
)
func TestBasicEncode(t *testing.T) {
for _, tc := range marshalingTestCases {
t.Run(tc.name, func(t *testing.T) {
got := make(sliceWriter, 0, 1024)
vw := NewDocumentWriter(&got)
reg := defaultRegistry
encoder, err := reg.LookupEncoder(reflect.TypeOf(tc.val))
noerr(t, err)
err = encoder.EncodeValue(EncodeContext{Registry: reg}, vw, reflect.ValueOf(tc.val))
noerr(t, err)
if !bytes.Equal(got, tc.want) {
t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
t.Errorf("Bytes:\n%v\n%v", got, tc.want)
}
})
}
}
func TestEncoderEncode(t *testing.T) {
for _, tc := range marshalingTestCases {
t.Run(tc.name, func(t *testing.T) {
got := make(sliceWriter, 0, 1024)
vw := NewDocumentWriter(&got)
enc := NewEncoder(vw)
err := enc.Encode(tc.val)
noerr(t, err)
if !bytes.Equal(got, tc.want) {
t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
t.Errorf("Bytes:\n%v\n%v", got, tc.want)
}
})
}
t.Run("Marshaler", func(t *testing.T) {
testCases := []struct {
name string
buf []byte
err error
wanterr error
vw ValueWriter
}{
{
"error",
nil,
errors.New("Marshaler error"),
errors.New("Marshaler error"),
&valueReaderWriter{},
},
{
"copy error",
[]byte{0x05, 0x00, 0x00, 0x00, 0x00},
nil,
errors.New("copy error"),
&valueReaderWriter{Err: errors.New("copy error"), ErrAfter: writeDocument},
},
{
"success",
[]byte{0x07, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00},
nil,
nil,
nil,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
marshaler := testMarshaler{buf: tc.buf, err: tc.err}
var vw ValueWriter
b := make(sliceWriter, 0, 100)
compareVW := false
if tc.vw != nil {
vw = tc.vw
} else {
compareVW = true
vw = NewDocumentWriter(&b)
}
enc := NewEncoder(vw)
got := enc.Encode(marshaler)
want := tc.wanterr
if !assert.CompareErrors(got, want) {
t.Errorf("Did not receive expected error. got %v; want %v", got, want)
}
if compareVW {
buf := b
if !bytes.Equal(buf, tc.buf) {
t.Errorf("Copied bytes do not match. got %v; want %v", buf, tc.buf)
}
}
})
}
})
}
type testMarshaler struct {
buf []byte
err error
}
func (tm testMarshaler) MarshalBSON() ([]byte, error) { return tm.buf, tm.err }
func docToBytes(d interface{}) []byte {
b, err := Marshal(d)
if err != nil {
panic(err)
}
return b
}
type stringerTest struct{}
func (stringerTest) String() string {
return "test key"
}
func TestEncoderConfiguration(t *testing.T) {
type inlineDuplicateInner struct {
Duplicate string
}
type inlineDuplicateOuter struct {
Inline inlineDuplicateInner `bson:",inline"`
Duplicate string
}
type zeroStruct struct {
MyString string
}
testCases := []struct {
description string
configure func(*Encoder)
input interface{}
want []byte
wantErr error
}{
// Test that ErrorOnInlineDuplicates causes the Encoder to return an error if there are any
// duplicate fields in the marshaled document caused by using the "inline" struct tag.
{
description: "ErrorOnInlineDuplicates",
configure: func(enc *Encoder) {
enc.ErrorOnInlineDuplicates()
},
input: inlineDuplicateOuter{
Inline: inlineDuplicateInner{Duplicate: "inner"},
Duplicate: "outer",
},
wantErr: errors.New("struct bson.inlineDuplicateOuter has duplicated key duplicate"),
},
// Test that IntMinSize encodes Go int and int64 values as BSON int32 if the value is small
// enough.
{
description: "IntMinSize",
configure: func(enc *Encoder) {
enc.IntMinSize()
},
input: D{
{Key: "myInt", Value: int(1)},
{Key: "myInt64", Value: int64(1)},
{Key: "myUint", Value: uint(1)},
{Key: "myUint32", Value: uint32(1)},
{Key: "myUint64", Value: uint64(1)},
},
want: bsoncore.NewDocumentBuilder().
AppendInt32("myInt", 1).
AppendInt32("myInt64", 1).
AppendInt32("myUint", 1).
AppendInt32("myUint32", 1).
AppendInt32("myUint64", 1).
Build(),
},
// Test that StringifyMapKeysWithFmt uses fmt.Sprint to convert map keys to BSON field names.
{
description: "StringifyMapKeysWithFmt",
configure: func(enc *Encoder) {
enc.StringifyMapKeysWithFmt()
},
input: map[stringerTest]string{
{}: "test value",
},
want: bsoncore.NewDocumentBuilder().
AppendString("test key", "test value").
Build(),
},
// Test that NilMapAsEmpty encodes nil Go maps as empty BSON documents.
{
description: "NilMapAsEmpty",
configure: func(enc *Encoder) {
enc.NilMapAsEmpty()
},
input: D{{Key: "myMap", Value: map[string]string(nil)}},
want: bsoncore.NewDocumentBuilder().
AppendDocument("myMap", bsoncore.NewDocumentBuilder().Build()).
Build(),
},
// Test that NilSliceAsEmpty encodes nil Go slices as empty BSON arrays.
{
description: "NilSliceAsEmpty",
configure: func(enc *Encoder) {
enc.NilSliceAsEmpty()
},
input: D{{Key: "mySlice", Value: []string(nil)}},
want: bsoncore.NewDocumentBuilder().
AppendArray("mySlice", bsoncore.NewArrayBuilder().Build()).
Build(),
},
// Test that NilByteSliceAsEmpty encodes nil Go byte slices as empty BSON binary elements.
{
description: "NilByteSliceAsEmpty",
configure: func(enc *Encoder) {
enc.NilByteSliceAsEmpty()
},
input: D{{Key: "myBytes", Value: []byte(nil)}},
want: bsoncore.NewDocumentBuilder().
AppendBinary("myBytes", TypeBinaryGeneric, []byte{}).
Build(),
},
// Test that OmitZeroStruct omits empty structs from the marshaled document if the
// "omitempty" struct tag is used.
{
description: "OmitZeroStruct",
configure: func(enc *Encoder) {
enc.OmitZeroStruct()
},
input: struct {
Zero zeroStruct `bson:",omitempty"`
}{},
want: bsoncore.NewDocumentBuilder().Build(),
},
// Test that UseJSONStructTags causes the Encoder to fall back to "json" struct tags if
// "bson" struct tags are not available.
{
description: "UseJSONStructTags",
configure: func(enc *Encoder) {
enc.UseJSONStructTags()
},
input: struct {
StructFieldName string `json:"jsonFieldName"`
}{
StructFieldName: "test value",
},
want: bsoncore.NewDocumentBuilder().
AppendString("jsonFieldName", "test value").
Build(),
},
}
for _, tc := range testCases {
tc := tc // Capture range variable.
t.Run(tc.description, func(t *testing.T) {
t.Parallel()
got := new(bytes.Buffer)
vw := NewDocumentWriter(got)
enc := NewEncoder(vw)
tc.configure(enc)
err := enc.Encode(tc.input)
if tc.wantErr != nil {
assert.Equal(t, tc.wantErr, err, "expected and actual errors do not match")
return
}
require.NoError(t, err, "Encode error")
assert.Equal(t, tc.want, got.Bytes(), "expected and actual encoded BSON do not match")
// After we compare the raw bytes, also decode the expected and actual BSON as a bson.D
// and compare them. The goal is to make assertion failures easier to debug because
// binary diffs are very difficult to understand.
var wantDoc D
err = Unmarshal(tc.want, &wantDoc)
require.NoError(t, err, "Unmarshal error")
var gotDoc D
err = Unmarshal(got.Bytes(), &gotDoc)
require.NoError(t, err, "Unmarshal error")
assert.Equal(t, wantDoc, gotDoc, "expected and actual decoded documents do not match")
})
}
}

example_test.go
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson_test
import (
"fmt"
"time"
"gitea.psichedelico.com/go/bson"
)
// This example uses Raw to skip parsing a nested document in a BSON message.
func ExampleRaw_unmarshal() {
b, err := bson.Marshal(bson.M{
"Word": "beach",
"Synonyms": bson.A{"coast", "shore", "waterfront"},
})
if err != nil {
panic(err)
}
var res struct {
Word string
Synonyms bson.Raw // Don't parse the whole list, we just want to count the elements.
}
err = bson.Unmarshal(b, &res)
if err != nil {
panic(err)
}
elems, err := res.Synonyms.Elements()
if err != nil {
panic(err)
}
fmt.Printf("%s, synonyms count: %d\n", res.Word, len(elems))
// Output: beach, synonyms count: 3
}
// This example uses Raw to add a precomputed BSON document during marshal.
func ExampleRaw_marshal() {
precomputed, err := bson.Marshal(bson.M{"Precomputed": true})
if err != nil {
panic(err)
}
msg := struct {
Message string
Metadata bson.Raw
}{
Message: "Hello World!",
Metadata: precomputed,
}
b, err := bson.Marshal(msg)
if err != nil {
panic(err)
}
// Print the Extended JSON by converting BSON to bson.Raw.
fmt.Println(bson.Raw(b).String())
// Output: {"message": "Hello World!","metadata": {"Precomputed": true}}
}
// This example uses RawValue to delay parsing a value in a BSON message.
func ExampleRawValue_unmarshal() {
b1, err := bson.Marshal(bson.M{
"Format": "UNIX",
"Timestamp": 1675282389,
})
if err != nil {
panic(err)
}
b2, err := bson.Marshal(bson.M{
"Format": "RFC3339",
"Timestamp": time.Unix(1675282389, 0).Format(time.RFC3339),
})
if err != nil {
panic(err)
}
for _, b := range [][]byte{b1, b2} {
var res struct {
Format string
Timestamp bson.RawValue // Delay parsing until we know the timestamp format.
}
err = bson.Unmarshal(b, &res)
if err != nil {
panic(err)
}
var t time.Time
switch res.Format {
case "UNIX":
t = time.Unix(res.Timestamp.AsInt64(), 0)
case "RFC3339":
t, err = time.Parse(time.RFC3339, res.Timestamp.StringValue())
if err != nil {
panic(err)
}
}
fmt.Println(res.Format, t.Unix())
}
// Output:
// UNIX 1675282389
// RFC3339 1675282389
}
// This example uses RawValue to add a precomputed BSON string value during marshal.
func ExampleRawValue_marshal() {
t, val, err := bson.MarshalValue("Precomputed message!")
if err != nil {
panic(err)
}
precomputed := bson.RawValue{
Type: t,
Value: val,
}
msg := struct {
Message bson.RawValue
Time time.Time
}{
Message: precomputed,
Time: time.Unix(1675282389, 0),
}
b, err := bson.Marshal(msg)
if err != nil {
panic(err)
}
// Print the Extended JSON by converting BSON to bson.Raw.
fmt.Println(bson.Raw(b).String())
// Output: {"message": "Precomputed message!","time": {"$date":{"$numberLong":"1675282389000"}}}
}

extjson_parser.go
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"io"
"strings"
)
const maxNestingDepth = 200
// ErrInvalidJSON indicates the JSON input is invalid
var ErrInvalidJSON = errors.New("invalid JSON input")
type jsonParseState byte
const (
jpsStartState jsonParseState = iota
jpsSawBeginObject
jpsSawEndObject
jpsSawBeginArray
jpsSawEndArray
jpsSawColon
jpsSawComma
jpsSawKey
jpsSawValue
jpsDoneState
jpsInvalidState
)
type jsonParseMode byte
const (
jpmInvalidMode jsonParseMode = iota
jpmObjectMode
jpmArrayMode
)
type extJSONValue struct {
t Type
v interface{}
}
type extJSONObject struct {
keys []string
values []*extJSONValue
}
type extJSONParser struct {
js *jsonScanner
s jsonParseState
m []jsonParseMode
k string
v *extJSONValue
err error
canonicalOnly bool
depth int
maxDepth int
emptyObject bool
relaxedUUID bool
}
// newExtJSONParser returns a new extended JSON parser, ready to begin
// parsing from the first character of the provided JSON input. It will not
// perform any read-ahead and will therefore not report any errors about
// malformed JSON at this point.
func newExtJSONParser(r io.Reader, canonicalOnly bool) *extJSONParser {
return &extJSONParser{
js: &jsonScanner{r: r},
s: jpsStartState,
m: []jsonParseMode{},
canonicalOnly: canonicalOnly,
maxDepth: maxNestingDepth,
}
}
// peekType examines the next value and returns its BSON Type
func (ejp *extJSONParser) peekType() (Type, error) {
var t Type
var err error
initialState := ejp.s
ejp.advanceState()
switch ejp.s {
case jpsSawValue:
t = ejp.v.t
case jpsSawBeginArray:
t = TypeArray
case jpsInvalidState:
err = ejp.err
case jpsSawComma:
// in array mode, seeing a comma means we need to progress again to actually observe a type
if ejp.peekMode() == jpmArrayMode {
return ejp.peekType()
}
case jpsSawEndArray:
// this would only be a valid state if we were in array mode, so return end-of-array error
err = ErrEOA
case jpsSawBeginObject:
// peek key to determine type
ejp.advanceState()
switch ejp.s {
case jpsSawEndObject: // empty embedded document
t = TypeEmbeddedDocument
ejp.emptyObject = true
case jpsInvalidState:
err = ejp.err
case jpsSawKey:
if initialState == jpsStartState {
return TypeEmbeddedDocument, nil
}
t = wrapperKeyBSONType(ejp.k)
// if $uuid is encountered, parse as binary subtype 4
if ejp.k == "$uuid" {
ejp.relaxedUUID = true
t = TypeBinary
}
switch t {
case TypeJavaScript:
// just saw $code, need to check for $scope at same level
_, err = ejp.readValue(TypeJavaScript)
if err != nil {
break
}
switch ejp.s {
case jpsSawEndObject: // type is TypeJavaScript
case jpsSawComma:
ejp.advanceState()
if ejp.s == jpsSawKey && ejp.k == "$scope" {
t = TypeCodeWithScope
} else {
err = fmt.Errorf("invalid extended JSON: unexpected key %s in CodeWithScope object", ejp.k)
}
case jpsInvalidState:
err = ejp.err
default:
err = ErrInvalidJSON
}
case TypeCodeWithScope:
err = errors.New("invalid extended JSON: code with $scope must contain $code before $scope")
}
}
}
return t, err
}
// readKey parses the next key and its type and returns them
func (ejp *extJSONParser) readKey() (string, Type, error) {
if ejp.emptyObject {
ejp.emptyObject = false
return "", 0, ErrEOD
}
// advance to key (or return with error)
switch ejp.s {
case jpsStartState:
ejp.advanceState()
if ejp.s == jpsSawBeginObject {
ejp.advanceState()
}
case jpsSawBeginObject:
ejp.advanceState()
case jpsSawValue, jpsSawEndObject, jpsSawEndArray:
ejp.advanceState()
switch ejp.s {
case jpsSawBeginObject, jpsSawComma:
ejp.advanceState()
case jpsSawEndObject:
return "", 0, ErrEOD
case jpsDoneState:
return "", 0, io.EOF
case jpsInvalidState:
return "", 0, ejp.err
default:
return "", 0, ErrInvalidJSON
}
case jpsSawKey: // do nothing (key was peeked before)
default:
return "", 0, invalidRequestError("key")
}
// read key
var key string
switch ejp.s {
case jpsSawKey:
key = ejp.k
case jpsSawEndObject:
return "", 0, ErrEOD
case jpsInvalidState:
return "", 0, ejp.err
default:
return "", 0, invalidRequestError("key")
}
// check for colon
ejp.advanceState()
if err := ensureColon(ejp.s, key); err != nil {
return "", 0, err
}
// peek at the value to determine type
t, err := ejp.peekType()
if err != nil {
return "", 0, err
}
return key, t, nil
}
// readValue returns the value corresponding to the Type returned by peekType
func (ejp *extJSONParser) readValue(t Type) (*extJSONValue, error) {
if ejp.s == jpsInvalidState {
return nil, ejp.err
}
var v *extJSONValue
switch t {
case TypeNull, TypeBoolean, TypeString:
if ejp.s != jpsSawValue {
return nil, invalidRequestError(t.String())
}
v = ejp.v
case TypeInt32, TypeInt64, TypeDouble:
// relaxed version allows these to be literal number values
if ejp.s == jpsSawValue {
v = ejp.v
break
}
fallthrough
case TypeDecimal128, TypeSymbol, TypeObjectID, TypeMinKey, TypeMaxKey, TypeUndefined:
switch ejp.s {
case jpsSawKey:
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
// read value
ejp.advanceState()
if ejp.s != jpsSawValue || !ejp.ensureExtValueType(t) {
return nil, invalidJSONErrorForType("value", t)
}
v = ejp.v
// read end object
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("} after value", t)
}
default:
return nil, invalidRequestError(t.String())
}
case TypeBinary, TypeRegex, TypeTimestamp, TypeDBPointer:
if ejp.s != jpsSawKey {
return nil, invalidRequestError(t.String())
}
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
ejp.advanceState()
if t == TypeBinary && ejp.s == jpsSawValue {
// convert relaxed $uuid format
if ejp.relaxedUUID {
defer func() { ejp.relaxedUUID = false }()
uuid, err := ejp.v.parseSymbol()
if err != nil {
return nil, err
}
// RFC 4122 defines the length of a UUID as 36 and the hyphens in a UUID as appearing
// in the 8th, 13th, 18th, and 23rd characters.
//
// See https://tools.ietf.org/html/rfc4122#section-3
valid := len(uuid) == 36 &&
string(uuid[8]) == "-" &&
string(uuid[13]) == "-" &&
string(uuid[18]) == "-" &&
string(uuid[23]) == "-"
if !valid {
return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens")
}
// remove hyphens
uuidNoHyphens := strings.ReplaceAll(uuid, "-", "")
if len(uuidNoHyphens) != 32 {
return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens")
}
// convert hex to bytes
bytes, err := hex.DecodeString(uuidNoHyphens)
if err != nil {
return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding hex bytes: %w", err)
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("$uuid and value and then }", TypeBinary)
}
base64 := &extJSONValue{
t: TypeString,
v: base64.StdEncoding.EncodeToString(bytes),
}
subType := &extJSONValue{
t: TypeString,
v: "04",
}
v = &extJSONValue{
t: TypeEmbeddedDocument,
v: &extJSONObject{
keys: []string{"base64", "subType"},
values: []*extJSONValue{base64, subType},
},
}
break
}
// convert legacy $binary format
base64 := ejp.v
ejp.advanceState()
if ejp.s != jpsSawComma {
return nil, invalidJSONErrorForType(",", TypeBinary)
}
ejp.advanceState()
key, t, err := ejp.readKey()
if err != nil {
return nil, err
}
if key != "$type" {
return nil, invalidJSONErrorForType("$type", TypeBinary)
}
subType, err := ejp.readValue(t)
if err != nil {
return nil, err
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("2 key-value pairs and then }", TypeBinary)
}
v = &extJSONValue{
t: TypeEmbeddedDocument,
v: &extJSONObject{
keys: []string{"base64", "subType"},
values: []*extJSONValue{base64, subType},
},
}
break
}
// read KV pairs
if ejp.s != jpsSawBeginObject {
return nil, invalidJSONErrorForType("{", t)
}
keys, vals, err := ejp.readObject(2, true)
if err != nil {
return nil, err
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("2 key-value pairs and then }", t)
}
v = &extJSONValue{t: TypeEmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}
case TypeDateTime:
switch ejp.s {
case jpsSawValue:
v = ejp.v
case jpsSawKey:
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
ejp.advanceState()
switch ejp.s {
case jpsSawBeginObject:
keys, vals, err := ejp.readObject(1, true)
if err != nil {
return nil, err
}
v = &extJSONValue{t: TypeEmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}
case jpsSawValue:
if ejp.canonicalOnly {
return nil, invalidJSONError("{")
}
v = ejp.v
default:
if ejp.canonicalOnly {
return nil, invalidJSONErrorForType("object", t)
}
return nil, invalidJSONErrorForType("ISO-8601 Internet Date/Time Format as described in RFC-3339", t)
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("value and then }", t)
}
default:
return nil, invalidRequestError(t.String())
}
case TypeJavaScript:
switch ejp.s {
case jpsSawKey:
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
// read value
ejp.advanceState()
if ejp.s != jpsSawValue {
return nil, invalidJSONErrorForType("value", t)
}
v = ejp.v
// read end object or comma and just return
ejp.advanceState()
case jpsSawEndObject:
v = ejp.v
default:
return nil, invalidRequestError(t.String())
}
case TypeCodeWithScope:
if ejp.s == jpsSawKey && ejp.k == "$scope" {
v = ejp.v // this is the $code string from earlier
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
// read {
ejp.advanceState()
if ejp.s != jpsSawBeginObject {
return nil, invalidJSONError("$scope to be embedded document")
}
} else {
return nil, invalidRequestError(t.String())
}
case TypeEmbeddedDocument, TypeArray:
return nil, invalidRequestError(t.String())
}
return v, nil
}
// readObject is a utility method for reading full objects of known (or expected) size.
// It is useful for extended JSON types such as binary, datetime, regex, and timestamp.
func (ejp *extJSONParser) readObject(numKeys int, started bool) ([]string, []*extJSONValue, error) {
keys := make([]string, numKeys)
vals := make([]*extJSONValue, numKeys)
if !started {
ejp.advanceState()
if ejp.s != jpsSawBeginObject {
return nil, nil, invalidJSONError("{")
}
}
for i := 0; i < numKeys; i++ {
key, t, err := ejp.readKey()
if err != nil {
return nil, nil, err
}
switch ejp.s {
case jpsSawKey:
v, err := ejp.readValue(t)
if err != nil {
return nil, nil, err
}
keys[i] = key
vals[i] = v
case jpsSawValue:
keys[i] = key
vals[i] = ejp.v
default:
return nil, nil, invalidJSONError("value")
}
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, nil, invalidJSONError("}")
}
return keys, vals, nil
}
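// readObject enforces a fixed shape: exactly numKeys key/value pairs followed
// by the closing brace, which is how wrappers like {"$timestamp": {"t": ..., "i": ...}}
// are consumed. The same contract can be sketched with the standard library's
// streaming token API; readPairs below is illustrative only and shares no code
// with the parser.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// readPairs reads exactly numKeys key/value pairs from a JSON object,
// then requires the closing '}' — the contract readObject enforces for
// fixed-shape extended JSON types.
func readPairs(input string, numKeys int) (map[string]interface{}, error) {
	dec := json.NewDecoder(strings.NewReader(input))
	if t, err := dec.Token(); err != nil || t != json.Delim('{') {
		return nil, fmt.Errorf("expected {")
	}
	out := make(map[string]interface{}, numKeys)
	for i := 0; i < numKeys; i++ {
		keyTok, err := dec.Token()
		if err != nil {
			return nil, err
		}
		key, ok := keyTok.(string)
		if !ok {
			return nil, fmt.Errorf("expected key, got %v", keyTok)
		}
		var v interface{}
		if err := dec.Decode(&v); err != nil {
			return nil, err
		}
		out[key] = v
	}
	if t, err := dec.Token(); err != nil || t != json.Delim('}') {
		return nil, fmt.Errorf("expected } after %d pairs", numKeys)
	}
	return out, nil
}

func main() {
	m, err := readPairs(`{"t": 42, "i": 1}`, 2)
	fmt.Println(m["t"], m["i"], err)
}
```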
// advanceState reads the next JSON token from the scanner and transitions
// from the current state based on that token's type
func (ejp *extJSONParser) advanceState() {
if ejp.s == jpsDoneState || ejp.s == jpsInvalidState {
return
}
jt, err := ejp.js.nextToken()
if err != nil {
ejp.err = err
ejp.s = jpsInvalidState
return
}
valid := ejp.validateToken(jt.t)
if !valid {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
return
}
switch jt.t {
case jttBeginObject:
ejp.s = jpsSawBeginObject
ejp.pushMode(jpmObjectMode)
ejp.depth++
if ejp.depth > ejp.maxDepth {
ejp.err = nestingDepthError(jt.p, ejp.depth)
ejp.s = jpsInvalidState
}
case jttEndObject:
ejp.s = jpsSawEndObject
ejp.depth--
if ejp.popMode() != jpmObjectMode {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
}
case jttBeginArray:
ejp.s = jpsSawBeginArray
ejp.pushMode(jpmArrayMode)
case jttEndArray:
ejp.s = jpsSawEndArray
if ejp.popMode() != jpmArrayMode {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
}
case jttColon:
ejp.s = jpsSawColon
case jttComma:
ejp.s = jpsSawComma
case jttEOF:
ejp.s = jpsDoneState
if len(ejp.m) != 0 {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
}
case jttString:
switch ejp.s {
case jpsSawComma:
if ejp.peekMode() == jpmArrayMode {
ejp.s = jpsSawValue
ejp.v = extendJSONToken(jt)
return
}
fallthrough
case jpsSawBeginObject:
ejp.s = jpsSawKey
ejp.k = jt.v.(string)
return
}
fallthrough
default:
ejp.s = jpsSawValue
ejp.v = extendJSONToken(jt)
}
}
var jpsValidTransitionTokens = map[jsonParseState]map[jsonTokenType]bool{
jpsStartState: {
jttBeginObject: true,
jttBeginArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
jttEOF: true,
},
jpsSawBeginObject: {
jttEndObject: true,
jttString: true,
},
jpsSawEndObject: {
jttEndObject: true,
jttEndArray: true,
jttComma: true,
jttEOF: true,
},
jpsSawBeginArray: {
jttBeginObject: true,
jttBeginArray: true,
jttEndArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
},
jpsSawEndArray: {
jttEndObject: true,
jttEndArray: true,
jttComma: true,
jttEOF: true,
},
jpsSawColon: {
jttBeginObject: true,
jttBeginArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
},
jpsSawComma: {
jttBeginObject: true,
jttBeginArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
},
jpsSawKey: {
jttColon: true,
},
jpsSawValue: {
jttEndObject: true,
jttEndArray: true,
jttComma: true,
jttEOF: true,
},
jpsDoneState: {},
jpsInvalidState: {},
}
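// jpsValidTransitionTokens makes validation declarative: whether the next
// token is legal is a table lookup keyed by (current state, token type), with
// only a few contextual exceptions handled in validateToken. A minimal sketch
// of the same table-driven idea, with illustrative state and token names:

```go
package main

import "fmt"

type state int
type token int

const (
	start state = iota
	sawKey
	sawColon
)

const (
	tokString token = iota
	tokColon
)

// validNext encodes which token types may legally follow each state,
// mirroring the shape of jpsValidTransitionTokens.
var validNext = map[state]map[token]bool{
	start:    {tokString: true},
	sawKey:   {tokColon: true},
	sawColon: {tokString: true},
}

func valid(s state, t token) bool { return validNext[s][t] }

func main() {
	fmt.Println(valid(sawKey, tokColon))  // a key must be followed by a colon
	fmt.Println(valid(sawKey, tokString)) // two keys in a row are rejected
}
```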
func (ejp *extJSONParser) validateToken(jtt jsonTokenType) bool {
switch ejp.s {
case jpsSawEndObject:
// if we are at depth zero and the next token is a '{',
// we can consider it valid only if we are not in array mode.
if jtt == jttBeginObject && ejp.depth == 0 {
return ejp.peekMode() != jpmArrayMode
}
case jpsSawComma:
switch ejp.peekMode() {
// the only valid next token after a comma inside a document is a string (a key)
case jpmObjectMode:
return jtt == jttString
case jpmInvalidMode:
return false
}
}
_, ok := jpsValidTransitionTokens[ejp.s][jtt]
return ok
}
// ensureExtValueType returns true if the current value has the expected
// value type for single-key extended JSON types. For example, in
// {"$numberInt": v}, v must be a TypeString value.
func (ejp *extJSONParser) ensureExtValueType(t Type) bool {
switch t {
case TypeMinKey, TypeMaxKey:
return ejp.v.t == TypeInt32
case TypeUndefined:
return ejp.v.t == TypeBoolean
case TypeInt32, TypeInt64, TypeDouble, TypeDecimal128, TypeSymbol, TypeObjectID:
return ejp.v.t == TypeString
default:
return false
}
}
func (ejp *extJSONParser) pushMode(m jsonParseMode) {
ejp.m = append(ejp.m, m)
}
func (ejp *extJSONParser) popMode() jsonParseMode {
l := len(ejp.m)
if l == 0 {
return jpmInvalidMode
}
m := ejp.m[l-1]
ejp.m = ejp.m[:l-1]
return m
}
func (ejp *extJSONParser) peekMode() jsonParseMode {
l := len(ejp.m)
if l == 0 {
return jpmInvalidMode
}
return ejp.m[l-1]
}
func extendJSONToken(jt *jsonToken) *extJSONValue {
var t Type
switch jt.t {
case jttInt32:
t = TypeInt32
case jttInt64:
t = TypeInt64
case jttDouble:
t = TypeDouble
case jttString:
t = TypeString
case jttBool:
t = TypeBoolean
case jttNull:
t = TypeNull
default:
return nil
}
return &extJSONValue{t: t, v: jt.v}
}
func ensureColon(s jsonParseState, key string) error {
if s != jpsSawColon {
return fmt.Errorf("invalid JSON input: missing colon after key \"%s\"", key)
}
return nil
}
func invalidRequestError(s string) error {
return fmt.Errorf("invalid request to read %s", s)
}
func invalidJSONError(expected string) error {
return fmt.Errorf("invalid JSON input; expected %s", expected)
}
func invalidJSONErrorForType(expected string, t Type) error {
return fmt.Errorf("invalid JSON input; expected %s for %s", expected, t)
}
func unexpectedTokenError(jt *jsonToken) error {
switch jt.t {
case jttInt32, jttInt64, jttDouble:
return fmt.Errorf("invalid JSON input; unexpected number (%v) at position %d", jt.v, jt.p)
case jttString:
return fmt.Errorf("invalid JSON input; unexpected string (\"%v\") at position %d", jt.v, jt.p)
case jttBool:
return fmt.Errorf("invalid JSON input; unexpected boolean literal (%v) at position %d", jt.v, jt.p)
case jttNull:
return fmt.Errorf("invalid JSON input; unexpected null literal at position %d", jt.p)
case jttEOF:
return fmt.Errorf("invalid JSON input; unexpected end of input at position %d", jt.p)
default:
return fmt.Errorf("invalid JSON input; unexpected %c at position %d", jt.v.(byte), jt.p)
}
}
func nestingDepthError(p, depth int) error {
return fmt.Errorf("invalid JSON input; nesting too deep (%d levels) at position %d", depth, p)
}

extjson_parser_test.go
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"io"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
var (
keyDiff = specificDiff("key")
typDiff = specificDiff("type")
valDiff = specificDiff("value")
expectErrEOF = expectSpecificError(io.EOF)
expectErrEOD = expectSpecificError(ErrEOD)
expectErrEOA = expectSpecificError(ErrEOA)
)
type expectedErrorFunc func(t *testing.T, err error, desc string)
type peekTypeTestCase struct {
desc string
input string
typs []Type
errFs []expectedErrorFunc
}
type readKeyValueTestCase struct {
desc string
input string
keys []string
typs []Type
vals []*extJSONValue
keyEFs []expectedErrorFunc
valEFs []expectedErrorFunc
}
func expectNoError(t *testing.T, err error, desc string) {
if err != nil {
t.Helper()
t.Errorf("%s: Unexpected error: %v", desc, err)
t.FailNow()
}
}
func expectError(t *testing.T, err error, desc string) {
if err == nil {
t.Helper()
t.Errorf("%s: Expected error", desc)
t.FailNow()
}
}
func expectSpecificError(expected error) expectedErrorFunc {
return func(t *testing.T, err error, desc string) {
if !errors.Is(err, expected) {
t.Helper()
t.Errorf("%s: Expected %v but got: %v", desc, expected, err)
t.FailNow()
}
}
}
func specificDiff(name string) func(t *testing.T, expected, actual interface{}, desc string) {
return func(t *testing.T, expected, actual interface{}, desc string) {
if diff := cmp.Diff(expected, actual); diff != "" {
t.Helper()
t.Errorf("%s: Incorrect JSON %s (-want, +got): %s\n", desc, name, diff)
t.FailNow()
}
}
}
func expectErrorNOOP(_ *testing.T, _ error, _ string) {
}
func readKeyDiff(t *testing.T, eKey, aKey string, eTyp, aTyp Type, err error, errF expectedErrorFunc, desc string) {
keyDiff(t, eKey, aKey, desc)
typDiff(t, eTyp, aTyp, desc)
errF(t, err, desc)
}
func readValueDiff(t *testing.T, eVal, aVal *extJSONValue, err error, errF expectedErrorFunc, desc string) {
if aVal != nil {
typDiff(t, eVal.t, aVal.t, desc)
valDiff(t, eVal.v, aVal.v, desc)
} else {
valDiff(t, eVal, aVal, desc)
}
errF(t, err, desc)
}
func TestExtJSONParserPeekType(t *testing.T) {
makeValidPeekTypeTestCase := func(input string, typ Type, desc string) peekTypeTestCase {
return peekTypeTestCase{
desc: desc, input: input,
typs: []Type{typ},
errFs: []expectedErrorFunc{expectNoError},
}
}
makeInvalidTestCase := func(desc, input string, lastEF expectedErrorFunc) peekTypeTestCase {
return peekTypeTestCase{
desc: desc, input: input,
typs: []Type{Type(0)},
errFs: []expectedErrorFunc{lastEF},
}
}
makeInvalidPeekTypeTestCase := func(desc, input string, lastEF expectedErrorFunc) peekTypeTestCase {
return peekTypeTestCase{
desc: desc, input: input,
typs: []Type{TypeArray, TypeString, Type(0)},
errFs: []expectedErrorFunc{expectNoError, expectNoError, lastEF},
}
}
cases := []peekTypeTestCase{
makeValidPeekTypeTestCase(`null`, TypeNull, "Null"),
makeValidPeekTypeTestCase(`"string"`, TypeString, "String"),
makeValidPeekTypeTestCase(`true`, TypeBoolean, "Boolean--true"),
makeValidPeekTypeTestCase(`false`, TypeBoolean, "Boolean--false"),
makeValidPeekTypeTestCase(`{"$minKey": 1}`, TypeMinKey, "MinKey"),
makeValidPeekTypeTestCase(`{"$maxKey": 1}`, TypeMaxKey, "MaxKey"),
makeValidPeekTypeTestCase(`{"$numberInt": "42"}`, TypeInt32, "Int32"),
makeValidPeekTypeTestCase(`{"$numberLong": "42"}`, TypeInt64, "Int64"),
makeValidPeekTypeTestCase(`{"$symbol": "symbol"}`, TypeSymbol, "Symbol"),
makeValidPeekTypeTestCase(`{"$numberDouble": "42.42"}`, TypeDouble, "Double"),
makeValidPeekTypeTestCase(`{"$undefined": true}`, TypeUndefined, "Undefined"),
makeValidPeekTypeTestCase(`{"$numberDouble": "NaN"}`, TypeDouble, "Double--NaN"),
makeValidPeekTypeTestCase(`{"$numberDecimal": "1234"}`, TypeDecimal128, "Decimal"),
makeValidPeekTypeTestCase(`{"foo": "bar"}`, TypeEmbeddedDocument, "Toplevel document"),
makeValidPeekTypeTestCase(`{"$date": {"$numberLong": "0"}}`, TypeDateTime, "Datetime"),
makeValidPeekTypeTestCase(`{"$code": "function() {}"}`, TypeJavaScript, "Code no scope"),
makeValidPeekTypeTestCase(`[{"$numberInt": "1"},{"$numberInt": "2"}]`, TypeArray, "Array"),
makeValidPeekTypeTestCase(`{"$timestamp": {"t": 42, "i": 1}}`, TypeTimestamp, "Timestamp"),
makeValidPeekTypeTestCase(`{"$oid": "57e193d7a9cc81b4027498b5"}`, TypeObjectID, "Object ID"),
makeValidPeekTypeTestCase(`{"$binary": {"base64": "AQIDBAU=", "subType": "80"}}`, TypeBinary, "Binary"),
makeValidPeekTypeTestCase(`{"$code": "function() {}", "$scope": {}}`, TypeCodeWithScope, "Code With Scope"),
makeValidPeekTypeTestCase(`{"$binary": {"base64": "o0w498Or7cijeBSpkquNtg==", "subType": "03"}}`, TypeBinary, "Binary"),
makeValidPeekTypeTestCase(`{"$binary": "o0w498Or7cijeBSpkquNtg==", "$type": "03"}`, TypeBinary, "Binary"),
makeValidPeekTypeTestCase(`{"$regularExpression": {"pattern": "foo*", "options": "ix"}}`, TypeRegex, "Regular expression"),
makeValidPeekTypeTestCase(`{"$dbPointer": {"$ref": "db.collection", "$id": {"$oid": "57e193d7a9cc81b4027498b1"}}}`, TypeDBPointer, "DBPointer"),
makeValidPeekTypeTestCase(`{"$ref": "collection", "$id": {"$oid": "57fd71e96e32ab4225b723fb"}, "$db": "database"}`, TypeEmbeddedDocument, "DBRef"),
makeInvalidPeekTypeTestCase("invalid array--missing ]", `["a"`, expectError),
makeInvalidPeekTypeTestCase("invalid array--colon in array", `["a":`, expectError),
makeInvalidPeekTypeTestCase("invalid array--extra comma", `["a",,`, expectError),
makeInvalidPeekTypeTestCase("invalid array--trailing comma", `["a",]`, expectError),
makeInvalidPeekTypeTestCase("peekType after end of array", `["a"]`, expectErrEOA),
{
desc: "invalid array--leading comma",
input: `[,`,
typs: []Type{TypeArray, Type(0)},
errFs: []expectedErrorFunc{expectNoError, expectError},
},
makeInvalidTestCase("lone $scope", `{"$scope": {}}`, expectError),
makeInvalidTestCase("empty code with unknown extra key", `{"$code":"", "0":""}`, expectError),
makeInvalidTestCase("non-empty code with unknown extra key", `{"$code":"foobar", "0":""}`, expectError),
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
ejp := newExtJSONParser(strings.NewReader(tc.input), true)
// Manually set the parser's starting state to jpsSawColon so peekType will read ahead to find the extjson
// type of the value. If not set, the parser will be in jpsStartState and advance to jpsSawKey, which will
// cause it to return without peeking the extjson type.
ejp.s = jpsSawColon
for i, eTyp := range tc.typs {
errF := tc.errFs[i]
typ, err := ejp.peekType()
errF(t, err, tc.desc)
if err != nil {
// Don't inspect the type if there was an error
return
}
typDiff(t, eTyp, typ, tc.desc)
}
})
}
}
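// As the comment in the test above notes, peekType reads ahead from the colon
// state to classify a value before it is consumed: plain literals map directly
// to BSON types, while a leading "$"-prefixed key selects an extended JSON
// type. A rough standalone sketch of that classification; unlike the real
// peekType it decodes the whole input, and its names and coverage are
// illustrative only.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// peekKind loosely mirrors peekType for a few cases: decode the input and,
// for objects, let a sole "$"-prefixed key determine the extended JSON type.
func peekKind(input string) string {
	var v interface{}
	if err := json.Unmarshal([]byte(input), &v); err != nil {
		return "invalid"
	}
	switch t := v.(type) {
	case map[string]interface{}:
		for _, k := range []string{"$numberInt", "$numberLong", "$numberDouble", "$oid", "$symbol"} {
			if _, ok := t[k]; ok && len(t) == 1 {
				return k
			}
		}
		return "document"
	case []interface{}:
		return "array"
	case string:
		return "string"
	case bool:
		return "boolean"
	case nil:
		return "null"
	default:
		return "number"
	}
}

func main() {
	fmt.Println(peekKind(`{"$numberInt": "42"}`))
	fmt.Println(peekKind(`{"foo": "bar"}`))
	fmt.Println(peekKind(`[1, 2]`))
}
```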
func TestExtJSONParserReadKeyReadValue(t *testing.T) {
// several test cases will use the same keys, types, and values, and only differ on input structure
keys := []string{"_id", "Symbol", "String", "Int32", "Int64", "Int", "MinKey"}
types := []Type{TypeObjectID, TypeSymbol, TypeString, TypeInt32, TypeInt64, TypeInt32, TypeMinKey}
values := []*extJSONValue{
{t: TypeString, v: "57e193d7a9cc81b4027498b5"},
{t: TypeString, v: "symbol"},
{t: TypeString, v: "string"},
{t: TypeString, v: "42"},
{t: TypeString, v: "42"},
{t: TypeInt32, v: int32(42)},
{t: TypeInt32, v: int32(1)},
}
errFuncs := make([]expectedErrorFunc, 7)
for i := 0; i < 7; i++ {
errFuncs[i] = expectNoError
}
firstKeyError := func(desc, input string) readKeyValueTestCase {
return readKeyValueTestCase{
desc: desc,
input: input,
keys: []string{""},
typs: []Type{Type(0)},
vals: []*extJSONValue{nil},
keyEFs: []expectedErrorFunc{expectError},
valEFs: []expectedErrorFunc{expectErrorNOOP},
}
}
secondKeyError := func(desc, input, firstKey string, firstType Type, firstValue *extJSONValue) readKeyValueTestCase {
return readKeyValueTestCase{
desc: desc,
input: input,
keys: []string{firstKey, ""},
typs: []Type{firstType, Type(0)},
vals: []*extJSONValue{firstValue, nil},
keyEFs: []expectedErrorFunc{expectNoError, expectError},
valEFs: []expectedErrorFunc{expectNoError, expectErrorNOOP},
}
}
cases := []readKeyValueTestCase{
{
desc: "normal spacing",
input: `{
"_id": { "$oid": "57e193d7a9cc81b4027498b5" },
"Symbol": { "$symbol": "symbol" },
"String": "string",
"Int32": { "$numberInt": "42" },
"Int64": { "$numberLong": "42" },
"Int": 42,
"MinKey": { "$minKey": 1 }
}`,
keys: keys, typs: types, vals: values,
keyEFs: errFuncs, valEFs: errFuncs,
},
{
desc: "new line before comma",
input: `{ "_id": { "$oid": "57e193d7a9cc81b4027498b5" }
, "Symbol": { "$symbol": "symbol" }
, "String": "string"
, "Int32": { "$numberInt": "42" }
, "Int64": { "$numberLong": "42" }
, "Int": 42
, "MinKey": { "$minKey": 1 }
}`,
keys: keys, typs: types, vals: values,
keyEFs: errFuncs, valEFs: errFuncs,
},
{
desc: "tabs around colons",
input: `{
"_id": { "$oid" : "57e193d7a9cc81b4027498b5" },
"Symbol": { "$symbol" : "symbol" },
"String": "string",
"Int32": { "$numberInt" : "42" },
"Int64": { "$numberLong": "42" },
"Int": 42,
"MinKey": { "$minKey": 1 }
}`,
keys: keys, typs: types, vals: values,
keyEFs: errFuncs, valEFs: errFuncs,
},
{
desc: "no whitespace",
input: `{"_id":{"$oid":"57e193d7a9cc81b4027498b5"},"Symbol":{"$symbol":"symbol"},"String":"string","Int32":{"$numberInt":"42"},"Int64":{"$numberLong":"42"},"Int":42,"MinKey":{"$minKey":1}}`,
keys: keys, typs: types, vals: values,
keyEFs: errFuncs, valEFs: errFuncs,
},
{
desc: "mixed whitespace",
input: ` {
"_id" : { "$oid": "57e193d7a9cc81b4027498b5" },
"Symbol" : { "$symbol": "symbol" } ,
"String" : "string",
"Int32" : { "$numberInt": "42" } ,
"Int64" : {"$numberLong" : "42"},
"Int" : 42,
"MinKey" : { "$minKey": 1 } } `,
keys: keys, typs: types, vals: values,
keyEFs: errFuncs, valEFs: errFuncs,
},
{
desc: "nested object",
input: `{"k1": 1, "k2": { "k3": { "k4": 4 } }, "k5": 5}`,
keys: []string{"k1", "k2", "k3", "k4", "", "", "k5", ""},
typs: []Type{TypeInt32, TypeEmbeddedDocument, TypeEmbeddedDocument, TypeInt32, Type(0), Type(0), TypeInt32, Type(0)},
vals: []*extJSONValue{
{t: TypeInt32, v: int32(1)}, nil, nil, {t: TypeInt32, v: int32(4)}, nil, nil, {t: TypeInt32, v: int32(5)}, nil,
},
keyEFs: []expectedErrorFunc{
expectNoError, expectNoError, expectNoError, expectNoError, expectErrEOD,
expectErrEOD, expectNoError, expectErrEOD,
},
valEFs: []expectedErrorFunc{
expectNoError, expectError, expectError, expectNoError, expectErrorNOOP,
expectErrorNOOP, expectNoError, expectErrorNOOP,
},
},
{
desc: "invalid input: invalid values for extended type",
input: `{"a": {"$numberInt": "1", "x"`,
keys: []string{"a"},
typs: []Type{TypeInt32},
vals: []*extJSONValue{nil},
keyEFs: []expectedErrorFunc{expectNoError},
valEFs: []expectedErrorFunc{expectError},
},
firstKeyError("invalid input: missing key--EOF", "{"),
firstKeyError("invalid input: missing key--colon first", "{:"),
firstKeyError("invalid input: missing value", `{"a":`),
firstKeyError("invalid input: missing colon", `{"a" 1`),
firstKeyError("invalid input: extra colon", `{"a"::`),
secondKeyError("invalid input: missing }", `{"a": 1`, "a", TypeInt32, &extJSONValue{t: TypeInt32, v: int32(1)}),
secondKeyError("invalid input: missing comma", `{"a": 1 "b"`, "a", TypeInt32, &extJSONValue{t: TypeInt32, v: int32(1)}),
secondKeyError("invalid input: extra comma", `{"a": 1,, "b"`, "a", TypeInt32, &extJSONValue{t: TypeInt32, v: int32(1)}),
secondKeyError("invalid input: trailing comma in object", `{"a": 1,}`, "a", TypeInt32, &extJSONValue{t: TypeInt32, v: int32(1)}),
{
desc: "invalid input: lone scope after a complete value",
input: `{"a": "", "b": {"$scope: ""}}`,
keys: []string{"a"},
typs: []Type{TypeString},
vals: []*extJSONValue{{TypeString, ""}},
keyEFs: []expectedErrorFunc{expectNoError, expectNoError},
valEFs: []expectedErrorFunc{expectNoError, expectError},
},
{
desc: "invalid input: lone scope nested",
input: `{"a":{"b":{"$scope":{`,
keys: []string{},
typs: []Type{},
vals: []*extJSONValue{nil},
keyEFs: []expectedErrorFunc{expectNoError},
valEFs: []expectedErrorFunc{expectError},
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
ejp := newExtJSONParser(strings.NewReader(tc.input), true)
for i, eKey := range tc.keys {
eTyp := tc.typs[i]
eVal := tc.vals[i]
keyErrF := tc.keyEFs[i]
valErrF := tc.valEFs[i]
k, typ, err := ejp.readKey()
readKeyDiff(t, eKey, k, eTyp, typ, err, keyErrF, tc.desc)
v, err := ejp.readValue(typ)
readValueDiff(t, eVal, v, err, valErrF, tc.desc)
}
})
}
}
type ejpExpectationTest func(t *testing.T, p *extJSONParser, expectedKey string, expectedType Type, expectedValue interface{})
type ejpTestCase struct {
f ejpExpectationTest
p *extJSONParser
k string
t Type
v interface{}
}
// expectSingleValue is used for simple JSON types (strings, numbers, literals) and for extended JSON types that
// have single key-value pairs (e.g. { "$minKey": 1 }, { "$numberLong": "42" })
func expectSingleValue(t *testing.T, p *extJSONParser, expectedKey string, expectedType Type, expectedValue interface{}) {
eVal := expectedValue.(*extJSONValue)
k, typ, err := p.readKey()
readKeyDiff(t, expectedKey, k, expectedType, typ, err, expectNoError, expectedKey)
v, err := p.readValue(typ)
readValueDiff(t, eVal, v, err, expectNoError, expectedKey)
}
// expectMultipleValues is used for values that are subdocuments of known size and with known keys (such as extended
// JSON types { "$timestamp": {"t": 1, "i": 1} } and { "$regularExpression": {"pattern": "", "options": ""} })
func expectMultipleValues(t *testing.T, p *extJSONParser, expectedKey string, expectedType Type, expectedValue interface{}) {
k, typ, err := p.readKey()
readKeyDiff(t, expectedKey, k, expectedType, typ, err, expectNoError, expectedKey)
v, err := p.readValue(typ)
expectNoError(t, err, "")
typDiff(t, TypeEmbeddedDocument, v.t, expectedKey)
actObj := v.v.(*extJSONObject)
expObj := expectedValue.(*extJSONObject)
for i, actKey := range actObj.keys {
expKey := expObj.keys[i]
actVal := actObj.values[i]
expVal := expObj.values[i]
keyDiff(t, expKey, actKey, expectedKey)
typDiff(t, expVal.t, actVal.t, expectedKey)
valDiff(t, expVal.v, actVal.v, expectedKey)
}
}
type ejpKeyTypValTriple struct {
key string
typ Type
val *extJSONValue
}
type ejpSubDocumentTestValue struct {
code string // code is only used for TypeCodeWithScope (and is ignored for TypeEmbeddedDocument)
ktvs []ejpKeyTypValTriple // list of (key, type, value) triples; this is "scope" for TypeCodeWithScope
}
// expectSubDocument is used for embedded documents and code with scope types; it reads all the keys and values
// in the embedded document (or scope for codeWithScope) and compares them to the expectedValue's list of (key, type,
// value) triples
func expectSubDocument(t *testing.T, p *extJSONParser, expectedKey string, expectedType Type, expectedValue interface{}) {
subdoc := expectedValue.(ejpSubDocumentTestValue)
k, typ, err := p.readKey()
readKeyDiff(t, expectedKey, k, expectedType, typ, err, expectNoError, expectedKey)
if expectedType == TypeCodeWithScope {
v, err := p.readValue(typ)
readValueDiff(t, &extJSONValue{t: TypeString, v: subdoc.code}, v, err, expectNoError, expectedKey)
}
for _, ktv := range subdoc.ktvs {
eKey := ktv.key
eTyp := ktv.typ
eVal := ktv.val
k, typ, err = p.readKey()
readKeyDiff(t, eKey, k, eTyp, typ, err, expectNoError, expectedKey)
v, err := p.readValue(typ)
readValueDiff(t, eVal, v, err, expectNoError, expectedKey)
}
if expectedType == TypeCodeWithScope {
// expect scope doc to close
k, typ, err = p.readKey()
readKeyDiff(t, "", k, Type(0), typ, err, expectErrEOD, expectedKey)
}
// expect subdoc to close
k, typ, err = p.readKey()
readKeyDiff(t, "", k, Type(0), typ, err, expectErrEOD, expectedKey)
}
// expectArray takes the expectedKey, ignores the expectedType, and uses the expectedValue
// as a slice of (type Type, value *extJSONValue) pairs
func expectArray(t *testing.T, p *extJSONParser, expectedKey string, _ Type, expectedValue interface{}) {
ktvs := expectedValue.([]ejpKeyTypValTriple)
k, typ, err := p.readKey()
readKeyDiff(t, expectedKey, k, TypeArray, typ, err, expectNoError, expectedKey)
for _, ktv := range ktvs {
eTyp := ktv.typ
eVal := ktv.val
typ, err = p.peekType()
typDiff(t, eTyp, typ, expectedKey)
expectNoError(t, err, expectedKey)
v, err := p.readValue(typ)
readValueDiff(t, eVal, v, err, expectNoError, expectedKey)
}
// expect array to end
typ, err = p.peekType()
typDiff(t, Type(0), typ, expectedKey)
expectErrEOA(t, err, expectedKey)
}
func TestExtJSONParserAllTypes(t *testing.T) {
in := ` { "_id" : { "$oid": "57e193d7a9cc81b4027498b5"}
, "Symbol" : { "$symbol": "symbol"}
, "String" : "string"
, "Int32" : { "$numberInt": "42"}
, "Int64" : { "$numberLong": "42"}
, "Double" : { "$numberDouble": "42.42"}
, "SpecialFloat" : { "$numberDouble": "NaN" }
, "Decimal" : { "$numberDecimal": "1234" }
, "Binary" : { "$binary": { "base64": "o0w498Or7cijeBSpkquNtg==", "subType": "03" } }
, "BinaryLegacy" : { "$binary": "o0w498Or7cijeBSpkquNtg==", "$type": "03" }
, "BinaryUserDefined" : { "$binary": { "base64": "AQIDBAU=", "subType": "80" } }
, "Code" : { "$code": "function() {}" }
, "CodeWithEmptyScope" : { "$code": "function() {}", "$scope": {} }
, "CodeWithScope" : { "$code": "function() {}", "$scope": { "x": 1 } }
, "EmptySubdocument" : {}
, "Subdocument" : { "foo": "bar", "baz": { "$numberInt": "42" } }
, "Array" : [{"$numberInt": "1"}, {"$numberLong": "2"}, {"$numberDouble": "3"}, 4, "string", 5.0]
, "Timestamp" : { "$timestamp": { "t": 42, "i": 1 } }
, "RegularExpression" : { "$regularExpression": { "pattern": "foo*", "options": "ix" } }
, "DatetimeEpoch" : { "$date": { "$numberLong": "0" } }
, "DatetimePositive" : { "$date": { "$numberLong": "9223372036854775807" } }
, "DatetimeNegative" : { "$date": { "$numberLong": "-9223372036854775808" } }
, "True" : true
, "False" : false
, "DBPointer" : { "$dbPointer": { "$ref": "db.collection", "$id": { "$oid": "57e193d7a9cc81b4027498b1" } } }
, "DBRef" : { "$ref": "collection", "$id": { "$oid": "57fd71e96e32ab4225b723fb" }, "$db": "database" }
, "DBRefNoDB" : { "$ref": "collection", "$id": { "$oid": "57fd71e96e32ab4225b723fb" } }
, "MinKey" : { "$minKey": 1 }
, "MaxKey" : { "$maxKey": 1 }
, "Null" : null
, "Undefined" : { "$undefined": true }
}`
ejp := newExtJSONParser(strings.NewReader(in), true)
cases := []ejpTestCase{
{
f: expectSingleValue, p: ejp,
k: "_id", t: TypeObjectID, v: &extJSONValue{t: TypeString, v: "57e193d7a9cc81b4027498b5"},
},
{
f: expectSingleValue, p: ejp,
k: "Symbol", t: TypeSymbol, v: &extJSONValue{t: TypeString, v: "symbol"},
},
{
f: expectSingleValue, p: ejp,
k: "String", t: TypeString, v: &extJSONValue{t: TypeString, v: "string"},
},
{
f: expectSingleValue, p: ejp,
k: "Int32", t: TypeInt32, v: &extJSONValue{t: TypeString, v: "42"},
},
{
f: expectSingleValue, p: ejp,
k: "Int64", t: TypeInt64, v: &extJSONValue{t: TypeString, v: "42"},
},
{
f: expectSingleValue, p: ejp,
k: "Double", t: TypeDouble, v: &extJSONValue{t: TypeString, v: "42.42"},
},
{
f: expectSingleValue, p: ejp,
k: "SpecialFloat", t: TypeDouble, v: &extJSONValue{t: TypeString, v: "NaN"},
},
{
f: expectSingleValue, p: ejp,
k: "Decimal", t: TypeDecimal128, v: &extJSONValue{t: TypeString, v: "1234"},
},
{
f: expectMultipleValues, p: ejp,
k: "Binary", t: TypeBinary,
v: &extJSONObject{
keys: []string{"base64", "subType"},
values: []*extJSONValue{
{t: TypeString, v: "o0w498Or7cijeBSpkquNtg=="},
{t: TypeString, v: "03"},
},
},
},
{
f: expectMultipleValues, p: ejp,
k: "BinaryLegacy", t: TypeBinary,
v: &extJSONObject{
keys: []string{"base64", "subType"},
values: []*extJSONValue{
{t: TypeString, v: "o0w498Or7cijeBSpkquNtg=="},
{t: TypeString, v: "03"},
},
},
},
{
f: expectMultipleValues, p: ejp,
k: "BinaryUserDefined", t: TypeBinary,
v: &extJSONObject{
keys: []string{"base64", "subType"},
values: []*extJSONValue{
{t: TypeString, v: "AQIDBAU="},
{t: TypeString, v: "80"},
},
},
},
{
f: expectSingleValue, p: ejp,
k: "Code", t: TypeJavaScript, v: &extJSONValue{t: TypeString, v: "function() {}"},
},
{
f: expectSubDocument, p: ejp,
k: "CodeWithEmptyScope", t: TypeCodeWithScope,
v: ejpSubDocumentTestValue{
code: "function() {}",
ktvs: []ejpKeyTypValTriple{},
},
},
{
f: expectSubDocument, p: ejp,
k: "CodeWithScope", t: TypeCodeWithScope,
v: ejpSubDocumentTestValue{
code: "function() {}",
ktvs: []ejpKeyTypValTriple{
{"x", TypeInt32, &extJSONValue{t: TypeInt32, v: int32(1)}},
},
},
},
{
f: expectSubDocument, p: ejp,
k: "EmptySubdocument", t: TypeEmbeddedDocument,
v: ejpSubDocumentTestValue{
ktvs: []ejpKeyTypValTriple{},
},
},
{
f: expectSubDocument, p: ejp,
k: "Subdocument", t: TypeEmbeddedDocument,
v: ejpSubDocumentTestValue{
ktvs: []ejpKeyTypValTriple{
{"foo", TypeString, &extJSONValue{t: TypeString, v: "bar"}},
{"baz", TypeInt32, &extJSONValue{t: TypeString, v: "42"}},
},
},
},
{
f: expectArray, p: ejp,
k: "Array", t: TypeArray,
v: []ejpKeyTypValTriple{
{typ: TypeInt32, val: &extJSONValue{t: TypeString, v: "1"}},
{typ: TypeInt64, val: &extJSONValue{t: TypeString, v: "2"}},
{typ: TypeDouble, val: &extJSONValue{t: TypeString, v: "3"}},
{typ: TypeInt32, val: &extJSONValue{t: TypeInt32, v: int32(4)}},
{typ: TypeString, val: &extJSONValue{t: TypeString, v: "string"}},
{typ: TypeDouble, val: &extJSONValue{t: TypeDouble, v: 5.0}},
},
},
{
f: expectMultipleValues, p: ejp,
k: "Timestamp", t: TypeTimestamp,
v: &extJSONObject{
keys: []string{"t", "i"},
values: []*extJSONValue{
{t: TypeInt32, v: int32(42)},
{t: TypeInt32, v: int32(1)},
},
},
},
{
f: expectMultipleValues, p: ejp,
k: "RegularExpression", t: TypeRegex,
v: &extJSONObject{
keys: []string{"pattern", "options"},
values: []*extJSONValue{
{t: TypeString, v: "foo*"},
{t: TypeString, v: "ix"},
},
},
},
{
f: expectMultipleValues, p: ejp,
k: "DatetimeEpoch", t: TypeDateTime,
v: &extJSONObject{
keys: []string{"$numberLong"},
values: []*extJSONValue{
{t: TypeString, v: "0"},
},
},
},
{
f: expectMultipleValues, p: ejp,
k: "DatetimePositive", t: TypeDateTime,
v: &extJSONObject{
keys: []string{"$numberLong"},
values: []*extJSONValue{
{t: TypeString, v: "9223372036854775807"},
},
},
},
{
f: expectMultipleValues, p: ejp,
k: "DatetimeNegative", t: TypeDateTime,
v: &extJSONObject{
keys: []string{"$numberLong"},
values: []*extJSONValue{
{t: TypeString, v: "-9223372036854775808"},
},
},
},
{
f: expectSingleValue, p: ejp,
k: "True", t: TypeBoolean, v: &extJSONValue{t: TypeBoolean, v: true},
},
{
f: expectSingleValue, p: ejp,
k: "False", t: TypeBoolean, v: &extJSONValue{t: TypeBoolean, v: false},
},
{
f: expectMultipleValues, p: ejp,
k: "DBPointer", t: TypeDBPointer,
v: &extJSONObject{
keys: []string{"$ref", "$id"},
values: []*extJSONValue{
{t: TypeString, v: "db.collection"},
{t: TypeString, v: "57e193d7a9cc81b4027498b1"},
},
},
},
{
f: expectSubDocument, p: ejp,
k: "DBRef", t: TypeEmbeddedDocument,
v: ejpSubDocumentTestValue{
ktvs: []ejpKeyTypValTriple{
{"$ref", TypeString, &extJSONValue{t: TypeString, v: "collection"}},
{"$id", TypeObjectID, &extJSONValue{t: TypeString, v: "57fd71e96e32ab4225b723fb"}},
{"$db", TypeString, &extJSONValue{t: TypeString, v: "database"}},
},
},
},
{
f: expectSubDocument, p: ejp,
k: "DBRefNoDB", t: TypeEmbeddedDocument,
v: ejpSubDocumentTestValue{
ktvs: []ejpKeyTypValTriple{
{"$ref", TypeString, &extJSONValue{t: TypeString, v: "collection"}},
{"$id", TypeObjectID, &extJSONValue{t: TypeString, v: "57fd71e96e32ab4225b723fb"}},
},
},
},
{
f: expectSingleValue, p: ejp,
k: "MinKey", t: TypeMinKey, v: &extJSONValue{t: TypeInt32, v: int32(1)},
},
{
f: expectSingleValue, p: ejp,
k: "MaxKey", t: TypeMaxKey, v: &extJSONValue{t: TypeInt32, v: int32(1)},
},
{
f: expectSingleValue, p: ejp,
k: "Null", t: TypeNull, v: &extJSONValue{t: TypeNull, v: nil},
},
{
f: expectSingleValue, p: ejp,
k: "Undefined", t: TypeUndefined, v: &extJSONValue{t: TypeBoolean, v: true},
},
}
// run the test cases
for _, tc := range cases {
tc.f(t, tc.p, tc.k, tc.t, tc.v)
}
// expect end of whole document: read final }
k, typ, err := ejp.readKey()
readKeyDiff(t, "", k, Type(0), typ, err, expectErrEOD, "")
// expect end of whole document: read EOF
k, typ, err = ejp.readKey()
readKeyDiff(t, "", k, Type(0), typ, err, expectErrEOF, "")
if diff := cmp.Diff(jpsDoneState, ejp.s); diff != "" {
t.Fatalf("expected parser to be in done state but instead is in %v", ejp.s)
}
}
func TestExtJSONValue(t *testing.T) {
t.Run("Large Date", func(t *testing.T) {
val := &extJSONValue{
t: TypeString,
v: "3001-01-01T00:00:00Z",
}
intVal, err := val.parseDateTime()
if err != nil {
t.Fatalf("error parsing date time: %v", err)
}
if intVal <= 0 {
t.Fatalf("expected value above 0, got %v", intVal)
}
})
t.Run("fallback time format", func(t *testing.T) {
val := &extJSONValue{
t: TypeString,
v: "2019-06-04T14:54:31.416+0000",
}
_, err := val.parseDateTime()
if err != nil {
t.Fatalf("error parsing date time: %v", err)
}
})
}

extjson_prose_test.go
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"fmt"
"testing"
"gitea.psichedelico.com/go/bson/internal/assert"
)
func TestExtJSON(t *testing.T) {
timestampNegativeInt32Err := fmt.Errorf("$timestamp i number should be uint32: -1")
timestampNegativeInt64Err := fmt.Errorf("$timestamp i number should be uint32: -2147483649")
timestampLargeValueErr := fmt.Errorf("$timestamp i number should be uint32: 4294967296")
testCases := []struct {
name string
input string
canonical bool
err error
}{
{"timestamp - negative int32 value", `{"":{"$timestamp":{"t":0,"i":-1}}}`, false, timestampNegativeInt32Err},
{"timestamp - negative int64 value", `{"":{"$timestamp":{"t":0,"i":-2147483649}}}`, false, timestampNegativeInt64Err},
{"timestamp - value overflows uint32", `{"":{"$timestamp":{"t":0,"i":4294967296}}}`, false, timestampLargeValueErr},
{"top level key is not treated as special", `{"$code": "foo"}`, false, nil},
{"escaped single quote errors", `{"f\'oo": "bar"}`, false, ErrInvalidJSON},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var res Raw
err := UnmarshalExtJSON([]byte(tc.input), tc.canonical, &res)
if tc.err == nil {
assert.Nil(t, err, "UnmarshalExtJSON error: %v", err)
return
}
assert.NotNil(t, err, "expected error %v, got nil", tc.err)
assert.Equal(t, tc.err.Error(), err.Error(), "expected error %v, got %v", tc.err, err)
})
}
}

extjson_reader.go
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"fmt"
"io"
)
type ejvrState struct {
mode mode
vType Type
depth int
}
// extJSONValueReader is for reading extended JSON.
type extJSONValueReader struct {
p *extJSONParser
stack []ejvrState
frame int
}
// NewExtJSONValueReader returns a ValueReader that reads Extended JSON values
// from r. If canonicalOnly is true, reading values from the ValueReader returns
// an error if the Extended JSON was not marshaled in canonical mode.
func NewExtJSONValueReader(r io.Reader, canonicalOnly bool) (ValueReader, error) {
return newExtJSONValueReader(r, canonicalOnly)
}
func newExtJSONValueReader(r io.Reader, canonicalOnly bool) (*extJSONValueReader, error) {
ejvr := new(extJSONValueReader)
return ejvr.reset(r, canonicalOnly)
}
func (ejvr *extJSONValueReader) reset(r io.Reader, canonicalOnly bool) (*extJSONValueReader, error) {
p := newExtJSONParser(r, canonicalOnly)
typ, err := p.peekType()
if err != nil {
return nil, ErrInvalidJSON
}
var m mode
switch typ {
case TypeEmbeddedDocument:
m = mTopLevel
case TypeArray:
m = mArray
default:
m = mValue
}
stack := make([]ejvrState, 1, 5)
stack[0] = ejvrState{
mode: m,
vType: typ,
}
return &extJSONValueReader{
p: p,
stack: stack,
}, nil
}
func (ejvr *extJSONValueReader) advanceFrame() {
if ejvr.frame+1 >= len(ejvr.stack) { // We need to grow the stack
length := len(ejvr.stack)
if length+1 >= cap(ejvr.stack) {
// double it
buf := make([]ejvrState, 2*cap(ejvr.stack)+1)
copy(buf, ejvr.stack)
ejvr.stack = buf
}
ejvr.stack = ejvr.stack[:length+1]
}
ejvr.frame++
// Clean the stack
ejvr.stack[ejvr.frame].mode = 0
ejvr.stack[ejvr.frame].vType = 0
ejvr.stack[ejvr.frame].depth = 0
}
func (ejvr *extJSONValueReader) pushDocument() {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = mDocument
ejvr.stack[ejvr.frame].depth = ejvr.p.depth
}
func (ejvr *extJSONValueReader) pushCodeWithScope() {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = mCodeWithScope
}
func (ejvr *extJSONValueReader) pushArray() {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = mArray
}
func (ejvr *extJSONValueReader) push(m mode, t Type) {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = m
ejvr.stack[ejvr.frame].vType = t
}
func (ejvr *extJSONValueReader) pop() {
switch ejvr.stack[ejvr.frame].mode {
case mElement, mValue:
ejvr.frame--
case mDocument, mArray, mCodeWithScope:
ejvr.frame -= 2 // we pop twice to jump over the vrElement: vrDocument -> vrElement -> vrDocument/TopLevel/etc...
}
}
func (ejvr *extJSONValueReader) skipObject() {
// read entire object until depth returns to 0 (last ending } or ] seen)
depth := 1
for depth > 0 {
ejvr.p.advanceState()
// If object is empty, raise depth and continue. When emptyObject is true, the
// parser has already read both the opening and closing brackets of an empty
// object ("{}"), so the next valid token will be part of the parent document,
// not part of the nested document.
//
// If there is a comma, there are remaining fields, emptyObject must be set back
// to false, and comma must be skipped with advanceState().
if ejvr.p.emptyObject {
if ejvr.p.s == jpsSawComma {
ejvr.p.emptyObject = false
ejvr.p.advanceState()
}
depth--
continue
}
switch ejvr.p.s {
case jpsSawBeginObject, jpsSawBeginArray:
depth++
case jpsSawEndObject, jpsSawEndArray:
depth--
}
}
}
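The depth-counting idea in skipObject can be sketched without the parser's state machine. The helper below, `skipValue`, is a hypothetical standalone version that scans a raw JSON buffer starting just past an opening `{` or `[`: it raises the depth on each opener, lowers it on each closer, and skips quoted strings so braces inside string values do not affect the count.

```go
package main

import "fmt"

// skipValue scans data (positioned just after an opening '{' or '[')
// until the matching closing bracket and returns the index one past it.
// Quoted strings are skipped wholesale, honoring backslash escapes.
func skipValue(data []byte) int {
	depth := 1
	i := 0
	for i < len(data) && depth > 0 {
		switch data[i] {
		case '{', '[':
			depth++
		case '}', ']':
			depth--
		case '"':
			// skip the string body so '{' or '}' inside it is ignored
			for i++; i < len(data) && data[i] != '"'; i++ {
				if data[i] == '\\' {
					i++ // skip the escaped character
				}
			}
		}
		i++
	}
	return i
}

func main() {
	// the buffer is everything after the document's opening '{'
	doc := []byte(`"a":{"b":[1,"}"]}},"x":2}`)
	fmt.Println(skipValue(doc))
}
```

The real skipObject additionally handles the parser's emptyObject flag, since the parser consumes both brackets of `{}` in one step.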
func (ejvr *extJSONValueReader) invalidTransitionErr(destination mode, name string, modes []mode) error {
te := TransitionError{
name: name,
current: ejvr.stack[ejvr.frame].mode,
destination: destination,
modes: modes,
action: "read",
}
if ejvr.frame != 0 {
te.parent = ejvr.stack[ejvr.frame-1].mode
}
return te
}
func (ejvr *extJSONValueReader) typeError(t Type) error {
return fmt.Errorf("positioned on %s, but attempted to read %s", ejvr.stack[ejvr.frame].vType, t)
}
func (ejvr *extJSONValueReader) ensureElementValue(t Type, destination mode, callerName string, addModes ...mode) error {
switch ejvr.stack[ejvr.frame].mode {
case mElement, mValue:
if ejvr.stack[ejvr.frame].vType != t {
return ejvr.typeError(t)
}
default:
modes := []mode{mElement, mValue}
if addModes != nil {
modes = append(modes, addModes...)
}
return ejvr.invalidTransitionErr(destination, callerName, modes)
}
return nil
}
func (ejvr *extJSONValueReader) Type() Type {
return ejvr.stack[ejvr.frame].vType
}
func (ejvr *extJSONValueReader) Skip() error {
switch ejvr.stack[ejvr.frame].mode {
case mElement, mValue:
default:
return ejvr.invalidTransitionErr(0, "Skip", []mode{mElement, mValue})
}
defer ejvr.pop()
t := ejvr.stack[ejvr.frame].vType
switch t {
case TypeArray, TypeEmbeddedDocument, TypeCodeWithScope:
// read entire array, doc or CodeWithScope
ejvr.skipObject()
default:
_, err := ejvr.p.readValue(t)
if err != nil {
return err
}
}
return nil
}
func (ejvr *extJSONValueReader) ReadArray() (ArrayReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mTopLevel: // allow reading array from top level
case mArray:
return ejvr, nil
default:
if err := ejvr.ensureElementValue(TypeArray, mArray, "ReadArray", mTopLevel, mArray); err != nil {
return nil, err
}
}
ejvr.pushArray()
return ejvr, nil
}
func (ejvr *extJSONValueReader) ReadBinary() (b []byte, btype byte, err error) {
if err := ejvr.ensureElementValue(TypeBinary, 0, "ReadBinary"); err != nil {
return nil, 0, err
}
v, err := ejvr.p.readValue(TypeBinary)
if err != nil {
return nil, 0, err
}
b, btype, err = v.parseBinary()
ejvr.pop()
return b, btype, err
}
func (ejvr *extJSONValueReader) ReadBoolean() (bool, error) {
if err := ejvr.ensureElementValue(TypeBoolean, 0, "ReadBoolean"); err != nil {
return false, err
}
v, err := ejvr.p.readValue(TypeBoolean)
if err != nil {
return false, err
}
if v.t != TypeBoolean {
return false, fmt.Errorf("expected type bool, but got type %s", v.t)
}
ejvr.pop()
return v.v.(bool), nil
}
func (ejvr *extJSONValueReader) ReadDocument() (DocumentReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mTopLevel:
return ejvr, nil
case mElement, mValue:
if ejvr.stack[ejvr.frame].vType != TypeEmbeddedDocument {
return nil, ejvr.typeError(TypeEmbeddedDocument)
}
ejvr.pushDocument()
return ejvr, nil
default:
return nil, ejvr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue})
}
}
func (ejvr *extJSONValueReader) ReadCodeWithScope() (code string, dr DocumentReader, err error) {
if err = ejvr.ensureElementValue(TypeCodeWithScope, 0, "ReadCodeWithScope"); err != nil {
return "", nil, err
}
v, err := ejvr.p.readValue(TypeCodeWithScope)
if err != nil {
return "", nil, err
}
code, err = v.parseJavascript()
ejvr.pushCodeWithScope()
return code, ejvr, err
}
func (ejvr *extJSONValueReader) ReadDBPointer() (ns string, oid ObjectID, err error) {
if err = ejvr.ensureElementValue(TypeDBPointer, 0, "ReadDBPointer"); err != nil {
return "", NilObjectID, err
}
v, err := ejvr.p.readValue(TypeDBPointer)
if err != nil {
return "", NilObjectID, err
}
ns, oid, err = v.parseDBPointer()
ejvr.pop()
return ns, oid, err
}
func (ejvr *extJSONValueReader) ReadDateTime() (int64, error) {
if err := ejvr.ensureElementValue(TypeDateTime, 0, "ReadDateTime"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(TypeDateTime)
if err != nil {
return 0, err
}
d, err := v.parseDateTime()
ejvr.pop()
return d, err
}
func (ejvr *extJSONValueReader) ReadDecimal128() (Decimal128, error) {
if err := ejvr.ensureElementValue(TypeDecimal128, 0, "ReadDecimal128"); err != nil {
return Decimal128{}, err
}
v, err := ejvr.p.readValue(TypeDecimal128)
if err != nil {
return Decimal128{}, err
}
d, err := v.parseDecimal128()
ejvr.pop()
return d, err
}
func (ejvr *extJSONValueReader) ReadDouble() (float64, error) {
if err := ejvr.ensureElementValue(TypeDouble, 0, "ReadDouble"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(TypeDouble)
if err != nil {
return 0, err
}
d, err := v.parseDouble()
ejvr.pop()
return d, err
}
func (ejvr *extJSONValueReader) ReadInt32() (int32, error) {
if err := ejvr.ensureElementValue(TypeInt32, 0, "ReadInt32"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(TypeInt32)
if err != nil {
return 0, err
}
i, err := v.parseInt32()
ejvr.pop()
return i, err
}
func (ejvr *extJSONValueReader) ReadInt64() (int64, error) {
if err := ejvr.ensureElementValue(TypeInt64, 0, "ReadInt64"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(TypeInt64)
if err != nil {
return 0, err
}
i, err := v.parseInt64()
ejvr.pop()
return i, err
}
func (ejvr *extJSONValueReader) ReadJavascript() (code string, err error) {
if err = ejvr.ensureElementValue(TypeJavaScript, 0, "ReadJavascript"); err != nil {
return "", err
}
v, err := ejvr.p.readValue(TypeJavaScript)
if err != nil {
return "", err
}
code, err = v.parseJavascript()
ejvr.pop()
return code, err
}
func (ejvr *extJSONValueReader) ReadMaxKey() error {
if err := ejvr.ensureElementValue(TypeMaxKey, 0, "ReadMaxKey"); err != nil {
return err
}
v, err := ejvr.p.readValue(TypeMaxKey)
if err != nil {
return err
}
err = v.parseMinMaxKey("max")
ejvr.pop()
return err
}
func (ejvr *extJSONValueReader) ReadMinKey() error {
if err := ejvr.ensureElementValue(TypeMinKey, 0, "ReadMinKey"); err != nil {
return err
}
v, err := ejvr.p.readValue(TypeMinKey)
if err != nil {
return err
}
err = v.parseMinMaxKey("min")
ejvr.pop()
return err
}
func (ejvr *extJSONValueReader) ReadNull() error {
if err := ejvr.ensureElementValue(TypeNull, 0, "ReadNull"); err != nil {
return err
}
v, err := ejvr.p.readValue(TypeNull)
if err != nil {
return err
}
if v.t != TypeNull {
return fmt.Errorf("expected type null but got type %s", v.t)
}
ejvr.pop()
return nil
}
func (ejvr *extJSONValueReader) ReadObjectID() (ObjectID, error) {
if err := ejvr.ensureElementValue(TypeObjectID, 0, "ReadObjectID"); err != nil {
return ObjectID{}, err
}
v, err := ejvr.p.readValue(TypeObjectID)
if err != nil {
return ObjectID{}, err
}
oid, err := v.parseObjectID()
ejvr.pop()
return oid, err
}
func (ejvr *extJSONValueReader) ReadRegex() (pattern string, options string, err error) {
if err = ejvr.ensureElementValue(TypeRegex, 0, "ReadRegex"); err != nil {
return "", "", err
}
v, err := ejvr.p.readValue(TypeRegex)
if err != nil {
return "", "", err
}
pattern, options, err = v.parseRegex()
ejvr.pop()
return pattern, options, err
}
func (ejvr *extJSONValueReader) ReadString() (string, error) {
if err := ejvr.ensureElementValue(TypeString, 0, "ReadString"); err != nil {
return "", err
}
v, err := ejvr.p.readValue(TypeString)
if err != nil {
return "", err
}
if v.t != TypeString {
return "", fmt.Errorf("expected type string but got type %s", v.t)
}
ejvr.pop()
return v.v.(string), nil
}
func (ejvr *extJSONValueReader) ReadSymbol() (symbol string, err error) {
if err = ejvr.ensureElementValue(TypeSymbol, 0, "ReadSymbol"); err != nil {
return "", err
}
v, err := ejvr.p.readValue(TypeSymbol)
if err != nil {
return "", err
}
symbol, err = v.parseSymbol()
ejvr.pop()
return symbol, err
}
func (ejvr *extJSONValueReader) ReadTimestamp() (t uint32, i uint32, err error) {
if err = ejvr.ensureElementValue(TypeTimestamp, 0, "ReadTimestamp"); err != nil {
return 0, 0, err
}
v, err := ejvr.p.readValue(TypeTimestamp)
if err != nil {
return 0, 0, err
}
t, i, err = v.parseTimestamp()
ejvr.pop()
return t, i, err
}
func (ejvr *extJSONValueReader) ReadUndefined() error {
if err := ejvr.ensureElementValue(TypeUndefined, 0, "ReadUndefined"); err != nil {
return err
}
v, err := ejvr.p.readValue(TypeUndefined)
if err != nil {
return err
}
err = v.parseUndefined()
ejvr.pop()
return err
}
func (ejvr *extJSONValueReader) ReadElement() (string, ValueReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mTopLevel, mDocument, mCodeWithScope:
default:
return "", nil, ejvr.invalidTransitionErr(mElement, "ReadElement", []mode{mTopLevel, mDocument, mCodeWithScope})
}
name, t, err := ejvr.p.readKey()
if err != nil {
if errors.Is(err, ErrEOD) {
if ejvr.stack[ejvr.frame].mode == mCodeWithScope {
_, err := ejvr.p.peekType()
if err != nil {
return "", nil, err
}
}
ejvr.pop()
}
return "", nil, err
}
ejvr.push(mElement, t)
return name, ejvr, nil
}
func (ejvr *extJSONValueReader) ReadValue() (ValueReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mArray:
default:
return nil, ejvr.invalidTransitionErr(mValue, "ReadValue", []mode{mArray})
}
t, err := ejvr.p.peekType()
if err != nil {
if errors.Is(err, ErrEOA) {
ejvr.pop()
}
return nil, err
}
ejvr.push(mValue, t)
return ejvr, nil
}

extjson_reader_test.go
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"errors"
"fmt"
"io"
"strings"
"testing"
"gitea.psichedelico.com/go/bson/internal/assert"
"github.com/google/go-cmp/cmp"
)
func TestExtJSONReader(t *testing.T) {
t.Run("ReadDocument", func(t *testing.T) {
t.Run("EmbeddedDocument", func(t *testing.T) {
ejvr := &extJSONValueReader{
stack: []ejvrState{
{mode: mTopLevel},
{mode: mElement, vType: TypeBoolean},
},
frame: 1,
}
ejvr.stack[1].mode = mArray
wanterr := ejvr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue})
_, err := ejvr.ReadDocument()
if err == nil || err.Error() != wanterr.Error() {
t.Errorf("Incorrect returned error. got %v; want %v", err, wanterr)
}
})
})
t.Run("invalid transition", func(t *testing.T) {
t.Run("Skip", func(t *testing.T) {
ejvr := &extJSONValueReader{stack: []ejvrState{{mode: mTopLevel}}}
wanterr := (&extJSONValueReader{stack: []ejvrState{{mode: mTopLevel}}}).invalidTransitionErr(0, "Skip", []mode{mElement, mValue})
goterr := ejvr.Skip()
if !cmp.Equal(goterr, wanterr, cmp.Comparer(assert.CompareErrors)) {
t.Errorf("Expected correct invalid transition error. got %v; want %v", goterr, wanterr)
}
})
})
}
func TestReadMultipleTopLevelDocuments(t *testing.T) {
testCases := []struct {
name string
input string
expected [][]byte
}{
{
"single top-level document",
"{\"foo\":1}",
[][]byte{
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
},
},
{
"single top-level document with leading and trailing whitespace",
"\n\n {\"foo\":1} \n",
[][]byte{
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
},
},
{
"two top-level documents",
"{\"foo\":1}{\"foo\":2}",
[][]byte{
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x02, 0x00, 0x00, 0x00, 0x00},
},
},
{
"two top-level documents with leading and trailing whitespace and whitespace separation",
"\n\n {\"foo\":1}\n{\"foo\":2}\n ",
[][]byte{
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x02, 0x00, 0x00, 0x00, 0x00},
},
},
{
"top-level array with single document",
"[{\"foo\":1}]",
[][]byte{
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
},
},
{
"top-level array with 2 documents",
"[{\"foo\":1},{\"foo\":2}]",
[][]byte{
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x02, 0x00, 0x00, 0x00, 0x00},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r := strings.NewReader(tc.input)
vr, err := NewExtJSONValueReader(r, false)
if err != nil {
t.Fatalf("expected no error, but got %v", err)
}
actual, err := readAllDocuments(vr)
if err != nil {
t.Fatalf("expected no error, but got %v", err)
}
if diff := cmp.Diff(tc.expected, actual); diff != "" {
t.Fatalf("expected does not match actual: %v", diff)
}
})
}
}
func readAllDocuments(vr ValueReader) ([][]byte, error) {
var actual [][]byte
switch vr.Type() {
case TypeEmbeddedDocument:
for {
result, err := copyDocumentToBytes(vr)
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return nil, err
}
actual = append(actual, result)
}
case TypeArray:
ar, err := vr.ReadArray()
if err != nil {
return nil, err
}
for {
evr, err := ar.ReadValue()
if err != nil {
if errors.Is(err, ErrEOA) {
break
}
return nil, err
}
result, err := copyDocumentToBytes(evr)
if err != nil {
return nil, err
}
actual = append(actual, result)
}
default:
return nil, fmt.Errorf("expected an array or a document, but got %s", vr.Type())
}
return actual, nil
}

extjson_tables.go
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/golang/go by The Go Authors
// See THIRD-PARTY-NOTICES for original license terms.
package bson
import "unicode/utf8"
// safeSet holds the value true if the ASCII character with the given array
// position can be represented inside a JSON string without any further
// escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), and the backslash character ("\").
var safeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': true,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': true,
'=': true,
'>': true,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}
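An encoder consults a table like safeSet once per byte. The false entries follow a simple rule, so the predicate can also be written arithmetically; `needsEscape` below is a hypothetical sketch of that rule, not part of this package's API.

```go
package main

import (
	"fmt"
	"unicode/utf8"
)

// needsEscape reports whether an ASCII byte must be escaped inside a
// JSON string: the control characters (0-31), the double quote, and
// the backslash. These are exactly the false entries in a safeSet-style
// table (control characters are false by zero-value omission).
func needsEscape(c byte) bool {
	if c >= utf8.RuneSelf {
		return false // non-ASCII bytes are handled by the UTF-8 path
	}
	return c < 0x20 || c == '"' || c == '\\'
}

func main() {
	for _, c := range []byte{'a', '"', '\\', '\n', '~'} {
		fmt.Printf("%q needs escape: %v\n", c, needsEscape(c))
	}
}
```

The table form trades a few hundred bytes of memory for a branch-free lookup in the encoder's hot loop.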
// htmlSafeSet holds the value true if the ASCII character with the given
// array position can be safely represented inside a JSON string, embedded
// inside of HTML <script> tags, without any additional escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), the backslash character ("\"), HTML opening and closing
// tags ("<" and ">"), and the ampersand ("&").
var htmlSafeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': false,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': false,
'=': true,
'>': false,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}

extjson_wrappers.go
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding/base64"
"errors"
"fmt"
"math"
"strconv"
"time"
)
func wrapperKeyBSONType(key string) Type {
switch key {
case "$numberInt":
return TypeInt32
case "$numberLong":
return TypeInt64
case "$oid":
return TypeObjectID
case "$symbol":
return TypeSymbol
case "$numberDouble":
return TypeDouble
case "$numberDecimal":
return TypeDecimal128
case "$binary":
return TypeBinary
case "$code":
return TypeJavaScript
case "$scope":
return TypeCodeWithScope
case "$timestamp":
return TypeTimestamp
case "$regularExpression":
return TypeRegex
case "$dbPointer":
return TypeDBPointer
case "$date":
return TypeDateTime
case "$minKey":
return TypeMinKey
case "$maxKey":
return TypeMaxKey
case "$undefined":
return TypeUndefined
}
return TypeEmbeddedDocument
}
func (ejv *extJSONValue) parseBinary() (b []byte, subType byte, err error) {
if ejv.t != TypeEmbeddedDocument {
return nil, 0, fmt.Errorf("$binary value should be object, but instead is %s", ejv.t)
}
binObj := ejv.v.(*extJSONObject)
bFound := false
stFound := false
for i, key := range binObj.keys {
val := binObj.values[i]
switch key {
case "base64":
if bFound {
return nil, 0, errors.New("duplicate base64 key in $binary")
}
if val.t != TypeString {
return nil, 0, fmt.Errorf("$binary base64 value should be string, but instead is %s", val.t)
}
base64Bytes, err := base64.StdEncoding.DecodeString(val.v.(string))
if err != nil {
return nil, 0, fmt.Errorf("invalid $binary base64 string: %s", val.v.(string))
}
b = base64Bytes
bFound = true
case "subType":
if stFound {
return nil, 0, errors.New("duplicate subType key in $binary")
}
if val.t != TypeString {
return nil, 0, fmt.Errorf("$binary subType value should be string, but instead is %s", val.t)
}
i, err := strconv.ParseUint(val.v.(string), 16, 8)
if err != nil {
return nil, 0, fmt.Errorf("invalid $binary subType string: %q: %w", val.v.(string), err)
}
subType = byte(i)
stFound = true
default:
return nil, 0, fmt.Errorf("invalid key in $binary object: %s", key)
}
}
if !bFound {
return nil, 0, errors.New("missing base64 field in $binary object")
}
if !stFound {
return nil, 0, errors.New("missing subType field in $binary object")
}
return b, subType, nil
}
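The two conversions at the heart of parseBinary can be seen in isolation: the payload is standard base64 and the subtype is a hex string that must fit in one byte, which is why strconv.ParseUint is called with base 16 and bitSize 8. The helper below, `decodeBinary`, is a minimal sketch covering only the happy path once both fields have been located.

```go
package main

import (
	"encoding/base64"
	"fmt"
	"strconv"
)

// decodeBinary converts the two $binary fields: the base64 payload and
// the one-byte hex subtype (e.g. "00" generic, "80" user-defined).
func decodeBinary(b64, subType string) ([]byte, byte, error) {
	data, err := base64.StdEncoding.DecodeString(b64)
	if err != nil {
		return nil, 0, fmt.Errorf("invalid $binary base64 string: %s", b64)
	}
	// base 16, 8 bits: rejects anything that does not fit in a byte
	st, err := strconv.ParseUint(subType, 16, 8)
	if err != nil {
		return nil, 0, fmt.Errorf("invalid $binary subType string: %q: %w", subType, err)
	}
	return data, byte(st), nil
}

func main() {
	data, st, err := decodeBinary("AQIDBAU=", "80")
	fmt.Println(data, st, err)
}
```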
func (ejv *extJSONValue) parseDBPointer() (ns string, oid ObjectID, err error) {
if ejv.t != TypeEmbeddedDocument {
return "", NilObjectID, fmt.Errorf("$dbPointer value should be object, but instead is %s", ejv.t)
}
dbpObj := ejv.v.(*extJSONObject)
oidFound := false
nsFound := false
for i, key := range dbpObj.keys {
val := dbpObj.values[i]
switch key {
case "$ref":
if nsFound {
return "", NilObjectID, errors.New("duplicate $ref key in $dbPointer")
}
if val.t != TypeString {
return "", NilObjectID, fmt.Errorf("$dbPointer $ref value should be string, but instead is %s", val.t)
}
ns = val.v.(string)
nsFound = true
case "$id":
if oidFound {
return "", NilObjectID, errors.New("duplicate $id key in $dbPointer")
}
if val.t != TypeString {
return "", NilObjectID, fmt.Errorf("$dbPointer $id value should be string, but instead is %s", val.t)
}
oid, err = ObjectIDFromHex(val.v.(string))
if err != nil {
return "", NilObjectID, err
}
oidFound = true
default:
return "", NilObjectID, fmt.Errorf("invalid key in $dbPointer object: %s", key)
}
}
if !nsFound {
return "", oid, errors.New("missing $ref field in $dbPointer object")
}
if !oidFound {
return "", oid, errors.New("missing $id field in $dbPointer object")
}
return ns, oid, nil
}
const (
rfc3339Milli = "2006-01-02T15:04:05.999Z07:00"
)
var (
timeFormats = []string{rfc3339Milli, "2006-01-02T15:04:05.999Z0700"}
)
func (ejv *extJSONValue) parseDateTime() (int64, error) {
switch ejv.t {
case TypeInt32:
return int64(ejv.v.(int32)), nil
case TypeInt64:
return ejv.v.(int64), nil
case TypeString:
return parseDatetimeString(ejv.v.(string))
case TypeEmbeddedDocument:
return parseDatetimeObject(ejv.v.(*extJSONObject))
default:
return 0, fmt.Errorf("$date value should be string or object, but instead is %s", ejv.t)
}
}
func parseDatetimeString(data string) (int64, error) {
var t time.Time
var err error
// try acceptable time formats until one matches
for _, format := range timeFormats {
t, err = time.Parse(format, data)
if err == nil {
break
}
}
if err != nil {
return 0, fmt.Errorf("invalid $date value string: %s", data)
}
return int64(NewDateTimeFromTime(t)), nil
}
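The fallback loop above can be exercised standalone. The sketch below mirrors the two layouts in timeFormats: RFC 3339 with millisecond precision first, then a variant whose zone offset has no colon (the "+0000" style seen in the "fallback time format" test). `parseDate` is a hypothetical name; the result is milliseconds since the Unix epoch, matching BSON datetime semantics.

```go
package main

import (
	"fmt"
	"time"
)

// The accepted layouts, tried in order until one matches.
var formats = []string{
	"2006-01-02T15:04:05.999Z07:00", // RFC 3339 with milliseconds
	"2006-01-02T15:04:05.999Z0700",  // zone offset without a colon
}

// parseDate returns the parsed time as milliseconds since the epoch.
func parseDate(s string) (int64, error) {
	for _, layout := range formats {
		if t, err := time.Parse(layout, s); err == nil {
			return t.UnixMilli(), nil
		}
	}
	return 0, fmt.Errorf("invalid $date value string: %s", s)
}

func main() {
	ms, err := parseDate("2019-06-04T14:54:31.416+0000")
	fmt.Println(ms, err)
}
```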
func parseDatetimeObject(data *extJSONObject) (d int64, err error) {
dFound := false
for i, key := range data.keys {
val := data.values[i]
switch key {
case "$numberLong":
if dFound {
return 0, errors.New("duplicate $numberLong key in $date")
}
if val.t != TypeString {
return 0, fmt.Errorf("$date $numberLong field should be string, but instead is %s", val.t)
}
d, err = val.parseInt64()
if err != nil {
return 0, err
}
dFound = true
default:
return 0, fmt.Errorf("invalid key in $date object: %s", key)
}
}
if !dFound {
return 0, errors.New("missing $numberLong field in $date object")
}
return d, nil
}
func (ejv *extJSONValue) parseDecimal128() (Decimal128, error) {
if ejv.t != TypeString {
return Decimal128{}, fmt.Errorf("$numberDecimal value should be string, but instead is %s", ejv.t)
}
d, err := ParseDecimal128(ejv.v.(string))
if err != nil {
return Decimal128{}, fmt.Errorf("invalid $numberDecimal string: %s", ejv.v.(string))
}
return d, nil
}
func (ejv *extJSONValue) parseDouble() (float64, error) {
if ejv.t == TypeDouble {
return ejv.v.(float64), nil
}
if ejv.t != TypeString {
return 0, fmt.Errorf("$numberDouble value should be string, but instead is %s", ejv.t)
}
switch ejv.v.(string) {
case "Infinity":
return math.Inf(1), nil
case "-Infinity":
return math.Inf(-1), nil
case "NaN":
return math.NaN(), nil
}
f, err := strconv.ParseFloat(ejv.v.(string), 64)
if err != nil {
return 0, err
}
return f, nil
}
func (ejv *extJSONValue) parseInt32() (int32, error) {
if ejv.t == TypeInt32 {
return ejv.v.(int32), nil
}
if ejv.t != TypeString {
return 0, fmt.Errorf("$numberInt value should be string, but instead is %s", ejv.t)
}
i, err := strconv.ParseInt(ejv.v.(string), 10, 64)
if err != nil {
return 0, err
}
if i < math.MinInt32 || i > math.MaxInt32 {
return 0, fmt.Errorf("$numberInt value should be int32 but instead is int64: %d", i)
}
return int32(i), nil
}
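The range check in parseInt32 is the interesting part: the digits are first parsed as int64 so that an out-of-range value produces a descriptive error rather than silently truncating. Sketched standalone (hypothetical `parseNumberInt`):

```go
package main

import (
	"fmt"
	"math"
	"strconv"
)

// parseNumberInt converts a $numberInt string payload, parsing into
// int64 first and then bounds-checking against the int32 range.
func parseNumberInt(s string) (int32, error) {
	i, err := strconv.ParseInt(s, 10, 64)
	if err != nil {
		return 0, err
	}
	if i < math.MinInt32 || i > math.MaxInt32 {
		return 0, fmt.Errorf("$numberInt value should be int32 but instead is int64: %d", i)
	}
	return int32(i), nil
}

func main() {
	fmt.Println(parseNumberInt("42"))
	fmt.Println(parseNumberInt("2147483648")) // math.MaxInt32 + 1
}
```

An equivalent effect could be had from strconv.ParseInt with bitSize 32, but parsing wide first lets the error message report the offending int64 value.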
func (ejv *extJSONValue) parseInt64() (int64, error) {
if ejv.t == TypeInt64 {
return ejv.v.(int64), nil
}
if ejv.t != TypeString {
return 0, fmt.Errorf("$numberLong value should be string, but instead is %s", ejv.t)
}
i, err := strconv.ParseInt(ejv.v.(string), 10, 64)
if err != nil {
return 0, err
}
return i, nil
}
func (ejv *extJSONValue) parseJavascript() (code string, err error) {
if ejv.t != TypeString {
return "", fmt.Errorf("$code value should be string, but instead is %s", ejv.t)
}
return ejv.v.(string), nil
}
func (ejv *extJSONValue) parseMinMaxKey(minmax string) error {
if ejv.t != TypeInt32 {
return fmt.Errorf("$%sKey value should be int32, but instead is %s", minmax, ejv.t)
}
if ejv.v.(int32) != 1 {
return fmt.Errorf("$%sKey value must be 1, but instead is %d", minmax, ejv.v.(int32))
}
return nil
}
func (ejv *extJSONValue) parseObjectID() (ObjectID, error) {
if ejv.t != TypeString {
return NilObjectID, fmt.Errorf("$oid value should be string, but instead is %s", ejv.t)
}
return ObjectIDFromHex(ejv.v.(string))
}
func (ejv *extJSONValue) parseRegex() (pattern, options string, err error) {
if ejv.t != TypeEmbeddedDocument {
return "", "", fmt.Errorf("$regularExpression value should be object, but instead is %s", ejv.t)
}
regexObj := ejv.v.(*extJSONObject)
patFound := false
optFound := false
for i, key := range regexObj.keys {
val := regexObj.values[i]
switch key {
case "pattern":
if patFound {
return "", "", errors.New("duplicate pattern key in $regularExpression")
}
if val.t != TypeString {
return "", "", fmt.Errorf("$regularExpression pattern value should be string, but instead is %s", val.t)
}
pattern = val.v.(string)
patFound = true
case "options":
if optFound {
return "", "", errors.New("duplicate options key in $regularExpression")
}
if val.t != TypeString {
return "", "", fmt.Errorf("$regularExpression options value should be string, but instead is %s", val.t)
}
options = val.v.(string)
optFound = true
default:
return "", "", fmt.Errorf("invalid key in $regularExpression object: %s", key)
}
}
if !patFound {
return "", "", errors.New("missing pattern field in $regularExpression object")
}
if !optFound {
return "", "", errors.New("missing options field in $regularExpression object")
}
return pattern, options, nil
}
func (ejv *extJSONValue) parseSymbol() (string, error) {
if ejv.t != TypeString {
return "", fmt.Errorf("$symbol value should be string, but instead is %s", ejv.t)
}
return ejv.v.(string), nil
}
func (ejv *extJSONValue) parseTimestamp() (t, i uint32, err error) {
if ejv.t != TypeEmbeddedDocument {
return 0, 0, fmt.Errorf("$timestamp value should be object, but instead is %s", ejv.t)
}
handleKey := func(key string, val *extJSONValue, flag bool) (uint32, error) {
if flag {
return 0, fmt.Errorf("duplicate %s key in $timestamp", key)
}
switch val.t {
case TypeInt32:
value := val.v.(int32)
if value < 0 {
return 0, fmt.Errorf("$timestamp %s number should be uint32: %d", key, value)
}
return uint32(value), nil
case TypeInt64:
value := val.v.(int64)
if value < 0 || value > int64(math.MaxUint32) {
return 0, fmt.Errorf("$timestamp %s number should be uint32: %d", key, value)
}
return uint32(value), nil
default:
return 0, fmt.Errorf("$timestamp %s value should be uint32, but instead is %s", key, val.t)
}
}
tsObj := ejv.v.(*extJSONObject)
tFound := false
iFound := false
for j, key := range tsObj.keys {
val := tsObj.values[j]
switch key {
case "t":
if t, err = handleKey(key, val, tFound); err != nil {
return 0, 0, err
}
tFound = true
case "i":
if i, err = handleKey(key, val, iFound); err != nil {
return 0, 0, err
}
iFound = true
default:
return 0, 0, fmt.Errorf("invalid key in $timestamp object: %s", key)
}
}
if !tFound {
return 0, 0, errors.New("missing t field in $timestamp object")
}
if !iFound {
return 0, 0, errors.New("missing i field in $timestamp object")
}
return t, i, nil
}
func (ejv *extJSONValue) parseUndefined() error {
if ejv.t != TypeBoolean {
return fmt.Errorf("undefined value should be boolean, but instead is %s", ejv.t)
}
if !ejv.v.(bool) {
return fmt.Errorf("$undefined value boolean should be true, but instead is %v", ejv.v.(bool))
}
return nil
}

690
extjson_writer.go Normal file

@@ -0,0 +1,690 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"math"
"sort"
"strconv"
"strings"
"time"
"unicode/utf8"
)
type ejvwState struct {
mode mode
}
type extJSONValueWriter struct {
w io.Writer
buf []byte
stack []ejvwState
frame int64
canonical bool
escapeHTML bool
newlines bool
}
// NewExtJSONValueWriter creates a ValueWriter that writes Extended JSON to w.
func NewExtJSONValueWriter(w io.Writer, canonical, escapeHTML bool) ValueWriter {
// Enable newlines for all Extended JSON value writers created by NewExtJSONValueWriter. We
// expect these value writers to be used with an Encoder, which should add newlines after
// encoded Extended JSON documents.
return newExtJSONWriter(w, canonical, escapeHTML, true)
}
func newExtJSONWriter(w io.Writer, canonical, escapeHTML, newlines bool) *extJSONValueWriter {
stack := make([]ejvwState, 1, 5)
stack[0] = ejvwState{mode: mTopLevel}
return &extJSONValueWriter{
w: w,
buf: []byte{},
stack: stack,
canonical: canonical,
escapeHTML: escapeHTML,
newlines: newlines,
}
}
func newExtJSONWriterFromSlice(buf []byte, canonical, escapeHTML bool) *extJSONValueWriter {
stack := make([]ejvwState, 1, 5)
stack[0] = ejvwState{mode: mTopLevel}
return &extJSONValueWriter{
buf: buf,
stack: stack,
canonical: canonical,
escapeHTML: escapeHTML,
}
}
func (ejvw *extJSONValueWriter) reset(buf []byte, canonical, escapeHTML bool) {
if ejvw.stack == nil {
ejvw.stack = make([]ejvwState, 1, 5)
}
ejvw.stack = ejvw.stack[:1]
ejvw.stack[0] = ejvwState{mode: mTopLevel}
ejvw.canonical = canonical
ejvw.escapeHTML = escapeHTML
ejvw.frame = 0
ejvw.buf = buf
ejvw.w = nil
}
func (ejvw *extJSONValueWriter) advanceFrame() {
if ejvw.frame+1 >= int64(len(ejvw.stack)) { // We need to grow the stack
length := len(ejvw.stack)
if length+1 >= cap(ejvw.stack) {
// double it
buf := make([]ejvwState, 2*cap(ejvw.stack)+1)
copy(buf, ejvw.stack)
ejvw.stack = buf
}
ejvw.stack = ejvw.stack[:length+1]
}
ejvw.frame++
}
func (ejvw *extJSONValueWriter) push(m mode) {
ejvw.advanceFrame()
ejvw.stack[ejvw.frame].mode = m
}
func (ejvw *extJSONValueWriter) pop() {
switch ejvw.stack[ejvw.frame].mode {
case mElement, mValue:
ejvw.frame--
case mDocument, mArray, mCodeWithScope:
ejvw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc...
}
}
func (ejvw *extJSONValueWriter) invalidTransitionErr(destination mode, name string, modes []mode) error {
te := TransitionError{
name: name,
current: ejvw.stack[ejvw.frame].mode,
destination: destination,
modes: modes,
action: "write",
}
if ejvw.frame != 0 {
te.parent = ejvw.stack[ejvw.frame-1].mode
}
return te
}
func (ejvw *extJSONValueWriter) ensureElementValue(destination mode, callerName string, addmodes ...mode) error {
switch ejvw.stack[ejvw.frame].mode {
case mElement, mValue:
default:
modes := []mode{mElement, mValue}
if addmodes != nil {
modes = append(modes, addmodes...)
}
return ejvw.invalidTransitionErr(destination, callerName, modes)
}
return nil
}
func (ejvw *extJSONValueWriter) writeExtendedSingleValue(key string, value string, quotes bool) {
var s string
if quotes {
s = fmt.Sprintf(`{"$%s":"%s"}`, key, value)
} else {
s = fmt.Sprintf(`{"$%s":%s}`, key, value)
}
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
func (ejvw *extJSONValueWriter) WriteArray() (ArrayWriter, error) {
if err := ejvw.ensureElementValue(mArray, "WriteArray"); err != nil {
return nil, err
}
ejvw.buf = append(ejvw.buf, '[')
ejvw.push(mArray)
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteBinary(b []byte) error {
return ejvw.WriteBinaryWithSubtype(b, 0x00)
}
func (ejvw *extJSONValueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
if err := ejvw.ensureElementValue(mode(0), "WriteBinaryWithSubtype"); err != nil {
return err
}
var buf bytes.Buffer
buf.WriteString(`{"$binary":{"base64":"`)
buf.WriteString(base64.StdEncoding.EncodeToString(b))
buf.WriteString(fmt.Sprintf(`","subType":"%02x"}},`, btype))
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
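WriteBinaryWithSubtype above emits the canonical `$binary` form: a base64-encoded payload plus a zero-padded two-hex-digit subtype. A standalone sketch of that rendering (extJSONBinary is a hypothetical helper name, and it omits the trailing comma the writer appends for element separation):

```go
package main

import (
	"encoding/base64"
	"fmt"
)

// extJSONBinary renders a byte slice the way WriteBinaryWithSubtype does:
// base64 payload plus a two-digit lowercase hex subtype.
func extJSONBinary(b []byte, subtype byte) string {
	return fmt.Sprintf(`{"$binary":{"base64":"%s","subType":"%02x"}}`,
		base64.StdEncoding.EncodeToString(b), subtype)
}

func main() {
	fmt.Println(extJSONBinary([]byte{0x01, 0x02, 0x03}, 0x00))
	// {"$binary":{"base64":"AQID","subType":"00"}}
}
```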
func (ejvw *extJSONValueWriter) WriteBoolean(b bool) error {
if err := ejvw.ensureElementValue(mode(0), "WriteBoolean"); err != nil {
return err
}
ejvw.buf = append(ejvw.buf, []byte(strconv.FormatBool(b))...)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
if err := ejvw.ensureElementValue(mCodeWithScope, "WriteCodeWithScope"); err != nil {
return nil, err
}
var buf bytes.Buffer
buf.WriteString(`{"$code":`)
writeStringWithEscapes(code, &buf, ejvw.escapeHTML)
buf.WriteString(`,"$scope":{`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.push(mCodeWithScope)
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteDBPointer(ns string, oid ObjectID) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDBPointer"); err != nil {
return err
}
var buf bytes.Buffer
buf.WriteString(`{"$dbPointer":{"$ref":"`)
buf.WriteString(ns)
buf.WriteString(`","$id":{"$oid":"`)
buf.WriteString(oid.Hex())
buf.WriteString(`"}}},`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDateTime(dt int64) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDateTime"); err != nil {
return err
}
t := time.Unix(dt/1e3, dt%1e3*1e6).UTC()
if ejvw.canonical || t.Year() < 1970 || t.Year() > 9999 {
s := fmt.Sprintf(`{"$numberLong":"%d"}`, dt)
ejvw.writeExtendedSingleValue("date", s, false)
} else {
ejvw.writeExtendedSingleValue("date", t.Format(rfc3339Milli), true)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDecimal128(d Decimal128) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDecimal128"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("numberDecimal", d.String(), true)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDocument() (DocumentWriter, error) {
if ejvw.stack[ejvw.frame].mode == mTopLevel {
ejvw.buf = append(ejvw.buf, '{')
return ejvw, nil
}
if err := ejvw.ensureElementValue(mDocument, "WriteDocument", mTopLevel); err != nil {
return nil, err
}
ejvw.buf = append(ejvw.buf, '{')
ejvw.push(mDocument)
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteDouble(f float64) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDouble"); err != nil {
return err
}
s := formatDouble(f)
if ejvw.canonical {
ejvw.writeExtendedSingleValue("numberDouble", s, true)
} else {
switch s {
case "Infinity":
fallthrough
case "-Infinity":
fallthrough
case "NaN":
s = fmt.Sprintf(`{"$numberDouble":"%s"}`, s)
}
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteInt32(i int32) error {
if err := ejvw.ensureElementValue(mode(0), "WriteInt32"); err != nil {
return err
}
s := strconv.FormatInt(int64(i), 10)
if ejvw.canonical {
ejvw.writeExtendedSingleValue("numberInt", s, true)
} else {
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteInt64(i int64) error {
if err := ejvw.ensureElementValue(mode(0), "WriteInt64"); err != nil {
return err
}
s := strconv.FormatInt(i, 10)
if ejvw.canonical {
ejvw.writeExtendedSingleValue("numberLong", s, true)
} else {
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteJavascript(code string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteJavascript"); err != nil {
return err
}
var buf bytes.Buffer
writeStringWithEscapes(code, &buf, ejvw.escapeHTML)
ejvw.writeExtendedSingleValue("code", buf.String(), false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteMaxKey() error {
if err := ejvw.ensureElementValue(mode(0), "WriteMaxKey"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("maxKey", "1", false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteMinKey() error {
if err := ejvw.ensureElementValue(mode(0), "WriteMinKey"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("minKey", "1", false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteNull() error {
if err := ejvw.ensureElementValue(mode(0), "WriteNull"); err != nil {
return err
}
ejvw.buf = append(ejvw.buf, []byte("null")...)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteObjectID(oid ObjectID) error {
if err := ejvw.ensureElementValue(mode(0), "WriteObjectID"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("oid", oid.Hex(), true)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteRegex(pattern string, options string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteRegex"); err != nil {
return err
}
options = sortStringAlphebeticAscending(options)
var buf bytes.Buffer
buf.WriteString(`{"$regularExpression":{"pattern":`)
writeStringWithEscapes(pattern, &buf, ejvw.escapeHTML)
buf.WriteString(`,"options":`)
writeStringWithEscapes(options, &buf, ejvw.escapeHTML)
buf.WriteString(`}},`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteString(s string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteString"); err != nil {
return err
}
var buf bytes.Buffer
writeStringWithEscapes(s, &buf, ejvw.escapeHTML)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteSymbol(symbol string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteSymbol"); err != nil {
return err
}
var buf bytes.Buffer
writeStringWithEscapes(symbol, &buf, ejvw.escapeHTML)
ejvw.writeExtendedSingleValue("symbol", buf.String(), false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteTimestamp(t uint32, i uint32) error {
if err := ejvw.ensureElementValue(mode(0), "WriteTimestamp"); err != nil {
return err
}
var buf bytes.Buffer
buf.WriteString(`{"$timestamp":{"t":`)
buf.WriteString(strconv.FormatUint(uint64(t), 10))
buf.WriteString(`,"i":`)
buf.WriteString(strconv.FormatUint(uint64(i), 10))
buf.WriteString(`}},`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteUndefined() error {
if err := ejvw.ensureElementValue(mode(0), "WriteUndefined"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("undefined", "true", false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDocumentElement(key string) (ValueWriter, error) {
switch ejvw.stack[ejvw.frame].mode {
case mDocument, mTopLevel, mCodeWithScope:
var buf bytes.Buffer
writeStringWithEscapes(key, &buf, ejvw.escapeHTML)
ejvw.buf = append(ejvw.buf, []byte(fmt.Sprintf(`%s:`, buf.String()))...)
ejvw.push(mElement)
default:
return nil, ejvw.invalidTransitionErr(mElement, "WriteDocumentElement", []mode{mDocument, mTopLevel, mCodeWithScope})
}
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteDocumentEnd() error {
switch ejvw.stack[ejvw.frame].mode {
case mDocument, mTopLevel, mCodeWithScope:
default:
return fmt.Errorf("incorrect mode to end document: %s", ejvw.stack[ejvw.frame].mode)
}
// close the document
if ejvw.buf[len(ejvw.buf)-1] == ',' {
ejvw.buf[len(ejvw.buf)-1] = '}'
} else {
ejvw.buf = append(ejvw.buf, '}')
}
switch ejvw.stack[ejvw.frame].mode {
case mCodeWithScope:
ejvw.buf = append(ejvw.buf, '}')
fallthrough
case mDocument:
ejvw.buf = append(ejvw.buf, ',')
case mTopLevel:
// If the value writer has newlines enabled, end top-level documents with a newline so that
// multiple documents encoded to the same writer are separated by newlines. That matches the
// Go json.Encoder behavior and also works with NewExtJSONValueReader.
if ejvw.newlines {
ejvw.buf = append(ejvw.buf, '\n')
}
if ejvw.w != nil {
if _, err := ejvw.w.Write(ejvw.buf); err != nil {
return err
}
ejvw.buf = ejvw.buf[:0]
}
}
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteArrayElement() (ValueWriter, error) {
switch ejvw.stack[ejvw.frame].mode {
case mArray:
ejvw.push(mValue)
default:
return nil, ejvw.invalidTransitionErr(mValue, "WriteArrayElement", []mode{mArray})
}
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteArrayEnd() error {
switch ejvw.stack[ejvw.frame].mode {
case mArray:
// close the array
if ejvw.buf[len(ejvw.buf)-1] == ',' {
ejvw.buf[len(ejvw.buf)-1] = ']'
} else {
ejvw.buf = append(ejvw.buf, ']')
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
default:
return fmt.Errorf("incorrect mode to end array: %s", ejvw.stack[ejvw.frame].mode)
}
return nil
}
func formatDouble(f float64) string {
var s string
switch {
case math.IsInf(f, 1):
s = "Infinity"
case math.IsInf(f, -1):
s = "-Infinity"
case math.IsNaN(f):
s = "NaN"
default:
		// Print exactly one decimal place for integers; otherwise, print as many
		// digits as necessary to represent the value exactly.
s = strconv.FormatFloat(f, 'G', -1, 64)
if !strings.ContainsRune(s, 'E') && !strings.ContainsRune(s, '.') {
s += ".0"
}
}
return s
}
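formatDouble's default branch uses the shortest 'G' representation, so very large or very small magnitudes come out in exponent form, and then appends ".0" when neither an exponent nor a decimal point appears so integer-valued doubles still read as doubles. A standalone sketch of that branch (formatLikeDriver is a hypothetical name):

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// formatLikeDriver mirrors formatDouble's default branch: shortest 'G'
// representation, with ".0" appended to bare integers so the value still
// parses as a floating-point number.
func formatLikeDriver(f float64) string {
	s := strconv.FormatFloat(f, 'G', -1, 64)
	if !strings.ContainsRune(s, 'E') && !strings.ContainsRune(s, '.') {
		s += ".0"
	}
	return s
}

func main() {
	fmt.Println(formatLikeDriver(3))              // 3.0
	fmt.Println(formatLikeDriver(3.14159))        // 3.14159
	fmt.Println(formatLikeDriver(0.000000000003)) // 3E-12
}
```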
var hexChars = "0123456789abcdef"
func writeStringWithEscapes(s string, buf *bytes.Buffer, escapeHTML bool) {
buf.WriteByte('"')
start := 0
for i := 0; i < len(s); {
if b := s[i]; b < utf8.RuneSelf {
if htmlSafeSet[b] || (!escapeHTML && safeSet[b]) {
i++
continue
}
if start < i {
buf.WriteString(s[start:i])
}
switch b {
case '\\', '"':
buf.WriteByte('\\')
buf.WriteByte(b)
case '\n':
buf.WriteByte('\\')
buf.WriteByte('n')
case '\r':
buf.WriteByte('\\')
buf.WriteByte('r')
case '\t':
buf.WriteByte('\\')
buf.WriteByte('t')
case '\b':
buf.WriteByte('\\')
buf.WriteByte('b')
case '\f':
buf.WriteByte('\\')
buf.WriteByte('f')
default:
// This encodes bytes < 0x20 except for \t, \n and \r.
// If escapeHTML is set, it also escapes <, >, and &
// because they can lead to security holes when
// user-controlled strings are rendered into JSON
// and served to some browsers.
buf.WriteString(`\u00`)
buf.WriteByte(hexChars[b>>4])
buf.WriteByte(hexChars[b&0xF])
}
i++
start = i
continue
}
c, size := utf8.DecodeRuneInString(s[i:])
if c == utf8.RuneError && size == 1 {
if start < i {
buf.WriteString(s[start:i])
}
buf.WriteString(`\ufffd`)
i += size
start = i
continue
}
// U+2028 is LINE SEPARATOR.
// U+2029 is PARAGRAPH SEPARATOR.
// They are both technically valid characters in JSON strings,
// but don't work in JSONP, which has to be evaluated as JavaScript,
// and can lead to security holes there. It is valid JSON to
// escape them, so we do so unconditionally.
// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
if c == '\u2028' || c == '\u2029' {
if start < i {
buf.WriteString(s[start:i])
}
buf.WriteString(`\u202`)
buf.WriteByte(hexChars[c&0xF])
i += size
start = i
continue
}
i += size
}
if start < len(s) {
buf.WriteString(s[start:])
}
buf.WriteByte('"')
}
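For control bytes below 0x20 with no short escape, writeStringWithEscapes above falls through to the `\u00XX` form built from the hexChars table. A tiny standalone sketch of just that fallback (escapeControl is a hypothetical helper):

```go
package main

import "fmt"

const hexChars = "0123456789abcdef"

// escapeControl produces the `\u00XX` escape the writer emits for control
// bytes that lack a dedicated short escape like \n or \t.
func escapeControl(b byte) string {
	return `\u00` + string(hexChars[b>>4]) + string(hexChars[b&0xF])
}

func main() {
	fmt.Println(escapeControl(0x01)) // \u0001
	fmt.Println(escapeControl(0x1F)) // \u001f
}
```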
type sortableString []rune
func (ss sortableString) Len() int {
return len(ss)
}
func (ss sortableString) Less(i, j int) bool {
return ss[i] < ss[j]
}
func (ss sortableString) Swap(i, j int) {
ss[i], ss[j] = ss[j], ss[i]
}
func sortStringAlphebeticAscending(s string) string {
ss := sortableString([]rune(s))
sort.Sort(ss)
return string([]rune(ss))
}
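sortStringAlphebeticAscending sorts a regex's option characters so the `$regularExpression` output is deterministic. With the generic slices package (Go 1.21+, available given this module's go 1.23.0 directive), the sortableString boilerplate collapses to a single call; sortRunes below is a hypothetical equivalent:

```go
package main

import (
	"fmt"
	"slices"
)

// sortRunes sorts a string's characters ascending, the same normalization
// applied to $regularExpression options above, without a sort.Interface
// implementation.
func sortRunes(s string) string {
	rs := []rune(s)
	slices.Sort(rs)
	return string(rs)
}

func main() {
	fmt.Println(sortRunes("xsmi")) // imsx
}
```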

259
extjson_writer_test.go Normal file

@@ -0,0 +1,259 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"fmt"
"io"
"reflect"
"strings"
"testing"
"gitea.psichedelico.com/go/bson/internal/assert"
)
func TestExtJSONValueWriter(t *testing.T) {
oid := ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}
testCases := []struct {
name string
fn interface{}
params []interface{}
}{
{
"WriteBinary",
(*extJSONValueWriter).WriteBinary,
[]interface{}{[]byte{0x01, 0x02, 0x03}},
},
{
"WriteBinaryWithSubtype (not 0x02)",
(*extJSONValueWriter).WriteBinaryWithSubtype,
[]interface{}{[]byte{0x01, 0x02, 0x03}, byte(0xFF)},
},
{
"WriteBinaryWithSubtype (0x02)",
(*extJSONValueWriter).WriteBinaryWithSubtype,
[]interface{}{[]byte{0x01, 0x02, 0x03}, byte(0x02)},
},
{
"WriteBoolean",
(*extJSONValueWriter).WriteBoolean,
[]interface{}{true},
},
{
"WriteDBPointer",
(*extJSONValueWriter).WriteDBPointer,
[]interface{}{"bar", oid},
},
{
"WriteDateTime",
(*extJSONValueWriter).WriteDateTime,
[]interface{}{int64(12345678)},
},
{
"WriteDecimal128",
(*extJSONValueWriter).WriteDecimal128,
[]interface{}{NewDecimal128(10, 20)},
},
{
"WriteDouble",
(*extJSONValueWriter).WriteDouble,
[]interface{}{float64(3.14159)},
},
{
"WriteInt32",
(*extJSONValueWriter).WriteInt32,
[]interface{}{int32(123456)},
},
{
"WriteInt64",
(*extJSONValueWriter).WriteInt64,
[]interface{}{int64(1234567890)},
},
{
"WriteJavascript",
(*extJSONValueWriter).WriteJavascript,
[]interface{}{"var foo = 'bar';"},
},
{
"WriteMaxKey",
(*extJSONValueWriter).WriteMaxKey,
[]interface{}{},
},
{
"WriteMinKey",
(*extJSONValueWriter).WriteMinKey,
[]interface{}{},
},
{
"WriteNull",
(*extJSONValueWriter).WriteNull,
[]interface{}{},
},
{
"WriteObjectID",
(*extJSONValueWriter).WriteObjectID,
[]interface{}{oid},
},
{
"WriteRegex",
(*extJSONValueWriter).WriteRegex,
[]interface{}{"bar", "baz"},
},
{
"WriteString",
(*extJSONValueWriter).WriteString,
[]interface{}{"hello, world!"},
},
{
"WriteSymbol",
(*extJSONValueWriter).WriteSymbol,
[]interface{}{"symbollolz"},
},
{
"WriteTimestamp",
(*extJSONValueWriter).WriteTimestamp,
[]interface{}{uint32(10), uint32(20)},
},
{
"WriteUndefined",
(*extJSONValueWriter).WriteUndefined,
[]interface{}{},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
fn := reflect.ValueOf(tc.fn)
if fn.Kind() != reflect.Func {
t.Fatalf("fn must be of kind Func but it is a %v", fn.Kind())
}
if fn.Type().NumIn() != len(tc.params)+1 || fn.Type().In(0) != reflect.TypeOf((*extJSONValueWriter)(nil)) {
t.Fatalf("fn must have at least one parameter and the first parameter must be a *extJSONValueWriter")
}
if fn.Type().NumOut() != 1 || fn.Type().Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
t.Fatalf("fn must have one return value and it must be an error.")
}
params := make([]reflect.Value, 1, len(tc.params)+1)
ejvw := newExtJSONWriter(io.Discard, true, true, false)
params[0] = reflect.ValueOf(ejvw)
for _, param := range tc.params {
params = append(params, reflect.ValueOf(param))
}
t.Run("incorrect transition", func(t *testing.T) {
results := fn.Call(params)
got := results[0].Interface().(error)
fnName := tc.name
if strings.Contains(fnName, "WriteBinary") {
fnName = "WriteBinaryWithSubtype"
}
want := TransitionError{current: mTopLevel, name: fnName, modes: []mode{mElement, mValue},
action: "write"}
if !assert.CompareErrors(got, want) {
t.Errorf("Errors do not match. got %v; want %v", got, want)
}
})
})
}
t.Run("WriteArray", func(t *testing.T) {
ejvw := newExtJSONWriter(io.Discard, true, true, false)
ejvw.push(mArray)
want := TransitionError{current: mArray, destination: mArray, parent: mTopLevel,
name: "WriteArray", modes: []mode{mElement, mValue}, action: "write"}
_, got := ejvw.WriteArray()
if !assert.CompareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteCodeWithScope", func(t *testing.T) {
ejvw := newExtJSONWriter(io.Discard, true, true, false)
ejvw.push(mArray)
want := TransitionError{current: mArray, destination: mCodeWithScope, parent: mTopLevel,
name: "WriteCodeWithScope", modes: []mode{mElement, mValue}, action: "write"}
_, got := ejvw.WriteCodeWithScope("")
if !assert.CompareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteDocument", func(t *testing.T) {
ejvw := newExtJSONWriter(io.Discard, true, true, false)
ejvw.push(mArray)
want := TransitionError{current: mArray, destination: mDocument, parent: mTopLevel,
name: "WriteDocument", modes: []mode{mElement, mValue, mTopLevel}, action: "write"}
_, got := ejvw.WriteDocument()
if !assert.CompareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteDocumentElement", func(t *testing.T) {
ejvw := newExtJSONWriter(io.Discard, true, true, false)
ejvw.push(mElement)
want := TransitionError{current: mElement,
destination: mElement,
parent: mTopLevel,
name: "WriteDocumentElement",
modes: []mode{mDocument, mTopLevel, mCodeWithScope},
action: "write"}
_, got := ejvw.WriteDocumentElement("")
if !assert.CompareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteDocumentEnd", func(t *testing.T) {
ejvw := newExtJSONWriter(io.Discard, true, true, false)
ejvw.push(mElement)
want := fmt.Errorf("incorrect mode to end document: %s", mElement)
got := ejvw.WriteDocumentEnd()
if !assert.CompareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteArrayElement", func(t *testing.T) {
ejvw := newExtJSONWriter(io.Discard, true, true, false)
ejvw.push(mElement)
want := TransitionError{current: mElement,
destination: mValue,
parent: mTopLevel,
name: "WriteArrayElement",
modes: []mode{mArray},
action: "write"}
_, got := ejvw.WriteArrayElement()
if !assert.CompareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteArrayEnd", func(t *testing.T) {
ejvw := newExtJSONWriter(io.Discard, true, true, false)
ejvw.push(mElement)
want := fmt.Errorf("incorrect mode to end array: %s", mElement)
got := ejvw.WriteArrayEnd()
if !assert.CompareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteBytes", func(t *testing.T) {
t.Run("writeElementHeader error", func(t *testing.T) {
ejvw := newExtJSONWriterFromSlice(nil, true, true)
want := TransitionError{current: mTopLevel, destination: mode(0),
name: "WriteBinaryWithSubtype", modes: []mode{mElement, mValue}, action: "write"}
got := ejvw.WriteBinaryWithSubtype(nil, (byte)(TypeEmbeddedDocument))
if !assert.CompareErrors(got, want) {
t.Errorf("Did not receive expected error. got %v; want %v", got, want)
}
})
})
t.Run("FormatDoubleWithExponent", func(t *testing.T) {
want := "3E-12"
got := formatDouble(float64(0.000000000003))
if got != want {
t.Errorf("Did not receive expected string. got %s; want %s", got, want)
}
})
}

40
fuzz_test.go Normal file

@@ -0,0 +1,40 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"testing"
)
func FuzzDecode(f *testing.F) {
seedBSONCorpus(f)
f.Fuzz(func(t *testing.T, data []byte) {
for _, typ := range []func() interface{}{
func() interface{} { return new(D) },
func() interface{} { return new([]E) },
func() interface{} { return new(M) },
func() interface{} { return new(interface{}) },
func() interface{} { return make(map[string]interface{}) },
func() interface{} { return new([]interface{}) },
} {
i := typ()
if err := Unmarshal(data, i); err != nil {
return
}
encoded, err := Marshal(i)
if err != nil {
t.Fatal("failed to marshal", err)
}
if err := Unmarshal(encoded, i); err != nil {
t.Fatal("failed to unmarshal", err)
}
}
})
}

12
go.mod Normal file

@@ -0,0 +1,12 @@
module gitea.psichedelico.com/go/bson
go 1.23.0
toolchain go1.23.7
require (
github.com/google/go-cmp v0.7.0
golang.org/x/sync v0.12.0
)
require github.com/davecgh/go-spew v1.1.1

6
go.sum Normal file

@@ -0,0 +1,6 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=


@@ -0,0 +1,481 @@
// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare.go
// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in
// the THIRD-PARTY-NOTICES file.
package assert
import (
"bytes"
"fmt"
"reflect"
"time"
)
type CompareType int
const (
compareLess CompareType = iota - 1
compareEqual
compareGreater
)
var (
intType = reflect.TypeOf(int(1))
int8Type = reflect.TypeOf(int8(1))
int16Type = reflect.TypeOf(int16(1))
int32Type = reflect.TypeOf(int32(1))
int64Type = reflect.TypeOf(int64(1))
uintType = reflect.TypeOf(uint(1))
uint8Type = reflect.TypeOf(uint8(1))
uint16Type = reflect.TypeOf(uint16(1))
uint32Type = reflect.TypeOf(uint32(1))
uint64Type = reflect.TypeOf(uint64(1))
float32Type = reflect.TypeOf(float32(1))
float64Type = reflect.TypeOf(float64(1))
stringType = reflect.TypeOf("")
timeType = reflect.TypeOf(time.Time{})
bytesType = reflect.TypeOf([]byte{})
)
func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
obj1Value := reflect.ValueOf(obj1)
obj2Value := reflect.ValueOf(obj2)
// throughout this switch we try and avoid calling .Convert() if possible,
// as this has a pretty big performance impact
switch kind {
case reflect.Int:
{
intobj1, ok := obj1.(int)
if !ok {
intobj1 = obj1Value.Convert(intType).Interface().(int)
}
intobj2, ok := obj2.(int)
if !ok {
intobj2 = obj2Value.Convert(intType).Interface().(int)
}
if intobj1 > intobj2 {
return compareGreater, true
}
if intobj1 == intobj2 {
return compareEqual, true
}
if intobj1 < intobj2 {
return compareLess, true
}
}
case reflect.Int8:
{
int8obj1, ok := obj1.(int8)
if !ok {
int8obj1 = obj1Value.Convert(int8Type).Interface().(int8)
}
int8obj2, ok := obj2.(int8)
if !ok {
int8obj2 = obj2Value.Convert(int8Type).Interface().(int8)
}
if int8obj1 > int8obj2 {
return compareGreater, true
}
if int8obj1 == int8obj2 {
return compareEqual, true
}
if int8obj1 < int8obj2 {
return compareLess, true
}
}
case reflect.Int16:
{
int16obj1, ok := obj1.(int16)
if !ok {
int16obj1 = obj1Value.Convert(int16Type).Interface().(int16)
}
int16obj2, ok := obj2.(int16)
if !ok {
int16obj2 = obj2Value.Convert(int16Type).Interface().(int16)
}
if int16obj1 > int16obj2 {
return compareGreater, true
}
if int16obj1 == int16obj2 {
return compareEqual, true
}
if int16obj1 < int16obj2 {
return compareLess, true
}
}
case reflect.Int32:
{
int32obj1, ok := obj1.(int32)
if !ok {
int32obj1 = obj1Value.Convert(int32Type).Interface().(int32)
}
int32obj2, ok := obj2.(int32)
if !ok {
int32obj2 = obj2Value.Convert(int32Type).Interface().(int32)
}
if int32obj1 > int32obj2 {
return compareGreater, true
}
if int32obj1 == int32obj2 {
return compareEqual, true
}
if int32obj1 < int32obj2 {
return compareLess, true
}
}
case reflect.Int64:
{
int64obj1, ok := obj1.(int64)
if !ok {
int64obj1 = obj1Value.Convert(int64Type).Interface().(int64)
}
int64obj2, ok := obj2.(int64)
if !ok {
int64obj2 = obj2Value.Convert(int64Type).Interface().(int64)
}
if int64obj1 > int64obj2 {
return compareGreater, true
}
if int64obj1 == int64obj2 {
return compareEqual, true
}
if int64obj1 < int64obj2 {
return compareLess, true
}
}
case reflect.Uint:
{
uintobj1, ok := obj1.(uint)
if !ok {
uintobj1 = obj1Value.Convert(uintType).Interface().(uint)
}
uintobj2, ok := obj2.(uint)
if !ok {
uintobj2 = obj2Value.Convert(uintType).Interface().(uint)
}
if uintobj1 > uintobj2 {
return compareGreater, true
}
if uintobj1 == uintobj2 {
return compareEqual, true
}
if uintobj1 < uintobj2 {
return compareLess, true
}
}
case reflect.Uint8:
{
uint8obj1, ok := obj1.(uint8)
if !ok {
uint8obj1 = obj1Value.Convert(uint8Type).Interface().(uint8)
}
uint8obj2, ok := obj2.(uint8)
if !ok {
uint8obj2 = obj2Value.Convert(uint8Type).Interface().(uint8)
}
if uint8obj1 > uint8obj2 {
return compareGreater, true
}
if uint8obj1 == uint8obj2 {
return compareEqual, true
}
if uint8obj1 < uint8obj2 {
return compareLess, true
}
}
case reflect.Uint16:
{
uint16obj1, ok := obj1.(uint16)
if !ok {
uint16obj1 = obj1Value.Convert(uint16Type).Interface().(uint16)
}
uint16obj2, ok := obj2.(uint16)
if !ok {
uint16obj2 = obj2Value.Convert(uint16Type).Interface().(uint16)
}
if uint16obj1 > uint16obj2 {
return compareGreater, true
}
if uint16obj1 == uint16obj2 {
return compareEqual, true
}
if uint16obj1 < uint16obj2 {
return compareLess, true
}
}
case reflect.Uint32:
{
uint32obj1, ok := obj1.(uint32)
if !ok {
uint32obj1 = obj1Value.Convert(uint32Type).Interface().(uint32)
}
uint32obj2, ok := obj2.(uint32)
if !ok {
uint32obj2 = obj2Value.Convert(uint32Type).Interface().(uint32)
}
if uint32obj1 > uint32obj2 {
return compareGreater, true
}
if uint32obj1 == uint32obj2 {
return compareEqual, true
}
if uint32obj1 < uint32obj2 {
return compareLess, true
}
}
case reflect.Uint64:
{
uint64obj1, ok := obj1.(uint64)
if !ok {
uint64obj1 = obj1Value.Convert(uint64Type).Interface().(uint64)
}
uint64obj2, ok := obj2.(uint64)
if !ok {
uint64obj2 = obj2Value.Convert(uint64Type).Interface().(uint64)
}
if uint64obj1 > uint64obj2 {
return compareGreater, true
}
if uint64obj1 == uint64obj2 {
return compareEqual, true
}
if uint64obj1 < uint64obj2 {
return compareLess, true
}
}
case reflect.Float32:
{
float32obj1, ok := obj1.(float32)
if !ok {
float32obj1 = obj1Value.Convert(float32Type).Interface().(float32)
}
float32obj2, ok := obj2.(float32)
if !ok {
float32obj2 = obj2Value.Convert(float32Type).Interface().(float32)
}
if float32obj1 > float32obj2 {
return compareGreater, true
}
if float32obj1 == float32obj2 {
return compareEqual, true
}
if float32obj1 < float32obj2 {
return compareLess, true
}
}
case reflect.Float64:
{
float64obj1, ok := obj1.(float64)
if !ok {
float64obj1 = obj1Value.Convert(float64Type).Interface().(float64)
}
float64obj2, ok := obj2.(float64)
if !ok {
float64obj2 = obj2Value.Convert(float64Type).Interface().(float64)
}
if float64obj1 > float64obj2 {
return compareGreater, true
}
if float64obj1 == float64obj2 {
return compareEqual, true
}
if float64obj1 < float64obj2 {
return compareLess, true
}
}
case reflect.String:
{
stringobj1, ok := obj1.(string)
if !ok {
stringobj1 = obj1Value.Convert(stringType).Interface().(string)
}
stringobj2, ok := obj2.(string)
if !ok {
stringobj2 = obj2Value.Convert(stringType).Interface().(string)
}
if stringobj1 > stringobj2 {
return compareGreater, true
}
if stringobj1 == stringobj2 {
return compareEqual, true
}
if stringobj1 < stringobj2 {
return compareLess, true
}
}
// Check for known struct types whose values we can compare.
case reflect.Struct:
{
// All structs enter here. We're not interested in most types.
if !canConvert(obj1Value, timeType) {
break
}
// time.Time can be compared!
timeObj1, ok := obj1.(time.Time)
if !ok {
timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
}
timeObj2, ok := obj2.(time.Time)
if !ok {
timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
}
return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
}
case reflect.Slice:
{
// We only care about the []byte type.
if !canConvert(obj1Value, bytesType) {
break
}
// []byte can be compared!
bytesObj1, ok := obj1.([]byte)
if !ok {
bytesObj1 = obj1Value.Convert(bytesType).Interface().([]byte)
}
bytesObj2, ok := obj2.([]byte)
if !ok {
bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte)
}
return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true
}
}
return compareEqual, false
}
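The convert-then-compare pattern repeated in each arm of the switch above can be sketched in isolation. The hypothetical `compareInts` below mirrors only the `reflect.Int` arm: a direct type assertion is tried first, and an explicit `reflect` conversion is the fallback for named types such as `type myInt int` (the names `compareInts` and `myInt` are illustrative and not part of this package).

```go
package main

import (
	"fmt"
	"reflect"
)

type myInt int

// compareInts returns (1, true) if a > b, (-1, true) if a < b,
// (0, true) if equal, and (0, false) if either value is not an int kind.
func compareInts(a, b interface{}) (int, bool) {
	intType := reflect.TypeOf(int(0))
	av, bv := reflect.ValueOf(a), reflect.ValueOf(b)
	if av.Kind() != reflect.Int || bv.Kind() != reflect.Int {
		return 0, false
	}
	// Fast path: plain int asserts directly; named int types fail the
	// assertion and take the reflect.Convert fallback.
	ai, ok := a.(int)
	if !ok {
		ai = av.Convert(intType).Interface().(int)
	}
	bi, ok := b.(int)
	if !ok {
		bi = bv.Convert(intType).Interface().(int)
	}
	switch {
	case ai > bi:
		return 1, true
	case ai < bi:
		return -1, true
	}
	return 0, true
}

func main() {
	r, ok := compareInts(myInt(3), 2) // named type converts cleanly to int
	fmt.Println(r, ok)                // prints: 1 true
}
```

The type assertion avoids the cost of `Convert` + `Interface` in the common case where the value is already the base type, which is the performance concern the comment at the top of the switch refers to.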
// Greater asserts that the first element is greater than the second
//
// assert.Greater(t, 2, 1)
// assert.Greater(t, float64(2), float64(1))
// assert.Greater(t, "b", "a")
func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
}
// GreaterOrEqual asserts that the first element is greater than or equal to the second
//
// assert.GreaterOrEqual(t, 2, 1)
// assert.GreaterOrEqual(t, 2, 2)
// assert.GreaterOrEqual(t, "b", "a")
// assert.GreaterOrEqual(t, "b", "b")
func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
}
// Less asserts that the first element is less than the second
//
// assert.Less(t, 1, 2)
// assert.Less(t, float64(1), float64(2))
// assert.Less(t, "a", "b")
func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
}
// LessOrEqual asserts that the first element is less than or equal to the second
//
// assert.LessOrEqual(t, 1, 2)
// assert.LessOrEqual(t, 2, 2)
// assert.LessOrEqual(t, "a", "b")
// assert.LessOrEqual(t, "b", "b")
func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
}
// Positive asserts that the specified element is positive
//
// assert.Positive(t, 1)
// assert.Positive(t, 1.23)
func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
zero := reflect.Zero(reflect.TypeOf(e))
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...)
}
// Negative asserts that the specified element is negative
//
// assert.Negative(t, -1)
// assert.Negative(t, -1.23)
func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
zero := reflect.Zero(reflect.TypeOf(e))
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...)
}
func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
e1Kind := reflect.ValueOf(e1).Kind()
e2Kind := reflect.ValueOf(e2).Kind()
if e1Kind != e2Kind {
return Fail(t, "Elements should be the same type", msgAndArgs...)
}
compareResult, isComparable := compare(e1, e2, e1Kind)
if !isComparable {
return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...)
}
if !containsValue(allowedComparesResults, compareResult) {
return Fail(t, fmt.Sprintf(failMessage, e1, e2), msgAndArgs...)
}
return true
}
func containsValue(values []CompareType, value CompareType) bool {
for _, v := range values {
if v == value {
return true
}
}
return false
}
// CompareErrors reports whether two errors are equal: both nil, or both
// non-nil with identical Error() strings.
func CompareErrors(err1, err2 error) bool {
if err1 == nil && err2 == nil {
return true
}
if err1 == nil || err2 == nil {
return false
}
return err1.Error() == err2.Error()
}

// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare_can_convert.go
// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in
// the THIRD-PARTY-NOTICES file.
//go:build go1.17
// +build go1.17
package assert
import "reflect"
// Wrapper around reflect.Value.CanConvert, for compatibility
// reasons.
func canConvert(value reflect.Value, to reflect.Type) bool {
return value.CanConvert(to)
}

// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare_go1.17_test.go
// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in
// the THIRD-PARTY-NOTICES file.
//go:build go1.17
// +build go1.17
package assert
import (
"bytes"
"reflect"
"testing"
"time"
)
func TestCompare17(t *testing.T) {
type customTime time.Time
type customBytes []byte
for _, currCase := range []struct {
less interface{}
greater interface{}
cType string
}{
{less: time.Now(), greater: time.Now().Add(time.Hour), cType: "time.Time"},
{less: customTime(time.Now()), greater: customTime(time.Now().Add(time.Hour)), cType: "time.Time"},
{less: []byte{1, 1}, greater: []byte{1, 2}, cType: "[]byte"},
{less: customBytes([]byte{1, 1}), greater: customBytes([]byte{1, 2}), cType: "[]byte"},
} {
resLess, isComparable := compare(currCase.less, currCase.greater, reflect.ValueOf(currCase.less).Kind())
if !isComparable {
t.Error("object should be comparable for type " + currCase.cType)
}
if resLess != compareLess {
t.Errorf("object less (%v) should be less than greater (%v) for type "+currCase.cType,
currCase.less, currCase.greater)
}
resGreater, isComparable := compare(currCase.greater, currCase.less, reflect.ValueOf(currCase.less).Kind())
if !isComparable {
t.Error("objects should be comparable for type " + currCase.cType)
}
if resGreater != compareGreater {
t.Error("object greater should be greater than less for type " + currCase.cType)
}
resEqual, isComparable := compare(currCase.less, currCase.less, reflect.ValueOf(currCase.less).Kind())
if !isComparable {
t.Error("objects should be comparable for type " + currCase.cType)
}
if resEqual != 0 {
t.Error("objects should be equal for type " + currCase.cType)
}
}
}
func TestGreater17(t *testing.T) {
mockT := new(testing.T)
if !Greater(mockT, 2, 1) {
t.Error("Greater should return true")
}
if Greater(mockT, 1, 1) {
t.Error("Greater should return false")
}
if Greater(mockT, 1, 2) {
t.Error("Greater should return false")
}
// Check error report
for _, currCase := range []struct {
less interface{}
greater interface{}
msg string
}{
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 1]" is not greater than "[1 2]"`},
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 00:00:00 +0000 UTC" is not greater than "0001-01-01 01:00:00 +0000 UTC"`},
} {
out := &outputT{buf: bytes.NewBuffer(nil)}
False(t, Greater(out, currCase.less, currCase.greater))
Contains(t, out.buf.String(), currCase.msg)
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.Greater")
}
}
func TestGreaterOrEqual17(t *testing.T) {
mockT := new(testing.T)
if !GreaterOrEqual(mockT, 2, 1) {
t.Error("GreaterOrEqual should return true")
}
if !GreaterOrEqual(mockT, 1, 1) {
t.Error("GreaterOrEqual should return true")
}
if GreaterOrEqual(mockT, 1, 2) {
t.Error("GreaterOrEqual should return false")
}
// Check error report
for _, currCase := range []struct {
less interface{}
greater interface{}
msg string
}{
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 1]" is not greater than or equal to "[1 2]"`},
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 00:00:00 +0000 UTC" is not greater than or equal to "0001-01-01 01:00:00 +0000 UTC"`},
} {
out := &outputT{buf: bytes.NewBuffer(nil)}
False(t, GreaterOrEqual(out, currCase.less, currCase.greater))
Contains(t, out.buf.String(), currCase.msg)
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.GreaterOrEqual")
}
}
func TestLess17(t *testing.T) {
mockT := new(testing.T)
if !Less(mockT, 1, 2) {
t.Error("Less should return true")
}
if Less(mockT, 1, 1) {
t.Error("Less should return false")
}
if Less(mockT, 2, 1) {
t.Error("Less should return false")
}
// Check error report
for _, currCase := range []struct {
less interface{}
greater interface{}
msg string
}{
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 2]" is not less than "[1 1]"`},
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 01:00:00 +0000 UTC" is not less than "0001-01-01 00:00:00 +0000 UTC"`},
} {
out := &outputT{buf: bytes.NewBuffer(nil)}
False(t, Less(out, currCase.greater, currCase.less))
Contains(t, out.buf.String(), currCase.msg)
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.Less")
}
}
func TestLessOrEqual17(t *testing.T) {
mockT := new(testing.T)
if !LessOrEqual(mockT, 1, 2) {
t.Error("LessOrEqual should return true")
}
if !LessOrEqual(mockT, 1, 1) {
t.Error("LessOrEqual should return true")
}
if LessOrEqual(mockT, 2, 1) {
t.Error("LessOrEqual should return false")
}
// Check error report
for _, currCase := range []struct {
less interface{}
greater interface{}
msg string
}{
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 2]" is not less than or equal to "[1 1]"`},
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 01:00:00 +0000 UTC" is not less than or equal to "0001-01-01 00:00:00 +0000 UTC"`},
} {
out := &outputT{buf: bytes.NewBuffer(nil)}
False(t, LessOrEqual(out, currCase.greater, currCase.less))
Contains(t, out.buf.String(), currCase.msg)
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.LessOrEqual")
}
}

// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare_legacy.go
// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in
// the THIRD-PARTY-NOTICES file.
//go:build !go1.17
// +build !go1.17
package assert
import "reflect"
// Older versions of Go do not have the reflect.Value.CanConvert
// method.
func canConvert(value reflect.Value, to reflect.Type) bool {
return false
}

// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare_test.go
// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in
// the THIRD-PARTY-NOTICES file.
package assert
import (
"bytes"
"fmt"
"reflect"
"runtime"
"testing"
)
func TestCompare(t *testing.T) {
type customInt int
type customInt8 int8
type customInt16 int16
type customInt32 int32
type customInt64 int64
type customUInt uint
type customUInt8 uint8
type customUInt16 uint16
type customUInt32 uint32
type customUInt64 uint64
type customFloat32 float32
type customFloat64 float64
type customString string
for _, currCase := range []struct {
less interface{}
greater interface{}
cType string
}{
{less: customString("a"), greater: customString("b"), cType: "string"},
{less: "a", greater: "b", cType: "string"},
{less: customInt(1), greater: customInt(2), cType: "int"},
{less: int(1), greater: int(2), cType: "int"},
{less: customInt8(1), greater: customInt8(2), cType: "int8"},
{less: int8(1), greater: int8(2), cType: "int8"},
{less: customInt16(1), greater: customInt16(2), cType: "int16"},
{less: int16(1), greater: int16(2), cType: "int16"},
{less: customInt32(1), greater: customInt32(2), cType: "int32"},
{less: int32(1), greater: int32(2), cType: "int32"},
{less: customInt64(1), greater: customInt64(2), cType: "int64"},
{less: int64(1), greater: int64(2), cType: "int64"},
{less: customUInt(1), greater: customUInt(2), cType: "uint"},
{less: uint8(1), greater: uint8(2), cType: "uint8"},
{less: customUInt8(1), greater: customUInt8(2), cType: "uint8"},
{less: uint16(1), greater: uint16(2), cType: "uint16"},
{less: customUInt16(1), greater: customUInt16(2), cType: "uint16"},
{less: uint32(1), greater: uint32(2), cType: "uint32"},
{less: customUInt32(1), greater: customUInt32(2), cType: "uint32"},
{less: uint64(1), greater: uint64(2), cType: "uint64"},
{less: customUInt64(1), greater: customUInt64(2), cType: "uint64"},
{less: float32(1.23), greater: float32(2.34), cType: "float32"},
{less: customFloat32(1.23), greater: customFloat32(2.23), cType: "float32"},
{less: float64(1.23), greater: float64(2.34), cType: "float64"},
{less: customFloat64(1.23), greater: customFloat64(2.34), cType: "float64"},
} {
resLess, isComparable := compare(currCase.less, currCase.greater, reflect.ValueOf(currCase.less).Kind())
if !isComparable {
t.Error("object should be comparable for type " + currCase.cType)
}
if resLess != compareLess {
t.Errorf("object less (%v) should be less than greater (%v) for type "+currCase.cType,
currCase.less, currCase.greater)
}
resGreater, isComparable := compare(currCase.greater, currCase.less, reflect.ValueOf(currCase.less).Kind())
if !isComparable {
t.Error("objects should be comparable for type " + currCase.cType)
}
if resGreater != compareGreater {
t.Error("object greater should be greater than less for type " + currCase.cType)
}
resEqual, isComparable := compare(currCase.less, currCase.less, reflect.ValueOf(currCase.less).Kind())
if !isComparable {
t.Error("objects should be comparable for type " + currCase.cType)
}
if resEqual != 0 {
t.Error("objects should be equal for type " + currCase.cType)
}
}
}
type outputT struct {
buf *bytes.Buffer
helpers map[string]struct{}
}
// Implements TestingT
func (t *outputT) Errorf(format string, args ...interface{}) {
s := fmt.Sprintf(format, args...)
t.buf.WriteString(s)
}
func (t *outputT) Helper() {
if t.helpers == nil {
t.helpers = make(map[string]struct{})
}
t.helpers[callerName(1)] = struct{}{}
}
// callerName gives the function name (qualified with a package path)
// for the caller after skip frames (where 0 means the current function).
func callerName(skip int) string {
// Make room for the skip PC.
var pc [1]uintptr
n := runtime.Callers(skip+2, pc[:]) // skip + runtime.Callers + callerName
if n == 0 {
panic("testing: zero callers found")
}
frames := runtime.CallersFrames(pc[:n])
frame, _ := frames.Next()
return frame.Function
}
func TestGreater(t *testing.T) {
mockT := new(testing.T)
if !Greater(mockT, 2, 1) {
t.Error("Greater should return true")
}
if Greater(mockT, 1, 1) {
t.Error("Greater should return false")
}
if Greater(mockT, 1, 2) {
t.Error("Greater should return false")
}
// Check error report
for _, currCase := range []struct {
less interface{}
greater interface{}
msg string
}{
{less: "a", greater: "b", msg: `"a" is not greater than "b"`},
{less: int(1), greater: int(2), msg: `"1" is not greater than "2"`},
{less: int8(1), greater: int8(2), msg: `"1" is not greater than "2"`},
{less: int16(1), greater: int16(2), msg: `"1" is not greater than "2"`},
{less: int32(1), greater: int32(2), msg: `"1" is not greater than "2"`},
{less: int64(1), greater: int64(2), msg: `"1" is not greater than "2"`},
{less: uint8(1), greater: uint8(2), msg: `"1" is not greater than "2"`},
{less: uint16(1), greater: uint16(2), msg: `"1" is not greater than "2"`},
{less: uint32(1), greater: uint32(2), msg: `"1" is not greater than "2"`},
{less: uint64(1), greater: uint64(2), msg: `"1" is not greater than "2"`},
{less: float32(1.23), greater: float32(2.34), msg: `"1.23" is not greater than "2.34"`},
{less: float64(1.23), greater: float64(2.34), msg: `"1.23" is not greater than "2.34"`},
} {
out := &outputT{buf: bytes.NewBuffer(nil)}
False(t, Greater(out, currCase.less, currCase.greater))
Contains(t, out.buf.String(), currCase.msg)
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.Greater")
}
}
func TestGreaterOrEqual(t *testing.T) {
mockT := new(testing.T)
if !GreaterOrEqual(mockT, 2, 1) {
t.Error("GreaterOrEqual should return true")
}
if !GreaterOrEqual(mockT, 1, 1) {
t.Error("GreaterOrEqual should return true")
}
if GreaterOrEqual(mockT, 1, 2) {
t.Error("GreaterOrEqual should return false")
}
// Check error report
for _, currCase := range []struct {
less interface{}
greater interface{}
msg string
}{
{less: "a", greater: "b", msg: `"a" is not greater than or equal to "b"`},
{less: int(1), greater: int(2), msg: `"1" is not greater than or equal to "2"`},
{less: int8(1), greater: int8(2), msg: `"1" is not greater than or equal to "2"`},
{less: int16(1), greater: int16(2), msg: `"1" is not greater than or equal to "2"`},
{less: int32(1), greater: int32(2), msg: `"1" is not greater than or equal to "2"`},
{less: int64(1), greater: int64(2), msg: `"1" is not greater than or equal to "2"`},
{less: uint8(1), greater: uint8(2), msg: `"1" is not greater than or equal to "2"`},
{less: uint16(1), greater: uint16(2), msg: `"1" is not greater than or equal to "2"`},
{less: uint32(1), greater: uint32(2), msg: `"1" is not greater than or equal to "2"`},
{less: uint64(1), greater: uint64(2), msg: `"1" is not greater than or equal to "2"`},
{less: float32(1.23), greater: float32(2.34), msg: `"1.23" is not greater than or equal to "2.34"`},
{less: float64(1.23), greater: float64(2.34), msg: `"1.23" is not greater than or equal to "2.34"`},
} {
out := &outputT{buf: bytes.NewBuffer(nil)}
False(t, GreaterOrEqual(out, currCase.less, currCase.greater))
Contains(t, out.buf.String(), currCase.msg)
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.GreaterOrEqual")
}
}
func TestLess(t *testing.T) {
mockT := new(testing.T)
if !Less(mockT, 1, 2) {
t.Error("Less should return true")
}
if Less(mockT, 1, 1) {
t.Error("Less should return false")
}
if Less(mockT, 2, 1) {
t.Error("Less should return false")
}
// Check error report
for _, currCase := range []struct {
less interface{}
greater interface{}
msg string
}{
{less: "a", greater: "b", msg: `"b" is not less than "a"`},
{less: int(1), greater: int(2), msg: `"2" is not less than "1"`},
{less: int8(1), greater: int8(2), msg: `"2" is not less than "1"`},
{less: int16(1), greater: int16(2), msg: `"2" is not less than "1"`},
{less: int32(1), greater: int32(2), msg: `"2" is not less than "1"`},
{less: int64(1), greater: int64(2), msg: `"2" is not less than "1"`},
{less: uint8(1), greater: uint8(2), msg: `"2" is not less than "1"`},
{less: uint16(1), greater: uint16(2), msg: `"2" is not less than "1"`},
{less: uint32(1), greater: uint32(2), msg: `"2" is not less than "1"`},
{less: uint64(1), greater: uint64(2), msg: `"2" is not less than "1"`},
{less: float32(1.23), greater: float32(2.34), msg: `"2.34" is not less than "1.23"`},
{less: float64(1.23), greater: float64(2.34), msg: `"2.34" is not less than "1.23"`},
} {
out := &outputT{buf: bytes.NewBuffer(nil)}
False(t, Less(out, currCase.greater, currCase.less))
Contains(t, out.buf.String(), currCase.msg)
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.Less")
}
}
func TestLessOrEqual(t *testing.T) {
mockT := new(testing.T)
if !LessOrEqual(mockT, 1, 2) {
t.Error("LessOrEqual should return true")
}
if !LessOrEqual(mockT, 1, 1) {
t.Error("LessOrEqual should return true")
}
if LessOrEqual(mockT, 2, 1) {
t.Error("LessOrEqual should return false")
}
// Check error report
for _, currCase := range []struct {
less interface{}
greater interface{}
msg string
}{
{less: "a", greater: "b", msg: `"b" is not less than or equal to "a"`},
{less: int(1), greater: int(2), msg: `"2" is not less than or equal to "1"`},
{less: int8(1), greater: int8(2), msg: `"2" is not less than or equal to "1"`},
{less: int16(1), greater: int16(2), msg: `"2" is not less than or equal to "1"`},
{less: int32(1), greater: int32(2), msg: `"2" is not less than or equal to "1"`},
{less: int64(1), greater: int64(2), msg: `"2" is not less than or equal to "1"`},
{less: uint8(1), greater: uint8(2), msg: `"2" is not less than or equal to "1"`},
{less: uint16(1), greater: uint16(2), msg: `"2" is not less than or equal to "1"`},
{less: uint32(1), greater: uint32(2), msg: `"2" is not less than or equal to "1"`},
{less: uint64(1), greater: uint64(2), msg: `"2" is not less than or equal to "1"`},
{less: float32(1.23), greater: float32(2.34), msg: `"2.34" is not less than or equal to "1.23"`},
{less: float64(1.23), greater: float64(2.34), msg: `"2.34" is not less than or equal to "1.23"`},
} {
out := &outputT{buf: bytes.NewBuffer(nil)}
False(t, LessOrEqual(out, currCase.greater, currCase.less))
Contains(t, out.buf.String(), currCase.msg)
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.LessOrEqual")
}
}
func TestPositive(t *testing.T) {
mockT := new(testing.T)
if !Positive(mockT, 1) {
t.Error("Positive should return true")
}
if !Positive(mockT, 1.23) {
t.Error("Positive should return true")
}
if Positive(mockT, -1) {
t.Error("Positive should return false")
}
if Positive(mockT, -1.23) {
t.Error("Positive should return false")
}
// Check error report
for _, currCase := range []struct {
e interface{}
msg string
}{
{e: int(-1), msg: `"-1" is not positive`},
{e: int8(-1), msg: `"-1" is not positive`},
{e: int16(-1), msg: `"-1" is not positive`},
{e: int32(-1), msg: `"-1" is not positive`},
{e: int64(-1), msg: `"-1" is not positive`},
{e: float32(-1.23), msg: `"-1.23" is not positive`},
{e: float64(-1.23), msg: `"-1.23" is not positive`},
} {
out := &outputT{buf: bytes.NewBuffer(nil)}
False(t, Positive(out, currCase.e))
Contains(t, out.buf.String(), currCase.msg)
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.Positive")
}
}
func TestNegative(t *testing.T) {
mockT := new(testing.T)
if !Negative(mockT, -1) {
t.Error("Negative should return true")
}
if !Negative(mockT, -1.23) {
t.Error("Negative should return true")
}
if Negative(mockT, 1) {
t.Error("Negative should return false")
}
if Negative(mockT, 1.23) {
t.Error("Negative should return false")
}
// Check error report
for _, currCase := range []struct {
e interface{}
msg string
}{
{e: int(1), msg: `"1" is not negative`},
{e: int8(1), msg: `"1" is not negative`},
{e: int16(1), msg: `"1" is not negative`},
{e: int32(1), msg: `"1" is not negative`},
{e: int64(1), msg: `"1" is not negative`},
{e: float32(1.23), msg: `"1.23" is not negative`},
{e: float64(1.23), msg: `"1.23" is not negative`},
} {
out := &outputT{buf: bytes.NewBuffer(nil)}
False(t, Negative(out, currCase.e))
Contains(t, out.buf.String(), currCase.msg)
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.Negative")
}
}
func Test_compareTwoValuesDifferentValuesTypes(t *testing.T) {
mockT := new(testing.T)
for _, currCase := range []struct {
v1 interface{}
v2 interface{}
compareResult bool
}{
{v1: 123, v2: "abc"},
{v1: "abc", v2: 123456},
{v1: float64(12), v2: "123"},
{v1: "float(12)", v2: float64(1)},
} {
compareResult := compareTwoValues(mockT, currCase.v1, currCase.v2, []CompareType{compareLess, compareEqual, compareGreater}, "testFailMessage")
False(t, compareResult)
}
}
func Test_compareTwoValuesNotComparableValues(t *testing.T) {
mockT := new(testing.T)
type CompareStruct struct {
}
for _, currCase := range []struct {
v1 interface{}
v2 interface{}
}{
{v1: CompareStruct{}, v2: CompareStruct{}},
{v1: map[string]int{}, v2: map[string]int{}},
{v1: make([]int, 5), v2: make([]int, 5)},
} {
compareResult := compareTwoValues(mockT, currCase.v1, currCase.v2, []CompareType{compareLess, compareEqual, compareGreater}, "testFailMessage")
False(t, compareResult)
}
}
func Test_compareTwoValuesCorrectCompareResult(t *testing.T) {
mockT := new(testing.T)
for _, currCase := range []struct {
v1 interface{}
v2 interface{}
compareTypes []CompareType
}{
{v1: 1, v2: 2, compareTypes: []CompareType{compareLess}},
{v1: 1, v2: 2, compareTypes: []CompareType{compareLess, compareEqual}},
{v1: 2, v2: 2, compareTypes: []CompareType{compareGreater, compareEqual}},
{v1: 2, v2: 2, compareTypes: []CompareType{compareEqual}},
{v1: 2, v2: 1, compareTypes: []CompareType{compareEqual, compareGreater}},
{v1: 2, v2: 1, compareTypes: []CompareType{compareGreater}},
} {
compareResult := compareTwoValues(mockT, currCase.v1, currCase.v2, currCase.compareTypes, "testFailMessage")
True(t, compareResult)
}
}
func Test_containsValue(t *testing.T) {
for _, currCase := range []struct {
values []CompareType
value CompareType
result bool
}{
{values: []CompareType{compareGreater}, value: compareGreater, result: true},
{values: []CompareType{compareGreater, compareLess}, value: compareGreater, result: true},
{values: []CompareType{compareGreater, compareLess}, value: compareLess, result: true},
{values: []CompareType{compareGreater, compareLess}, value: compareEqual, result: false},
} {
compareResult := containsValue(currCase.values, currCase.value)
Equal(t, currCase.result, compareResult)
}
}
func TestComparingMsgAndArgsForwarding(t *testing.T) {
msgAndArgs := []interface{}{"format %s %x", "this", 0xc001}
expectedOutput := "format this c001\n"
funcs := []func(t TestingT){
func(t TestingT) { Greater(t, 1, 2, msgAndArgs...) },
func(t TestingT) { GreaterOrEqual(t, 1, 2, msgAndArgs...) },
func(t TestingT) { Less(t, 2, 1, msgAndArgs...) },
func(t TestingT) { LessOrEqual(t, 2, 1, msgAndArgs...) },
func(t TestingT) { Positive(t, 0, msgAndArgs...) },
func(t TestingT) { Negative(t, 0, msgAndArgs...) },
}
for _, f := range funcs {
out := &outputT{buf: bytes.NewBuffer(nil)}
f(out)
Contains(t, out.buf.String(), expectedOutput)
}
}

// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_format.go
// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in
// the THIRD-PARTY-NOTICES file.
package assert
import (
time "time"
)
// Containsf asserts that the specified string, list(array, slice...) or map contains the
// specified substring or element.
//
// assert.Containsf(t, "Hello World", "World", "error message %s", "formatted")
// assert.Containsf(t, ["Hello", "World"], "World", "error message %s", "formatted")
// assert.Containsf(t, {"Hello": "World"}, "Hello", "error message %s", "formatted")
func Containsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Contains(t, s, contains, append([]interface{}{msg}, args...)...)
}
// ElementsMatchf asserts that the specified listA(array, slice...) is equal to specified
// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
// the number of appearances of each of them in both lists should match.
//
// assert.ElementsMatchf(t, [1, 3, 2, 3], [1, 3, 3, 2], "error message %s", "formatted")
func ElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return ElementsMatch(t, listA, listB, append([]interface{}{msg}, args...)...)
}
// Equalf asserts that two objects are equal.
//
// assert.Equalf(t, 123, 123, "error message %s", "formatted")
//
// Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses). Function equality
// cannot be determined and will always fail.
func Equalf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Equal(t, expected, actual, append([]interface{}{msg}, args...)...)
}
// EqualErrorf asserts that a function returned an error (i.e. not `nil`)
// and that it is equal to the provided error.
//
// actualObj, err := SomeFunction()
// assert.EqualErrorf(t, err, expectedErrorString, "error message %s", "formatted")
func EqualErrorf(t TestingT, theError error, errString string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return EqualError(t, theError, errString, append([]interface{}{msg}, args...)...)
}
// EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal.
//
// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted")
func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return EqualValues(t, expected, actual, append([]interface{}{msg}, args...)...)
}
// Errorf asserts that a function returned an error (i.e. not `nil`).
//
// actualObj, err := SomeFunction()
// if assert.Errorf(t, err, "error message %s", "formatted") {
// assert.Equal(t, expectedErrorf, err)
// }
func Errorf(t TestingT, err error, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Error(t, err, append([]interface{}{msg}, args...)...)
}
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted")
func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return ErrorContains(t, theError, contains, append([]interface{}{msg}, args...)...)
}
// Eventuallyf asserts that the given condition will be met within waitFor time,
// periodically checking the target function each tick.
//
// assert.Eventuallyf(t, func() bool { return true; }, time.Second, 10*time.Millisecond, "error message %s", "formatted")
func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Eventually(t, condition, waitFor, tick, append([]interface{}{msg}, args...)...)
}
// Failf reports a failure through the test framework.
func Failf(t TestingT, failureMessage string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Fail(t, failureMessage, append([]interface{}{msg}, args...)...)
}
// FailNowf fails the test immediately.
func FailNowf(t TestingT, failureMessage string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return FailNow(t, failureMessage, append([]interface{}{msg}, args...)...)
}
// Falsef asserts that the specified value is false.
//
// assert.Falsef(t, myBool, "error message %s", "formatted")
func Falsef(t TestingT, value bool, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return False(t, value, append([]interface{}{msg}, args...)...)
}
// Greaterf asserts that the first element is greater than the second
//
// assert.Greaterf(t, 2, 1, "error message %s", "formatted")
// assert.Greaterf(t, float64(2), float64(1), "error message %s", "formatted")
// assert.Greaterf(t, "b", "a", "error message %s", "formatted")
func Greaterf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Greater(t, e1, e2, append([]interface{}{msg}, args...)...)
}
// GreaterOrEqualf asserts that the first element is greater than or equal to the second
//
// assert.GreaterOrEqualf(t, 2, 1, "error message %s", "formatted")
// assert.GreaterOrEqualf(t, 2, 2, "error message %s", "formatted")
// assert.GreaterOrEqualf(t, "b", "a", "error message %s", "formatted")
// assert.GreaterOrEqualf(t, "b", "b", "error message %s", "formatted")
func GreaterOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return GreaterOrEqual(t, e1, e2, append([]interface{}{msg}, args...)...)
}
// InDeltaf asserts that the two numerals are within delta of each other.
//
// assert.InDeltaf(t, math.Pi, 22/7.0, 0.01, "error message %s", "formatted")
func InDeltaf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return InDelta(t, expected, actual, delta, append([]interface{}{msg}, args...)...)
}
// IsTypef asserts that the specified objects are of the same type.
func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return IsType(t, expectedType, object, append([]interface{}{msg}, args...)...)
}
// Lenf asserts that the specified object has the specified length.
// Lenf also fails if the object has a type that len() does not accept.
//
// assert.Lenf(t, mySlice, 3, "error message %s", "formatted")
func Lenf(t TestingT, object interface{}, length int, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Len(t, object, length, append([]interface{}{msg}, args...)...)
}
// Lessf asserts that the first element is less than the second
//
// assert.Lessf(t, 1, 2, "error message %s", "formatted")
// assert.Lessf(t, float64(1), float64(2), "error message %s", "formatted")
// assert.Lessf(t, "a", "b", "error message %s", "formatted")
func Lessf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Less(t, e1, e2, append([]interface{}{msg}, args...)...)
}
// LessOrEqualf asserts that the first element is less than or equal to the second
//
// assert.LessOrEqualf(t, 1, 2, "error message %s", "formatted")
// assert.LessOrEqualf(t, 2, 2, "error message %s", "formatted")
// assert.LessOrEqualf(t, "a", "b", "error message %s", "formatted")
// assert.LessOrEqualf(t, "b", "b", "error message %s", "formatted")
func LessOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return LessOrEqual(t, e1, e2, append([]interface{}{msg}, args...)...)
}
// Negativef asserts that the specified element is negative
//
// assert.Negativef(t, -1, "error message %s", "formatted")
// assert.Negativef(t, -1.23, "error message %s", "formatted")
func Negativef(t TestingT, e interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Negative(t, e, append([]interface{}{msg}, args...)...)
}
// Nilf asserts that the specified object is nil.
//
// assert.Nilf(t, err, "error message %s", "formatted")
func Nilf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Nil(t, object, append([]interface{}{msg}, args...)...)
}
// NoErrorf asserts that a function returned no error (i.e. `nil`).
//
// actualObj, err := SomeFunction()
// if assert.NoErrorf(t, err, "error message %s", "formatted") {
// assert.Equal(t, expectedObj, actualObj)
// }
func NoErrorf(t TestingT, err error, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NoError(t, err, append([]interface{}{msg}, args...)...)
}
// NotContainsf asserts that the specified string, list(array, slice...) or map does NOT contain the
// specified substring or element.
//
// assert.NotContainsf(t, "Hello World", "Earth", "error message %s", "formatted")
// assert.NotContainsf(t, ["Hello", "World"], "Earth", "error message %s", "formatted")
// assert.NotContainsf(t, {"Hello": "World"}, "Earth", "error message %s", "formatted")
func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotContains(t, s, contains, append([]interface{}{msg}, args...)...)
}
// NotEqualf asserts that the specified values are NOT equal.
//
// assert.NotEqualf(t, obj1, obj2, "error message %s", "formatted")
//
// Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses).
func NotEqualf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotEqual(t, expected, actual, append([]interface{}{msg}, args...)...)
}
// NotEqualValuesf asserts that two objects are not equal even when converted to the same type
//
// assert.NotEqualValuesf(t, obj1, obj2, "error message %s", "formatted")
func NotEqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotEqualValues(t, expected, actual, append([]interface{}{msg}, args...)...)
}
// NotNilf asserts that the specified object is not nil.
//
// assert.NotNilf(t, err, "error message %s", "formatted")
func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotNil(t, object, append([]interface{}{msg}, args...)...)
}
// Positivef asserts that the specified element is positive
//
// assert.Positivef(t, 1, "error message %s", "formatted")
// assert.Positivef(t, 1.23, "error message %s", "formatted")
func Positivef(t TestingT, e interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Positive(t, e, append([]interface{}{msg}, args...)...)
}
// Truef asserts that the specified value is true.
//
// assert.Truef(t, myBool, "error message %s", "formatted")
func Truef(t TestingT, value bool, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return True(t, value, append([]interface{}{msg}, args...)...)
}
// WithinDurationf asserts that the two times are within duration delta of each other.
//
// assert.WithinDurationf(t, time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted")
func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return WithinDuration(t, expected, actual, delta, append([]interface{}{msg}, args...)...)
}

// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// assertion_mongo.go contains MongoDB-specific extensions to the "assert"
// package.
package assert
import (
"context"
"fmt"
"reflect"
"time"
"unsafe"
)
// DifferentAddressRanges asserts that two byte slices reference distinct memory
// address ranges, meaning they reference different underlying byte arrays.
func DifferentAddressRanges(t TestingT, a, b []byte) (ok bool) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if len(a) == 0 || len(b) == 0 {
return true
}
// Find the start and end memory addresses for the underlying byte array for
// each input byte slice.
sliceAddrRange := func(b []byte) (uintptr, uintptr) {
sh := (*reflect.SliceHeader)(unsafe.Pointer(&b))
return sh.Data, sh.Data + uintptr(sh.Cap-1)
}
aStart, aEnd := sliceAddrRange(a)
bStart, bEnd := sliceAddrRange(b)
// If "b" starts after "a" ends or "a" starts after "b" ends, there is no
// overlap.
if bStart > aEnd || aStart > bEnd {
return true
}
// Otherwise, calculate the overlap start and end and print the memory
// overlap error message.
min := func(a, b uintptr) uintptr {
if a < b {
return a
}
return b
}
max := func(a, b uintptr) uintptr {
if a > b {
return a
}
return b
}
overlapLow := max(aStart, bStart)
overlapHigh := min(aEnd, bEnd)
t.Errorf("Byte slices point to the same underlying byte array:\n"+
"\ta addresses:\t%d ... %d\n"+
"\tb addresses:\t%d ... %d\n"+
"\toverlap:\t%d ... %d",
aStart, aEnd,
bStart, bEnd,
overlapLow, overlapHigh)
return false
}
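The overlap test above reduces to comparing the `[start, start+cap-1]` address ranges of the two backing arrays. A self-contained sketch of the same check (inverted to report overlap rather than distinctness, and using `&b[0]` instead of the deprecated `reflect.SliceHeader`):

```go
package main

import (
	"fmt"
	"unsafe"
)

// overlap reports whether two non-empty byte slices share any portion of
// the same underlying array, via the same address-range comparison as
// DifferentAddressRanges.
func overlap(a, b []byte) bool {
	if len(a) == 0 || len(b) == 0 {
		return false
	}
	aStart := uintptr(unsafe.Pointer(&a[0]))
	aEnd := aStart + uintptr(cap(a)-1)
	bStart := uintptr(unsafe.Pointer(&b[0]))
	bEnd := bStart + uintptr(cap(b)-1)
	// No overlap iff one range ends before the other begins.
	return !(bStart > aEnd || aStart > bEnd)
}

func main() {
	s := []byte{0, 1, 2, 3, 4, 5, 6, 7}
	fmt.Println(overlap(s, s[:4]))                  // a subslice shares the array
	fmt.Println(overlap(s, append([]byte{}, s...))) // a copy does not
}
```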
// EqualBSON asserts that the expected and actual BSON binary values are equal.
// If the values are not equal, it prints both the binary and Extended JSON diff
// of the BSON values. The provided BSON value types must implement the
// fmt.Stringer interface.
func EqualBSON(t TestingT, expected, actual interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Equal(t,
expected,
actual,
`expected and actual BSON values do not match
As Extended JSON:
Expected: %s
Actual : %s`,
expected.(fmt.Stringer).String(),
actual.(fmt.Stringer).String())
}
// Soon runs the provided callback and fails the passed-in test if the callback
// does not complete within timeout. The provided callback should respect the
// passed-in context and cease execution when it has expired.
//
// Deprecated: This function will be removed with GODRIVER-2667, use
// assert.Eventually instead.
func Soon(t TestingT, callback func(ctx context.Context), timeout time.Duration) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
// Create context to manually cancel callback after Soon assertion.
callbackCtx, cancel := context.WithCancel(context.Background())
defer cancel()
	// Buffer the channel so the callback goroutine can still send and exit
	// if the timeout fires first and nobody is receiving.
	done := make(chan struct{}, 1)
	fullCallback := func() {
		callback(callbackCtx)
		done <- struct{}{}
	}
timer := time.NewTimer(timeout)
defer timer.Stop()
go fullCallback()
select {
case <-done:
return
case <-timer.C:
t.Errorf("timed out in %s waiting for callback", timeout)
}
}

// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package assert
import (
"testing"
"gitea.psichedelico.com/go/bson"
)
func TestDifferentAddressRanges(t *testing.T) {
t.Parallel()
slice := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
testCases := []struct {
name string
a []byte
b []byte
want bool
}{
{
name: "distinct byte slices",
a: []byte{0, 1, 2, 3},
b: []byte{0, 1, 2, 3},
want: true,
},
{
name: "same byte slice",
a: slice,
b: slice,
want: false,
},
{
name: "whole and subslice",
a: slice,
b: slice[:4],
want: false,
},
{
name: "two subslices",
a: slice[1:2],
b: slice[3:4],
want: false,
},
{
name: "empty",
a: []byte{0, 1, 2, 3},
b: []byte{},
want: true,
},
{
name: "nil",
a: []byte{0, 1, 2, 3},
b: nil,
want: true,
},
}
for _, tc := range testCases {
tc := tc // Capture range variable.
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := DifferentAddressRanges(new(testing.T), tc.a, tc.b)
if got != tc.want {
t.Errorf("DifferentAddressRanges(%p, %p) = %v, want %v", tc.a, tc.b, got, tc.want)
}
})
}
}
func TestEqualBSON(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
expected interface{}
actual interface{}
want bool
}{
{
name: "equal bson.Raw",
expected: bson.Raw{5, 0, 0, 0, 0},
actual: bson.Raw{5, 0, 0, 0, 0},
want: true,
},
{
name: "different bson.Raw",
expected: bson.Raw{8, 0, 0, 0, 10, 120, 0, 0},
actual: bson.Raw{5, 0, 0, 0, 0},
want: false,
},
{
name: "invalid bson.Raw",
expected: bson.Raw{99, 99, 99, 99},
actual: bson.Raw{5, 0, 0, 0, 0},
want: false,
},
{
name: "nil bson.Raw",
expected: bson.Raw(nil),
actual: bson.Raw(nil),
want: true,
},
}
for _, tc := range testCases {
tc := tc // Capture range variable.
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := EqualBSON(new(testing.T), tc.expected, tc.actual)
if got != tc.want {
t.Errorf("EqualBSON(%#v, %#v) = %v, want %v", tc.expected, tc.actual, got, tc.want)
}
})
}
}

// internal/assert/difflib.go
// Copied from https://github.com/pmezard/go-difflib/blob/5d4384ee4fb2527b0a1256a821ebfc92f91efefc/difflib/difflib.go
// Copyright 2013 Patrick Mezard. All rights reserved. Use of this source code is
// governed by a license that can be found in the THIRD-PARTY-NOTICES file.
package assert
import (
"bufio"
"bytes"
"fmt"
"io"
"strings"
)
func min(a, b int) int {
if a < b {
return a
}
return b
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
func calculateRatio(matches, length int) float64 {
if length > 0 {
return 2.0 * float64(matches) / float64(length)
}
return 1.0
}
type Match struct {
A int
B int
Size int
}
type OpCode struct {
Tag byte
I1 int
I2 int
J1 int
J2 int
}
// SequenceMatcher compares sequence of strings. The basic
// algorithm predates, and is a little fancier than, an algorithm
// published in the late 1980's by Ratcliff and Obershelp under the
// hyperbolic name "gestalt pattern matching". The basic idea is to find
// the longest contiguous matching subsequence that contains no "junk"
// elements (R-O doesn't address junk). The same idea is then applied
// recursively to the pieces of the sequences to the left and to the right
// of the matching subsequence. This does not yield minimal edit
// sequences, but does tend to yield matches that "look right" to people.
//
// SequenceMatcher tries to compute a "human-friendly diff" between two
// sequences. Unlike e.g. UNIX(tm) diff, the fundamental notion is the
// longest *contiguous* & junk-free matching subsequence. That's what
// catches peoples' eyes. The Windows(tm) windiff has another interesting
// notion, pairing up elements that appear uniquely in each sequence.
// That, and the method here, appear to yield more intuitive difference
// reports than does diff. This method appears to be the least vulnerable
// to syncing up on blocks of "junk lines", though (like blank lines in
// ordinary text files, or maybe "<P>" lines in HTML files). That may be
// because this is the only method of the 3 that has a *concept* of
// "junk" <wink>.
//
// Timing: Basic R-O is cubic time worst case and quadratic time expected
// case. SequenceMatcher is quadratic time for the worst case and has
// expected-case behavior dependent in a complicated way on how many
// elements the sequences have in common; best case time is linear.
type SequenceMatcher struct {
a []string
b []string
b2j map[string][]int
IsJunk func(string) bool
autoJunk bool
bJunk map[string]struct{}
matchingBlocks []Match
fullBCount map[string]int
bPopular map[string]struct{}
opCodes []OpCode
}
func NewMatcher(a, b []string) *SequenceMatcher {
m := SequenceMatcher{autoJunk: true}
m.SetSeqs(a, b)
return &m
}
func NewMatcherWithJunk(a, b []string, autoJunk bool,
isJunk func(string) bool) *SequenceMatcher {
m := SequenceMatcher{IsJunk: isJunk, autoJunk: autoJunk}
m.SetSeqs(a, b)
return &m
}
// SetSeqs sets the two sequences to be compared.
func (m *SequenceMatcher) SetSeqs(a, b []string) {
m.SetSeq1(a)
m.SetSeq2(b)
}
// SetSeq1 sets the first sequence to be compared. The second sequence to be compared is
// not changed.
//
// SequenceMatcher computes and caches detailed information about the second
// sequence, so if you want to compare one sequence S against many sequences,
// use .SetSeq2(s) once and call .SetSeq1(x) repeatedly for each of the other
// sequences.
//
// See also SetSeqs() and SetSeq2().
func (m *SequenceMatcher) SetSeq1(a []string) {
if &a == &m.a {
return
}
m.a = a
m.matchingBlocks = nil
m.opCodes = nil
}
// SetSeq2 sets the second sequence to be compared. The first sequence to be compared is
// not changed.
func (m *SequenceMatcher) SetSeq2(b []string) {
if &b == &m.b {
return
}
m.b = b
m.matchingBlocks = nil
m.opCodes = nil
m.fullBCount = nil
m.chainB()
}
func (m *SequenceMatcher) chainB() {
// Populate line -> index mapping
b2j := map[string][]int{}
for i, s := range m.b {
indices := b2j[s]
indices = append(indices, i)
b2j[s] = indices
}
// Purge junk elements
m.bJunk = map[string]struct{}{}
if m.IsJunk != nil {
junk := m.bJunk
for s := range b2j {
if m.IsJunk(s) {
junk[s] = struct{}{}
}
}
for s := range junk {
delete(b2j, s)
}
}
// Purge remaining popular elements
popular := map[string]struct{}{}
n := len(m.b)
if m.autoJunk && n >= 200 {
ntest := n/100 + 1
for s, indices := range b2j {
if len(indices) > ntest {
popular[s] = struct{}{}
}
}
for s := range popular {
delete(b2j, s)
}
}
m.bPopular = popular
m.b2j = b2j
}
func (m *SequenceMatcher) isBJunk(s string) bool {
_, ok := m.bJunk[s]
return ok
}
// Find longest matching block in a[alo:ahi] and b[blo:bhi].
//
// If IsJunk is not defined:
//
// Return (i,j,k) such that a[i:i+k] is equal to b[j:j+k], where
//
// alo <= i <= i+k <= ahi
// blo <= j <= j+k <= bhi
//
// and for all (i',j',k') meeting those conditions,
//
// k >= k'
// i <= i'
// and if i == i', j <= j'
//
// In other words, of all maximal matching blocks, return one that
// starts earliest in a, and of all those maximal matching blocks that
// start earliest in a, return the one that starts earliest in b.
//
// If IsJunk is defined, first the longest matching block is
// determined as above, but with the additional restriction that no
// junk element appears in the block. Then that block is extended as
// far as possible by matching (only) junk elements on both sides. So
// the resulting block never matches on junk except as identical junk
// happens to be adjacent to an "interesting" match.
//
// If no blocks match, return (alo, blo, 0).
func (m *SequenceMatcher) findLongestMatch(alo, ahi, blo, bhi int) Match {
// CAUTION: stripping common prefix or suffix would be incorrect.
// E.g.,
// ab
// acab
// Longest matching block is "ab", but if common prefix is
// stripped, it's "a" (tied with "b"). UNIX(tm) diff does so
// strip, so ends up claiming that ab is changed to acab by
// inserting "ca" in the middle. That's minimal but unintuitive:
// "it's obvious" that someone inserted "ac" at the front.
// Windiff ends up at the same place as diff, but by pairing up
// the unique 'b's and then matching the first two 'a's.
besti, bestj, bestsize := alo, blo, 0
// find longest junk-free match
// during an iteration of the loop, j2len[j] = length of longest
// junk-free match ending with a[i-1] and b[j]
j2len := map[int]int{}
for i := alo; i != ahi; i++ {
// look at all instances of a[i] in b; note that because
// b2j has no junk keys, the loop is skipped if a[i] is junk
newj2len := map[int]int{}
for _, j := range m.b2j[m.a[i]] {
// a[i] matches b[j]
if j < blo {
continue
}
if j >= bhi {
break
}
k := j2len[j-1] + 1
newj2len[j] = k
if k > bestsize {
besti, bestj, bestsize = i-k+1, j-k+1, k
}
}
j2len = newj2len
}
// Extend the best by non-junk elements on each end. In particular,
// "popular" non-junk elements aren't in b2j, which greatly speeds
// the inner loop above, but also means "the best" match so far
// doesn't contain any junk *or* popular non-junk elements.
for besti > alo && bestj > blo && !m.isBJunk(m.b[bestj-1]) &&
m.a[besti-1] == m.b[bestj-1] {
besti, bestj, bestsize = besti-1, bestj-1, bestsize+1
}
for besti+bestsize < ahi && bestj+bestsize < bhi &&
!m.isBJunk(m.b[bestj+bestsize]) &&
m.a[besti+bestsize] == m.b[bestj+bestsize] {
bestsize++
}
// Now that we have a wholly interesting match (albeit possibly
// empty!), we may as well suck up the matching junk on each
// side of it too. Can't think of a good reason not to, and it
// saves post-processing the (possibly considerable) expense of
// figuring out what to do with it. In the case of an empty
// interesting match, this is clearly the right thing to do,
// because no other kind of match is possible in the regions.
for besti > alo && bestj > blo && m.isBJunk(m.b[bestj-1]) &&
m.a[besti-1] == m.b[bestj-1] {
besti, bestj, bestsize = besti-1, bestj-1, bestsize+1
}
for besti+bestsize < ahi && bestj+bestsize < bhi &&
m.isBJunk(m.b[bestj+bestsize]) &&
m.a[besti+bestsize] == m.b[bestj+bestsize] {
bestsize++
}
return Match{A: besti, B: bestj, Size: bestsize}
}
// GetMatchingBlocks returns list of triples describing matching subsequences.
//
// Each triple is of the form (i, j, n), and means that
// a[i:i+n] == b[j:j+n]. The triples are monotonically increasing in
// i and in j. It's also guaranteed that if (i, j, n) and (i', j', n') are
// adjacent triples in the list, and the second is not the last triple in the
// list, then i+n != i' or j+n != j'. IOW, adjacent triples never describe
// adjacent equal blocks.
//
// The last triple is a dummy, (len(a), len(b), 0), and is the only
// triple with n==0.
func (m *SequenceMatcher) GetMatchingBlocks() []Match {
if m.matchingBlocks != nil {
return m.matchingBlocks
}
var matchBlocks func(alo, ahi, blo, bhi int, matched []Match) []Match
matchBlocks = func(alo, ahi, blo, bhi int, matched []Match) []Match {
match := m.findLongestMatch(alo, ahi, blo, bhi)
i, j, k := match.A, match.B, match.Size
if match.Size > 0 {
if alo < i && blo < j {
matched = matchBlocks(alo, i, blo, j, matched)
}
matched = append(matched, match)
if i+k < ahi && j+k < bhi {
matched = matchBlocks(i+k, ahi, j+k, bhi, matched)
}
}
return matched
}
matched := matchBlocks(0, len(m.a), 0, len(m.b), nil)
// It's possible that we have adjacent equal blocks in the
// matching_blocks list now.
nonAdjacent := []Match{}
i1, j1, k1 := 0, 0, 0
for _, b := range matched {
// Is this block adjacent to i1, j1, k1?
i2, j2, k2 := b.A, b.B, b.Size
if i1+k1 == i2 && j1+k1 == j2 {
// Yes, so collapse them -- this just increases the length of
// the first block by the length of the second, and the first
// block so lengthened remains the block to compare against.
k1 += k2
} else {
// Not adjacent. Remember the first block (k1==0 means it's
// the dummy we started with), and make the second block the
// new block to compare against.
if k1 > 0 {
nonAdjacent = append(nonAdjacent, Match{i1, j1, k1})
}
i1, j1, k1 = i2, j2, k2
}
}
if k1 > 0 {
nonAdjacent = append(nonAdjacent, Match{i1, j1, k1})
}
nonAdjacent = append(nonAdjacent, Match{len(m.a), len(m.b), 0})
m.matchingBlocks = nonAdjacent
return m.matchingBlocks
}
// GetOpCodes returns a list of 5-tuples describing how to turn a into b.
//
// Each tuple is of the form (tag, i1, i2, j1, j2). The first tuple
// has i1 == j1 == 0, and remaining tuples have i1 == the i2 from the
// tuple preceding it, and likewise for j1 == the previous j2.
//
// The tags are characters, with these meanings:
//
// 'r' (replace): a[i1:i2] should be replaced by b[j1:j2]
//
// 'd' (delete): a[i1:i2] should be deleted, j1==j2 in this case.
//
// 'i' (insert): b[j1:j2] should be inserted at a[i1:i1], i1==i2 in this case.
//
// 'e' (equal): a[i1:i2] == b[j1:j2]
func (m *SequenceMatcher) GetOpCodes() []OpCode {
if m.opCodes != nil {
return m.opCodes
}
i, j := 0, 0
matching := m.GetMatchingBlocks()
opCodes := make([]OpCode, 0, len(matching))
for _, m := range matching {
// invariant: we've pumped out correct diffs to change
// a[:i] into b[:j], and the next matching block is
// a[ai:ai+size] == b[bj:bj+size]. So we need to pump
// out a diff to change a[i:ai] into b[j:bj], pump out
// the matching block, and move (i,j) beyond the match
ai, bj, size := m.A, m.B, m.Size
tag := byte(0)
if i < ai && j < bj {
tag = 'r'
} else if i < ai {
tag = 'd'
} else if j < bj {
tag = 'i'
}
if tag > 0 {
opCodes = append(opCodes, OpCode{tag, i, ai, j, bj})
}
i, j = ai+size, bj+size
// the list of matching blocks is terminated by a
// sentinel with size 0
if size > 0 {
opCodes = append(opCodes, OpCode{'e', ai, i, bj, j})
}
}
m.opCodes = opCodes
return m.opCodes
}
// GetGroupedOpCodes isolates change clusters by eliminating ranges with no changes.
//
// Returns a generator of groups with up to n lines of context.
// Each group is in the same format as returned by GetOpCodes().
func (m *SequenceMatcher) GetGroupedOpCodes(n int) [][]OpCode {
if n < 0 {
n = 3
}
codes := m.GetOpCodes()
if len(codes) == 0 {
codes = []OpCode{{'e', 0, 1, 0, 1}}
}
// Fixup leading and trailing groups if they show no changes.
if codes[0].Tag == 'e' {
c := codes[0]
i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
codes[0] = OpCode{c.Tag, max(i1, i2-n), i2, max(j1, j2-n), j2}
}
if codes[len(codes)-1].Tag == 'e' {
c := codes[len(codes)-1]
i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
codes[len(codes)-1] = OpCode{c.Tag, i1, min(i2, i1+n), j1, min(j2, j1+n)}
}
nn := n + n
groups := [][]OpCode{}
group := []OpCode{}
for _, c := range codes {
i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
// End the current group and start a new one whenever
// there is a large range with no changes.
if c.Tag == 'e' && i2-i1 > nn {
group = append(group, OpCode{c.Tag, i1, min(i2, i1+n),
j1, min(j2, j1+n)})
groups = append(groups, group)
group = []OpCode{}
i1, j1 = max(i1, i2-n), max(j1, j2-n)
}
group = append(group, OpCode{c.Tag, i1, i2, j1, j2})
}
if len(group) > 0 && !(len(group) == 1 && group[0].Tag == 'e') {
groups = append(groups, group)
}
return groups
}
// Ratio returns a measure of the sequences' similarity (float in [0,1]).
//
// Where T is the total number of elements in both sequences, and
// M is the number of matches, this is 2.0*M / T.
// Note that this is 1 if the sequences are identical, and 0 if
// they have nothing in common.
//
// .Ratio() is expensive to compute if you haven't already computed
// .GetMatchingBlocks() or .GetOpCodes(), in which case you may
// want to try .QuickRatio() or .RealQuickRatio() first to get an
// upper bound.
func (m *SequenceMatcher) Ratio() float64 {
matches := 0
for _, m := range m.GetMatchingBlocks() {
matches += m.Size
}
return calculateRatio(matches, len(m.a)+len(m.b))
}
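The 2.0*M/T formula from Ratio's doc comment is worth seeing with concrete numbers, since it is what makes identical sequences score 1 and disjoint ones score 0. A self-contained sketch of the same computation as `calculateRatio` above:

```go
package main

import "fmt"

// ratio computes the similarity measure 2.0*M/T, where M is the number of
// matched elements and T the total element count of both sequences.
func ratio(matches, length int) float64 {
	if length > 0 {
		return 2.0 * float64(matches) / float64(length)
	}
	// Two empty sequences are defined as identical.
	return 1.0
}

func main() {
	// Identical four-element sequences: M = 4, T = 8.
	fmt.Println(ratio(4, 8))
	// Nothing in common: M = 0.
	fmt.Println(ratio(0, 8))
}
```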
// QuickRatio returns an upper bound on ratio() relatively quickly.
//
// This isn't defined beyond that it is an upper bound on .Ratio(), and
// is faster to compute.
func (m *SequenceMatcher) QuickRatio() float64 {
// viewing a and b as multisets, set matches to the cardinality
// of their intersection; this counts the number of matches
// without regard to order, so is clearly an upper bound
if m.fullBCount == nil {
m.fullBCount = map[string]int{}
for _, s := range m.b {
m.fullBCount[s] = m.fullBCount[s] + 1
}
}
// avail[x] is the number of times x appears in 'b' less the
// number of times we've seen it in 'a' so far ... kinda
avail := map[string]int{}
matches := 0
for _, s := range m.a {
n, ok := avail[s]
if !ok {
n = m.fullBCount[s]
}
avail[s] = n - 1
if n > 0 {
matches++
}
}
return calculateRatio(matches, len(m.a)+len(m.b))
}
// RealQuickRatio returns an upper bound on ratio() very quickly.
//
// This isn't defined beyond that it is an upper bound on .Ratio(), and
// is faster to compute than either .Ratio() or .QuickRatio().
func (m *SequenceMatcher) RealQuickRatio() float64 {
la, lb := len(m.a), len(m.b)
return calculateRatio(min(la, lb), la+lb)
}
// Convert range to the "ed" format
func formatRangeUnified(start, stop int) string {
// Per the diff spec at http://www.unix.org/single_unix_specification/
beginning := start + 1 // lines start numbering with one
length := stop - start
if length == 1 {
return fmt.Sprintf("%d", beginning)
}
if length == 0 {
beginning-- // empty ranges begin at line just before the range
}
return fmt.Sprintf("%d,%d", beginning, length)
}
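The range formatting above has two non-obvious cases: single-line ranges omit the length, and empty ranges back the 1-based start up by one. A self-contained copy to make those branches concrete:

```go
package main

import "fmt"

// formatRangeUnified mirrors the helper above: 1-based start, with the
// length omitted for single-line ranges and the beginning decremented
// for empty ones, per the unified-diff hunk-header convention.
func formatRangeUnified(start, stop int) string {
	beginning := start + 1 // lines start numbering with one
	length := stop - start
	if length == 1 {
		return fmt.Sprintf("%d", beginning)
	}
	if length == 0 {
		beginning-- // empty ranges begin at the line just before the range
	}
	return fmt.Sprintf("%d,%d", beginning, length)
}

func main() {
	fmt.Println(formatRangeUnified(2, 3)) // single line: length omitted
	fmt.Println(formatRangeUnified(2, 5)) // multi-line range
	fmt.Println(formatRangeUnified(2, 2)) // empty range
}
```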
// UnifiedDiff represents the unified diff parameters.
type UnifiedDiff struct {
A []string // First sequence lines
FromFile string // First file name
FromDate string // First file time
B []string // Second sequence lines
ToFile string // Second file name
ToDate string // Second file time
Eol string // Headers end of line, defaults to LF
Context int // Number of context lines
}
// WriteUnifiedDiff compares two sequences of lines; generates the delta as
// a unified diff.
//
// Unified diffs are a compact way of showing line changes and a few
// lines of context. The number of context lines is set by 'n' which
// defaults to three.
//
// By default, the diff control lines (those with ---, +++, or @@) are
// created with a trailing newline. This is helpful so that inputs
// created from file.readlines() result in diffs that are suitable for
// file.writelines() since both the inputs and outputs have trailing
// newlines.
//
// For inputs that do not have trailing newlines, set the lineterm
// argument to "" so that the output will be uniformly newline free.
//
// The unidiff format normally has a header for filenames and modification
// times. Any or all of these may be specified using strings for
// 'fromfile', 'tofile', 'fromfiledate', and 'tofiledate'.
// The modification times are normally expressed in the ISO 8601 format.
func WriteUnifiedDiff(writer io.Writer, diff UnifiedDiff) error {
buf := bufio.NewWriter(writer)
defer buf.Flush()
wf := func(format string, args ...interface{}) error {
_, err := buf.WriteString(fmt.Sprintf(format, args...))
return err
}
ws := func(s string) error {
_, err := buf.WriteString(s)
return err
}
if len(diff.Eol) == 0 {
diff.Eol = "\n"
}
started := false
m := NewMatcher(diff.A, diff.B)
for _, g := range m.GetGroupedOpCodes(diff.Context) {
if !started {
started = true
fromDate := ""
if len(diff.FromDate) > 0 {
fromDate = "\t" + diff.FromDate
}
toDate := ""
if len(diff.ToDate) > 0 {
toDate = "\t" + diff.ToDate
}
if diff.FromFile != "" || diff.ToFile != "" {
err := wf("--- %s%s%s", diff.FromFile, fromDate, diff.Eol)
if err != nil {
return err
}
err = wf("+++ %s%s%s", diff.ToFile, toDate, diff.Eol)
if err != nil {
return err
}
}
}
first, last := g[0], g[len(g)-1]
range1 := formatRangeUnified(first.I1, last.I2)
range2 := formatRangeUnified(first.J1, last.J2)
if err := wf("@@ -%s +%s @@%s", range1, range2, diff.Eol); err != nil {
return err
}
for _, c := range g {
i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
if c.Tag == 'e' {
for _, line := range diff.A[i1:i2] {
if err := ws(" " + line); err != nil {
return err
}
}
continue
}
if c.Tag == 'r' || c.Tag == 'd' {
for _, line := range diff.A[i1:i2] {
if err := ws("-" + line); err != nil {
return err
}
}
}
if c.Tag == 'r' || c.Tag == 'i' {
for _, line := range diff.B[j1:j2] {
if err := ws("+" + line); err != nil {
return err
}
}
}
}
}
return nil
}
// GetUnifiedDiffString is like WriteUnifiedDiff but returns the diff as a string.
func GetUnifiedDiffString(diff UnifiedDiff) (string, error) {
w := &bytes.Buffer{}
err := WriteUnifiedDiff(w, diff)
return w.String(), err
}
// Convert range to the "ed" format.
func formatRangeContext(start, stop int) string {
// Per the diff spec at http://www.unix.org/single_unix_specification/
beginning := start + 1 // lines start numbering with one
length := stop - start
if length == 0 {
beginning-- // empty ranges begin at line just before the range
}
if length <= 1 {
return fmt.Sprintf("%d", beginning)
}
return fmt.Sprintf("%d,%d", beginning, beginning+length-1)
}
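Both range formatters take the same half-open interval [start, stop) but encode it differently: the unified form prints a begin,length pair, while the context form prints begin,end. A self-contained sketch of both (mirroring the two functions in this file, runnable on its own):

```go
package main

import "fmt"

// formatRangeUnified mirrors the unified-diff formatter: "begin,length",
// with the length omitted when it is exactly one.
func formatRangeUnified(start, stop int) string {
	beginning := start + 1 // lines are numbered from one
	length := stop - start
	if length == 1 {
		return fmt.Sprintf("%d", beginning)
	}
	if length == 0 {
		beginning-- // empty ranges begin at the line just before the range
	}
	return fmt.Sprintf("%d,%d", beginning, length)
}

// formatRangeContext mirrors the context-diff formatter: "begin,end",
// collapsed to a single number for ranges of zero or one lines.
func formatRangeContext(start, stop int) string {
	beginning := start + 1
	length := stop - start
	if length == 0 {
		beginning--
	}
	if length <= 1 {
		return fmt.Sprintf("%d", beginning)
	}
	return fmt.Sprintf("%d,%d", beginning, beginning+length-1)
}

func main() {
	// The half-open range [3, 5) covers source lines 4 and 5.
	fmt.Println(formatRangeUnified(3, 5)) // 4,2 (start line 4, two lines)
	fmt.Println(formatRangeContext(3, 5)) // 4,5 (lines 4 through 5)
}
```

The tests later in this commit (TestOutputFormatRangeFormatUnified and TestOutputFormatRangeFormatContext) pin down the same boundary cases, including empty ranges at the start of the file.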
type ContextDiff UnifiedDiff
// WriteContextDiff compares two sequences of lines; generates the delta as a context diff.
//
// Context diffs are a compact way of showing line changes and a few
// lines of context. The number of context lines is set by diff.Context;
// a negative value selects the default of three.
//
// By default, the diff control lines (those with *** or ---) are
// created with a trailing newline.
//
// For inputs that do not have trailing newlines, set the diff.Eol
// argument to "" so that the output will be uniformly newline free.
//
// The context diff format normally has a header for filenames and
// modification times. Any or all of these may be specified using
// strings for diff.FromFile, diff.ToFile, diff.FromDate, diff.ToDate.
// The modification times are normally expressed in the ISO 8601 format.
// If not specified, the strings default to blanks.
func WriteContextDiff(writer io.Writer, diff ContextDiff) error {
buf := bufio.NewWriter(writer)
defer buf.Flush()
var diffErr error
wf := func(format string, args ...interface{}) {
_, err := buf.WriteString(fmt.Sprintf(format, args...))
if diffErr == nil && err != nil {
diffErr = err
}
}
ws := func(s string) {
_, err := buf.WriteString(s)
if diffErr == nil && err != nil {
diffErr = err
}
}
if len(diff.Eol) == 0 {
diff.Eol = "\n"
}
prefix := map[byte]string{
'i': "+ ",
'd': "- ",
'r': "! ",
'e': " ",
}
started := false
m := NewMatcher(diff.A, diff.B)
for _, g := range m.GetGroupedOpCodes(diff.Context) {
if !started {
started = true
fromDate := ""
if len(diff.FromDate) > 0 {
fromDate = "\t" + diff.FromDate
}
toDate := ""
if len(diff.ToDate) > 0 {
toDate = "\t" + diff.ToDate
}
if diff.FromFile != "" || diff.ToFile != "" {
wf("*** %s%s%s", diff.FromFile, fromDate, diff.Eol)
wf("--- %s%s%s", diff.ToFile, toDate, diff.Eol)
}
}
first, last := g[0], g[len(g)-1]
ws("***************" + diff.Eol)
range1 := formatRangeContext(first.I1, last.I2)
wf("*** %s ****%s", range1, diff.Eol)
for _, c := range g {
if c.Tag == 'r' || c.Tag == 'd' {
for _, cc := range g {
if cc.Tag == 'i' {
continue
}
for _, line := range diff.A[cc.I1:cc.I2] {
ws(prefix[cc.Tag] + line)
}
}
break
}
}
range2 := formatRangeContext(first.J1, last.J2)
wf("--- %s ----%s", range2, diff.Eol)
for _, c := range g {
if c.Tag == 'r' || c.Tag == 'i' {
for _, cc := range g {
if cc.Tag == 'd' {
continue
}
for _, line := range diff.B[cc.J1:cc.J2] {
ws(prefix[cc.Tag] + line)
}
}
break
}
}
}
return diffErr
}
// GetContextDiffString is like WriteContextDiff but returns the diff as a string.
func GetContextDiffString(diff ContextDiff) (string, error) {
w := &bytes.Buffer{}
err := WriteContextDiff(w, diff)
return w.String(), err
}
// SplitLines splits a string on "\n" while preserving the newlines. The
// output can be used as input for UnifiedDiff and ContextDiff structures.
func SplitLines(s string) []string {
lines := strings.SplitAfter(s, "\n")
lines[len(lines)-1] += "\n"
return lines
}
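SplitLines guarantees that every element, including the last, ends in "\n", which is what the diff writers expect. A standalone sketch of the same behavior:

```go
package main

import (
	"fmt"
	"strings"
)

// splitLines mirrors SplitLines above: split after each "\n" (keeping the
// separator), then newline-terminate the final, possibly empty, element.
func splitLines(s string) []string {
	lines := strings.SplitAfter(s, "\n")
	lines[len(lines)-1] += "\n"
	return lines
}

func main() {
	fmt.Printf("%q\n", splitLines("foo\nbar"))   // ["foo\n" "bar\n"]
	fmt.Printf("%q\n", splitLines("foo\nbar\n")) // ["foo\n" "bar\n" "\n"]
}
```

Note the extra "\n" element when the input already ends in a newline; TestSplitLines later in this commit relies on exactly that shape.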

@@ -0,0 +1,326 @@
// Copied from https://github.com/pmezard/go-difflib/blob/5d4384ee4fb2527b0a1256a821ebfc92f91efefc/difflib/difflib_test.go
// Copyright 2013 Patrick Mezard. All rights reserved. Use of this source code is
// governed by a license that can be found in the THIRD-PARTY-NOTICES file.
package assert
import (
"bytes"
"fmt"
"math"
"reflect"
"strings"
"testing"
)
func assertAlmostEqual(t *testing.T, a, b float64, places int) {
if math.Abs(a-b) > math.Pow10(-places) {
t.Errorf("%.7f != %.7f", a, b)
}
}
func assertEqual(t *testing.T, a, b interface{}) {
if !reflect.DeepEqual(a, b) {
t.Errorf("%v != %v", a, b)
}
}
func splitChars(s string) []string {
chars := make([]string, 0, len(s))
// Assume ASCII inputs
for i := 0; i != len(s); i++ {
chars = append(chars, string(s[i]))
}
return chars
}
func TestSequenceMatcherRatio(t *testing.T) {
s := NewMatcher(splitChars("abcd"), splitChars("bcde"))
assertEqual(t, s.Ratio(), 0.75)
assertEqual(t, s.QuickRatio(), 0.75)
assertEqual(t, s.RealQuickRatio(), 1.0)
}
func TestGetOptCodes(t *testing.T) {
a := "qabxcd"
b := "abycdf"
s := NewMatcher(splitChars(a), splitChars(b))
w := &bytes.Buffer{}
for _, op := range s.GetOpCodes() {
fmt.Fprintf(w, "%s a[%d:%d], (%s) b[%d:%d] (%s)\n", string(op.Tag),
op.I1, op.I2, a[op.I1:op.I2], op.J1, op.J2, b[op.J1:op.J2])
}
result := w.String()
expected := `d a[0:1], (q) b[0:0] ()
e a[1:3], (ab) b[0:2] (ab)
r a[3:4], (x) b[2:3] (y)
e a[4:6], (cd) b[3:5] (cd)
i a[6:6], () b[5:6] (f)
`
if expected != result {
t.Errorf("unexpected op codes: \n%s", result)
}
}
func TestGroupedOpCodes(t *testing.T) {
a := []string{}
for i := 0; i != 39; i++ {
a = append(a, fmt.Sprintf("%02d", i))
}
b := []string{}
b = append(b, a[:8]...)
b = append(b, " i")
b = append(b, a[8:19]...)
b = append(b, " x")
b = append(b, a[20:22]...)
b = append(b, a[27:34]...)
b = append(b, " y")
b = append(b, a[35:]...)
s := NewMatcher(a, b)
w := &bytes.Buffer{}
for _, g := range s.GetGroupedOpCodes(-1) {
fmt.Fprintf(w, "group\n")
for _, op := range g {
fmt.Fprintf(w, " %s, %d, %d, %d, %d\n", string(op.Tag),
op.I1, op.I2, op.J1, op.J2)
}
}
result := w.String()
expected := `group
e, 5, 8, 5, 8
i, 8, 8, 8, 9
e, 8, 11, 9, 12
group
e, 16, 19, 17, 20
r, 19, 20, 20, 21
e, 20, 22, 21, 23
d, 22, 27, 23, 23
e, 27, 30, 23, 26
group
e, 31, 34, 27, 30
r, 34, 35, 30, 31
e, 35, 38, 31, 34
`
if expected != result {
t.Errorf("unexpected op codes: \n%s", result)
}
}
func rep(s string, count int) string {
return strings.Repeat(s, count)
}
func TestWithAsciiOneInsert(t *testing.T) {
sm := NewMatcher(splitChars(rep("b", 100)),
splitChars("a"+rep("b", 100)))
assertAlmostEqual(t, sm.Ratio(), 0.995, 3)
assertEqual(t, sm.GetOpCodes(),
[]OpCode{{'i', 0, 0, 0, 1}, {'e', 0, 100, 1, 101}})
assertEqual(t, len(sm.bPopular), 0)
sm = NewMatcher(splitChars(rep("b", 100)),
splitChars(rep("b", 50)+"a"+rep("b", 50)))
assertAlmostEqual(t, sm.Ratio(), 0.995, 3)
assertEqual(t, sm.GetOpCodes(),
[]OpCode{{'e', 0, 50, 0, 50}, {'i', 50, 50, 50, 51}, {'e', 50, 100, 51, 101}})
assertEqual(t, len(sm.bPopular), 0)
}
func TestWithAsciiOnDelete(t *testing.T) {
sm := NewMatcher(splitChars(rep("a", 40)+"c"+rep("b", 40)),
splitChars(rep("a", 40)+rep("b", 40)))
assertAlmostEqual(t, sm.Ratio(), 0.994, 3)
assertEqual(t, sm.GetOpCodes(),
[]OpCode{{'e', 0, 40, 0, 40}, {'d', 40, 41, 40, 40}, {'e', 41, 81, 40, 80}})
}
func TestWithAsciiBJunk(t *testing.T) {
isJunk := func(s string) bool {
return s == " "
}
sm := NewMatcherWithJunk(splitChars(rep("a", 40)+rep("b", 40)),
splitChars(rep("a", 44)+rep("b", 40)), true, isJunk)
assertEqual(t, sm.bJunk, map[string]struct{}{})
sm = NewMatcherWithJunk(splitChars(rep("a", 40)+rep("b", 40)),
splitChars(rep("a", 44)+rep("b", 40)+rep(" ", 20)), false, isJunk)
assertEqual(t, sm.bJunk, map[string]struct{}{" ": struct{}{}})
isJunk = func(s string) bool {
return s == " " || s == "b"
}
sm = NewMatcherWithJunk(splitChars(rep("a", 40)+rep("b", 40)),
splitChars(rep("a", 44)+rep("b", 40)+rep(" ", 20)), false, isJunk)
assertEqual(t, sm.bJunk, map[string]struct{}{" ": struct{}{}, "b": struct{}{}})
}
func TestSFBugsRatioForNullSeqn(t *testing.T) {
sm := NewMatcher(nil, nil)
assertEqual(t, sm.Ratio(), 1.0)
assertEqual(t, sm.QuickRatio(), 1.0)
assertEqual(t, sm.RealQuickRatio(), 1.0)
}
func TestSFBugsComparingEmptyLists(t *testing.T) {
groups := NewMatcher(nil, nil).GetGroupedOpCodes(-1)
assertEqual(t, len(groups), 0)
diff := UnifiedDiff{
FromFile: "Original",
ToFile: "Current",
Context: 3,
}
result, err := GetUnifiedDiffString(diff)
assertEqual(t, err, nil)
assertEqual(t, result, "")
}
func TestOutputFormatRangeFormatUnified(t *testing.T) {
// Per the diff spec at http://www.unix.org/single_unix_specification/
//
// Each <range> field shall be of the form:
// "%1d", <beginning line number> if the range contains exactly one line,
// and:
// "%1d,%1d", <beginning line number>, <number of lines> otherwise.
// If a range is empty, its beginning line number shall be the number of
// the line just before the range, or 0 if the empty range starts the file.
fm := formatRangeUnified
assertEqual(t, fm(3, 3), "3,0")
assertEqual(t, fm(3, 4), "4")
assertEqual(t, fm(3, 5), "4,2")
assertEqual(t, fm(3, 6), "4,3")
assertEqual(t, fm(0, 0), "0,0")
}
func TestOutputFormatRangeFormatContext(t *testing.T) {
// Per the diff spec at http://www.unix.org/single_unix_specification/
//
// The range of lines in file1 shall be written in the following format
// if the range contains two or more lines:
// "*** %d,%d ****\n", <beginning line number>, <ending line number>
// and the following format otherwise:
// "*** %d ****\n", <ending line number>
// The ending line number of an empty range shall be the number of the preceding line,
// or 0 if the range is at the start of the file.
//
// Next, the range of lines in file2 shall be written in the following format
// if the range contains two or more lines:
// "--- %d,%d ----\n", <beginning line number>, <ending line number>
// and the following format otherwise:
// "--- %d ----\n", <ending line number>
fm := formatRangeContext
assertEqual(t, fm(3, 3), "3")
assertEqual(t, fm(3, 4), "4")
assertEqual(t, fm(3, 5), "4,5")
assertEqual(t, fm(3, 6), "4,6")
assertEqual(t, fm(0, 0), "0")
}
func TestOutputFormatTabDelimiter(t *testing.T) {
diff := UnifiedDiff{
A: splitChars("one"),
B: splitChars("two"),
FromFile: "Original",
FromDate: "2005-01-26 23:30:50",
ToFile: "Current",
ToDate: "2010-04-12 10:20:52",
Eol: "\n",
}
ud, err := GetUnifiedDiffString(diff)
assertEqual(t, err, nil)
assertEqual(t, SplitLines(ud)[:2], []string{
"--- Original\t2005-01-26 23:30:50\n",
"+++ Current\t2010-04-12 10:20:52\n",
})
cd, err := GetContextDiffString(ContextDiff(diff))
assertEqual(t, err, nil)
assertEqual(t, SplitLines(cd)[:2], []string{
"*** Original\t2005-01-26 23:30:50\n",
"--- Current\t2010-04-12 10:20:52\n",
})
}
func TestOutputFormatNoTrailingTabOnEmptyFiledate(t *testing.T) {
diff := UnifiedDiff{
A: splitChars("one"),
B: splitChars("two"),
FromFile: "Original",
ToFile: "Current",
Eol: "\n",
}
ud, err := GetUnifiedDiffString(diff)
assertEqual(t, err, nil)
assertEqual(t, SplitLines(ud)[:2], []string{"--- Original\n", "+++ Current\n"})
cd, err := GetContextDiffString(ContextDiff(diff))
assertEqual(t, err, nil)
assertEqual(t, SplitLines(cd)[:2], []string{"*** Original\n", "--- Current\n"})
}
func TestOmitFilenames(t *testing.T) {
diff := UnifiedDiff{
A: SplitLines("o\nn\ne\n"),
B: SplitLines("t\nw\no\n"),
Eol: "\n",
}
ud, err := GetUnifiedDiffString(diff)
assertEqual(t, err, nil)
assertEqual(t, SplitLines(ud), []string{
"@@ -0,0 +1,2 @@\n",
"+t\n",
"+w\n",
"@@ -2,2 +3,0 @@\n",
"-n\n",
"-e\n",
"\n",
})
cd, err := GetContextDiffString(ContextDiff(diff))
assertEqual(t, err, nil)
assertEqual(t, SplitLines(cd), []string{
"***************\n",
"*** 0 ****\n",
"--- 1,2 ----\n",
"+ t\n",
"+ w\n",
"***************\n",
"*** 2,3 ****\n",
"- n\n",
"- e\n",
"--- 3 ----\n",
"\n",
})
}
func TestSplitLines(t *testing.T) {
allTests := []struct {
input string
want []string
}{
{"foo", []string{"foo\n"}},
{"foo\nbar", []string{"foo\n", "bar\n"}},
{"foo\nbar\n", []string{"foo\n", "bar\n", "\n"}},
}
for _, test := range allTests {
assertEqual(t, SplitLines(test.input), test.want)
}
}
func benchmarkSplitLines(b *testing.B, count int) {
str := strings.Repeat("foo\n", count)
b.ResetTimer()
n := 0
for i := 0; i < b.N; i++ {
n += len(SplitLines(str))
}
}
func BenchmarkSplitLines100(b *testing.B) {
benchmarkSplitLines(b, 100)
}
func BenchmarkSplitLines10000(b *testing.B) {
benchmarkSplitLines(b, 10000)
}

@@ -0,0 +1,60 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/awserr/error.go
// See THIRD-PARTY-NOTICES for original license terms
// Package awserr represents API error interface accessors for the SDK.
package awserr
// An Error wraps lower level errors with code, message and an original error.
// The underlying concrete error type may also satisfy other interfaces which
// can be used to obtain more specific information about the error.
type Error interface {
// Satisfy the generic error interface.
error
// Returns the short phrase depicting the classification of the error.
Code() string
// Returns the error details message.
Message() string
// Returns the original error if one was set. Nil is returned if not set.
OrigErr() error
}
// BatchedErrors is a batch of errors which also wraps lower level errors with
// code, message, and original errors. Calling Error() will include all errors
// that occurred in the batch.
//
// Replaces BatchError
type BatchedErrors interface {
// Satisfy the base Error interface.
Error
// Returns the original error if one was set. Nil is returned if not set.
OrigErrs() []error
}
// New returns an Error object described by the code, message, and origErr.
//
// If origErr satisfies the Error interface it will not be wrapped within a new
// Error object and will instead be returned.
func New(code, message string, origErr error) Error {
var errs []error
if origErr != nil {
errs = append(errs, origErr)
}
return newBaseError(code, message, errs)
}
// NewBatchError returns a BatchedErrors wrapping the given collection of
// errors.
func NewBatchError(code, message string, errs []error) BatchedErrors {
return newBaseError(code, message, errs)
}

@@ -0,0 +1,144 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/awserr/types.go
// See THIRD-PARTY-NOTICES for original license terms
package awserr
import (
"fmt"
)
// SprintError returns a string of the formatted error code.
//
// Both extra and origErr are optional. Their lines are appended to the
// message only when they are provided.
func SprintError(code, message, extra string, origErr error) string {
msg := fmt.Sprintf("%s: %s", code, message)
if extra != "" {
msg = fmt.Sprintf("%s\n\t%s", msg, extra)
}
if origErr != nil {
msg = fmt.Sprintf("%s\ncaused by: %s", msg, origErr.Error())
}
return msg
}
// A baseError wraps the code and message which defines an error. It also
// can be used to wrap an original error object.
//
// Should be used as the root for errors satisfying the awserr.Error. Also
// for any error which does not fit into a specific error wrapper type.
type baseError struct {
// Classification of error
code string
// Detailed information about error
message string
// Optional original error this error is based off of. Allows building
// chained errors.
errs []error
}
// newBaseError returns an error object for the code, message, and errors.
//
// code is a short, whitespace-free phrase depicting the classification of
// the error that is being created.
//
// message is the free flow string containing detailed information about the
// error.
//
// origErrs contains the error objects that will be nested under the new
// error to be returned.
func newBaseError(code, message string, origErrs []error) *baseError {
b := &baseError{
code: code,
message: message,
errs: origErrs,
}
return b
}
// Error returns the string representation of the error.
//
// See ErrorWithExtra for formatting.
//
// Satisfies the error interface.
func (b baseError) Error() string {
size := len(b.errs)
if size > 0 {
return SprintError(b.code, b.message, "", errorList(b.errs))
}
return SprintError(b.code, b.message, "", nil)
}
// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (b baseError) String() string {
return b.Error()
}
// Code returns the short phrase depicting the classification of the error.
func (b baseError) Code() string {
return b.code
}
// Message returns the error details message.
func (b baseError) Message() string {
return b.message
}
// OrigErr returns the original error if one was set. Nil is returned if no
// error was set. This only returns the first element in the list. If the full
// list is needed, use BatchedErrors.
func (b baseError) OrigErr() error {
switch len(b.errs) {
case 0:
return nil
case 1:
return b.errs[0]
default:
if err, ok := b.errs[0].(Error); ok {
return NewBatchError(err.Code(), err.Message(), b.errs[1:])
}
return NewBatchError("BatchedErrors",
"multiple errors occurred", b.errs)
}
}
// OrigErrs returns the original errors if any were set. An empty slice
// is returned if no errors were set.
func (b baseError) OrigErrs() []error {
return b.errs
}
// errorList is a slice of errors that satisfies the error interface.
type errorList []error
// Error returns the string representation of the error.
//
// Satisfies the error interface.
func (e errorList) Error() string {
msg := ""
// A zero-length list produces an empty message.
if size := len(e); size > 0 {
for i := 0; i < size; i++ {
msg += e[i].Error()
// Append a newline between messages, but not after the last one,
// so the joined output has no trailing newline.
if i+1 < size {
msg += "\n"
}
}
}
return msg
}

@@ -0,0 +1,72 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/chain_provider.go
// See THIRD-PARTY-NOTICES for original license terms
package credentials
import (
"gitea.psichedelico.com/go/bson/internal/aws/awserr"
)
// A ChainProvider will search for a provider which returns credentials
// and cache that provider until Retrieve is called again.
//
// The ChainProvider provides a way of chaining multiple providers together
// which will pick the first available using priority order of the Providers
// in the list.
//
// If none of the Providers retrieve valid credentials Value, ChainProvider's
// Retrieve() will return the error ErrNoValidProvidersFoundInChain.
//
// If a Provider is found which returns valid credentials Value ChainProvider
// will cache that Provider for all calls to IsExpired(), until Retrieve is
// called again.
type ChainProvider struct {
Providers []Provider
curr Provider
}
// NewChainCredentials returns a pointer to a new Credentials object
// wrapping a chain of providers.
func NewChainCredentials(providers []Provider) *Credentials {
return NewCredentials(&ChainProvider{
Providers: append([]Provider{}, providers...),
})
}
// Retrieve returns the credentials value, or an error if no provider
// returned a value without error.
//
// If a provider is found it will be cached and any calls to IsExpired()
// will return the expired state of the cached provider.
func (c *ChainProvider) Retrieve() (Value, error) {
var errs = make([]error, 0, len(c.Providers))
for _, p := range c.Providers {
creds, err := p.Retrieve()
if err == nil {
c.curr = p
return creds, nil
}
errs = append(errs, err)
}
c.curr = nil
var err = awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs)
return Value{}, err
}
// IsExpired returns the expired state of the currently cached provider
// if there is one. If there is no current provider, true is returned.
func (c *ChainProvider) IsExpired() bool {
if c.curr != nil {
return c.curr.IsExpired()
}
return true
}

@@ -0,0 +1,176 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/chain_provider_test.go
// See THIRD-PARTY-NOTICES for original license terms
package credentials
import (
"reflect"
"testing"
"gitea.psichedelico.com/go/bson/internal/aws/awserr"
)
type secondStubProvider struct {
creds Value
expired bool
err error
}
func (s *secondStubProvider) Retrieve() (Value, error) {
s.expired = false
s.creds.ProviderName = "secondStubProvider"
return s.creds, s.err
}
func (s *secondStubProvider) IsExpired() bool {
return s.expired
}
func TestChainProviderWithNames(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
&secondStubProvider{
creds: Value{
AccessKeyID: "AKIF",
SecretAccessKey: "NOSECRET",
SessionToken: "",
},
},
&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
},
},
}
creds, err := p.Retrieve()
if err != nil {
t.Errorf("Expect no error, got %v", err)
}
if e, a := "secondStubProvider", creds.ProviderName; e != a {
t.Errorf("Expect provider name to match, %v got, %v", e, a)
}
// Also check credentials
if e, a := "AKIF", creds.AccessKeyID; e != a {
t.Errorf("Expect access key ID to match, %v got %v", e, a)
}
if e, a := "NOSECRET", creds.SecretAccessKey; e != a {
t.Errorf("Expect secret access key to match, %v got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect session token to be empty, %v", v)
}
}
func TestChainProviderGet(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
},
},
}
creds, err := p.Retrieve()
if err != nil {
t.Errorf("Expect no error, got %v", err)
}
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("Expect access key ID to match, %v got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("Expect secret access key to match, %v got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect session token to be empty, %v", v)
}
}
func TestChainProviderIsExpired(t *testing.T) {
stubProvider := &stubProvider{expired: true}
p := &ChainProvider{
Providers: []Provider{
stubProvider,
},
}
if !p.IsExpired() {
t.Errorf("Expect expired to be true before any Retrieve")
}
_, err := p.Retrieve()
if err != nil {
t.Errorf("Expect no error, got %v", err)
}
if p.IsExpired() {
t.Errorf("Expect not expired after retrieve")
}
stubProvider.expired = true
if !p.IsExpired() {
t.Errorf("Expect return of expired provider")
}
_, err = p.Retrieve()
if err != nil {
t.Errorf("Expect no error, got %v", err)
}
if p.IsExpired() {
t.Errorf("Expect not expired after retrieve")
}
}
func TestChainProviderWithNoProvider(t *testing.T) {
p := &ChainProvider{
Providers: []Provider{},
}
if !p.IsExpired() {
t.Errorf("Expect expired with no providers")
}
_, err := p.Retrieve()
if err.Error() != "NoCredentialProviders: no valid providers in chain" {
t.Errorf("Expect no providers error returned, got %v", err)
}
}
func TestChainProviderWithNoValidProvider(t *testing.T) {
errs := []error{
awserr.New("FirstError", "first provider error", nil),
awserr.New("SecondError", "second provider error", nil),
}
p := &ChainProvider{
Providers: []Provider{
&stubProvider{err: errs[0]},
&stubProvider{err: errs[1]},
},
}
if !p.IsExpired() {
t.Errorf("Expect expired with no providers")
}
_, err := p.Retrieve()
expectErr := awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs)
if e, a := expectErr, err; !reflect.DeepEqual(e, a) {
t.Errorf("Expect no providers error returned, %v, got %v", e, a)
}
}

@@ -0,0 +1,197 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/credentials.go
// See THIRD-PARTY-NOTICES for original license terms
package credentials
import (
"context"
"sync"
"time"
"gitea.psichedelico.com/go/bson/internal/aws/awserr"
"golang.org/x/sync/singleflight"
)
// A Value is the AWS credentials value for individual credential fields.
//
// A Value is also used to represent Azure credentials.
// Azure credentials only consist of an access token, which is stored in the `SessionToken` field.
type Value struct {
// AWS Access key ID
AccessKeyID string
// AWS Secret Access Key
SecretAccessKey string
// AWS Session Token
SessionToken string
// Provider used to get credentials
ProviderName string
}
// HasKeys returns if the credentials Value has both AccessKeyID and
// SecretAccessKey value set.
func (v Value) HasKeys() bool {
return len(v.AccessKeyID) != 0 && len(v.SecretAccessKey) != 0
}
// A Provider is the interface for any component which will provide credentials
// Value. A provider is required to manage its own Expired state, and what
// being expired means.
//
// The Provider should not need to implement its own mutexes, because
// that will be managed by Credentials.
type Provider interface {
// Retrieve returns nil if it successfully retrieved the value.
// An error is returned if the value was not obtainable or is empty.
Retrieve() (Value, error)
// IsExpired returns if the credentials are no longer valid, and need
// to be retrieved.
IsExpired() bool
}
// ProviderWithContext is a Provider that can retrieve credentials with a Context
type ProviderWithContext interface {
Provider
RetrieveWithContext(context.Context) (Value, error)
}
// A Credentials provides concurrency safe retrieval of AWS credentials Value.
//
// A Credentials is also used to fetch Azure credentials Value.
//
// Credentials will cache the credentials value until they expire. Once the value
// expires the next Get will attempt to retrieve valid credentials.
//
// Credentials is safe to use across multiple goroutines and will manage the
// synchronous state so the Providers do not need to implement their own
// synchronization.
//
// The first Credentials.Get() will always call Provider.Retrieve() to get the
// first instance of the credentials Value. All calls to Get() after that
// will return the cached credentials Value until IsExpired() returns true.
type Credentials struct {
sf singleflight.Group
m sync.RWMutex
creds Value
provider Provider
}
// NewCredentials returns a pointer to a new Credentials with the provider set.
func NewCredentials(provider Provider) *Credentials {
c := &Credentials{
provider: provider,
}
return c
}
// GetWithContext returns the credentials value, or error if the credentials
// Value failed to be retrieved. Will return early if the passed in context is
// canceled.
//
// Will return the cached credentials Value if it has not expired. If the
// credentials Value has expired the Provider's Retrieve() will be called
// to refresh the credentials.
//
// If Credentials.Expire() was called the credentials Value will be force
// expired, and the next call to Get() will cause them to be refreshed.
func (c *Credentials) GetWithContext(ctx context.Context) (Value, error) {
// Check if credentials are cached, and not expired.
select {
case curCreds, ok := <-c.asyncIsExpired():
// ok will only be true if the credentials were not expired. ok will
// be false with no value if the credentials are expired.
if ok {
return curCreds, nil
}
case <-ctx.Done():
return Value{}, awserr.New("RequestCanceled",
"request context canceled", ctx.Err())
}
// Cannot pass the context down to the actual retrieve, because the first
// context would cancel the whole group when there is no direct
// association of items in the group.
resCh := c.sf.DoChan("", func() (interface{}, error) {
return c.singleRetrieve(&suppressedContext{ctx})
})
select {
case res := <-resCh:
return res.Val.(Value), res.Err
case <-ctx.Done():
return Value{}, awserr.New("RequestCanceled",
"request context canceled", ctx.Err())
}
}
func (c *Credentials) singleRetrieve(ctx context.Context) (interface{}, error) {
c.m.Lock()
defer c.m.Unlock()
if curCreds := c.creds; !c.isExpiredLocked(curCreds) {
return curCreds, nil
}
var creds Value
var err error
if p, ok := c.provider.(ProviderWithContext); ok {
creds, err = p.RetrieveWithContext(ctx)
} else {
creds, err = c.provider.Retrieve()
}
if err == nil {
c.creds = creds
}
return creds, err
}
// asyncIsExpired returns a channel of credentials Value. If the channel
// is closed without sending a value, the credentials are expired.
func (c *Credentials) asyncIsExpired() <-chan Value {
ch := make(chan Value, 1)
go func() {
c.m.RLock()
defer c.m.RUnlock()
if curCreds := c.creds; !c.isExpiredLocked(curCreds) {
ch <- curCreds
}
close(ch)
}()
return ch
}
// isExpiredLocked helper method wrapping the definition of expired credentials.
func (c *Credentials) isExpiredLocked(creds interface{}) bool {
return creds == nil || creds.(Value) == Value{} || c.provider.IsExpired()
}
type suppressedContext struct {
context.Context
}
func (s *suppressedContext) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}
func (s *suppressedContext) Done() <-chan struct{} {
return nil
}
func (s *suppressedContext) Err() error {
return nil
}

View File

@@ -0,0 +1,192 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/credentials_test.go
// See THIRD-PARTY-NOTICES for original license terms
package credentials
import (
"context"
"sync"
"testing"
"time"
"gitea.psichedelico.com/go/bson/internal/aws/awserr"
)
func isExpired(c *Credentials) bool {
c.m.RLock()
defer c.m.RUnlock()
return c.isExpiredLocked(c.creds)
}
type stubProvider struct {
creds Value
retrievedCount int
expired bool
err error
}
func (s *stubProvider) Retrieve() (Value, error) {
s.retrievedCount++
s.expired = false
s.creds.ProviderName = "stubProvider"
return s.creds, s.err
}
func (s *stubProvider) IsExpired() bool {
return s.expired
}
func TestCredentialsGet(t *testing.T) {
c := NewCredentials(&stubProvider{
creds: Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "",
},
expired: true,
})
creds, err := c.GetWithContext(context.Background())
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("Expect access key ID to match, %v got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("Expect secret access key to match, %v got %v", e, a)
}
if v := creds.SessionToken; len(v) != 0 {
t.Errorf("Expect session token to be empty, %v", v)
}
}
func TestCredentialsGetWithError(t *testing.T) {
c := NewCredentials(&stubProvider{err: awserr.New("provider error", "", nil), expired: true})
_, err := c.GetWithContext(context.Background())
if e, a := "provider error", err.(awserr.Error).Code(); e != a {
t.Errorf("Expected provider error, %v got %v", e, a)
}
}
func TestCredentialsExpire(t *testing.T) {
stub := &stubProvider{}
c := NewCredentials(stub)
stub.expired = false
if !isExpired(c) {
t.Errorf("Expected to start out expired")
}
_, err := c.GetWithContext(context.Background())
if err != nil {
t.Errorf("Expected no err, got %v", err)
}
if isExpired(c) {
t.Errorf("Expected not to be expired")
}
stub.expired = true
if !isExpired(c) {
t.Errorf("Expected to be expired")
}
}
func TestCredentialsGetWithProviderName(t *testing.T) {
stub := &stubProvider{}
c := NewCredentials(stub)
creds, err := c.GetWithContext(context.Background())
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if e, a := creds.ProviderName, "stubProvider"; e != a {
t.Errorf("Expected provider name to match, %v got %v", e, a)
}
}
type MockProvider struct {
// The date/time on which to expire.
expiration time.Time
// If set will be used by IsExpired to determine the current time.
// Defaults to time.Now if CurrentTime is not set. Available for testing
// to be able to mock out the current time.
CurrentTime func() time.Time
}
// IsExpired returns if the credentials are expired.
func (e *MockProvider) IsExpired() bool {
curTime := e.CurrentTime
if curTime == nil {
curTime = time.Now
}
return e.expiration.Before(curTime())
}
func (*MockProvider) Retrieve() (Value, error) {
return Value{}, nil
}
func TestCredentialsIsExpired_Race(_ *testing.T) {
creds := NewChainCredentials([]Provider{&MockProvider{}})
starter := make(chan struct{})
var wg sync.WaitGroup
wg.Add(10)
for i := 0; i < 10; i++ {
go func() {
defer wg.Done()
<-starter
for i := 0; i < 100; i++ {
isExpired(creds)
}
}()
}
close(starter)
wg.Wait()
}
type stubProviderConcurrent struct {
stubProvider
done chan struct{}
}
func (s *stubProviderConcurrent) Retrieve() (Value, error) {
<-s.done
return s.stubProvider.Retrieve()
}
func TestCredentialsGetConcurrent(t *testing.T) {
stub := &stubProviderConcurrent{
done: make(chan struct{}),
}
c := NewCredentials(stub)
done := make(chan struct{})
for i := 0; i < 2; i++ {
go func() {
_, err := c.GetWithContext(context.Background())
if err != nil {
t.Errorf("Expected no err, got %v", err)
}
done <- struct{}{}
}()
}
// Validates that a single call to Retrieve is shared between two calls to Get
stub.done <- struct{}{}
<-done
<-done
}


@@ -0,0 +1,51 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/signer/v4/header_rules.go
// See THIRD-PARTY-NOTICES for original license terms
package v4
// rules houses the set of rules needed to validate a string value
type rules []rule
// rule is an interface that allows for more flexible validation rules; it
// simply checks whether or not a value adheres to the rule
type rule interface {
IsValid(value string) bool
}
// IsValid iterates through all rules and reports whether any rule applies
// to the value; nested rules are supported
func (r rules) IsValid(value string) bool {
for _, rule := range r {
if rule.IsValid(value) {
return true
}
}
return false
}
// mapRule is a generic rule for maps
type mapRule map[string]struct{}
// IsValid for the map rule reports whether the value exists in the map
func (m mapRule) IsValid(value string) bool {
_, ok := m[value]
return ok
}
// excludeList is a generic rule for exclude listing
type excludeList struct {
rule
}
// IsValid for the exclude list reports whether the value is not matched by the wrapped rule
func (b excludeList) IsValid(value string) bool {
return !b.rule.IsValid(value)
}
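Since rules, mapRule, and excludeList are unexported, the sketch below reproduces them standalone to show how an ignoredHeaders-style exclude list composes. This is a minimal sketch, not the package's API.

```go
package main

import "fmt"

// Minimal reproductions of the rule types above, copied for illustration.
type rule interface{ IsValid(value string) bool }

type rules []rule

// IsValid reports whether any nested rule accepts the value.
func (r rules) IsValid(value string) bool {
	for _, ru := range r {
		if ru.IsValid(value) {
			return true
		}
	}
	return false
}

type mapRule map[string]struct{}

func (m mapRule) IsValid(value string) bool { _, ok := m[value]; return ok }

// excludeList inverts the wrapped rule: a value is valid when NOT matched.
type excludeList struct{ rule }

func (b excludeList) IsValid(value string) bool { return !b.rule.IsValid(value) }

func main() {
	ignored := rules{excludeList{mapRule{"Authorization": {}, "User-Agent": {}}}}
	fmt.Println(ignored.IsValid("Content-Type"))  // true: this header is signed
	fmt.Println(ignored.IsValid("Authorization")) // false: excluded from signing
}
```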


@@ -0,0 +1,80 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/request/request.go
// See THIRD-PARTY-NOTICES for original license terms
package v4
import (
"net/http"
"strings"
)
// getHost returns the host from the request
func getHost(r *http.Request) string {
if r.Host != "" {
return r.Host
}
if r.URL == nil {
return ""
}
return r.URL.Host
}
// Hostname returns u.Host, without any port number.
//
// If Host is an IPv6 literal with a port number, Hostname returns the
// IPv6 literal without the square brackets. IPv6 literals may include
// a zone identifier.
//
// Copied from the Go 1.8 standard library (net/url)
func stripPort(hostport string) string {
colon := strings.IndexByte(hostport, ':')
if colon == -1 {
return hostport
}
if i := strings.IndexByte(hostport, ']'); i != -1 {
return strings.TrimPrefix(hostport[:i], "[")
}
return hostport[:colon]
}
// Port returns the port part of u.Host, without the leading colon.
// If u.Host doesn't contain a port, Port returns an empty string.
//
// Copied from the Go 1.8 standard library (net/url)
func portOnly(hostport string) string {
colon := strings.IndexByte(hostport, ':')
if colon == -1 {
return ""
}
if i := strings.Index(hostport, "]:"); i != -1 {
return hostport[i+len("]:"):]
}
if strings.Contains(hostport, "]") {
return ""
}
return hostport[colon+len(":"):]
}
// Returns true if the specified URI is using the standard port
// (i.e. port 80 for HTTP URIs or 443 for HTTPS URIs)
func isDefaultPort(scheme, port string) bool {
if port == "" {
return true
}
lowerCaseScheme := strings.ToLower(scheme)
if (lowerCaseScheme == "http" && port == "80") || (lowerCaseScheme == "https" && port == "443") {
return true
}
return false
}
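The interplay of these helpers on IPv6 literals is easy to get wrong, so here is a standalone sketch using local copies of stripPort and portOnly with the same bodies as above:

```go
package main

import (
	"fmt"
	"strings"
)

// stripPort returns hostport without any port, unwrapping IPv6 brackets.
func stripPort(hostport string) string {
	colon := strings.IndexByte(hostport, ':')
	if colon == -1 {
		return hostport
	}
	if i := strings.IndexByte(hostport, ']'); i != -1 {
		return strings.TrimPrefix(hostport[:i], "[")
	}
	return hostport[:colon]
}

// portOnly returns just the port, or "" when none is present.
func portOnly(hostport string) string {
	colon := strings.IndexByte(hostport, ':')
	if colon == -1 {
		return ""
	}
	if i := strings.Index(hostport, "]:"); i != -1 {
		return hostport[i+len("]:"):]
	}
	if strings.Contains(hostport, "]") {
		return ""
	}
	return hostport[colon+len(":"):]
}

func main() {
	fmt.Println(stripPort("example.com:443")) // example.com
	fmt.Println(stripPort("[::1]:8080"))      // ::1
	fmt.Println(portOnly("[::1]:8080"))       // 8080
	fmt.Println(portOnly("[::1]"))            // (empty: bracketed literal, no port)
}
```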


@@ -0,0 +1,65 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/signer/v4/uri_path.go
// - github.com/aws/aws-sdk-go/blob/v1.44.225/private/protocol/rest/build.go
// See THIRD-PARTY-NOTICES for original license terms
package v4
import (
"bytes"
"fmt"
"net/url"
"strings"
)
// Whether the byte value can be sent without escaping in AWS URLs
var noEscape [256]bool
func init() {
for i := 0; i < len(noEscape); i++ {
// AWS expects every character except these to be escaped
noEscape[i] = (i >= 'A' && i <= 'Z') ||
(i >= 'a' && i <= 'z') ||
(i >= '0' && i <= '9') ||
i == '-' ||
i == '.' ||
i == '_' ||
i == '~'
}
}
func getURIPath(u *url.URL) string {
var uri string
if len(u.Opaque) > 0 {
uri = "/" + strings.Join(strings.Split(u.Opaque, "/")[3:], "/")
} else {
uri = u.EscapedPath()
}
if len(uri) == 0 {
uri = "/"
}
return uri
}
// EscapePath escapes part of a URL path in Amazon style
func EscapePath(path string, encodeSep bool) string {
var buf bytes.Buffer
for i := 0; i < len(path); i++ {
c := path[i]
if noEscape[c] || (c == '/' && !encodeSep) {
buf.WriteByte(c)
} else {
fmt.Fprintf(&buf, "%%%02X", c)
}
}
return buf.String()
}
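A standalone sketch of the escape table and loop above (the function is renamed escapePath here only to keep the example self-contained) shows what encodeSep changes:

```go
package main

import (
	"bytes"
	"fmt"
)

// noEscape marks the unreserved characters AWS leaves unescaped in paths.
var noEscape [256]bool

func init() {
	for i := 0; i < len(noEscape); i++ {
		noEscape[i] = (i >= 'A' && i <= 'Z') ||
			(i >= 'a' && i <= 'z') ||
			(i >= '0' && i <= '9') ||
			i == '-' || i == '.' || i == '_' || i == '~'
	}
}

// escapePath percent-encodes every byte outside the table; encodeSep
// controls whether '/' itself is encoded.
func escapePath(path string, encodeSep bool) string {
	var buf bytes.Buffer
	for i := 0; i < len(path); i++ {
		c := path[i]
		if noEscape[c] || (c == '/' && !encodeSep) {
			buf.WriteByte(c)
		} else {
			fmt.Fprintf(&buf, "%%%02X", c)
		}
	}
	return buf.String()
}

func main() {
	fmt.Println(escapePath("/bucket/my key", false)) // /bucket/my%20key
	fmt.Println(escapePath("/bucket/my key", true))  // %2Fbucket%2Fmy%20key
}
```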


@@ -0,0 +1,421 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/signer/v4/v4.go
// See THIRD-PARTY-NOTICES for original license terms
package v4
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"sort"
"strings"
"time"
"gitea.psichedelico.com/go/bson/internal/aws"
"gitea.psichedelico.com/go/bson/internal/aws/credentials"
)
const (
authorizationHeader = "Authorization"
authHeaderSignatureElem = "Signature="
authHeaderPrefix = "AWS4-HMAC-SHA256"
timeFormat = "20060102T150405Z"
shortTimeFormat = "20060102"
awsV4Request = "aws4_request"
// emptyStringSHA256 is a SHA256 of an empty string
emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
)
var ignoredHeaders = rules{
excludeList{
mapRule{
authorizationHeader: struct{}{},
"User-Agent": struct{}{},
"X-Amzn-Trace-Id": struct{}{},
},
},
}
// Signer applies AWS v4 signing to given request. Use this to sign requests
// that need to be signed with AWS V4 Signatures.
type Signer struct {
// The authentication credentials the request will be signed against.
// This value must be set to sign requests.
Credentials *credentials.Credentials
}
// NewSigner returns a Signer pointer configured with the credentials provided.
func NewSigner(credentials *credentials.Credentials) *Signer {
v4 := &Signer{
Credentials: credentials,
}
return v4
}
type signingCtx struct {
ServiceName string
Region string
Request *http.Request
Body io.ReadSeeker
Query url.Values
Time time.Time
SignedHeaderVals http.Header
credValues credentials.Value
bodyDigest string
signedHeaders string
canonicalHeaders string
canonicalString string
credentialString string
stringToSign string
signature string
}
// Sign signs AWS v4 requests with the provided body, service name, region the
// request is made to, and time the request is signed at. The signTime allows
// you to specify that a request is signed for the future, and cannot be
// used until then.
//
// Returns a list of HTTP headers that were included in the signature or an
// error if signing the request failed. Generally for signed requests this value
// is not needed as the full request context will be captured by the http.Request
// value. It is included for reference though.
//
// Sign will set the request's Body to be the `body` parameter passed in. If
// the body is not already an io.ReadCloser, it will be wrapped within one. If
// a `nil` body parameter is passed to Sign, the request's Body field will
// also be set to nil. It is important to note that this will not change
// the request's ContentLength.
//
// Sign differs from Presign in that it will sign the request using HTTP
// header values. This type of signing is intended for http.Request values that
// will not be shared, or are shared in a way the header values on the request
// will not be lost.
//
// The request's body is an io.ReadSeeker so the SHA256 of the body can be
// generated. To bypass the signer computing the hash you can set the
// "X-Amz-Content-Sha256" header with a precomputed value. The signer will
// only compute the hash if the request header value is empty.
func (v4 Signer) Sign(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
return v4.signWithBody(r, body, service, region, signTime)
}
func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
ctx := &signingCtx{
Request: r,
Body: body,
Query: r.URL.Query(),
Time: signTime,
ServiceName: service,
Region: region,
}
for key := range ctx.Query {
sort.Strings(ctx.Query[key])
}
if ctx.isRequestSigned() {
ctx.Time = time.Now()
}
var err error
ctx.credValues, err = v4.Credentials.GetWithContext(r.Context())
if err != nil {
return http.Header{}, err
}
ctx.sanitizeHostForHeader()
ctx.assignAmzQueryValues()
if err := ctx.build(); err != nil {
return nil, err
}
var reader io.ReadCloser
if body != nil {
var ok bool
if reader, ok = body.(io.ReadCloser); !ok {
reader = ioutil.NopCloser(body)
}
}
r.Body = reader
return ctx.SignedHeaderVals, nil
}
// sanitizeHostForHeader removes default port from host and updates request.Host
func (ctx *signingCtx) sanitizeHostForHeader() {
r := ctx.Request
host := getHost(r)
port := portOnly(host)
if port != "" && isDefaultPort(r.URL.Scheme, port) {
r.Host = stripPort(host)
}
}
func (ctx *signingCtx) assignAmzQueryValues() {
if ctx.credValues.SessionToken != "" {
ctx.Request.Header.Set("X-Amz-Security-Token", ctx.credValues.SessionToken)
}
}
func (ctx *signingCtx) build() error {
ctx.buildTime() // no depends
ctx.buildCredentialString() // no depends
if err := ctx.buildBodyDigest(); err != nil {
return err
}
unsignedHeaders := ctx.Request.Header
ctx.buildCanonicalHeaders(ignoredHeaders, unsignedHeaders)
ctx.buildCanonicalString() // depends on canon headers / signed headers
ctx.buildStringToSign() // depends on canon string
ctx.buildSignature() // depends on string to sign
parts := []string{
authHeaderPrefix + " Credential=" + ctx.credValues.AccessKeyID + "/" + ctx.credentialString,
"SignedHeaders=" + ctx.signedHeaders,
authHeaderSignatureElem + ctx.signature,
}
ctx.Request.Header.Set(authorizationHeader, strings.Join(parts, ", "))
return nil
}
func (ctx *signingCtx) buildTime() {
ctx.Request.Header.Set("X-Amz-Date", formatTime(ctx.Time))
}
func (ctx *signingCtx) buildCredentialString() {
ctx.credentialString = buildSigningScope(ctx.Region, ctx.ServiceName, ctx.Time)
}
func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) {
headers := make([]string, 0, len(header)+1)
headers = append(headers, "host")
for k, v := range header {
if !r.IsValid(k) {
continue // ignored header
}
if ctx.SignedHeaderVals == nil {
ctx.SignedHeaderVals = make(http.Header)
}
lowerCaseKey := strings.ToLower(k)
if _, ok := ctx.SignedHeaderVals[lowerCaseKey]; ok {
// include additional values
ctx.SignedHeaderVals[lowerCaseKey] = append(ctx.SignedHeaderVals[lowerCaseKey], v...)
continue
}
headers = append(headers, lowerCaseKey)
ctx.SignedHeaderVals[lowerCaseKey] = v
}
sort.Strings(headers)
ctx.signedHeaders = strings.Join(headers, ";")
headerItems := make([]string, len(headers))
for i, k := range headers {
if k == "host" {
if ctx.Request.Host != "" {
headerItems[i] = "host:" + ctx.Request.Host
} else {
headerItems[i] = "host:" + ctx.Request.URL.Host
}
} else {
headerValues := make([]string, len(ctx.SignedHeaderVals[k]))
for i, v := range ctx.SignedHeaderVals[k] {
headerValues[i] = strings.TrimSpace(v)
}
headerItems[i] = k + ":" +
strings.Join(headerValues, ",")
}
}
stripExcessSpaces(headerItems)
ctx.canonicalHeaders = strings.Join(headerItems, "\n")
}
func (ctx *signingCtx) buildCanonicalString() {
ctx.Request.URL.RawQuery = strings.Replace(ctx.Query.Encode(), "+", "%20", -1)
uri := getURIPath(ctx.Request.URL)
uri = EscapePath(uri, false)
ctx.canonicalString = strings.Join([]string{
ctx.Request.Method,
uri,
ctx.Request.URL.RawQuery,
ctx.canonicalHeaders + "\n",
ctx.signedHeaders,
ctx.bodyDigest,
}, "\n")
}
func (ctx *signingCtx) buildStringToSign() {
ctx.stringToSign = strings.Join([]string{
authHeaderPrefix,
formatTime(ctx.Time),
ctx.credentialString,
hex.EncodeToString(hashSHA256([]byte(ctx.canonicalString))),
}, "\n")
}
func (ctx *signingCtx) buildSignature() {
creds := deriveSigningKey(ctx.Region, ctx.ServiceName, ctx.credValues.SecretAccessKey, ctx.Time)
signature := hmacSHA256(creds, []byte(ctx.stringToSign))
ctx.signature = hex.EncodeToString(signature)
}
func (ctx *signingCtx) buildBodyDigest() error {
hash := ctx.Request.Header.Get("X-Amz-Content-Sha256")
if hash == "" {
if ctx.Body == nil {
hash = emptyStringSHA256
} else {
if !aws.IsReaderSeekable(ctx.Body) {
return fmt.Errorf("cannot use unseekable request body %T, for signed request with body", ctx.Body)
}
hashBytes, err := makeSha256Reader(ctx.Body)
if err != nil {
return err
}
hash = hex.EncodeToString(hashBytes)
}
}
ctx.bodyDigest = hash
return nil
}
// isRequestSigned returns if the request is currently signed or presigned
func (ctx *signingCtx) isRequestSigned() bool {
return ctx.Request.Header.Get("Authorization") != ""
}
func hmacSHA256(key []byte, data []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(data)
return hash.Sum(nil)
}
func hashSHA256(data []byte) []byte {
hash := sha256.New()
hash.Write(data)
return hash.Sum(nil)
}
func makeSha256Reader(reader io.ReadSeeker) (hashBytes []byte, err error) {
hash := sha256.New()
start, err := reader.Seek(0, io.SeekCurrent)
if err != nil {
return nil, err
}
defer func() {
// ensure the error is returned if unable to seek back to the start of the payload.
_, err = reader.Seek(start, io.SeekStart)
}()
// Use CopyN to avoid allocating the 32KB buffer in io.Copy for bodies
// smaller than 32KB. Fall back to io.Copy if we fail to determine the size.
size, err := aws.SeekerLen(reader)
if err != nil {
_, _ = io.Copy(hash, reader)
} else {
_, _ = io.CopyN(hash, reader, size)
}
return hash.Sum(nil), nil
}
const doubleSpace = " "
// stripExcessSpaces will rewrite the passed in slice's string values to not
// contain multiple side-by-side spaces.
func stripExcessSpaces(vals []string) {
var j, k, l, m, spaces int
for i, str := range vals {
// revive:disable:empty-block
// Trim trailing spaces
for j = len(str) - 1; j >= 0 && str[j] == ' '; j-- {
}
// Trim leading spaces
for k = 0; k < j && str[k] == ' '; k++ {
}
// revive:enable:empty-block
str = str[k : j+1]
// Strip multiple spaces.
j = strings.Index(str, doubleSpace)
if j < 0 {
vals[i] = str
continue
}
buf := []byte(str)
for k, m, l = j, j, len(buf); k < l; k++ {
if buf[k] == ' ' {
if spaces == 0 {
// First space.
buf[m] = buf[k]
m++
}
spaces++
} else {
// End of multiple spaces.
spaces = 0
buf[m] = buf[k]
m++
}
}
vals[i] = string(buf[:m])
}
}
func buildSigningScope(region, service string, dt time.Time) string {
return strings.Join([]string{
formatShortTime(dt),
region,
service,
awsV4Request,
}, "/")
}
func deriveSigningKey(region, service, secretKey string, dt time.Time) []byte {
keyDate := hmacSHA256([]byte("AWS4"+secretKey), []byte(formatShortTime(dt)))
keyRegion := hmacSHA256(keyDate, []byte(region))
keyService := hmacSHA256(keyRegion, []byte(service))
signingKey := hmacSHA256(keyService, []byte(awsV4Request))
return signingKey
}
func formatShortTime(dt time.Time) string {
return dt.UTC().Format(shortTimeFormat)
}
func formatTime(dt time.Time) string {
return dt.UTC().Format(timeFormat)
}
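The chained-HMAC key derivation at the end of the file follows the SigV4 scheme: date, then region, then service, then the aws4_request terminator. A standalone sketch with hypothetical credential values:

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

func hmacSHA256(key, data []byte) []byte {
	h := hmac.New(sha256.New, key)
	h.Write(data)
	return h.Sum(nil)
}

// deriveSigningKey mirrors the derivation above:
// HMAC("AWS4"+secret, date) -> region -> service -> "aws4_request".
func deriveSigningKey(region, service, secretKey, shortDate string) []byte {
	kDate := hmacSHA256([]byte("AWS4"+secretKey), []byte(shortDate))
	kRegion := hmacSHA256(kDate, []byte(region))
	kService := hmacSHA256(kRegion, []byte(service))
	return hmacSHA256(kService, []byte("aws4_request"))
}

func main() {
	// Hypothetical values; the resulting 32-byte key signs the string-to-sign.
	key := deriveSigningKey("us-east-1", "dynamodb", "SECRET", "19700101")
	fmt.Println(hex.EncodeToString(key))
}
```

Because the secret key only ever enters the first HMAC, the derived key can be cached per date/region/service without exposing the long-term secret.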


@@ -0,0 +1,434 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/signer/v4/v4_test.go
// See THIRD-PARTY-NOTICES for original license terms
package v4
import (
"bytes"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"reflect"
"strconv"
"strings"
"testing"
"time"
"gitea.psichedelico.com/go/bson/internal/aws"
"gitea.psichedelico.com/go/bson/internal/aws/credentials"
"gitea.psichedelico.com/go/bson/internal/credproviders"
)
func epochTime() time.Time { return time.Unix(0, 0) }
func TestStripExcessHeaders(t *testing.T) {
vals := []string{
"",
"123",
"1 2 3",
"1 2 3 ",
" 1 2 3",
"1 2 3",
"1 23",
"1 2 3",
"1 2 ",
" 1 2 ",
"12 3",
"12 3 1",
"12 3 1",
"12 3 1abc123",
}
expected := []string{
"",
"123",
"1 2 3",
"1 2 3",
"1 2 3",
"1 2 3",
"1 23",
"1 2 3",
"1 2",
"1 2",
"12 3",
"12 3 1",
"12 3 1",
"12 3 1abc123",
}
stripExcessSpaces(vals)
for i := 0; i < len(vals); i++ {
if e, a := expected[i], vals[i]; e != a {
t.Errorf("%d, expect %v, got %v", i, e, a)
}
}
}
func buildRequest(body string) (*http.Request, io.ReadSeeker) {
reader := strings.NewReader(body)
return buildRequestWithBodyReader("dynamodb", "us-east-1", reader)
}
func buildRequestReaderSeeker(serviceName, region, body string) (*http.Request, io.ReadSeeker) {
reader := &readerSeekerWrapper{strings.NewReader(body)}
return buildRequestWithBodyReader(serviceName, region, reader)
}
func buildRequestWithBodyReader(serviceName, region string, body io.Reader) (*http.Request, io.ReadSeeker) {
var bodyLen int
type lenner interface {
Len() int
}
if lr, ok := body.(lenner); ok {
bodyLen = lr.Len()
}
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
req, _ := http.NewRequest("POST", endpoint, body)
req.URL.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()"
req.Header.Set("X-Amz-Target", "prefix.Operation")
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
if bodyLen > 0 {
req.Header.Set("Content-Length", strconv.Itoa(bodyLen))
}
req.Header.Set("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
req.Header.Add("X-Amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
req.Header.Add("X-amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
var seeker io.ReadSeeker
if sr, ok := body.(io.ReadSeeker); ok {
seeker = sr
} else {
seeker = aws.ReadSeekCloser(body)
}
return req, seeker
}
func buildSigner() Signer {
return Signer{
Credentials: newTestStaticCredentials(),
}
}
func newTestStaticCredentials() *credentials.Credentials {
return credentials.NewCredentials(&credproviders.StaticProvider{Value: credentials.Value{
AccessKeyID: "AKID",
SecretAccessKey: "SECRET",
SessionToken: "SESSION",
}})
}
func TestSignRequest(t *testing.T) {
req, body := buildRequest("{}")
signer := buildSigner()
_, err := signer.Sign(req, body, "dynamodb", "us-east-1", epochTime())
if err != nil {
t.Errorf("Expected no err, got %v", err)
}
expectedDate := "19700101T000000Z"
expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-security-token;x-amz-target, Signature=a518299330494908a70222cec6899f6f32f297f8595f6df1776d998936652ad9"
q := req.Header
if e, a := expectedSig, q.Get("Authorization"); e != a {
t.Errorf("expect\n%v\nactual\n%v\n", e, a)
}
if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
t.Errorf("expect\n%v\nactual\n%v\n", e, a)
}
}
func TestSignUnseekableBody(t *testing.T) {
req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello")))
signer := buildSigner()
_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
if err == nil {
t.Fatalf("expect error signing request")
}
if e, a := "unseekable request body", err.Error(); !strings.Contains(a, e) {
t.Errorf("expect %q to be in %q", e, a)
}
}
func TestSignPreComputedHashUnseekableBody(t *testing.T) {
req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello")))
signer := buildSigner()
req.Header.Set("X-Amz-Content-Sha256", "some-content-sha256")
_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
hash := req.Header.Get("X-Amz-Content-Sha256")
if e, a := "some-content-sha256", hash; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSignPrecomputedBodyChecksum(t *testing.T) {
req, body := buildRequest("hello")
req.Header.Set("X-Amz-Content-Sha256", "PRECOMPUTED")
signer := buildSigner()
_, err := signer.Sign(req, body, "dynamodb", "us-east-1", time.Now())
if err != nil {
t.Errorf("Expected no err, got %v", err)
}
hash := req.Header.Get("X-Amz-Content-Sha256")
if e, a := "PRECOMPUTED", hash; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSignWithRequestBody(t *testing.T) {
creds := newTestStaticCredentials()
signer := NewSigner(creds)
expectBody := []byte("abc123")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, err := ioutil.ReadAll(r.Body)
r.Body.Close()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := expectBody, b; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v, got %v", e, a)
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
req, err := http.NewRequest("POST", server.URL, nil)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
_, err = signer.Sign(req, bytes.NewReader(expectBody), "service", "region", time.Now())
if err != nil {
t.Errorf("expect no error, got %v", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := http.StatusOK, resp.StatusCode; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSignWithRequestBody_Overwrite(t *testing.T) {
creds := newTestStaticCredentials()
signer := NewSigner(creds)
var expectBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, err := ioutil.ReadAll(r.Body)
r.Body.Close()
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := len(expectBody), len(b); e != a {
t.Errorf("expect %v, got %v", e, a)
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
req, err := http.NewRequest("GET", server.URL, strings.NewReader("invalid body"))
if err != nil {
t.Errorf("expect no error, got %v", err)
}
_, err = signer.Sign(req, nil, "service", "region", time.Now())
req.ContentLength = 0
if err != nil {
t.Errorf("expect no error, got %v", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := http.StatusOK, resp.StatusCode; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestBuildCanonicalRequest(t *testing.T) {
req, body := buildRequest("{}")
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
ctx := &signingCtx{
ServiceName: "dynamodb",
Region: "us-east-1",
Request: req,
Body: body,
Query: req.URL.Query(),
Time: time.Now(),
}
ctx.buildCanonicalString()
expected := "https://example.org/bucket/key-._~,!@#$%^&*()?Foo=z&Foo=o&Foo=m&Foo=a"
if e, a := expected, ctx.Request.URL.String(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
}
func TestSignWithBody_ReplaceRequestBody(t *testing.T) {
creds := newTestStaticCredentials()
req, seekerBody := buildRequest("{}")
req.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
s := NewSigner(creds)
origBody := req.Body
_, err := s.Sign(req, seekerBody, "dynamodb", "us-east-1", time.Now())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if req.Body == origBody {
t.Errorf("expect request body to not be origBody")
}
if req.Body == nil {
t.Errorf("expect request body to be changed but was nil")
}
}
func TestRequestHost(t *testing.T) {
req, body := buildRequest("{}")
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
req.Host = "myhost"
ctx := &signingCtx{
ServiceName: "dynamodb",
Region: "us-east-1",
Request: req,
Body: body,
Query: req.URL.Query(),
Time: time.Now(),
}
ctx.buildCanonicalHeaders(ignoredHeaders, ctx.Request.Header)
if !strings.Contains(ctx.canonicalHeaders, "host:"+req.Host) {
t.Errorf("canonical host header invalid")
}
}
func TestSign_buildCanonicalHeaders(t *testing.T) {
serviceName := "mockAPI"
region := "mock-region"
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
req, err := http.NewRequest("POST", endpoint, nil)
if err != nil {
t.Fatalf("failed to create request, %v", err)
}
req.Header.Set("FooInnerSpace", " inner space ")
req.Header.Set("FooLeadingSpace", " leading-space")
req.Header.Add("FooMultipleSpace", "no-space")
req.Header.Add("FooMultipleSpace", "\ttab-space")
req.Header.Add("FooMultipleSpace", "trailing-space ")
req.Header.Set("FooNoSpace", "no-space")
req.Header.Set("FooTabSpace", "\ttab-space\t")
req.Header.Set("FooTrailingSpace", "trailing-space ")
req.Header.Set("FooWrappedSpace", " wrapped-space ")
ctx := &signingCtx{
ServiceName: serviceName,
Region: region,
Request: req,
Body: nil,
Query: req.URL.Query(),
Time: time.Now(),
}
ctx.buildCanonicalHeaders(ignoredHeaders, ctx.Request.Header)
expectCanonicalHeaders := strings.Join([]string{
`fooinnerspace:inner space`,
`fooleadingspace:leading-space`,
`foomultiplespace:no-space,tab-space,trailing-space`,
`foonospace:no-space`,
`footabspace:tab-space`,
`footrailingspace:trailing-space`,
`foowrappedspace:wrapped-space`,
`host:mockAPI.mock-region.amazonaws.com`,
}, "\n")
if e, a := expectCanonicalHeaders, ctx.canonicalHeaders; e != a {
t.Errorf("expect:\n%s\n\nactual:\n%s", e, a)
}
}
func BenchmarkSignRequest(b *testing.B) {
signer := buildSigner()
req, body := buildRequestReaderSeeker("dynamodb", "us-east-1", "{}")
for i := 0; i < b.N; i++ {
_, err := signer.Sign(req, body, "dynamodb", "us-east-1", time.Now())
if err != nil {
b.Errorf("Expected no err, got %v", err)
}
}
}
var stripExcessSpaceCases = []string{
`AWS4-HMAC-SHA256 Credential=AKIDFAKEIDFAKEID/20160628/us-west-2/s3/aws4_request, SignedHeaders=host;x-amz-date, Signature=1234567890abcdef1234567890abcdef1234567890abcdef`,
`123 321 123 321`,
` 123 321 123 321 `,
` 123 321 123 321 `,
"123",
"1 2 3",
" 1 2 3",
"1 2 3",
"1 23",
"1 2 3",
"1 2 ",
" 1 2 ",
"12 3",
"12 3 1",
"12 3 1",
"12 3 1abc123",
}
func BenchmarkStripExcessSpaces(b *testing.B) {
for i := 0; i < b.N; i++ {
// Make sure to start with a copy of the cases
cases := append([]string{}, stripExcessSpaceCases...)
stripExcessSpaces(cases)
}
}
// readerSeekerWrapper mimics the interface provided by request.offsetReader
type readerSeekerWrapper struct {
r *strings.Reader
}
func (r *readerSeekerWrapper) Read(p []byte) (n int, err error) {
return r.r.Read(p)
}
func (r *readerSeekerWrapper) Seek(offset int64, whence int) (int64, error) {
return r.r.Seek(offset, whence)
}
func (r *readerSeekerWrapper) Len() int {
return r.r.Len()
}

internal/aws/types.go Normal file

@@ -0,0 +1,153 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/types.go
// See THIRD-PARTY-NOTICES for original license terms
package aws
import (
"io"
)
// ReadSeekCloser wraps an io.Reader, returning a ReaderSeekerCloser. This allows the
// SDK to accept an io.Reader that is not also an io.Seeker for unsigned
// streaming payload API operations.
//
// A ReadSeekCloser wrapping a nonseekable io.Reader used in an API
// operation's input will prevent that operation from being retried in the case of
// network errors, and cause operation requests to fail if the operation
// requires payload signing.
//
// Note: If using With S3 PutObject to stream an object upload The SDK's S3
// Upload manager (s3manager.Uploader) provides support for streaming with the
// ability to retry network errors.
func ReadSeekCloser(r io.Reader) ReaderSeekerCloser {
return ReaderSeekerCloser{r}
}
// ReaderSeekerCloser represents a reader that can also delegate io.Seeker and
// io.Closer interfaces to the underlying object if they are available.
type ReaderSeekerCloser struct {
r io.Reader
}
// IsReaderSeekable returns whether the underlying reader can be seeked. An
// io.Reader might not actually be seekable if it is wrapped in a
// ReaderSeekerCloser.
func IsReaderSeekable(r io.Reader) bool {
switch v := r.(type) {
case ReaderSeekerCloser:
return v.IsSeeker()
case *ReaderSeekerCloser:
return v.IsSeeker()
case io.ReadSeeker:
return true
default:
return false
}
}
// Read reads from the reader up to the size of p. It returns the number of
// bytes read and any error that occurred.
//
// If the wrapped value is not an io.Reader, zero bytes read and a nil error
// will be returned.
//
// Performs the same functionality as io.Reader's Read.
func (r ReaderSeekerCloser) Read(p []byte) (int, error) {
switch t := r.r.(type) {
case io.Reader:
return t.Read(p)
}
return 0, nil
}
// Seek sets the offset for the next Read to offset, interpreted according to
// whence: 0 means relative to the origin of the file, 1 means relative to the
// current offset, and 2 means relative to the end. Seek returns the new offset
// and an error, if any.
//
// If the wrapped reader is not an io.Seeker, nothing will be done.
func (r ReaderSeekerCloser) Seek(offset int64, whence int) (int64, error) {
switch t := r.r.(type) {
case io.Seeker:
return t.Seek(offset, whence)
}
return int64(0), nil
}
// IsSeeker returns if the underlying reader is also a seeker.
func (r ReaderSeekerCloser) IsSeeker() bool {
_, ok := r.r.(io.Seeker)
return ok
}
// HasLen returns the length of the underlying reader if the value implements
// the Len() int method.
func (r ReaderSeekerCloser) HasLen() (int, bool) {
type lenner interface {
Len() int
}
if lr, ok := r.r.(lenner); ok {
return lr.Len(), true
}
return 0, false
}
// GetLen returns the length of the bytes remaining in the underlying reader.
// Checks first for Len(), then io.Seeker to determine the size of the
// underlying reader.
//
// Will return -1 if the length cannot be determined.
func (r ReaderSeekerCloser) GetLen() (int64, error) {
if l, ok := r.HasLen(); ok {
return int64(l), nil
}
if s, ok := r.r.(io.Seeker); ok {
return seekerLen(s)
}
return -1, nil
}
// SeekerLen attempts to get the number of bytes remaining at the seeker's
// current position. Returns the number of bytes remaining or error.
func SeekerLen(s io.Seeker) (int64, error) {
// Determine if the seeker is actually seekable. ReaderSeekerCloser
// hides the fact that an io.Reader might not actually be seekable.
switch v := s.(type) {
case ReaderSeekerCloser:
return v.GetLen()
case *ReaderSeekerCloser:
return v.GetLen()
}
return seekerLen(s)
}
func seekerLen(s io.Seeker) (int64, error) {
curOffset, err := s.Seek(0, io.SeekCurrent)
if err != nil {
return 0, err
}
endOffset, err := s.Seek(0, io.SeekEnd)
if err != nil {
return 0, err
}
_, err = s.Seek(curOffset, io.SeekStart)
if err != nil {
return 0, err
}
return endOffset - curOffset, nil
}


@ -0,0 +1,40 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncoreutil
import "unicode/utf8"

// Truncate truncates str to at most width bytes without splitting a
// multi-byte UTF-8 character.
func Truncate(str string, width int) string {
if width == 0 {
return ""
}
if len(str) <= width {
return str
}
// Truncate the byte slice of the string to the given width.
newStr := str[:width]
// If the cut landed inside a multi-byte character, the string is no longer
// valid UTF-8. Step back past the continuation bytes (10xxxxxx) to the
// lead byte (11xxxxxx) of the split character and truncate before it.
if !utf8.ValidString(newStr) {
for i := len(newStr) - 1; i >= 0; i-- {
if newStr[i]&0xC0 == 0xC0 {
return newStr[:i]
}
}
}
return newStr
}


@ -0,0 +1,59 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncoreutil
import (
"testing"
"gitea.psichedelico.com/go/bson/internal/assert"
)
func TestTruncate(t *testing.T) {
t.Parallel()
for _, tcase := range []struct {
name string
arg string
width int
expected string
}{
{
name: "empty",
arg: "",
width: 0,
expected: "",
},
{
name: "short",
arg: "foo",
width: 1000,
expected: "foo",
},
{
name: "long",
arg: "foo bar baz",
width: 9,
expected: "foo bar b",
},
{
name: "multi-byte",
arg: "你好",
width: 4,
expected: "你",
},
} {
tcase := tcase
t.Run(tcase.name, func(t *testing.T) {
t.Parallel()
actual := Truncate(tcase.arg, tcase.width)
assert.Equal(t, tcase.expected, actual)
})
}
}


@ -0,0 +1,62 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonutil
import (
"fmt"
"gitea.psichedelico.com/go/bson"
)
// StringSliceFromRawValue decodes the provided BSON value into a []string. This function returns an error if the value
// is not an array or any of the elements in the array are not strings. The name parameter is used to add context to
// error messages.
func StringSliceFromRawValue(name string, val bson.RawValue) ([]string, error) {
arr, ok := val.ArrayOK()
if !ok {
return nil, fmt.Errorf("expected '%s' to be an array but it's a BSON %s", name, val.Type)
}
arrayValues, err := arr.Values()
if err != nil {
return nil, err
}
strs := make([]string, 0, len(arrayValues))
for _, arrayVal := range arrayValues {
str, ok := arrayVal.StringValueOK()
if !ok {
return nil, fmt.Errorf("expected '%s' to be an array of strings, but found a BSON %s", name, arrayVal.Type)
}
strs = append(strs, str)
}
return strs, nil
}
// RawArrayToDocuments converts an array of documents to []bson.Raw.
func RawArrayToDocuments(arr bson.RawArray) []bson.Raw {
values, err := arr.Values()
if err != nil {
panic(fmt.Sprintf("error converting BSON array to values: %v", err))
}
out := make([]bson.Raw, len(values))
for i := range values {
out[i] = values[i].Document()
}
return out
}
// RawToInterfaces takes one or many bson.Raw documents and returns them as a []interface{}.
func RawToInterfaces(docs ...bson.Raw) []interface{} {
out := make([]interface{}, len(docs))
for i := range docs {
out[i] = docs[i]
}
return out
}


@ -0,0 +1,62 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package codecutil
import (
"bytes"
"errors"
"fmt"
"io"
"reflect"
"gitea.psichedelico.com/go/bson"
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
)
var ErrNilValue = errors.New("value is nil")
// MarshalError is returned when attempting to transform a value into a document
// results in an error.
type MarshalError struct {
Value interface{}
Err error
}
// Error implements the error interface.
func (e MarshalError) Error() string {
return fmt.Sprintf("cannot transform type %s to a BSON Document: %v",
reflect.TypeOf(e.Value), e.Err)
}
// EncoderFn is used to functionally construct an encoder for marshaling values.
type EncoderFn func(io.Writer) *bson.Encoder
// MarshalValue will attempt to encode the value with the encoder returned by
// the encoder function.
func MarshalValue(val interface{}, encFn EncoderFn) (bsoncore.Value, error) {
// If the val is already a bsoncore.Value, then do nothing.
if bval, ok := val.(bsoncore.Value); ok {
return bval, nil
}
if val == nil {
return bsoncore.Value{}, ErrNilValue
}
buf := new(bytes.Buffer)
enc := encFn(buf)
// Encode the value in a single-element document with an empty key. Use
// bsoncore to extract the first element and return the BSON value.
err := enc.Encode(bson.D{{Key: "", Value: val}})
if err != nil {
return bsoncore.Value{}, MarshalError{Value: val, Err: err}
}
return bsoncore.Document(buf.Bytes()).Index(0).Value(), nil
}


@ -0,0 +1,82 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package codecutil
import (
"io"
"testing"
"gitea.psichedelico.com/go/bson"
"gitea.psichedelico.com/go/bson/internal/assert"
)
func testEncFn(t *testing.T) EncoderFn {
t.Helper()
return func(w io.Writer) *bson.Encoder {
rw := bson.NewDocumentWriter(w)
return bson.NewEncoder(rw)
}
}
func TestMarshalValue(t *testing.T) {
t.Parallel()
tests := []struct {
name string
val interface{}
registry *bson.Registry
encFn EncoderFn
want string
wantErr error
}{
{
name: "empty",
val: nil,
want: "",
wantErr: ErrNilValue,
encFn: testEncFn(t),
},
{
name: "bson.D",
val: bson.D{{"foo", "bar"}},
want: `{"foo": "bar"}`,
encFn: testEncFn(t),
},
{
name: "map",
val: map[string]interface{}{"foo": "bar"},
want: `{"foo": "bar"}`,
encFn: testEncFn(t),
},
{
name: "struct",
val: struct{ Foo string }{Foo: "bar"},
want: `{"foo": "bar"}`,
encFn: testEncFn(t),
},
{
name: "non-document type",
val: "foo: bar",
want: `"foo: bar"`,
encFn: testEncFn(t),
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
value, err := MarshalValue(test.val, test.encFn)
assert.Equal(t, test.wantErr, err, "expected and actual error do not match")
assert.Equal(t, test.want, value.String(), "expected and actual values are different")
})
}
}


@ -0,0 +1,148 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package credproviders
import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"time"
"gitea.psichedelico.com/go/bson/internal/aws/credentials"
"gitea.psichedelico.com/go/bson/internal/uuid"
)
const (
// assumeRoleProviderName is the name of the assume role provider
assumeRoleProviderName = "AssumeRoleProvider"
stsURI = `https://sts.amazonaws.com/?Action=AssumeRoleWithWebIdentity&RoleSessionName=%s&RoleArn=%s&WebIdentityToken=%s&Version=2011-06-15`
)
// An AssumeRoleProvider retrieves credentials for assume role with web identity.
type AssumeRoleProvider struct {
AwsRoleArnEnv EnvVar
AwsWebIdentityTokenFileEnv EnvVar
AwsRoleSessionNameEnv EnvVar
httpClient *http.Client
expiration time.Time
// expiryWindow will allow the credentials to trigger refreshing prior to the credentials actually expiring.
// This is beneficial so expiring credentials do not cause requests to fail unexpectedly due to exceptions.
//
// So an expiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
expiryWindow time.Duration
}
// NewAssumeRoleProvider returns a pointer to an assume role provider.
func NewAssumeRoleProvider(httpClient *http.Client, expiryWindow time.Duration) *AssumeRoleProvider {
return &AssumeRoleProvider{
// AwsRoleArnEnv is the environment variable for AWS_ROLE_ARN
AwsRoleArnEnv: EnvVar("AWS_ROLE_ARN"),
// AwsWebIdentityTokenFileEnv is the environment variable for AWS_WEB_IDENTITY_TOKEN_FILE
AwsWebIdentityTokenFileEnv: EnvVar("AWS_WEB_IDENTITY_TOKEN_FILE"),
// AwsRoleSessionNameEnv is the environment variable for AWS_ROLE_SESSION_NAME
AwsRoleSessionNameEnv: EnvVar("AWS_ROLE_SESSION_NAME"),
httpClient: httpClient,
expiryWindow: expiryWindow,
}
}
// RetrieveWithContext retrieves the keys from the AWS service.
func (a *AssumeRoleProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
const defaultHTTPTimeout = 10 * time.Second
v := credentials.Value{ProviderName: assumeRoleProviderName}
roleArn := a.AwsRoleArnEnv.Get()
tokenFile := a.AwsWebIdentityTokenFileEnv.Get()
if tokenFile == "" && roleArn == "" {
return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_ARN are missing")
}
if tokenFile != "" && roleArn == "" {
return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE is set, but AWS_ROLE_ARN is missing")
}
if tokenFile == "" && roleArn != "" {
return v, errors.New("AWS_ROLE_ARN is set, but AWS_WEB_IDENTITY_TOKEN_FILE is missing")
}
token, err := ioutil.ReadFile(tokenFile)
if err != nil {
return v, err
}
sessionName := a.AwsRoleSessionNameEnv.Get()
if sessionName == "" {
// Use a UUID if the RoleSessionName is not given.
id, err := uuid.New()
if err != nil {
return v, err
}
sessionName = id.String()
}
fullURI := fmt.Sprintf(stsURI, sessionName, roleArn, string(token))
req, err := http.NewRequest(http.MethodPost, fullURI, nil)
if err != nil {
return v, err
}
req.Header.Set("Accept", "application/json")
ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
defer cancel()
resp, err := a.httpClient.Do(req.WithContext(ctx))
if err != nil {
return v, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return v, fmt.Errorf("response failure: %s", resp.Status)
}
var stsResp struct {
Response struct {
Result struct {
Credentials struct {
AccessKeyID string `json:"AccessKeyId"`
SecretAccessKey string `json:"SecretAccessKey"`
Token string `json:"SessionToken"`
Expiration float64 `json:"Expiration"`
} `json:"Credentials"`
} `json:"AssumeRoleWithWebIdentityResult"`
} `json:"AssumeRoleWithWebIdentityResponse"`
}
err = json.NewDecoder(resp.Body).Decode(&stsResp)
if err != nil {
return v, err
}
v.AccessKeyID = stsResp.Response.Result.Credentials.AccessKeyID
v.SecretAccessKey = stsResp.Response.Result.Credentials.SecretAccessKey
v.SessionToken = stsResp.Response.Result.Credentials.Token
if !v.HasKeys() {
return v, errors.New("failed to retrieve web identity keys")
}
sec := int64(stsResp.Response.Result.Credentials.Expiration)
a.expiration = time.Unix(sec, 0).Add(-a.expiryWindow)
return v, nil
}
// Retrieve retrieves the keys from the AWS service.
func (a *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
return a.RetrieveWithContext(context.Background())
}
// IsExpired returns true if the credentials are expired.
func (a *AssumeRoleProvider) IsExpired() bool {
return a.expiration.Before(time.Now())
}


@ -0,0 +1,183 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package credproviders
import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"time"
"gitea.psichedelico.com/go/bson/internal/aws/credentials"
)
const (
// ec2ProviderName is the name of the EC2 provider
ec2ProviderName = "EC2Provider"
awsEC2URI = "http://169.254.169.254/"
awsEC2RolePath = "latest/meta-data/iam/security-credentials/"
awsEC2TokenPath = "latest/api/token"
defaultHTTPTimeout = 10 * time.Second
)
// An EC2Provider retrieves credentials from EC2 metadata.
type EC2Provider struct {
httpClient *http.Client
expiration time.Time
// expiryWindow will allow the credentials to trigger refreshing prior to the credentials actually expiring.
// This is beneficial so expiring credentials do not cause requests to fail unexpectedly due to exceptions.
//
// So an expiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
expiryWindow time.Duration
}
// NewEC2Provider returns a pointer to an EC2 credential provider.
func NewEC2Provider(httpClient *http.Client, expiryWindow time.Duration) *EC2Provider {
return &EC2Provider{
httpClient: httpClient,
expiryWindow: expiryWindow,
}
}
func (e *EC2Provider) getToken(ctx context.Context) (string, error) {
req, err := http.NewRequest(http.MethodPut, awsEC2URI+awsEC2TokenPath, nil)
if err != nil {
return "", err
}
const defaultEC2TTLSeconds = "30"
req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", defaultEC2TTLSeconds)
ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
defer cancel()
resp, err := e.httpClient.Do(req.WithContext(ctx))
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("%s %s failed: %s", req.Method, req.URL.String(), resp.Status)
}
token, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
}
if len(token) == 0 {
return "", errors.New("unable to retrieve token from EC2 metadata")
}
return string(token), nil
}
func (e *EC2Provider) getRoleName(ctx context.Context, token string) (string, error) {
req, err := http.NewRequest(http.MethodGet, awsEC2URI+awsEC2RolePath, nil)
if err != nil {
return "", err
}
req.Header.Set("X-aws-ec2-metadata-token", token)
ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
defer cancel()
resp, err := e.httpClient.Do(req.WithContext(ctx))
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("%s %s failed: %s", req.Method, req.URL.String(), resp.Status)
}
role, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
}
if len(role) == 0 {
return "", errors.New("unable to retrieve role_name from EC2 metadata")
}
return string(role), nil
}
func (e *EC2Provider) getCredentials(ctx context.Context, token string, role string) (credentials.Value, time.Time, error) {
v := credentials.Value{ProviderName: ec2ProviderName}
pathWithRole := awsEC2URI + awsEC2RolePath + role
req, err := http.NewRequest(http.MethodGet, pathWithRole, nil)
if err != nil {
return v, time.Time{}, err
}
req.Header.Set("X-aws-ec2-metadata-token", token)
ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
defer cancel()
resp, err := e.httpClient.Do(req.WithContext(ctx))
if err != nil {
return v, time.Time{}, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return v, time.Time{}, fmt.Errorf("%s %s failed: %s", req.Method, req.URL.String(), resp.Status)
}
var ec2Resp struct {
AccessKeyID string `json:"AccessKeyId"`
SecretAccessKey string `json:"SecretAccessKey"`
Token string `json:"Token"`
Expiration time.Time `json:"Expiration"`
}
err = json.NewDecoder(resp.Body).Decode(&ec2Resp)
if err != nil {
return v, time.Time{}, err
}
v.AccessKeyID = ec2Resp.AccessKeyID
v.SecretAccessKey = ec2Resp.SecretAccessKey
v.SessionToken = ec2Resp.Token
return v, ec2Resp.Expiration, nil
}
// RetrieveWithContext retrieves the keys from the AWS service.
func (e *EC2Provider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
v := credentials.Value{ProviderName: ec2ProviderName}
token, err := e.getToken(ctx)
if err != nil {
return v, err
}
role, err := e.getRoleName(ctx, token)
if err != nil {
return v, err
}
v, exp, err := e.getCredentials(ctx, token, role)
if err != nil {
return v, err
}
if !v.HasKeys() {
return v, errors.New("failed to retrieve EC2 keys")
}
e.expiration = exp.Add(-e.expiryWindow)
return v, nil
}
// Retrieve retrieves the keys from the AWS service.
func (e *EC2Provider) Retrieve() (credentials.Value, error) {
return e.RetrieveWithContext(context.Background())
}
// IsExpired returns true if the credentials are expired.
func (e *EC2Provider) IsExpired() bool {
return e.expiration.Before(time.Now())
}


@ -0,0 +1,112 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package credproviders
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"time"
"gitea.psichedelico.com/go/bson/internal/aws/credentials"
)
const (
// ecsProviderName is the name of the ECS provider
ecsProviderName = "ECSProvider"
awsRelativeURI = "http://169.254.170.2/"
)
// An ECSProvider retrieves credentials from ECS metadata.
type ECSProvider struct {
AwsContainerCredentialsRelativeURIEnv EnvVar
httpClient *http.Client
expiration time.Time
// expiryWindow will allow the credentials to trigger refreshing prior to the credentials actually expiring.
// This is beneficial so expiring credentials do not cause requests to fail unexpectedly due to exceptions.
//
// So an expiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired.
expiryWindow time.Duration
}
// NewECSProvider returns a pointer to an ECS credential provider.
func NewECSProvider(httpClient *http.Client, expiryWindow time.Duration) *ECSProvider {
return &ECSProvider{
// AwsContainerCredentialsRelativeURIEnv is the environment variable for AWS_CONTAINER_CREDENTIALS_RELATIVE_URI
AwsContainerCredentialsRelativeURIEnv: EnvVar("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"),
httpClient: httpClient,
expiryWindow: expiryWindow,
}
}
// RetrieveWithContext retrieves the keys from the AWS service.
func (e *ECSProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
const defaultHTTPTimeout = 10 * time.Second
v := credentials.Value{ProviderName: ecsProviderName}
relativeEcsURI := e.AwsContainerCredentialsRelativeURIEnv.Get()
if len(relativeEcsURI) == 0 {
return v, errors.New("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI is missing")
}
fullURI := awsRelativeURI + relativeEcsURI
req, err := http.NewRequest(http.MethodGet, fullURI, nil)
if err != nil {
return v, err
}
req.Header.Set("Accept", "application/json")
ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
defer cancel()
resp, err := e.httpClient.Do(req.WithContext(ctx))
if err != nil {
return v, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return v, fmt.Errorf("response failure: %s", resp.Status)
}
var ecsResp struct {
AccessKeyID string `json:"AccessKeyId"`
SecretAccessKey string `json:"SecretAccessKey"`
Token string `json:"Token"`
Expiration time.Time `json:"Expiration"`
}
err = json.NewDecoder(resp.Body).Decode(&ecsResp)
if err != nil {
return v, err
}
v.AccessKeyID = ecsResp.AccessKeyID
v.SecretAccessKey = ecsResp.SecretAccessKey
v.SessionToken = ecsResp.Token
if !v.HasKeys() {
return v, errors.New("failed to retrieve ECS keys")
}
e.expiration = ecsResp.Expiration.Add(-e.expiryWindow)
return v, nil
}
// Retrieve retrieves the keys from the AWS service.
func (e *ECSProvider) Retrieve() (credentials.Value, error) {
return e.RetrieveWithContext(context.Background())
}
// IsExpired returns true if the credentials are expired.
func (e *ECSProvider) IsExpired() bool {
return e.expiration.Before(time.Now())
}


@ -0,0 +1,69 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package credproviders
import (
"os"
"gitea.psichedelico.com/go/bson/internal/aws/credentials"
)
// envProviderName is the name of the Env provider
const envProviderName = "EnvProvider"
// EnvVar is an environment variable
type EnvVar string
// Get retrieves the environment variable
func (ev EnvVar) Get() string {
return os.Getenv(string(ev))
}
// An EnvProvider retrieves credentials from the environment variables of the
// running process. Environment credentials never expire.
type EnvProvider struct {
AwsAccessKeyIDEnv EnvVar
AwsSecretAccessKeyEnv EnvVar
AwsSessionTokenEnv EnvVar
retrieved bool
}
// NewEnvProvider returns a pointer to an Env credential provider.
func NewEnvProvider() *EnvProvider {
return &EnvProvider{
// AwsAccessKeyIDEnv is the environment variable for AWS_ACCESS_KEY_ID
AwsAccessKeyIDEnv: EnvVar("AWS_ACCESS_KEY_ID"),
// AwsSecretAccessKeyEnv is the environment variable for AWS_SECRET_ACCESS_KEY
AwsSecretAccessKeyEnv: EnvVar("AWS_SECRET_ACCESS_KEY"),
// AwsSessionTokenEnv is the environment variable for AWS_SESSION_TOKEN
AwsSessionTokenEnv: EnvVar("AWS_SESSION_TOKEN"),
}
}
// Retrieve retrieves the keys from the environment.
func (e *EnvProvider) Retrieve() (credentials.Value, error) {
e.retrieved = false
v := credentials.Value{
AccessKeyID: e.AwsAccessKeyIDEnv.Get(),
SecretAccessKey: e.AwsSecretAccessKeyEnv.Get(),
SessionToken: e.AwsSessionTokenEnv.Get(),
ProviderName: envProviderName,
}
err := verify(v)
if err == nil {
e.retrieved = true
}
return v, err
}
// IsExpired returns true if the credentials have not been retrieved.
func (e *EnvProvider) IsExpired() bool {
return !e.retrieved
}


@ -0,0 +1,103 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package credproviders
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"time"
"gitea.psichedelico.com/go/bson/internal/aws/credentials"
)
const (
// AzureProviderName is the name of the Azure provider
AzureProviderName = "AzureProvider"
azureURI = "http://169.254.169.254/metadata/identity/oauth2/token"
)
// An AzureProvider retrieves credentials from Azure IMDS.
type AzureProvider struct {
httpClient *http.Client
expiration time.Time
expiryWindow time.Duration
}
// NewAzureProvider returns a pointer to an Azure credential provider.
func NewAzureProvider(httpClient *http.Client, expiryWindow time.Duration) *AzureProvider {
return &AzureProvider{
httpClient: httpClient,
expiration: time.Time{},
expiryWindow: expiryWindow,
}
}
// RetrieveWithContext retrieves the keys from the Azure service.
func (a *AzureProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
v := credentials.Value{ProviderName: AzureProviderName}
req, err := http.NewRequest(http.MethodGet, azureURI, nil)
if err != nil {
return v, fmt.Errorf("unable to retrieve Azure credentials: %w", err)
}
q := make(url.Values)
q.Set("api-version", "2018-02-01")
q.Set("resource", "https://vault.azure.net")
req.URL.RawQuery = q.Encode()
req.Header.Set("Metadata", "true")
req.Header.Set("Accept", "application/json")
resp, err := a.httpClient.Do(req.WithContext(ctx))
if err != nil {
return v, fmt.Errorf("unable to retrieve Azure credentials: %w", err)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return v, fmt.Errorf("unable to retrieve Azure credentials: error reading response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return v, fmt.Errorf("unable to retrieve Azure credentials: expected StatusCode 200, got StatusCode: %v. Response body: %s", resp.StatusCode, body)
}
var tokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn string `json:"expires_in"`
}
// Attempt to read body as JSON
err = json.Unmarshal(body, &tokenResponse)
if err != nil {
return v, fmt.Errorf("unable to retrieve Azure credentials: error reading body JSON: %w (response body: %s)", err, body)
}
if tokenResponse.AccessToken == "" {
return v, fmt.Errorf("unable to retrieve Azure credentials: got unexpected empty accessToken from Azure Metadata Server. Response body: %s", body)
}
v.SessionToken = tokenResponse.AccessToken
expiresIn, err := time.ParseDuration(tokenResponse.ExpiresIn + "s")
if err != nil {
return v, err
}
if expiration := expiresIn - a.expiryWindow; expiration > 0 {
a.expiration = time.Now().Add(expiration)
}
return v, err
}
// Retrieve retrieves the keys from the Azure service.
func (a *AzureProvider) Retrieve() (credentials.Value, error) {
return a.RetrieveWithContext(context.Background())
}
// IsExpired returns true if the credentials are expired.
func (a *AzureProvider) IsExpired() bool {
return a.expiration.Before(time.Now())
}


@ -0,0 +1,59 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package credproviders
import (
"errors"
"gitea.psichedelico.com/go/bson/internal/aws/credentials"
)
// staticProviderName is the name of the Static provider
const staticProviderName = "StaticProvider"
// A StaticProvider is a set of credentials which are set programmatically,
// and will never expire.
type StaticProvider struct {
credentials.Value
verified bool
err error
}
func verify(v credentials.Value) error {
if !v.HasKeys() {
return errors.New("failed to retrieve ACCESS_KEY_ID and SECRET_ACCESS_KEY")
}
if v.AccessKeyID != "" && v.SecretAccessKey == "" {
return errors.New("ACCESS_KEY_ID is set, but SECRET_ACCESS_KEY is missing")
}
if v.AccessKeyID == "" && v.SecretAccessKey != "" {
return errors.New("SECRET_ACCESS_KEY is set, but ACCESS_KEY_ID is missing")
}
if v.AccessKeyID == "" && v.SecretAccessKey == "" && v.SessionToken != "" {
return errors.New("AWS_SESSION_TOKEN is set, but ACCESS_KEY_ID and SECRET_ACCESS_KEY are missing")
}
return nil
}
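The validation in verify can be exercised standalone. The struct below stands in for credentials.Value, which lives in an internal package, and hasKeys is assumed to mean that both static keys are present:

```go
package main

import (
	"errors"
	"fmt"
)

// value stands in for credentials.Value from the internal credentials
// package.
type value struct {
	AccessKeyID     string
	SecretAccessKey string
	SessionToken    string
}

// hasKeys assumes the same semantics as credentials.Value.HasKeys:
// both static keys must be non-empty.
func (v value) hasKeys() bool {
	return v.AccessKeyID != "" && v.SecretAccessKey != ""
}

// checkPair mirrors the gist of verify above: a credential pair is only
// usable when both halves are present.
func checkPair(v value) error {
	if !v.hasKeys() {
		return errors.New("failed to retrieve ACCESS_KEY_ID and SECRET_ACCESS_KEY")
	}
	return nil
}

func main() {
	ok := value{AccessKeyID: "AKID", SecretAccessKey: "SECRET"}
	missing := value{AccessKeyID: "AKID"} // secret not set

	fmt.Println(checkPair(ok) == nil)      // true
	fmt.Println(checkPair(missing) != nil) // true
}
```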
// Retrieve returns the credentials or error if the credentials are invalid.
func (s *StaticProvider) Retrieve() (credentials.Value, error) {
if !s.verified {
s.err = verify(s.Value)
s.Value.ProviderName = staticProviderName
s.verified = true
}
return s.Value, s.err
}
// IsExpired returns true if the credentials are expired.
//
// For StaticProvider, the credentials never expire.
func (s *StaticProvider) IsExpired() bool {
return false
}

internal/csfle/csfle.go Normal file

@ -0,0 +1,40 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package csfle
import (
"errors"
"fmt"
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
)
const (
EncryptedCacheCollection = "ecc"
EncryptedStateCollection = "esc"
EncryptedCompactionCollection = "ecoc"
)
// GetEncryptedStateCollectionName returns the encrypted state collection name associated with dataCollectionName.
func GetEncryptedStateCollectionName(efBSON bsoncore.Document, dataCollectionName string, stateCollection string) (string, error) {
fieldName := stateCollection + "Collection"
val, err := efBSON.LookupErr(fieldName)
if err != nil {
if !errors.Is(err, bsoncore.ErrElementNotFound) {
return "", err
}
// Return default name.
defaultName := "enxcol_." + dataCollectionName + "." + stateCollection
return defaultName, nil
}
stateCollectionName, ok := val.StringValueOK()
if !ok {
return "", fmt.Errorf("expected string for '%v', got: %v", fieldName, val.Type)
}
return stateCollectionName, nil
}
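When the encryptedFields document carries no override, the default collection name follows the enxcol_.<data>.<state> pattern used above; a minimal sketch (hypothetical helper name):

```go
package main

import "fmt"

// defaultStateCollectionName builds the fallback name used when the
// encryptedFields document has no "<state>Collection" override, matching
// the "enxcol_." + dataCollectionName + "." + stateCollection construction
// in GetEncryptedStateCollectionName.
func defaultStateCollectionName(dataCollectionName, stateCollection string) string {
	return "enxcol_." + dataCollectionName + "." + stateCollection
}

func main() {
	// The state-collection suffixes come from the constants above.
	for _, state := range []string{"ecc", "esc", "ecoc"} {
		fmt.Println(defaultStateCollectionName("users", state))
	}
}
```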

internal/csot/csot.go Normal file

@ -0,0 +1,106 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package csot
import (
"context"
"time"
)
type clientLevel struct{}
func isClientLevel(ctx context.Context) bool {
val := ctx.Value(clientLevel{})
if val == nil {
return false
}
return val.(bool)
}
// IsTimeoutContext checks if the provided context has been assigned a deadline
// or has unlimited retries.
func IsTimeoutContext(ctx context.Context) bool {
_, ok := ctx.Deadline()
return ok || isClientLevel(ctx)
}
// WithTimeout will set the given timeout on the context, if no deadline has
// already been set.
//
// This function assumes that the timeout field is static, given that the
// timeout should be sourced from the client. Therefore, once a timeout function
// parameter has been applied to the context, it will remain for the lifetime
// of the context.
func WithTimeout(parent context.Context, timeout *time.Duration) (context.Context, context.CancelFunc) {
cancel := func() {}
if timeout == nil || IsTimeoutContext(parent) {
// In the following conditions, do nothing:
// 1. The parent already has a deadline
// 2. The parent does not have a deadline, but a client-level timeout has
// been applied.
// 3. The parent does not have a deadline, there is no client-level
// timeout, and the timeout parameter is nil.
return parent, cancel
}
// If a client-level timeout has not been applied, then apply it.
parent = context.WithValue(parent, clientLevel{}, true)
dur := *timeout
if dur == 0 {
// If the parent does not have a deadline and the timeout is zero, then
// do nothing.
return parent, cancel
}
// If the parent does not have a deadline and the timeout is non-zero, then
// apply the timeout.
return context.WithTimeout(parent, dur)
}
// WithServerSelectionTimeout creates a context with a timeout that is the
// minimum of serverSelectionTimeoutMS and context deadline. The usage of
// non-positive values for serverSelectionTimeoutMS is an anti-pattern, and
// such values are not considered in this calculation.
func WithServerSelectionTimeout(
parent context.Context,
serverSelectionTimeout time.Duration,
) (context.Context, context.CancelFunc) {
if serverSelectionTimeout <= 0 {
return parent, func() {}
}
return context.WithTimeout(parent, serverSelectionTimeout)
}
// ZeroRTTMonitor implements the RTTMonitor interface and is used internally for testing. It returns 0 for all
// RTT calculations and an empty string for RTT statistics.
type ZeroRTTMonitor struct{}
// EWMA implements the RTT monitor interface.
func (zrm *ZeroRTTMonitor) EWMA() time.Duration {
return 0
}
// Min implements the RTT monitor interface.
func (zrm *ZeroRTTMonitor) Min() time.Duration {
return 0
}
// P90 implements the RTT monitor interface.
func (zrm *ZeroRTTMonitor) P90() time.Duration {
return 0
}
// Stats implements the RTT monitor interface.
func (zrm *ZeroRTTMonitor) Stats() string {
return ""
}

internal/csot/csot_test.go Normal file

@ -0,0 +1,249 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package csot
import (
"context"
"testing"
"time"
"gitea.psichedelico.com/go/bson/internal/assert"
"gitea.psichedelico.com/go/bson/internal/ptrutil"
)
func newTestContext(t *testing.T, timeout time.Duration) context.Context {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
t.Cleanup(cancel)
return ctx
}
func TestWithServerSelectionTimeout(t *testing.T) {
t.Parallel()
tests := []struct {
name string
parent context.Context
serverSelectionTimeout time.Duration
wantTimeout time.Duration
wantOk bool
}{
{
name: "no context deadline and ssto is zero",
parent: context.Background(),
serverSelectionTimeout: 0,
wantTimeout: 0,
wantOk: false,
},
{
name: "no context deadline and ssto is positive",
parent: context.Background(),
serverSelectionTimeout: 1,
wantTimeout: 1,
wantOk: true,
},
{
name: "no context deadline and ssto is negative",
parent: context.Background(),
serverSelectionTimeout: -1,
wantTimeout: 0,
wantOk: false,
},
{
name: "context deadline is zero and ssto is positive",
parent: newTestContext(t, 0),
serverSelectionTimeout: 1,
wantTimeout: 1,
wantOk: true,
},
{
name: "context deadline is zero and ssto is negative",
parent: newTestContext(t, 0),
serverSelectionTimeout: -1,
wantTimeout: 0,
wantOk: true,
},
{
name: "context deadline is negative and ssto is zero",
parent: newTestContext(t, -1),
serverSelectionTimeout: 0,
wantTimeout: -1,
wantOk: true,
},
{
name: "context deadline is negative and ssto is positive",
parent: newTestContext(t, -1),
serverSelectionTimeout: 1,
wantTimeout: 1,
wantOk: true,
},
{
name: "context deadline is negative and ssto is negative",
parent: newTestContext(t, -1),
serverSelectionTimeout: -1,
wantTimeout: -1,
wantOk: true,
},
{
name: "context deadline is positive and ssto is zero",
parent: newTestContext(t, 1),
serverSelectionTimeout: 0,
wantTimeout: 1,
wantOk: true,
},
{
name: "context deadline is positive and equal to ssto",
parent: newTestContext(t, 1),
serverSelectionTimeout: 1,
wantTimeout: 1,
wantOk: true,
},
{
name: "context deadline is positive lt ssto",
parent: newTestContext(t, 1),
serverSelectionTimeout: 2,
wantTimeout: 2,
wantOk: true,
},
{
name: "context deadline is positive gt ssto",
parent: newTestContext(t, 2),
serverSelectionTimeout: 1,
wantTimeout: 2,
wantOk: true,
},
{
name: "context deadline is positive and ssto is negative",
parent: newTestContext(t, 1),
serverSelectionTimeout: -1,
wantTimeout: 1,
wantOk: true,
},
}
for _, test := range tests {
test := test // Capture the range variable
t.Run(test.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := WithServerSelectionTimeout(test.parent, test.serverSelectionTimeout)
t.Cleanup(cancel)
deadline, gotOk := ctx.Deadline()
assert.Equal(t, test.wantOk, gotOk)
if gotOk {
delta := time.Until(deadline) - test.wantTimeout
tolerance := 10 * time.Millisecond
assert.True(t, delta > -1*tolerance, "expected delta=%d > %d", delta, -1*tolerance)
assert.True(t, delta <= tolerance, "expected delta=%d <= %d", delta, tolerance)
}
})
}
}
func TestWithTimeout(t *testing.T) {
t.Parallel()
tests := []struct {
name string
parent context.Context
timeout *time.Duration
wantTimeout time.Duration
wantDeadline bool
wantValues []interface{}
}{
{
name: "deadline set with non-zero timeout",
parent: newTestContext(t, 1),
timeout: ptrutil.Ptr(time.Duration(2)),
wantTimeout: 1,
wantDeadline: true,
wantValues: []interface{}{},
},
{
name: "deadline set with zero timeout",
parent: newTestContext(t, 1),
timeout: ptrutil.Ptr(time.Duration(0)),
wantTimeout: 1,
wantDeadline: true,
wantValues: []interface{}{},
},
{
name: "deadline set with nil timeout",
parent: newTestContext(t, 1),
timeout: nil,
wantTimeout: 1,
wantDeadline: true,
wantValues: []interface{}{},
},
{
name: "deadline unset with non-zero timeout",
parent: context.Background(),
timeout: ptrutil.Ptr(time.Duration(1)),
wantTimeout: 1,
wantDeadline: true,
wantValues: []interface{}{},
},
{
name: "deadline unset with zero timeout",
parent: context.Background(),
timeout: ptrutil.Ptr(time.Duration(0)),
wantTimeout: 0,
wantDeadline: false,
wantValues: []interface{}{clientLevel{}},
},
{
name: "deadline unset with nil timeout",
parent: context.Background(),
timeout: nil,
wantTimeout: 0,
wantDeadline: false,
wantValues: []interface{}{},
},
{
// If "clientLevel" has been set, but a new timeout is applied
// to the context, then the constructed context should retain the old
// timeout. To simplify the code, we assume the first timeout is static.
name: "deadline unset with non-zero timeout at clientLevel",
parent: context.WithValue(context.Background(), clientLevel{}, true),
timeout: ptrutil.Ptr(time.Duration(1)),
wantTimeout: 0,
wantDeadline: false,
wantValues: []interface{}{},
},
}
for _, test := range tests {
test := test // Capture the range variable
t.Run(test.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := WithTimeout(test.parent, test.timeout)
t.Cleanup(cancel)
deadline, gotDeadline := ctx.Deadline()
assert.Equal(t, test.wantDeadline, gotDeadline)
if gotDeadline {
delta := time.Until(deadline) - test.wantTimeout
tolerance := 10 * time.Millisecond
assert.True(t, delta > -1*tolerance, "expected delta=%d > %d", delta, -1*tolerance)
assert.True(t, delta <= tolerance, "expected delta=%d <= %d", delta, tolerance)
}
for _, wantValue := range test.wantValues {
assert.NotNil(t, ctx.Value(wantValue), "expected context to have value %v", wantValue)
}
})
}
}


@ -0,0 +1,117 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package decimal128
import (
"strconv"
)
// These constants are the maximum and minimum values for the exponent field in a decimal128 value.
const (
MaxDecimal128Exp = 6111
MinDecimal128Exp = -6176
)
func divmod(h, l uint64, div uint32) (qh, ql uint64, rem uint32) {
div64 := uint64(div)
a := h >> 32
aq := a / div64
ar := a % div64
b := ar<<32 + h&(1<<32-1)
bq := b / div64
br := b % div64
c := br<<32 + l>>32
cq := c / div64
cr := c % div64
d := cr<<32 + l&(1<<32-1)
dq := d / div64
dr := d % div64
return (aq<<32 | bq), (cq<<32 | dq), uint32(dr)
}
// String returns a string representation of the decimal value.
func String(h, l uint64) string {
var posSign int // positive sign
var exp int // exponent
var high, low uint64 // significand high/low
if h>>63&1 == 0 {
posSign = 1
}
switch h >> 58 & (1<<5 - 1) {
case 0x1F:
return "NaN"
case 0x1E:
return "-Infinity"[posSign:]
}
low = l
if h>>61&3 == 3 {
// Bits: 1*sign 2*ignored 14*exponent 111*significand.
// Implicit 0b100 prefix in significand.
exp = int(h >> 47 & (1<<14 - 1))
// Spec says all of these values are out of range.
high, low = 0, 0
} else {
// Bits: 1*sign 14*exponent 113*significand
exp = int(h >> 49 & (1<<14 - 1))
high = h & (1<<49 - 1)
}
exp += MinDecimal128Exp
// Would be handled by the logic below, but that's trivial and common.
if high == 0 && low == 0 && exp == 0 {
return "-0"[posSign:]
}
var repr [48]byte // Loop 5 times over 9 digits plus dot, negative sign, and leading zero.
var last = len(repr)
var i = len(repr)
var dot = len(repr) + exp
var rem uint32
Loop:
for d9 := 0; d9 < 5; d9++ {
high, low, rem = divmod(high, low, 1e9)
for d1 := 0; d1 < 9; d1++ {
// Handle "-0.0", "0.00123400", "-1.00E-6", "1.050E+3", etc.
if i < len(repr) && (dot == i || low == 0 && high == 0 && rem > 0 && rem < 10 && (dot < i-6 || exp > 0)) {
exp += len(repr) - i
i--
repr[i] = '.'
last = i - 1
dot = len(repr) // Unmark.
}
c := '0' + byte(rem%10)
rem /= 10
i--
repr[i] = c
// Handle "0E+3", "1E+3", etc.
if low == 0 && high == 0 && rem == 0 && i == len(repr)-1 && (dot < i-5 || exp > 0) {
last = i
break Loop
}
if c != '0' {
last = i
}
// Break early. The loop works without this, but the remaining
// iterations would only produce leading zeros that are discarded.
if dot > i && low == 0 && high == 0 && rem == 0 {
break Loop
}
}
}
repr[last-1] = '-'
last--
if exp > 0 {
return string(repr[last+posSign:]) + "E+" + strconv.Itoa(exp)
}
if exp < 0 {
return string(repr[last+posSign:]) + "E" + strconv.Itoa(exp)
}
return string(repr[last+posSign:])
}

internal/errutil/join.go Normal file

@ -0,0 +1,85 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package errutil
import "errors"
// join is a Go 1.13-1.19 compatible version of [errors.Join]. It is only called
// by Join in join_go1.19.go. It is included here in a file without build
// constraints only for testing purposes.
//
// It is heavily based on Join from
// https://cs.opensource.google/go/go/+/refs/tags/go1.21.0:src/errors/join.go
func join(errs ...error) error {
n := 0
for _, err := range errs {
if err != nil {
n++
}
}
if n == 0 {
return nil
}
e := &joinError{
errs: make([]error, 0, n),
}
for _, err := range errs {
if err != nil {
e.errs = append(e.errs, err)
}
}
return e
}
// joinError is a Go 1.13-1.19 compatible joinable error type. Its error
// message is identical to [errors.Join], but it implements "Unwrap() error"
// instead of "Unwrap() []error".
//
// It is heavily based on the joinError from
// https://cs.opensource.google/go/go/+/refs/tags/go1.21.0:src/errors/join.go
type joinError struct {
errs []error
}
func (e *joinError) Error() string {
var b []byte
for i, err := range e.errs {
if i > 0 {
b = append(b, '\n')
}
b = append(b, err.Error()...)
}
return string(b)
}
// Unwrap returns another joinError with the same errors as the current
// joinError except the first error in the slice. Continuing to call Unwrap
// on each returned error will increment through every error in the slice. The
// resulting behavior when using [errors.Is] and [errors.As] is similar to an
// error created using [errors.Join] in Go 1.20+.
func (e *joinError) Unwrap() error {
if len(e.errs) == 1 {
return e.errs[0]
}
return &joinError{errs: e.errs[1:]}
}
// Is calls [errors.Is] with the first error in the slice.
func (e *joinError) Is(target error) bool {
if len(e.errs) == 0 {
return false
}
return errors.Is(e.errs[0], target)
}
// As calls [errors.As] with the first error in the slice.
func (e *joinError) As(target interface{}) bool {
if len(e.errs) == 0 {
return false
}
return errors.As(e.errs[0], target)
}


@ -0,0 +1,20 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build !go1.20
// +build !go1.20
package errutil
// Join returns an error that wraps the given errors. Any nil error values are
// discarded. Join returns nil if every value in errs is nil. The error formats
// as the concatenation of the strings obtained by calling the Error method of
// each element of errs, with a newline between each string.
//
// A non-nil error returned by Join implements the "Unwrap() error" method.
func Join(errs ...error) error {
return join(errs...)
}


@ -0,0 +1,17 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build go1.20
// +build go1.20
package errutil
import "errors"
// Join calls [errors.Join].
func Join(errs ...error) error {
return errors.Join(errs...)
}


@ -0,0 +1,243 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package errutil
import (
"context"
"errors"
"fmt"
"testing"
"gitea.psichedelico.com/go/bson/internal/assert"
)
// TestJoin_Nil asserts that join returns a nil error for the same inputs that
// [errors.Join] returns a nil error.
func TestJoin_Nil(t *testing.T) {
t.Parallel()
assert.Equal(t, errors.Join(), join(), "errors.Join() != join()")
assert.Equal(t, errors.Join(nil), join(nil), "errors.Join(nil) != join(nil)")
assert.Equal(t, errors.Join(nil, nil), join(nil, nil), "errors.Join(nil, nil) != join(nil, nil)")
}
// TestJoin_Error asserts that join returns an error with the same error message
// as the error returned by [errors.Join].
func TestJoin_Error(t *testing.T) {
t.Parallel()
err1 := errors.New("err1")
err2 := errors.New("err2")
tests := []struct {
desc string
errs []error
}{{
desc: "single error",
errs: []error{err1},
}, {
desc: "two errors",
errs: []error{err1, err2},
}, {
desc: "two errors and a nil value",
errs: []error{err1, nil, err2},
}}
for _, test := range tests {
test := test // Capture range variable.
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
want := errors.Join(test.errs...).Error()
got := join(test.errs...).Error()
assert.Equal(t,
want,
got,
"errors.Join().Error() != join().Error() for input %v",
test.errs)
})
}
}
// TestJoin_ErrorsIs asserts that join returns an error that behaves identically
// to the error returned by [errors.Join] when passed to [errors.Is].
func TestJoin_ErrorsIs(t *testing.T) {
t.Parallel()
err1 := errors.New("err1")
err2 := errors.New("err2")
tests := []struct {
desc string
errs []error
target error
}{{
desc: "one error with a matching target",
errs: []error{err1},
target: err1,
}, {
desc: "one error with a non-matching target",
errs: []error{err1},
target: err2,
}, {
desc: "nil error",
errs: []error{nil},
target: err1,
}, {
desc: "no errors",
errs: []error{},
target: err1,
}, {
desc: "two different errors with a matching target",
errs: []error{err1, err2},
target: err2,
}, {
desc: "two identical errors with a matching target",
errs: []error{err1, err1},
target: err1,
}, {
desc: "wrapped error with a matching target",
errs: []error{fmt.Errorf("error: %w", err1)},
target: err1,
}, {
desc: "nested joined error with a matching target",
errs: []error{err1, join(err2, errors.New("nope"))},
target: err2,
}, {
desc: "nested joined error with no matching targets",
errs: []error{err1, join(errors.New("nope"), errors.New("nope 2"))},
target: err2,
}, {
desc: "nested joined error with a wrapped matching target",
errs: []error{join(fmt.Errorf("error: %w", err1), errors.New("nope")), err2},
target: err1,
}, {
desc: "context.DeadlineExceeded",
errs: []error{err1, nil, context.DeadlineExceeded, err2},
target: context.DeadlineExceeded,
}, {
desc: "wrapped context.DeadlineExceeded",
errs: []error{err1, nil, fmt.Errorf("error: %w", context.DeadlineExceeded), err2},
target: context.DeadlineExceeded,
}}
for _, test := range tests {
test := test // Capture range variable.
t.Run(test.desc, func(t *testing.T) {
// Assert that top-level errors returned by errors.Join and join
// behave the same with errors.Is.
want := errors.Join(test.errs...)
got := join(test.errs...)
assert.Equal(t,
errors.Is(want, test.target),
errors.Is(got, test.target),
"errors.Join() and join() behave differently with errors.Is")
// Assert that wrapped errors returned by errors.Join and join
// behave the same with errors.Is.
want = fmt.Errorf("error: %w", errors.Join(test.errs...))
got = fmt.Errorf("error: %w", join(test.errs...))
assert.Equal(t,
errors.Is(want, test.target),
errors.Is(got, test.target),
"errors.Join() and join(), when wrapped, behave differently with errors.Is")
})
}
}
type errType1 struct{}
func (errType1) Error() string { return "" }
type errType2 struct{}
func (errType2) Error() string { return "" }
// TestJoin_ErrorsAs asserts that join returns an error that behaves identically
// to the error returned by [errors.Join] when passed to [errors.As].
func TestJoin_ErrorsAs(t *testing.T) {
t.Parallel()
err1 := errType1{}
err2 := errType2{}
tests := []struct {
desc string
errs []error
target interface{}
}{{
desc: "one error with a matching target",
errs: []error{err1},
target: &errType1{},
}, {
desc: "one error with a non-matching target",
errs: []error{err1},
target: &errType2{},
}, {
desc: "nil error",
errs: []error{nil},
target: &errType1{},
}, {
desc: "no errors",
errs: []error{},
target: &errType1{},
}, {
desc: "two different errors with a matching target",
errs: []error{err1, err2},
target: &errType2{},
}, {
desc: "two identical errors with a matching target",
errs: []error{err1, err1},
target: &errType1{},
}, {
desc: "wrapped error with a matching target",
errs: []error{fmt.Errorf("error: %w", err1)},
target: &errType1{},
}, {
desc: "nested joined error with a matching target",
errs: []error{err1, join(err2, errors.New("nope"))},
target: &errType2{},
}, {
desc: "nested joined error with no matching targets",
errs: []error{err1, join(errors.New("nope"), errors.New("nope 2"))},
target: &errType2{},
}, {
desc: "nested joined error with a wrapped matching target",
errs: []error{join(fmt.Errorf("error: %w", err1), errors.New("nope")), err2},
target: &errType1{},
}, {
desc: "context.DeadlineExceeded",
errs: []error{err1, nil, context.DeadlineExceeded, err2},
target: &errType2{},
}}
for _, test := range tests {
test := test // Capture range variable.
t.Run(test.desc, func(t *testing.T) {
// Assert that top-level errors returned by errors.Join and join
// behave the same with errors.As.
want := errors.Join(test.errs...)
got := join(test.errs...)
assert.Equal(t,
errors.As(want, test.target),
errors.As(got, test.target),
"errors.Join() and join() behave differently with errors.As")
// Assert that wrapped errors returned by errors.Join and join
// behave the same with errors.As.
want = fmt.Errorf("error: %w", errors.Join(test.errs...))
got = fmt.Errorf("error: %w", join(test.errs...))
assert.Equal(t,
errors.As(want, test.target),
errors.As(got, test.target),
"errors.Join() and join(), when wrapped, behave differently with errors.As")
})
}
}


@ -0,0 +1,63 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package failpoint
import (
"gitea.psichedelico.com/go/bson"
)
const (
// ModeAlwaysOn is the fail point mode that enables the fail point for an
// indefinite number of matching commands.
ModeAlwaysOn = "alwaysOn"
// ModeOff is the fail point mode that disables the fail point.
ModeOff = "off"
)
// FailPoint is used to configure a server fail point. It is intended to be
// passed as the command argument to RunCommand.
//
// For more information about fail points, see
// https://github.com/mongodb/specifications/tree/HEAD/source/transactions/tests#server-fail-point
type FailPoint struct {
ConfigureFailPoint string `bson:"configureFailPoint"`
// Mode should be a string, a Mode struct, or a map[string]interface{}
Mode interface{} `bson:"mode"`
Data Data `bson:"data"`
}
// Mode configures when a fail point will be enabled. It is used to set the
// FailPoint.Mode field.
type Mode struct {
Times int32 `bson:"times"`
Skip int32 `bson:"skip"`
}
// Data configures how a fail point will behave. It is used to set the
// FailPoint.Data field.
type Data struct {
FailCommands []string `bson:"failCommands,omitempty"`
CloseConnection bool `bson:"closeConnection,omitempty"`
ErrorCode int32 `bson:"errorCode,omitempty"`
FailBeforeCommitExceptionCode int32 `bson:"failBeforeCommitExceptionCode,omitempty"`
ErrorLabels *[]string `bson:"errorLabels,omitempty"`
WriteConcernError *WriteConcernError `bson:"writeConcernError,omitempty"`
BlockConnection bool `bson:"blockConnection,omitempty"`
BlockTimeMS int32 `bson:"blockTimeMS,omitempty"`
AppName string `bson:"appName,omitempty"`
}
// WriteConcernError is the write concern error to return when the fail point is
// triggered. It is used to set the FailPoint.Data.WriteConcernError field.
type WriteConcernError struct {
Code int32 `bson:"code"`
Name string `bson:"codeName"`
Errmsg string `bson:"errmsg"`
ErrorLabels *[]string `bson:"errorLabels,omitempty"`
ErrInfo bson.Raw `bson:"errInfo,omitempty"`
}


@ -0,0 +1,13 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package handshake
// LegacyHello is the legacy version of the hello command.
var LegacyHello = "isMaster"
// LegacyHelloLowercase is the lowercase, legacy version of the hello command.
var LegacyHelloLowercase = "ismaster"


@ -0,0 +1,30 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package httputil
import (
"net/http"
)
// DefaultHTTPClient is the default HTTP client used across the driver.
var DefaultHTTPClient = &http.Client{
Transport: http.DefaultTransport.(*http.Transport).Clone(),
}
// CloseIdleHTTPConnections closes any connections which were previously
// connected from previous requests but are now sitting idle in a "keep-alive"
// state. It does not interrupt any connections currently in use.
//
// Borrowed from the Go standard library.
func CloseIdleHTTPConnections(client *http.Client) {
type closeIdler interface {
CloseIdleConnections()
}
if tr, ok := client.Transport.(closeIdler); ok {
tr.CloseIdleConnections()
}
}

internal/israce/norace.go Normal file

@ -0,0 +1,14 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build !race
// +build !race
// Package israce reports if the Go race detector is enabled.
package israce
// Enabled reports if the race detector is enabled.
const Enabled = false

internal/israce/race.go Normal file

@ -0,0 +1,14 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build race
// +build race
// Package israce reports if the Go race detector is enabled.
package israce
// Enabled reports if the race detector is enabled.
const Enabled = true


@ -0,0 +1,313 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package logger
import (
"os"
"strconv"
"gitea.psichedelico.com/go/bson"
)
const (
CommandFailed = "Command failed"
CommandStarted = "Command started"
CommandSucceeded = "Command succeeded"
ConnectionPoolCreated = "Connection pool created"
ConnectionPoolReady = "Connection pool ready"
ConnectionPoolCleared = "Connection pool cleared"
ConnectionPoolClosed = "Connection pool closed"
ConnectionCreated = "Connection created"
ConnectionReady = "Connection ready"
ConnectionClosed = "Connection closed"
ConnectionCheckoutStarted = "Connection checkout started"
ConnectionCheckoutFailed = "Connection checkout failed"
ConnectionCheckedOut = "Connection checked out"
ConnectionCheckedIn = "Connection checked in"
ServerSelectionFailed = "Server selection failed"
ServerSelectionStarted = "Server selection started"
ServerSelectionSucceeded = "Server selection succeeded"
ServerSelectionWaiting = "Waiting for suitable server to become available"
TopologyClosed = "Stopped topology monitoring"
TopologyDescriptionChanged = "Topology description changed"
TopologyOpening = "Starting topology monitoring"
TopologyServerClosed = "Stopped server monitoring"
TopologyServerHeartbeatFailed = "Server heartbeat failed"
TopologyServerHeartbeatStarted = "Server heartbeat started"
TopologyServerHeartbeatSucceeded = "Server heartbeat succeeded"
TopologyServerOpening = "Starting server monitoring"
)
const (
KeyAwaited = "awaited"
KeyCommand = "command"
KeyCommandName = "commandName"
KeyDatabaseName = "databaseName"
KeyDriverConnectionID = "driverConnectionId"
KeyDurationMS = "durationMS"
KeyError = "error"
KeyFailure = "failure"
KeyMaxConnecting = "maxConnecting"
KeyMaxIdleTimeMS = "maxIdleTimeMS"
KeyMaxPoolSize = "maxPoolSize"
KeyMessage = "message"
KeyMinPoolSize = "minPoolSize"
KeyNewDescription = "newDescription"
KeyOperation = "operation"
KeyOperationID = "operationId"
KeyPreviousDescription = "previousDescription"
KeyRemainingTimeMS = "remainingTimeMS"
KeyReason = "reason"
KeyReply = "reply"
KeyRequestID = "requestId"
KeySelector = "selector"
KeyServerConnectionID = "serverConnectionId"
KeyServerHost = "serverHost"
KeyServerPort = "serverPort"
KeyServiceID = "serviceId"
KeyTimestamp = "timestamp"
KeyTopologyDescription = "topologyDescription"
KeyTopologyID = "topologyId"
)
// KeyValues is a list of key-value pairs.
type KeyValues []interface{}
// Add adds a key-value pair to an instance of a KeyValues list.
func (kvs *KeyValues) Add(key string, value interface{}) {
*kvs = append(*kvs, key, value)
}
const (
ReasonConnClosedStale = "Connection became stale because the pool was cleared"
ReasonConnClosedIdle = "Connection has been available but unused for longer than the configured max idle time"
ReasonConnClosedError = "An error occurred while using the connection"
ReasonConnClosedPoolClosed = "Connection pool was closed"
ReasonConnCheckoutFailedTimout = "Wait queue timeout elapsed without a connection becoming available"
ReasonConnCheckoutFailedError = "An error occurred while trying to establish a new connection"
ReasonConnCheckoutFailedPoolClosed = "Connection pool was closed"
)
// Component is an enumeration representing the "components" which can be
// logged against. A LogLevel can be configured on a per-component basis.
type Component int
const (
// ComponentAll enables logging for all components.
ComponentAll Component = iota
// ComponentCommand enables command monitor logging.
ComponentCommand
// ComponentTopology enables topology logging.
ComponentTopology
// ComponentServerSelection enables server selection logging.
ComponentServerSelection
// ComponentConnection enables connection services logging.
ComponentConnection
)
const (
mongoDBLogAllEnvVar = "MONGODB_LOG_ALL"
mongoDBLogCommandEnvVar = "MONGODB_LOG_COMMAND"
mongoDBLogTopologyEnvVar = "MONGODB_LOG_TOPOLOGY"
mongoDBLogServerSelectionEnvVar = "MONGODB_LOG_SERVER_SELECTION"
mongoDBLogConnectionEnvVar = "MONGODB_LOG_CONNECTION"
)
var componentEnvVarMap = map[string]Component{
mongoDBLogAllEnvVar: ComponentAll,
mongoDBLogCommandEnvVar: ComponentCommand,
mongoDBLogTopologyEnvVar: ComponentTopology,
mongoDBLogServerSelectionEnvVar: ComponentServerSelection,
mongoDBLogConnectionEnvVar: ComponentConnection,
}
// EnvHasComponentVariables returns true if the environment contains any of the
// component environment variables.
func EnvHasComponentVariables() bool {
for envVar := range componentEnvVarMap {
if os.Getenv(envVar) != "" {
return true
}
}
return false
}
// Command is a struct defining common fields that must be included in all
// commands.
type Command struct {
DriverConnectionID int64 // Driver's ID for the connection
Name string // Command name
DatabaseName string // Database name
Message string // Message associated with the command
OperationID int32 // Driver-generated operation ID
RequestID int64 // Driver-generated request ID
ServerConnectionID *int64 // Server's ID for the connection used for the command
ServerHost string // Hostname or IP address for the server
ServerPort string // Port for the server
ServiceID *bson.ObjectID // ID for the command in load balancer mode
}
// SerializeCommand takes a command and a variable number of key-value pairs and
// returns a slice of interface{} that can be passed to the logger for
// structured logging.
func SerializeCommand(cmd Command, extraKeysAndValues ...interface{}) KeyValues {
// Initialize the boilerplate keys and values.
keysAndValues := KeyValues{
KeyCommandName, cmd.Name,
KeyDatabaseName, cmd.DatabaseName,
KeyDriverConnectionID, cmd.DriverConnectionID,
KeyMessage, cmd.Message,
KeyOperationID, cmd.OperationID,
KeyRequestID, cmd.RequestID,
KeyServerHost, cmd.ServerHost,
}
// Add the extra keys and values.
for i := 0; i < len(extraKeysAndValues); i += 2 {
keysAndValues.Add(extraKeysAndValues[i].(string), extraKeysAndValues[i+1])
}
port, err := strconv.ParseInt(cmd.ServerPort, 10, 32)
if err == nil {
keysAndValues.Add(KeyServerPort, port)
}
// Add the "serverConnectionId" if it is not nil.
if cmd.ServerConnectionID != nil {
keysAndValues.Add(KeyServerConnectionID, *cmd.ServerConnectionID)
}
// Add the "serviceId" if it is not nil.
if cmd.ServiceID != nil {
keysAndValues.Add(KeyServiceID, cmd.ServiceID.Hex())
}
return keysAndValues
}
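The port handling above silently drops an unparsable port rather than logging a bogus value. A minimal standalone sketch of that pattern (the `appendPort` helper is illustrative, not part of the driver):

```go
package main

import (
	"fmt"
	"strconv"
)

// appendPort adds a "serverPort" key/value pair only when the port
// string parses as an integer, mirroring how SerializeCommand skips
// the key entirely on a parse error.
func appendPort(kv []interface{}, port string) []interface{} {
	p, err := strconv.ParseInt(port, 10, 32)
	if err != nil {
		return kv // unparsable port: omit the pair
	}
	return append(kv, "serverPort", p)
}

func main() {
	kv := []interface{}{"serverHost", "localhost"}
	fmt.Println(len(appendPort(kv, "27017"))) // 4: pair added
	fmt.Println(len(appendPort(kv, "bad")))   // 2: pair omitted
}
```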
// Connection contains data that all connection log messages MUST contain.
type Connection struct {
Message string // Message associated with the connection
ServerHost string // Hostname or IP address for the server
ServerPort string // Port for the server
}
// SerializeConnection serializes a Connection message into a slice of keys and
// values that can be passed to a logger.
func SerializeConnection(conn Connection, extraKeysAndValues ...interface{}) KeyValues {
// Initialize the boilerplate keys and values.
keysAndValues := KeyValues{
KeyMessage, conn.Message,
KeyServerHost, conn.ServerHost,
}
// Add the optional keys and values.
for i := 0; i < len(extraKeysAndValues); i += 2 {
keysAndValues.Add(extraKeysAndValues[i].(string), extraKeysAndValues[i+1])
}
port, err := strconv.ParseInt(conn.ServerPort, 10, 32)
if err == nil {
keysAndValues.Add(KeyServerPort, port)
}
return keysAndValues
}
// Server contains data that all server messages MAY contain.
type Server struct {
DriverConnectionID int64 // Driver's ID for the connection
TopologyID bson.ObjectID // Driver's unique ID for this topology
Message string // Message associated with the topology
ServerConnectionID *int64 // Server's ID for the connection
ServerHost string // Hostname or IP address for the server
ServerPort string // Port for the server
}
// SerializeServer serializes a Server message into a slice of keys and
// values that can be passed to a logger.
func SerializeServer(srv Server, extraKV ...interface{}) KeyValues {
// Initialize the boilerplate keys and values.
keysAndValues := KeyValues{
KeyDriverConnectionID, srv.DriverConnectionID,
KeyMessage, srv.Message,
KeyServerHost, srv.ServerHost,
KeyTopologyID, srv.TopologyID.Hex(),
}
if connID := srv.ServerConnectionID; connID != nil {
keysAndValues.Add(KeyServerConnectionID, *connID)
}
port, err := strconv.ParseInt(srv.ServerPort, 10, 32)
if err == nil {
keysAndValues.Add(KeyServerPort, port)
}
// Add the optional keys and values.
for i := 0; i < len(extraKV); i += 2 {
keysAndValues.Add(extraKV[i].(string), extraKV[i+1])
}
return keysAndValues
}
// ServerSelection contains data that all server selection messages MUST
// contain.
type ServerSelection struct {
Selector string
OperationID *int32
Operation string
TopologyDescription string
}
// SerializeServerSelection serializes a ServerSelection message into a slice
// of keys and values that can be passed to a logger.
func SerializeServerSelection(srvSelection ServerSelection, extraKV ...interface{}) KeyValues {
keysAndValues := KeyValues{
KeySelector, srvSelection.Selector,
KeyOperation, srvSelection.Operation,
KeyTopologyDescription, srvSelection.TopologyDescription,
}
if srvSelection.OperationID != nil {
keysAndValues.Add(KeyOperationID, *srvSelection.OperationID)
}
// Add the optional keys and values.
for i := 0; i < len(extraKV); i += 2 {
keysAndValues.Add(extraKV[i].(string), extraKV[i+1])
}
return keysAndValues
}
// Topology contains data that all topology messages MAY contain.
type Topology struct {
ID bson.ObjectID // Driver's unique ID for this topology
Message string // Message associated with the topology
}
// SerializeTopology serializes a Topology message into a slice of keys and
// values that can be passed to a logger.
func SerializeTopology(topo Topology, extraKV ...interface{}) KeyValues {
keysAndValues := KeyValues{
KeyTopologyID, topo.ID.Hex(),
}
// Add the optional keys and values.
for i := 0; i < len(extraKV); i += 2 {
keysAndValues.Add(extraKV[i].(string), extraKV[i+1])
}
return keysAndValues
}
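All of the Serialize* helpers return a flat, alternating key/value slice. A sink typically folds that slice into a map before encoding; a minimal sketch of that fold (the `kvToMap` helper is illustrative):

```go
package main

import "fmt"

// kvToMap folds an alternating key/value slice into a map, the same
// shape the Serialize* helpers produce and a sink consumes.
func kvToMap(kv []interface{}) map[string]interface{} {
	m := make(map[string]interface{}, len(kv)/2)
	for i := 0; i+1 < len(kv); i += 2 {
		key, ok := kv[i].(string)
		if !ok {
			continue // skip malformed pairs instead of panicking
		}
		m[key] = kv[i+1]
	}
	return m
}

func main() {
	kv := []interface{}{"commandName", "find", "driverConnectionId", int64(1)}
	m := kvToMap(kv)
	fmt.Println(m["commandName"]) // find
}
```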

// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package logger
import (
"testing"
"gitea.psichedelico.com/go/bson"
"gitea.psichedelico.com/go/bson/internal/assert"
)
func verifySerialization(t *testing.T, got, want KeyValues) {
t.Helper()
assert.Equal(t, len(want), len(got), "length mismatch")
for i := 0; i < len(got); i += 2 {
assert.Equal(t, want[i], got[i], "key position mismatch")
assert.Equal(t, want[i+1], got[i+1], "value position mismatch for %q", want[i])
}
}
func TestSerializeCommand(t *testing.T) {
t.Parallel()
serverConnectionID := int64(100)
serviceID := bson.NewObjectID()
tests := []struct {
name string
cmd Command
extraKeysAndValues []interface{}
want KeyValues
}{
{
name: "empty",
want: KeyValues{
KeyCommandName, "",
KeyDatabaseName, "",
KeyDriverConnectionID, int64(0),
KeyMessage, "",
KeyOperationID, int32(0),
KeyRequestID, int64(0),
KeyServerHost, "",
},
},
{
name: "complete Command object",
cmd: Command{
DriverConnectionID: 1,
Name: "foo",
DatabaseName: "db",
Message: "bar",
OperationID: 2,
RequestID: 3,
ServerHost: "localhost",
ServerPort: "27017",
ServerConnectionID: &serverConnectionID,
ServiceID: &serviceID,
},
want: KeyValues{
KeyCommandName, "foo",
KeyDatabaseName, "db",
KeyDriverConnectionID, int64(1),
KeyMessage, "bar",
KeyOperationID, int32(2),
KeyRequestID, int64(3),
KeyServerHost, "localhost",
KeyServerPort, int64(27017),
KeyServerConnectionID, serverConnectionID,
KeyServiceID, serviceID.Hex(),
},
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
got := SerializeCommand(test.cmd, test.extraKeysAndValues...)
verifySerialization(t, got, test.want)
})
}
}
func TestSerializeConnection(t *testing.T) {
t.Parallel()
tests := []struct {
name string
conn Connection
extraKeysAndValues []interface{}
want KeyValues
}{
{
name: "empty",
want: KeyValues{
KeyMessage, "",
KeyServerHost, "",
},
},
{
name: "complete Connection object",
conn: Connection{
Message: "foo",
ServerHost: "localhost",
ServerPort: "27017",
},
want: KeyValues{
KeyMessage, "foo",
KeyServerHost, "localhost",
KeyServerPort, int64(27017),
},
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
got := SerializeConnection(test.conn, test.extraKeysAndValues...)
verifySerialization(t, got, test.want)
})
}
}
func TestSerializeServer(t *testing.T) {
t.Parallel()
topologyID := bson.NewObjectID()
serverConnectionID := int64(100)
tests := []struct {
name string
srv Server
extraKeysAndValues []interface{}
want KeyValues
}{
{
name: "empty",
want: KeyValues{
KeyDriverConnectionID, int64(0),
KeyMessage, "",
KeyServerHost, "",
KeyTopologyID, bson.ObjectID{}.Hex(),
},
},
{
name: "complete Server object",
srv: Server{
DriverConnectionID: 1,
TopologyID: topologyID,
Message: "foo",
ServerConnectionID: &serverConnectionID,
ServerHost: "localhost",
ServerPort: "27017",
},
want: KeyValues{
KeyDriverConnectionID, int64(1),
KeyMessage, "foo",
KeyServerHost, "localhost",
KeyTopologyID, topologyID.Hex(),
KeyServerConnectionID, serverConnectionID,
KeyServerPort, int64(27017),
},
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
got := SerializeServer(test.srv, test.extraKeysAndValues...)
verifySerialization(t, got, test.want)
})
}
}
func TestSerializeTopology(t *testing.T) {
t.Parallel()
topologyID := bson.NewObjectID()
tests := []struct {
name string
topo Topology
extraKeysAndValues []interface{}
want KeyValues
}{
{
name: "empty",
want: KeyValues{
KeyTopologyID, bson.ObjectID{}.Hex(),
},
},
{
name: "complete Topology object",
topo: Topology{
ID: topologyID,
Message: "foo",
},
want: KeyValues{
KeyTopologyID, topologyID.Hex(),
KeyMessage, "foo",
},
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
got := SerializeTopology(test.topo, test.extraKeysAndValues...)
verifySerialization(t, got, test.want)
})
}
}

// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package logger
import "context"
// contextKey is a custom type used to prevent key collisions when using the
// context package.
type contextKey string
const (
contextKeyOperation contextKey = "operation"
contextKeyOperationID contextKey = "operationID"
)
// WithOperationName adds the operation name to the context.
func WithOperationName(ctx context.Context, operation string) context.Context {
return context.WithValue(ctx, contextKeyOperation, operation)
}
// WithOperationID adds the operation ID to the context.
func WithOperationID(ctx context.Context, operationID int32) context.Context {
return context.WithValue(ctx, contextKeyOperationID, operationID)
}
// OperationName returns the operation name from the context.
func OperationName(ctx context.Context) (string, bool) {
operationName := ctx.Value(contextKeyOperation)
if operationName == nil {
return "", false
}
return operationName.(string), true
}
// OperationID returns the operation ID from the context.
func OperationID(ctx context.Context) (int32, bool) {
operationID := ctx.Value(contextKeyOperationID)
if operationID == nil {
return 0, false
}
return operationID.(int32), true
}

// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package logger_test
import (
"context"
"testing"
"gitea.psichedelico.com/go/bson/internal/assert"
"gitea.psichedelico.com/go/bson/internal/logger"
)
func TestContext_WithOperationName(t *testing.T) {
t.Parallel()
tests := []struct {
name string
ctx context.Context
opName string
ok bool
}{
{
name: "simple",
ctx: context.Background(),
opName: "foo",
ok: true,
},
}
for _, tt := range tests {
tt := tt // Capture the range variable.
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := logger.WithOperationName(tt.ctx, tt.opName)
opName, ok := logger.OperationName(ctx)
assert.Equal(t, tt.ok, ok)
if ok {
assert.Equal(t, tt.opName, opName)
}
})
}
}
func TestContext_OperationName(t *testing.T) {
t.Parallel()
tests := []struct {
name string
ctx context.Context
opName interface{}
ok bool
}{
{
name: "nil",
ctx: context.Background(),
opName: nil,
ok: false,
},
{
name: "string type",
ctx: context.Background(),
opName: "foo",
ok: true,
},
{
name: "non-string type",
ctx: context.Background(),
opName: int32(1),
ok: false,
},
}
for _, tt := range tests {
tt := tt // Capture the range variable.
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
if opNameStr, ok := tt.opName.(string); ok {
ctx = logger.WithOperationName(tt.ctx, opNameStr)
}
opName, ok := logger.OperationName(ctx)
assert.Equal(t, tt.ok, ok)
if ok {
assert.Equal(t, tt.opName, opName)
}
})
}
}
func TestContext_WithOperationID(t *testing.T) {
t.Parallel()
tests := []struct {
name string
ctx context.Context
opID int32
ok bool
}{
{
name: "non-zero",
ctx: context.Background(),
opID: 1,
ok: true,
},
}
for _, tt := range tests {
tt := tt // Capture the range variable.
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := logger.WithOperationID(tt.ctx, tt.opID)
opID, ok := logger.OperationID(ctx)
assert.Equal(t, tt.ok, ok)
if ok {
assert.Equal(t, tt.opID, opID)
}
})
}
}
func TestContext_OperationID(t *testing.T) {
t.Parallel()
tests := []struct {
name string
ctx context.Context
opID interface{}
ok bool
}{
{
name: "nil",
ctx: context.Background(),
opID: nil,
ok: false,
},
{
name: "i32 type",
ctx: context.Background(),
opID: int32(1),
ok: true,
},
{
name: "non-i32 type",
ctx: context.Background(),
opID: "foo",
ok: false,
},
}
for _, tt := range tests {
tt := tt // Capture the range variable.
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
if opIDI32, ok := tt.opID.(int32); ok {
ctx = logger.WithOperationID(tt.ctx, opIDI32)
}
opID, ok := logger.OperationID(ctx)
assert.Equal(t, tt.ok, ok)
if ok {
assert.Equal(t, tt.opID, opID)
}
})
}
}

// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package logger
import (
"encoding/json"
"io"
"math"
"sync"
"time"
)
// IOSink writes a JSON-encoded message to the io.Writer.
type IOSink struct {
enc *json.Encoder
// encMu protects the encoder from concurrent writes. While the logger
// itself does not concurrently write to the sink, the sink may be used
// concurrently within the driver.
encMu sync.Mutex
}
// Compile-time check to ensure IOSink implements the LogSink interface.
var _ LogSink = &IOSink{}
// NewIOSink will create an IOSink object that writes JSON messages to the
// provided io.Writer.
func NewIOSink(out io.Writer) *IOSink {
return &IOSink{
enc: json.NewEncoder(out),
}
}
// Info will write a JSON-encoded message to the io.Writer.
func (sink *IOSink) Info(_ int, msg string, keysAndValues ...interface{}) {
mapSize := len(keysAndValues) / 2
if math.MaxInt-mapSize >= 2 {
mapSize += 2
}
kvMap := make(map[string]interface{}, mapSize)
kvMap[KeyTimestamp] = time.Now().UnixNano()
kvMap[KeyMessage] = msg
for i := 0; i < len(keysAndValues); i += 2 {
kvMap[keysAndValues[i].(string)] = keysAndValues[i+1]
}
sink.encMu.Lock()
defer sink.encMu.Unlock()
_ = sink.enc.Encode(kvMap)
}
// Error will write a JSON-encoded error message to the io.Writer.
func (sink *IOSink) Error(err error, msg string, kv ...interface{}) {
kv = append(kv, KeyError, err.Error())
sink.Info(0, msg, kv...)
}

internal/logger/level.go
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package logger
import "strings"
// DiffToInfo is the number of levels in the Go Driver that come before the
// "Info" level. This should ensure that "Info" is the 0th level passed to the
// sink.
const DiffToInfo = 1
// Level is an enumeration representing the log severity levels supported by
// the driver. The order of the logging levels is important. The driver expects
// that a user will likely use the "logr" package to create a LogSink, which
// treats the info level as 0. Any additions to the Level enumeration before
// LevelInfo will also need to update the "DiffToInfo" constant.
type Level int
const (
// LevelOff suppresses logging.
LevelOff Level = iota
// LevelInfo enables logging of informational messages. These logs are
// high-level information about normal driver behavior.
LevelInfo
// LevelDebug enables logging of debug messages. These logs can be
// voluminous and are intended for detailed information that may be
// helpful when debugging an application.
LevelDebug
)
const (
levelLiteralOff = "off"
levelLiteralEmergency = "emergency"
levelLiteralAlert = "alert"
levelLiteralCritical = "critical"
levelLiteralError = "error"
levelLiteralWarning = "warning"
levelLiteralNotice = "notice"
levelLiteralInfo = "info"
levelLiteralDebug = "debug"
levelLiteralTrace = "trace"
)
// LevelLiteralMap maps the supported log severity literals to the driver
// Level that each one enables.
var LevelLiteralMap = map[string]Level{
levelLiteralOff: LevelOff,
levelLiteralEmergency: LevelInfo,
levelLiteralAlert: LevelInfo,
levelLiteralCritical: LevelInfo,
levelLiteralError: LevelInfo,
levelLiteralWarning: LevelInfo,
levelLiteralNotice: LevelInfo,
levelLiteralInfo: LevelInfo,
levelLiteralDebug: LevelDebug,
levelLiteralTrace: LevelDebug,
}
// ParseLevel will check if the given string is a valid environment variable
// for a logging severity level. If it is, then it will return the associated
// driver's Level. The default Level is "LevelOff".
func ParseLevel(str string) Level {
for literal, level := range LevelLiteralMap {
if strings.EqualFold(literal, str) {
return level
}
}
return LevelOff
}
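ParseLevel folds the full syslog-style vocabulary down to just three driver levels, matching case-insensitively and defaulting to off. A condensed standalone sketch of that lookup (names are illustrative):

```go
package main

import (
	"fmt"
	"strings"
)

const (
	levelOff = iota
	levelInfo
	levelDebug
)

// parseLevel maps many accepted literals onto three levels,
// comparing case-insensitively and defaulting to "off".
func parseLevel(s string) int {
	literals := map[string]int{
		"off": levelOff,
		"error": levelInfo, "warning": levelInfo, "info": levelInfo,
		"debug": levelDebug, "trace": levelDebug,
	}
	for lit, lvl := range literals {
		if strings.EqualFold(lit, s) {
			return lvl
		}
	}
	return levelOff // unrecognized input disables logging
}

func main() {
	fmt.Println(parseLevel("DEBUG"))   // 2
	fmt.Println(parseLevel("Warning")) // 1
	fmt.Println(parseLevel("bogus"))   // 0
}
```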

internal/logger/logger.go
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package logger provides the internal logging solution for the MongoDB Go
// Driver.
package logger
import (
"fmt"
"os"
"strconv"
"strings"
"gitea.psichedelico.com/go/bson"
"gitea.psichedelico.com/go/bson/internal/bsoncoreutil"
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
)
// DefaultMaxDocumentLength is the default maximum number of bytes that can be
// logged for a stringified BSON document.
const DefaultMaxDocumentLength = 1000
// TruncationSuffix is the trailing ellipsis "..." appended to a message to
// indicate to the user that truncation occurred. This constant does not count
// toward the max document length.
const TruncationSuffix = "..."
const logSinkPathEnvVar = "MONGODB_LOG_PATH"
const maxDocumentLengthEnvVar = "MONGODB_LOG_MAX_DOCUMENT_LENGTH"
// LogSink represents a logging implementation; this interface should be 1-1
// with the exported "LogSink" interface in the mongo/options package.
type LogSink interface {
// Info logs a non-error message with the given key/value pairs. The
// level argument is provided for optional logging.
Info(level int, msg string, keysAndValues ...interface{})
// Error logs an error, with the given message and key/value pairs.
Error(err error, msg string, keysAndValues ...interface{})
}
// Logger represents the configuration for the internal logger.
type Logger struct {
ComponentLevels map[Component]Level // Log levels for each component.
Sink LogSink // LogSink for log printing.
MaxDocumentLength uint // Command truncation width.
logFile *os.File // File to write logs to.
}
// New will construct a new logger. If any of the given options are the
// zero-value of the argument type, then the constructor will attempt to
// source the data from the environment. If the environment has not been set,
// then the constructor will use the respective default values.
func New(sink LogSink, maxDocLen uint, compLevels map[Component]Level) (*Logger, error) {
logger := &Logger{
ComponentLevels: selectComponentLevels(compLevels),
MaxDocumentLength: selectMaxDocumentLength(maxDocLen),
}
sink, logFile, err := selectLogSink(sink)
if err != nil {
return nil, err
}
logger.Sink = sink
logger.logFile = logFile
return logger, nil
}
// Close will close the logger's log file, if it exists.
func (logger *Logger) Close() error {
if logger.logFile != nil {
return logger.logFile.Close()
}
return nil
}
// LevelComponentEnabled will return true if the given LogLevel is enabled for
// the given LogComponent. If the ComponentLevels on the logger are enabled for
// "ComponentAll", then this function will return true for any level bound by
// the level assigned to "ComponentAll".
//
// If the level is not enabled (i.e. LevelOff), then false is returned. This is
// to avoid false positives, such as returning "true" for a component that is
// not enabled. For example, without this condition, an empty ComponentLevels
// map would be considered "enabled" for "LevelOff".
func (logger *Logger) LevelComponentEnabled(level Level, component Component) bool {
if level == LevelOff {
return false
}
if logger.ComponentLevels == nil {
return false
}
return logger.ComponentLevels[component] >= level ||
logger.ComponentLevels[ComponentAll] >= level
}
// Print will synchronously print the given message to the configured LogSink.
// If the LogSink is nil, then this method will do nothing. Future work could
// be done to make this method asynchronous; see buffer management in libraries
// such as log4j.
//
// It's worth noting that many structured logs defined by DBX-wide
// specifications include a "message" field, which is often shared with the
// message arguments passed to this print function. The "Info" method used by
// this function is implemented based on the go-logr/logr LogSink interface,
// which is why "Print" has a message parameter. Any duplication in code is
// intentional to adhere to the logr pattern.
func (logger *Logger) Print(level Level, component Component, msg string, keysAndValues ...interface{}) {
// If the level is not enabled for the component, then
// skip the message.
if !logger.LevelComponentEnabled(level, component) {
return
}
// If the sink is nil, then skip the message.
if logger.Sink == nil {
return
}
logger.Sink.Info(int(level)-DiffToInfo, msg, keysAndValues...)
}
// Error logs an error, with the given message and key/value pairs.
// It functions similarly to Print, but may have unique behavior, and should be
// preferred for logging errors.
func (logger *Logger) Error(err error, msg string, keysAndValues ...interface{}) {
if logger.Sink == nil {
return
}
logger.Sink.Error(err, msg, keysAndValues...)
}
// selectMaxDocumentLength will return the integer value of the first non-zero
// function, with the user-defined function taking priority over the environment
// variables. For the environment, the function will attempt to get the value of
// "MONGODB_LOG_MAX_DOCUMENT_LENGTH" and parse it as an unsigned integer. If the
// environment variable is not set or is not an unsigned integer, then this
// function will return the default max document length.
func selectMaxDocumentLength(maxDocLen uint) uint {
if maxDocLen != 0 {
return maxDocLen
}
maxDocLenEnv := os.Getenv(maxDocumentLengthEnvVar)
if maxDocLenEnv != "" {
maxDocLenEnvInt, err := strconv.ParseUint(maxDocLenEnv, 10, 32)
if err == nil {
return uint(maxDocLenEnvInt)
}
}
return DefaultMaxDocumentLength
}
const (
logSinkPathStdout = "stdout"
logSinkPathStderr = "stderr"
)
// selectLogSink will return the first non-nil LogSink, with the user-defined
// LogSink taking precedence over the environment-defined LogSink. If no LogSink
// is defined, then this function will return a LogSink that writes to stderr.
func selectLogSink(sink LogSink) (LogSink, *os.File, error) {
if sink != nil {
return sink, nil, nil
}
path := os.Getenv(logSinkPathEnvVar)
lowerPath := strings.ToLower(path)
if lowerPath == logSinkPathStderr {
return NewIOSink(os.Stderr), nil, nil
}
if lowerPath == logSinkPathStdout {
return NewIOSink(os.Stdout), nil, nil
}
if path != "" {
logFile, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0666)
if err != nil {
return nil, nil, fmt.Errorf("unable to open log file: %w", err)
}
return NewIOSink(logFile), logFile, nil
}
return NewIOSink(os.Stderr), nil, nil
}
// selectComponentLevels returns a new map of LogComponents to LogLevels that is
// the result of merging the user-defined data with the environment, with the
// user-defined data taking priority.
func selectComponentLevels(componentLevels map[Component]Level) map[Component]Level {
selected := make(map[Component]Level)
// Determine if the "MONGODB_LOG_ALL" environment variable is set.
var globalEnvLevel *Level
if all := os.Getenv(mongoDBLogAllEnvVar); all != "" {
level := ParseLevel(all)
globalEnvLevel = &level
}
for envVar, component := range componentEnvVarMap {
// If the component already has a level, then skip it.
if _, ok := componentLevels[component]; ok {
selected[component] = componentLevels[component]
continue
}
// If the "MONGODB_LOG_ALL" environment variable is set, then
// set the level for the component to the value of the
// environment variable.
if globalEnvLevel != nil {
selected[component] = *globalEnvLevel
continue
}
// Otherwise, set the level for the component to the value of
// the environment variable.
selected[component] = ParseLevel(os.Getenv(envVar))
}
return selected
}
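The merge above gives an explicit user setting priority over "MONGODB_LOG_ALL", which in turn overrides the component's own environment variable. A standalone sketch of that precedence chain (the `pickLevel` helper is illustrative):

```go
package main

import "fmt"

// pickLevel resolves a component's level with the same precedence
// selectComponentLevels uses: an explicit user setting wins, then a
// global environment value, then the component's own environment value.
func pickLevel(user, global, env string) string {
	if user != "" {
		return user
	}
	if global != "" {
		return global
	}
	if env != "" {
		return env
	}
	return "off" // nothing configured: logging disabled
}

func main() {
	fmt.Println(pickLevel("debug", "info", "trace")) // debug
	fmt.Println(pickLevel("", "info", "trace"))      // info
	fmt.Println(pickLevel("", "", ""))               // off
}
```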
// FormatDocument formats a BSON document or RawValue for logging. The document is truncated
// to the given width.
func FormatDocument(msg bson.Raw, width uint) string {
if len(msg) == 0 {
return "{}"
}
str := bsoncore.Document(msg).StringN(int(width))
// If the last byte is not a closing brace, then the document was truncated.
if len(str) > 0 && str[len(str)-1] != '}' {
str += TruncationSuffix
}
return str
}
// FormatString formats a string for logging. The string is truncated to the
// given width.
func FormatString(str string, width uint) string {
strTrunc := bsoncoreutil.Truncate(str, int(width))
// Check if the string was truncated by comparing the lengths of the two strings.
if len(strTrunc) < len(str) {
strTrunc += TruncationSuffix
}
return strTrunc
}
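In both FormatDocument and FormatString the suffix is appended after truncation, so the ellipsis does not count toward the width. A standalone sketch of a rune-safe version of that behavior (the `truncate` helper is illustrative, not the driver's `bsoncoreutil.Truncate`):

```go
package main

import "fmt"

// truncate shortens s to at most width runes (not bytes) and marks
// truncation with a trailing ellipsis; the ellipsis does not count
// toward the width, mirroring how TruncationSuffix is applied.
func truncate(s string, width int) string {
	runes := []rune(s)
	if len(runes) <= width {
		return s
	}
	return string(runes[:width]) + "..."
}

func main() {
	fmt.Println(truncate("hello world", 5)) // hello...
	fmt.Println(truncate("hi", 5))          // hi
}
```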

// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package logger
import (
"bytes"
"encoding/json"
"fmt"
"os"
"reflect"
"strings"
"sync"
"testing"
"gitea.psichedelico.com/go/bson"
"gitea.psichedelico.com/go/bson/internal/assert"
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
)
type mockLogSink struct{}
func (mockLogSink) Info(int, string, ...interface{}) {}
func (mockLogSink) Error(error, string, ...interface{}) {}
func BenchmarkLoggerWithLargeDocuments(b *testing.B) {
// Define the large document test cases
testCases := []struct {
name string
create func() bson.D
}{
{
name: "LargeStrings",
create: func() bson.D { return createLargeStringsDocument(10) },
},
{
name: "MassiveArrays",
create: func() bson.D { return createMassiveArraysDocument(100000) },
},
{
name: "VeryVoluminousDocument",
create: func() bson.D { return createVoluminousDocument(100000) },
},
}
for _, tc := range testCases {
tc := tc
b.Run(tc.name, func(b *testing.B) {
// Run benchmark with logging and truncation enabled
b.Run("LoggingWithTruncation", func(b *testing.B) {
logger, err := New(mockLogSink{}, 0, map[Component]Level{
ComponentCommand: LevelDebug,
})
if err != nil {
b.Fatal(err)
}
bs, err := bson.Marshal(tc.create())
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Print(LevelInfo, ComponentCommand, FormatDocument(bs, 1024), "foo", "bar", "baz")
}
})
// Run benchmark with logging enabled without truncation
b.Run("LoggingWithoutTruncation", func(b *testing.B) {
logger, err := New(mockLogSink{}, 0, map[Component]Level{
ComponentCommand: LevelDebug,
})
if err != nil {
b.Fatal(err)
}
bs, err := bson.Marshal(tc.create())
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
msg := bsoncore.Document(bs).String()
logger.Print(LevelInfo, ComponentCommand, msg, "foo", "bar", "baz")
}
})
// Run benchmark without logging or truncation
b.Run("WithoutLoggingOrTruncation", func(b *testing.B) {
bs, err := bson.Marshal(tc.create())
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = bsoncore.Document(bs).String()
}
})
})
}
}
// Helper functions to create large documents
func createVoluminousDocument(numKeys int) bson.D {
d := make(bson.D, numKeys)
for i := 0; i < numKeys; i++ {
d = append(d, bson.E{Key: fmt.Sprintf("key%d", i), Value: "value"})
}
return d
}
func createLargeStringsDocument(sizeMB int) bson.D {
largeString := strings.Repeat("a", sizeMB*1024*1024)
return bson.D{
{Key: "largeString1", Value: largeString},
{Key: "largeString2", Value: largeString},
{Key: "largeString3", Value: largeString},
{Key: "largeString4", Value: largeString},
}
}
func createMassiveArraysDocument(arraySize int) bson.D {
massiveArray := make([]string, arraySize)
for i := 0; i < arraySize; i++ {
massiveArray[i] = "value"
}
return bson.D{
{Key: "massiveArray1", Value: massiveArray},
{Key: "massiveArray2", Value: massiveArray},
{Key: "massiveArray3", Value: massiveArray},
{Key: "massiveArray4", Value: massiveArray},
}
}
func BenchmarkLogger(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
b.Run("Print", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
logger, err := New(mockLogSink{}, 0, map[Component]Level{
ComponentCommand: LevelDebug,
})
if err != nil {
b.Fatal(err)
}
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
logger.Print(LevelInfo, ComponentCommand, "foo", "bar", "baz")
}
})
})
}
func mockKeyValues(length int) (KeyValues, map[string]interface{}) {
keysAndValues := KeyValues{}
m := map[string]interface{}{}
for i := 0; i < length; i++ {
keyName := fmt.Sprintf("key%d", i)
valueName := fmt.Sprintf("value%d", i)
keysAndValues.Add(keyName, valueName)
m[keyName] = valueName
}
return keysAndValues, m
}
func BenchmarkIOSinkInfo(b *testing.B) {
keysAndValues, _ := mockKeyValues(10)
b.ReportAllocs()
b.ResetTimer()
sink := NewIOSink(bytes.NewBuffer(nil))
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
sink.Info(0, "foo", keysAndValues...)
}
})
}
func TestIOSinkInfo(t *testing.T) {
t.Parallel()
const threshold = 1000
mockKeyValues, kvmap := mockKeyValues(10)
buf := new(bytes.Buffer)
sink := NewIOSink(buf)
wg := sync.WaitGroup{}
wg.Add(threshold)
for i := 0; i < threshold; i++ {
go func() {
defer wg.Done()
sink.Info(0, "foo", mockKeyValues...)
}()
}
wg.Wait()
dec := json.NewDecoder(buf)
for dec.More() {
var m map[string]interface{}
if err := dec.Decode(&m); err != nil {
t.Fatalf("error unmarshaling JSON: %v", err)
}
delete(m, KeyTimestamp)
delete(m, KeyMessage)
if !reflect.DeepEqual(m, kvmap) {
t.Fatalf("expected %v, got %v", kvmap, m)
}
}
}
func TestSelectMaxDocumentLength(t *testing.T) {
for _, tcase := range []struct {
name string
arg uint
expected uint
env map[string]string
}{
{
name: "default",
arg: 0,
expected: DefaultMaxDocumentLength,
},
{
name: "non-zero",
arg: 100,
expected: 100,
},
{
name: "valid env",
arg: 0,
expected: 100,
env: map[string]string{
maxDocumentLengthEnvVar: "100",
},
},
{
name: "invalid env",
arg: 0,
expected: DefaultMaxDocumentLength,
env: map[string]string{
maxDocumentLengthEnvVar: "foo",
},
},
} {
tcase := tcase
t.Run(tcase.name, func(t *testing.T) {
for k, v := range tcase.env {
t.Setenv(k, v)
}
actual := selectMaxDocumentLength(tcase.arg)
if actual != tcase.expected {
t.Errorf("expected %d, got %d", tcase.expected, actual)
}
})
}
}
func TestSelectLogSink(t *testing.T) {
for _, tcase := range []struct {
name string
arg LogSink
expected LogSink
env map[string]string
}{
{
name: "default",
arg: nil,
expected: NewIOSink(os.Stderr),
},
{
name: "non-nil",
arg: mockLogSink{},
expected: mockLogSink{},
},
{
name: "stdout",
arg: nil,
expected: NewIOSink(os.Stdout),
env: map[string]string{
logSinkPathEnvVar: logSinkPathStdout,
},
},
{
name: "stderr",
arg: nil,
expected: NewIOSink(os.Stderr),
env: map[string]string{
logSinkPathEnvVar: logSinkPathStderr,
},
},
} {
tcase := tcase
t.Run(tcase.name, func(t *testing.T) {
for k, v := range tcase.env {
t.Setenv(k, v)
}
actual, _, _ := selectLogSink(tcase.arg)
if !reflect.DeepEqual(actual, tcase.expected) {
t.Errorf("expected %+v, got %+v", tcase.expected, actual)
}
})
}
}
func TestSelectedComponentLevels(t *testing.T) {
for _, tcase := range []struct {
name string
arg map[Component]Level
expected map[Component]Level
env map[string]string
}{
{
name: "default",
arg: nil,
expected: map[Component]Level{
ComponentCommand: LevelOff,
ComponentTopology: LevelOff,
ComponentServerSelection: LevelOff,
ComponentConnection: LevelOff,
},
},
{
name: "non-nil",
arg: map[Component]Level{
ComponentCommand: LevelDebug,
},
expected: map[Component]Level{
ComponentCommand: LevelDebug,
ComponentTopology: LevelOff,
ComponentServerSelection: LevelOff,
ComponentConnection: LevelOff,
},
},
{
name: "valid env",
arg: nil,
expected: map[Component]Level{
ComponentCommand: LevelDebug,
ComponentTopology: LevelInfo,
ComponentServerSelection: LevelOff,
ComponentConnection: LevelOff,
},
env: map[string]string{
mongoDBLogCommandEnvVar: levelLiteralDebug,
mongoDBLogTopologyEnvVar: levelLiteralInfo,
},
},
{
name: "invalid env",
arg: nil,
expected: map[Component]Level{
ComponentCommand: LevelOff,
ComponentTopology: LevelOff,
ComponentServerSelection: LevelOff,
ComponentConnection: LevelOff,
},
env: map[string]string{
mongoDBLogCommandEnvVar: "foo",
mongoDBLogTopologyEnvVar: "bar",
},
},
} {
tcase := tcase
t.Run(tcase.name, func(t *testing.T) {
for k, v := range tcase.env {
t.Setenv(k, v)
}
actual := selectComponentLevels(tcase.arg)
for k, v := range tcase.expected {
if actual[k] != v {
t.Errorf("expected %d, got %d", v, actual[k])
}
}
})
}
}

// TestLogger_LevelComponentEnabled verifies the precedence rules for
// component log levels: LevelOff is never enabled, and a level is enabled
// when it is at or below the configured level for either the specific
// component or ComponentAll.
func TestLogger_LevelComponentEnabled(t *testing.T) {
t.Parallel()
tests := []struct {
name string
logger Logger
level Level
component Component
want bool
}{
{
name: "zero",
logger: Logger{},
level: LevelOff,
component: ComponentCommand,
want: false,
},
{
name: "empty",
logger: Logger{
ComponentLevels: map[Component]Level{},
},
level: LevelOff,
component: ComponentCommand,
want: false, // LevelOff should never be considered enabled.
},
{
name: "one level below",
logger: Logger{
ComponentLevels: map[Component]Level{
ComponentCommand: LevelDebug,
},
},
level: LevelInfo,
component: ComponentCommand,
want: true,
},
{
name: "equal levels",
logger: Logger{
ComponentLevels: map[Component]Level{
ComponentCommand: LevelDebug,
},
},
level: LevelDebug,
component: ComponentCommand,
want: true,
},
{
name: "one level above",
logger: Logger{
ComponentLevels: map[Component]Level{
ComponentCommand: LevelInfo,
},
},
level: LevelDebug,
component: ComponentCommand,
want: false,
},
{
name: "component mismatch",
logger: Logger{
ComponentLevels: map[Component]Level{
ComponentCommand: LevelDebug,
},
},
level: LevelDebug,
component: ComponentTopology,
want: false,
},
{
name: "component all enables with topology",
logger: Logger{
ComponentLevels: map[Component]Level{
ComponentAll: LevelDebug,
},
},
level: LevelDebug,
component: ComponentTopology,
want: true,
},
{
name: "component all enables with server selection",
logger: Logger{
ComponentLevels: map[Component]Level{
ComponentAll: LevelDebug,
},
},
level: LevelDebug,
component: ComponentServerSelection,
want: true,
},
{
name: "component all enables with connection",
logger: Logger{
ComponentLevels: map[Component]Level{
ComponentAll: LevelDebug,
},
},
level: LevelDebug,
component: ComponentConnection,
want: true,
},
{
name: "component all enables with command",
logger: Logger{
ComponentLevels: map[Component]Level{
ComponentAll: LevelDebug,
},
},
level: LevelDebug,
component: ComponentCommand,
want: true,
},
{
name: "component all enables with all",
logger: Logger{
ComponentLevels: map[Component]Level{
ComponentAll: LevelDebug,
},
},
level: LevelDebug,
component: ComponentAll,
want: true,
},
{
name: "component all does not enable with lower level",
logger: Logger{
ComponentLevels: map[Component]Level{
ComponentAll: LevelInfo,
},
},
level: LevelDebug,
component: ComponentCommand,
want: false,
},
{
name: "component all has a lower log level than command",
logger: Logger{
ComponentLevels: map[Component]Level{
ComponentAll: LevelInfo,
ComponentCommand: LevelDebug,
},
},
level: LevelDebug,
component: ComponentCommand,
want: true,
},
{
name: "component all has a higher log level than command",
logger: Logger{
ComponentLevels: map[Component]Level{
ComponentAll: LevelDebug,
ComponentCommand: LevelInfo,
},
},
level: LevelDebug,
component: ComponentCommand,
want: true,
},
}
for _, tcase := range tests {
tcase := tcase // Capture the range variable.
t.Run(tcase.name, func(t *testing.T) {
t.Parallel()
got := tcase.logger.LevelComponentEnabled(tcase.level, tcase.component)
assert.Equal(t, tcase.want, got, "unexpected result for LevelComponentEnabled")
})
}
}
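The table cases above encode the precedence rule the tests expect: LevelOff is never enabled, and a check passes when the requested level is at or below either the component's own configured level or the ComponentAll level. A minimal standalone sketch of that rule, assuming a hypothetical `levelEnabled` helper with local type definitions (not the driver's actual implementation):

```go
package main

import "fmt"

type Component int

type Level int

const (
	ComponentAll Component = iota
	ComponentCommand
)

const (
	LevelOff Level = iota
	LevelInfo
	LevelDebug
)

// levelEnabled mirrors the behavior the table asserts: LevelOff is never
// enabled, and a component check passes if the requested level is at or
// below the configured level for ComponentAll or for the component itself.
// Missing map entries read as LevelOff, so they enable nothing.
func levelEnabled(levels map[Component]Level, level Level, comp Component) bool {
	if level == LevelOff {
		return false
	}
	if levels[ComponentAll] >= level {
		return true
	}
	return levels[comp] >= level
}

func main() {
	levels := map[Component]Level{
		ComponentAll:     LevelInfo,
		ComponentCommand: LevelDebug,
	}
	// Command's own level (Debug) wins over the lower ComponentAll level.
	fmt.Println(levelEnabled(levels, LevelDebug, ComponentCommand)) // true
	// Topology has no entry, and ComponentAll is only Info.
	fmt.Println(levelEnabled(levels, LevelDebug, ComponentTopology())) // false
}

// ComponentTopology is a stand-in constant accessor so the example stays
// self-contained; in the driver this is another Component constant.
func ComponentTopology() Component { return Component(2) }
```

This matches, for example, the "component all has a lower log level than command" and "component all does not enable with lower level" cases in the table above.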
