Initial commit

commit 4b4cceb81c

LICENSE (new file, 201 lines)
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
THIRD-PARTY-NOTICES (new file, 1692 lines)
File diff suppressed because it is too large

array_codec.go (new file, 42 lines)
@@ -0,0 +1,42 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"reflect"

	"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
)

// arrayCodec is the Codec used for bsoncore.Array values.
type arrayCodec struct{}

// EncodeValue is the ValueEncoder for bsoncore.Array values.
func (ac *arrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tCoreArray {
		return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: val}
	}

	arr := val.Interface().(bsoncore.Array)
	return copyArrayFromBytes(vw, arr)
}

// DecodeValue is the ValueDecoder for bsoncore.Array values.
func (ac *arrayCodec) DecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error {
	if !val.CanSet() || val.Type() != tCoreArray {
		return ValueDecoderError{Name: "CoreArrayDecodeValue", Types: []reflect.Type{tCoreArray}, Received: val}
	}

	if val.IsNil() {
		val.Set(reflect.MakeSlice(val.Type(), 0, 0))
	}

	val.SetLen(0)
	arr, err := appendArrayBytes(val.Interface().(bsoncore.Array), vr)
	val.Set(reflect.ValueOf(arr))
	return err
}
benchmark_test.go (new file, 449 lines)
@@ -0,0 +1,449 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"bytes"
	"compress/gzip"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"os"
	"path"
	"sync"
	"testing"
)

type encodetest struct {
	Field1String  string
	Field1Int64   int64
	Field1Float64 float64
	Field2String  string
	Field2Int64   int64
	Field2Float64 float64
	Field3String  string
	Field3Int64   int64
	Field3Float64 float64
	Field4String  string
	Field4Int64   int64
	Field4Float64 float64
}

type nestedtest1 struct {
	Nested nestedtest2
}

type nestedtest2 struct {
	Nested nestedtest3
}

type nestedtest3 struct {
	Nested nestedtest4
}

type nestedtest4 struct {
	Nested nestedtest5
}

type nestedtest5 struct {
	Nested nestedtest6
}

type nestedtest6 struct {
	Nested nestedtest7
}

type nestedtest7 struct {
	Nested nestedtest8
}

type nestedtest8 struct {
	Nested nestedtest9
}

type nestedtest9 struct {
	Nested nestedtest10
}

type nestedtest10 struct {
	Nested nestedtest11
}

type nestedtest11 struct {
	Nested encodetest
}

var encodetestInstance = encodetest{
	Field1String:  "foo",
	Field1Int64:   1,
	Field1Float64: 3.0,
	Field2String:  "bar",
	Field2Int64:   2,
	Field2Float64: 3.1,
	Field3String:  "baz",
	Field3Int64:   3,
	Field3Float64: 3.14,
	Field4String:  "qux",
	Field4Int64:   4,
	Field4Float64: 3.141,
}

var nestedInstance = nestedtest1{
	nestedtest2{
		nestedtest3{
			nestedtest4{
				nestedtest5{
					nestedtest6{
						nestedtest7{
							nestedtest8{
								nestedtest9{
									nestedtest10{
										nestedtest11{
											encodetest{
												Field1String:  "foo",
												Field1Int64:   1,
												Field1Float64: 3.0,
												Field2String:  "bar",
												Field2Int64:   2,
												Field2Float64: 3.1,
												Field3String:  "baz",
												Field3Int64:   3,
												Field3Float64: 3.14,
												Field4String:  "qux",
												Field4Int64:   4,
												Field4Float64: 3.141,
											},
										},
									},
								},
							},
						},
					},
				},
			},
		},
	},
}

const extendedBSONDir = "./testdata/extended_bson"

var (
	extJSONFiles   map[string]map[string]interface{}
	extJSONFilesMu sync.Mutex
)

// readExtJSONFile reads the GZIP-compressed extended JSON document from the given filename in the
// "extended BSON" test data directory (./testdata/extended_bson) and returns it as a
// map[string]interface{}. It panics on any errors.
func readExtJSONFile(filename string) map[string]interface{} {
	extJSONFilesMu.Lock()
	defer extJSONFilesMu.Unlock()
	if v, ok := extJSONFiles[filename]; ok {
		return v
	}
	filePath := path.Join(extendedBSONDir, filename)
	file, err := os.Open(filePath)
	if err != nil {
		panic(fmt.Sprintf("error opening file %q: %s", filePath, err))
	}
	defer func() {
		_ = file.Close()
	}()

	gz, err := gzip.NewReader(file)
	if err != nil {
		panic(fmt.Sprintf("error creating GZIP reader: %s", err))
	}
	defer func() {
		_ = gz.Close()
	}()

	data, err := ioutil.ReadAll(gz)
	if err != nil {
		panic(fmt.Sprintf("error reading GZIP contents of file: %s", err))
	}

	var v map[string]interface{}
	err = UnmarshalExtJSON(data, false, &v)
	if err != nil {
		panic(fmt.Sprintf("error unmarshalling extended JSON: %s", err))
	}

	if extJSONFiles == nil {
		extJSONFiles = make(map[string]map[string]interface{})
	}
	extJSONFiles[filename] = v
	return v
}
func BenchmarkMarshal(b *testing.B) {
	cases := []struct {
		desc  string
		value interface{}
	}{
		{
			desc:  "simple struct",
			value: encodetestInstance,
		},
		{
			desc:  "nested struct",
			value: nestedInstance,
		},
		{
			desc:  "deep_bson.json.gz",
			value: readExtJSONFile("deep_bson.json.gz"),
		},
		{
			desc:  "flat_bson.json.gz",
			value: readExtJSONFile("flat_bson.json.gz"),
		},
		{
			desc:  "full_bson.json.gz",
			value: readExtJSONFile("full_bson.json.gz"),
		},
	}

	for _, tc := range cases {
		b.Run(tc.desc, func(b *testing.B) {
			b.Run("BSON", func(b *testing.B) {
				for i := 0; i < b.N; i++ {
					_, err := Marshal(tc.value)
					if err != nil {
						b.Errorf("error marshalling BSON: %s", err)
					}
				}
			})

			b.Run("extJSON", func(b *testing.B) {
				for i := 0; i < b.N; i++ {
					_, err := MarshalExtJSON(tc.value, true, false)
					if err != nil {
						b.Errorf("error marshalling extended JSON: %s", err)
					}
				}
			})

			b.Run("JSON", func(b *testing.B) {
				for i := 0; i < b.N; i++ {
					_, err := json.Marshal(tc.value)
					if err != nil {
						b.Errorf("error marshalling JSON: %s", err)
					}
				}
			})
		})
	}
}

func BenchmarkUnmarshal(b *testing.B) {
	cases := []struct {
		desc  string
		value interface{}
	}{
		{
			desc:  "simple struct",
			value: encodetestInstance,
		},
		{
			desc:  "nested struct",
			value: nestedInstance,
		},
		{
			desc:  "deep_bson.json.gz",
			value: readExtJSONFile("deep_bson.json.gz"),
		},
		{
			desc:  "flat_bson.json.gz",
			value: readExtJSONFile("flat_bson.json.gz"),
		},
		{
			desc:  "full_bson.json.gz",
			value: readExtJSONFile("full_bson.json.gz"),
		},
	}

	for _, tc := range cases {
		b.Run(tc.desc, func(b *testing.B) {
			b.Run("BSON", func(b *testing.B) {
				data, err := Marshal(tc.value)
				if err != nil {
					b.Errorf("error marshalling BSON: %s", err)
					return
				}

				b.ResetTimer()
				var v2 map[string]interface{}
				for i := 0; i < b.N; i++ {
					err := Unmarshal(data, &v2)
					if err != nil {
						b.Errorf("error unmarshalling BSON: %s", err)
					}
				}
			})

			b.Run("extJSON", func(b *testing.B) {
				data, err := MarshalExtJSON(tc.value, true, false)
				if err != nil {
					b.Errorf("error marshalling extended JSON: %s", err)
					return
				}

				b.ResetTimer()
				var v2 map[string]interface{}
				for i := 0; i < b.N; i++ {
					err := UnmarshalExtJSON(data, true, &v2)
					if err != nil {
						b.Errorf("error unmarshalling extended JSON: %s", err)
					}
				}
			})

			b.Run("JSON", func(b *testing.B) {
				data, err := json.Marshal(tc.value)
				if err != nil {
					b.Errorf("error marshalling JSON: %s", err)
					return
				}

				b.ResetTimer()
				var v2 map[string]interface{}
				for i := 0; i < b.N; i++ {
					err := json.Unmarshal(data, &v2)
					if err != nil {
						b.Errorf("error unmarshalling JSON: %s", err)
					}
				}
			})
		})
	}
}

// The following benchmarks are copied from the Go standard library's
// encoding/json package.

type codeResponse struct {
	Tree     *codeNode `json:"tree"`
	Username string    `json:"username"`
}

type codeNode struct {
	Name     string      `json:"name"`
	Kids     []*codeNode `json:"kids"`
	CLWeight float64     `json:"cl_weight"`
	Touches  int         `json:"touches"`
	MinT     int64       `json:"min_t"`
	MaxT     int64       `json:"max_t"`
	MeanT    int64       `json:"mean_t"`
}

var codeJSON []byte
var codeBSON []byte
var codeStruct codeResponse

func codeInit() {
	f, err := os.Open("testdata/code.json.gz")
	if err != nil {
		panic(err)
	}
	defer f.Close()
	gz, err := gzip.NewReader(f)
	if err != nil {
		panic(err)
	}
	data, err := io.ReadAll(gz)
	if err != nil {
		panic(err)
	}

	codeJSON = data

	if err := json.Unmarshal(codeJSON, &codeStruct); err != nil {
		panic("json.Unmarshal code.json: " + err.Error())
	}

	if data, err = json.Marshal(&codeStruct); err != nil {
		panic("json.Marshal code.json: " + err.Error())
	}

	if codeBSON, err = Marshal(&codeStruct); err != nil {
		panic("Marshal code.json: " + err.Error())
	}

	if !bytes.Equal(data, codeJSON) {
		println("different lengths", len(data), len(codeJSON))
		for i := 0; i < len(data) && i < len(codeJSON); i++ {
			if data[i] != codeJSON[i] {
				println("re-marshal: changed at byte", i)
				println("orig: ", string(codeJSON[i-10:i+10]))
				println("new: ", string(data[i-10:i+10]))
				break
			}
		}
		panic("re-marshal code.json: different result")
	}
}

func BenchmarkCodeUnmarshal(b *testing.B) {
	b.ReportAllocs()
	if codeJSON == nil {
		b.StopTimer()
		codeInit()
		b.StartTimer()
	}
	b.Run("BSON", func(b *testing.B) {
		b.RunParallel(func(pb *testing.PB) {
			for pb.Next() {
				var r codeResponse
				if err := Unmarshal(codeBSON, &r); err != nil {
					b.Fatal("Unmarshal:", err)
				}
			}
		})
		b.SetBytes(int64(len(codeBSON)))
	})
	b.Run("JSON", func(b *testing.B) {
		b.RunParallel(func(pb *testing.PB) {
			for pb.Next() {
				var r codeResponse
				if err := json.Unmarshal(codeJSON, &r); err != nil {
					b.Fatal("json.Unmarshal:", err)
				}
			}
		})
		b.SetBytes(int64(len(codeJSON)))
	})
}

func BenchmarkCodeMarshal(b *testing.B) {
	b.ReportAllocs()
	if codeJSON == nil {
		b.StopTimer()
		codeInit()
		b.StartTimer()
	}
	b.Run("BSON", func(b *testing.B) {
		b.RunParallel(func(pb *testing.PB) {
			for pb.Next() {
				if _, err := Marshal(&codeStruct); err != nil {
					b.Fatal("Marshal:", err)
				}
			}
		})
		b.SetBytes(int64(len(codeBSON)))
	})
	b.Run("JSON", func(b *testing.B) {
		b.RunParallel(func(pb *testing.PB) {
			for pb.Next() {
				if _, err := json.Marshal(&codeStruct); err != nil {
					b.Fatal("json.Marshal:", err)
				}
			}
		})
		b.SetBytes(int64(len(codeJSON)))
	})
}
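readExtJSONFile in the file above combines a sync.Mutex with a lazily initialized map so each fixture is decompressed and parsed at most once, no matter how many benchmarks request it. The caching skeleton, separated from the gzip/JSON work (the `load` and `cache` names are illustrative, not from the package):

```go
package main

import (
	"fmt"
	"sync"
)

var (
	cache   map[string]string
	cacheMu sync.Mutex
	loads   int // counts how many times the expensive work ran
)

// load returns the cached value for key, computing it at most once.
// The mutex serializes the lookup, the expensive work, and the lazy
// initialization of the map itself.
func load(key string) string {
	cacheMu.Lock()
	defer cacheMu.Unlock()
	if v, ok := cache[key]; ok {
		return v
	}
	loads++
	v := "parsed:" + key // stands in for the gzip + extended-JSON parsing
	if cache == nil {
		cache = make(map[string]string)
	}
	cache[key] = v
	return v
}

func main() {
	load("a.json.gz")
	load("a.json.gz") // second call hits the cache
	fmt.Println(loads)
}
```

Holding the lock across the whole body (rather than only the map accesses) is what keeps two goroutines from parsing the same file concurrently; sync.Once per key would be an alternative when contention matters.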
bson_binary_vector_spec_test.go (new file, 191 lines)
@@ -0,0 +1,191 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"encoding/hex"
	"encoding/json"
	"math"
	"os"
	"path"
	"testing"

	"gitea.psichedelico.com/go/bson/internal/require"
)

const bsonBinaryVectorDir = "./testdata/bson-binary-vector/"

type bsonBinaryVectorTests struct {
	Description string                     `json:"description"`
	TestKey     string                     `json:"test_key"`
	Tests       []bsonBinaryVectorTestCase `json:"tests"`
}

type bsonBinaryVectorTestCase struct {
	Description   string        `json:"description"`
	Valid         bool          `json:"valid"`
	Vector        []interface{} `json:"vector"`
	DtypeHex      string        `json:"dtype_hex"`
	DtypeAlias    string        `json:"dtype_alias"`
	Padding       int           `json:"padding"`
	CanonicalBson string        `json:"canonical_bson"`
}

func TestBsonBinaryVectorSpec(t *testing.T) {
	t.Parallel()

	jsonFiles, err := findJSONFilesInDir(bsonBinaryVectorDir)
	require.NoErrorf(t, err, "error finding JSON files in %s: %v", bsonBinaryVectorDir, err)

	for _, file := range jsonFiles {
		filepath := path.Join(bsonBinaryVectorDir, file)
		content, err := os.ReadFile(filepath)
		require.NoErrorf(t, err, "reading test file %s", filepath)

		var tests bsonBinaryVectorTests
		require.NoErrorf(t, json.Unmarshal(content, &tests), "parsing test file %s", filepath)

		t.Run(tests.Description, func(t *testing.T) {
			t.Parallel()

			for _, test := range tests.Tests {
				test := test
				t.Run(test.Description, func(t *testing.T) {
					t.Parallel()

					runBsonBinaryVectorTest(t, tests.TestKey, test)
				})
			}
		})
	}

	t.Run("Padding specified with no vector data PACKED_BIT", func(t *testing.T) {
		t.Parallel()

		t.Run("Marshaling", func(t *testing.T) {
			_, err := NewPackedBitVector(nil, 1)
			require.EqualError(t, err, errNonZeroVectorPadding.Error())
		})
	})

	t.Run("Exceeding maximum padding PACKED_BIT", func(t *testing.T) {
		t.Parallel()

		t.Run("Marshaling", func(t *testing.T) {
			_, err := NewPackedBitVector(nil, 8)
			require.EqualError(t, err, errVectorPaddingTooLarge.Error())
		})
	})
}

func convertSlice[T int8 | float32 | byte](s []interface{}) []T {
	v := make([]T, len(s))
	for i, e := range s {
		f := math.NaN()
		switch val := e.(type) {
		case float64:
			f = val
		case string:
			if val == "inf" {
				f = math.Inf(0)
			} else if val == "-inf" {
				f = math.Inf(-1)
			}
		}
		v[i] = T(f)
	}
	return v
}
func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVectorTestCase) {
	testVector := make(map[string]Vector)
	switch alias := test.DtypeHex; alias {
	case "0x03":
		testVector[testKey] = Vector{
			dType:    Int8Vector,
			int8Data: convertSlice[int8](test.Vector),
		}
	case "0x27":
		testVector[testKey] = Vector{
			dType:       Float32Vector,
			float32Data: convertSlice[float32](test.Vector),
		}
	case "0x10":
		testVector[testKey] = Vector{
			dType:      PackedBitVector,
			bitData:    convertSlice[byte](test.Vector),
			bitPadding: uint8(test.Padding),
		}
	default:
		t.Fatalf("unsupported vector type: %s", alias)
	}

	testBSON, err := hex.DecodeString(test.CanonicalBson)
	require.NoError(t, err, "decoding canonical BSON")

	t.Run("Unmarshaling", func(t *testing.T) {
		skipCases := map[string]string{
			"Overflow Vector INT8":                "compile-time restriction",
			"Underflow Vector INT8":               "compile-time restriction",
			"INT8 with float inputs":              "compile-time restriction",
			"Overflow Vector PACKED_BIT":          "compile-time restriction",
			"Underflow Vector PACKED_BIT":         "compile-time restriction",
			"Vector with float values PACKED_BIT": "compile-time restriction",
			"Negative padding PACKED_BIT":         "compile-time restriction",
		}
		if reason, ok := skipCases[test.Description]; ok {
			t.Skipf("skip test case %s: %s", test.Description, reason)
		}

		errMap := map[string]string{
			"FLOAT32 with padding":                             "padding must be 0",
			"INT8 with padding":                                "padding must be 0",
			"Padding specified with no vector data PACKED_BIT": "padding must be 0",
			"Exceeding maximum padding PACKED_BIT":             "padding cannot be larger than 7",
		}

		t.Parallel()

		var got map[string]Vector
		err := Unmarshal(testBSON, &got)
		if test.Valid {
			require.NoError(t, err)
			require.Equal(t, testVector, got)
		} else if errMsg, ok := errMap[test.Description]; ok {
			require.ErrorContains(t, err, errMsg)
		} else {
			require.Error(t, err)
		}
	})

	t.Run("Marshaling", func(t *testing.T) {
		skipCases := map[string]string{
			"FLOAT32 with padding": "private padding field",
|
||||||
|
"Insufficient vector data with 3 bytes FLOAT32": "invalid case",
|
||||||
|
"Insufficient vector data with 5 bytes FLOAT32": "invalid case",
|
||||||
|
"Overflow Vector INT8": "compile-time restriction",
|
||||||
|
"Underflow Vector INT8": "compile-time restriction",
|
||||||
|
"INT8 with padding": "private padding field",
|
||||||
|
"INT8 with float inputs": "compile-time restriction",
|
||||||
|
"Overflow Vector PACKED_BIT": "compile-time restriction",
|
||||||
|
"Underflow Vector PACKED_BIT": "compile-time restriction",
|
||||||
|
"Vector with float values PACKED_BIT": "compile-time restriction",
|
||||||
|
"Padding specified with no vector data PACKED_BIT": "run in alternative case",
|
||||||
|
"Exceeding maximum padding PACKED_BIT": "run in alternative case",
|
||||||
|
"Negative padding PACKED_BIT": "compile-time restriction",
|
||||||
|
}
|
||||||
|
if reason, ok := skipCases[test.Description]; ok {
|
||||||
|
t.Skipf("skip test case %s: %s", test.Description, reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
got, err := Marshal(testVector)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, testBSON, got)
|
||||||
|
})
|
||||||
|
}
|
504
bson_corpus_spec_test.go
Normal file
@ -0,0 +1,504 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"bytes"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"os"
	"path"
	"reflect"
	"strconv"
	"strings"
	"testing"
	"unicode"
	"unicode/utf8"

	"gitea.psichedelico.com/go/bson/internal/assert"
	"gitea.psichedelico.com/go/bson/internal/require"
	"github.com/google/go-cmp/cmp"
)

type testCase struct {
	Description  string                `json:"description"`
	BsonType     string                `json:"bson_type"`
	TestKey      *string               `json:"test_key"`
	Valid        []validityTestCase    `json:"valid"`
	DecodeErrors []decodeErrorTestCase `json:"decodeErrors"`
	ParseErrors  []parseErrorTestCase  `json:"parseErrors"`
	Deprecated   *bool                 `json:"deprecated"`
}

type validityTestCase struct {
	Description       string  `json:"description"`
	CanonicalBson     string  `json:"canonical_bson"`
	CanonicalExtJSON  string  `json:"canonical_extjson"`
	RelaxedExtJSON    *string `json:"relaxed_extjson"`
	DegenerateBSON    *string `json:"degenerate_bson"`
	DegenerateExtJSON *string `json:"degenerate_extjson"`
	ConvertedBSON     *string `json:"converted_bson"`
	ConvertedExtJSON  *string `json:"converted_extjson"`
	Lossy             *bool   `json:"lossy"`
}

type decodeErrorTestCase struct {
	Description string `json:"description"`
	Bson        string `json:"bson"`
}

type parseErrorTestCase struct {
	Description string `json:"description"`
	String      string `json:"string"`
}

const dataDir = "./testdata/bson-corpus/"

func findJSONFilesInDir(dir string) ([]string, error) {
	files := make([]string, 0)

	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil, err
	}

	for _, entry := range entries {
		if entry.IsDir() || path.Ext(entry.Name()) != ".json" {
			continue
		}

		files = append(files, entry.Name())
	}

	return files, nil
}

// seedExtJSON will add the byte representation of the "extJSON" string to the fuzzer's corpus.
func seedExtJSON(f *testing.F, extJSON string, extJSONType string, desc string) {
	jbytes, err := jsonToBytes(extJSON, extJSONType, desc)
	if err != nil {
		f.Fatalf("failed to convert JSON to bytes: %v", err)
	}

	f.Add(jbytes)
}

// seedTestCase will add the byte representation for each "extJSON" string of each valid test case to the fuzzer's
// corpus.
func seedTestCase(f *testing.F, tcase *testCase) {
	for _, vtc := range tcase.Valid {
		seedExtJSON(f, vtc.CanonicalExtJSON, "canonical", vtc.Description)

		// Seed the relaxed extended JSON.
		if vtc.RelaxedExtJSON != nil {
			seedExtJSON(f, *vtc.RelaxedExtJSON, "relaxed", vtc.Description)
		}

		// Seed the degenerate extended JSON.
		if vtc.DegenerateExtJSON != nil {
			seedExtJSON(f, *vtc.DegenerateExtJSON, "degenerate", vtc.Description)
		}

		// Seed the converted extended JSON.
		if vtc.ConvertedExtJSON != nil {
			seedExtJSON(f, *vtc.ConvertedExtJSON, "converted", vtc.Description)
		}
	}
}

// seedBSONCorpus will unmarshal the data from "testdata/bson-corpus" into a slice of "testCase" structs and then
// marshal the "*_extjson" field of each "validityTestCase" into a slice of bytes to seed the fuzz corpus.
func seedBSONCorpus(f *testing.F) {
	fileNames, err := findJSONFilesInDir(dataDir)
	if err != nil {
		f.Fatalf("failed to find JSON files in directory %q: %v", dataDir, err)
	}

	for _, fileName := range fileNames {
		filePath := path.Join(dataDir, fileName)

		file, err := os.Open(filePath)
		if err != nil {
			f.Fatalf("failed to open file %q: %v", filePath, err)
		}

		var tcase testCase
		if err := json.NewDecoder(file).Decode(&tcase); err != nil {
			f.Fatal(err)
		}

		seedTestCase(f, &tcase)
	}
}

func needsEscapedUnicode(bsonType string) bool {
	return bsonType == "0x02" || bsonType == "0x0D" || bsonType == "0x0E" || bsonType == "0x0F"
}

func unescapeUnicode(s, bsonType string) string {
	if !needsEscapedUnicode(bsonType) {
		return s
	}

	newS := ""

	for i := 0; i < len(s); i++ {
		c := s[i]
		switch c {
		case '\\':
			switch s[i+1] {
			case 'u':
				us := s[i : i+6]
				u, err := strconv.Unquote(strings.Replace(strconv.Quote(us), `\\u`, `\u`, 1))
				if err != nil {
					return ""
				}
				for _, r := range u {
					if r < ' ' {
						newS += fmt.Sprintf(`\u%04x`, r)
					} else {
						newS += string(r)
					}
				}
				i += 5
			default:
				newS += string(c)
			}
		default:
			if c > unicode.MaxASCII {
				r, size := utf8.DecodeRune([]byte(s[i:]))
				newS += string(r)
				i += size - 1
			} else {
				newS += string(c)
			}
		}
	}

	return newS
}
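A standalone sketch (not part of the diff) of the escape-decoding trick used by `unescapeUnicode` above: `strconv.Quote` escapes the literal backslash, a single `Replace` turns the doubled `\\u` back into one `\u` escape, and `strconv.Unquote` then interprets the whole six-character sequence as a rune. The helper name `decodeEscape` is hypothetical.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// decodeEscape turns a raw six-character `\uXXXX` sequence into the rune it
// names, by quoting, un-doubling the backslash, and unquoting.
func decodeEscape(us string) (string, error) {
	return strconv.Unquote(strings.Replace(strconv.Quote(us), `\\u`, `\u`, 1))
}

func main() {
	u, err := decodeEscape(`\u00e9`) // six characters: \, u, 0, 0, e, 9
	if err != nil {
		panic(err)
	}
	fmt.Println(u) // é
}
```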
func normalizeCanonicalDouble(t *testing.T, key string, cEJ string) string {
	// Unmarshal string into map
	cEJMap := make(map[string]map[string]string)
	err := json.Unmarshal([]byte(cEJ), &cEJMap)
	require.NoError(t, err)

	// Parse the float contained by the map.
	expectedString := cEJMap[key]["$numberDouble"]
	expectedFloat, err := strconv.ParseFloat(expectedString, 64)
	require.NoError(t, err)

	// Normalize the string
	return fmt.Sprintf(`{"%s":{"$numberDouble":"%s"}}`, key, formatDouble(expectedFloat))
}

func normalizeRelaxedDouble(t *testing.T, key string, rEJ string) string {
	// Unmarshal string into map
	rEJMap := make(map[string]float64)
	err := json.Unmarshal([]byte(rEJ), &rEJMap)
	if err != nil {
		return normalizeCanonicalDouble(t, key, rEJ)
	}

	// Parse the float contained by the map.
	expectedFloat := rEJMap[key]

	// Normalize the string
	return fmt.Sprintf(`{"%s":%s}`, key, formatDouble(expectedFloat))
}
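A sketch (not part of the diff) of the idea behind `normalizeCanonicalDouble` above: re-parse the `$numberDouble` string and re-format it so equivalent spellings collapse to one representation before string comparison. The real file delegates formatting to the package-internal `formatDouble`, whose exact rules may differ; `strconv.FormatFloat` is used here as a stand-in.

```go
package main

import (
	"fmt"
	"strconv"
)

// normalize re-parses a double spelled as a string and re-emits it in a
// single canonical form, so "1.0" and "1E+0" compare equal as strings.
func normalize(s string) string {
	f, err := strconv.ParseFloat(s, 64)
	if err != nil {
		panic(err)
	}
	return strconv.FormatFloat(f, 'G', -1, 64)
}

func main() {
	fmt.Println(normalize("1.0"))    // 1
	fmt.Println(normalize("1E+0"))   // 1
	fmt.Println(normalize("-1.5e3")) // -1500
}
```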
// bsonToNative decodes the BSON bytes (b) into a native Document
func bsonToNative(t *testing.T, b []byte, bType, testDesc string) D {
	var doc D
	err := Unmarshal(b, &doc)
	require.NoErrorf(t, err, "%s: decoding %s BSON", testDesc, bType)
	return doc
}

// nativeToBSON encodes the native Document (doc) into canonical BSON and compares it to the expected
// canonical BSON (cB)
func nativeToBSON(t *testing.T, cB []byte, doc D, testDesc, bType, docSrcDesc string) {
	actual, err := Marshal(doc)
	require.NoErrorf(t, err, "%s: encoding %s BSON", testDesc, bType)

	if diff := cmp.Diff(cB, actual); diff != "" {
		t.Errorf("%s: 'native_to_bson(%s) = cB' failed (-want, +got):\n-%v\n+%v\n",
			testDesc, docSrcDesc, cB, actual)
		t.FailNow()
	}
}

// jsonToNative decodes the extended JSON string (ej) into a native Document
func jsonToNative(ej, ejType, testDesc string) (D, error) {
	var doc D
	if err := UnmarshalExtJSON([]byte(ej), ejType != "relaxed", &doc); err != nil {
		return nil, fmt.Errorf("%s: decoding %s extended JSON: %w", testDesc, ejType, err)
	}
	return doc, nil
}

// jsonToBytes decodes the extended JSON string (ej) into canonical BSON and then encodes it into a byte slice.
func jsonToBytes(ej, ejType, testDesc string) ([]byte, error) {
	native, err := jsonToNative(ej, ejType, testDesc)
	if err != nil {
		return nil, err
	}

	b, err := Marshal(native)
	if err != nil {
		return nil, fmt.Errorf("%s: encoding %s BSON: %w", testDesc, ejType, err)
	}

	return b, nil
}

// nativeToJSON encodes the native Document (doc) into an extended JSON string
func nativeToJSON(t *testing.T, ej string, doc D, testDesc, ejType, ejShortName, docSrcDesc string) {
	actualEJ, err := MarshalExtJSON(doc, ejType != "relaxed", true)
	require.NoErrorf(t, err, "%s: encoding %s extended JSON", testDesc, ejType)

	if diff := cmp.Diff(ej, string(actualEJ)); diff != "" {
		t.Errorf("%s: 'native_to_%s_extended_json(%s) = %s' failed (-want, +got):\n%s\n",
			testDesc, ejType, docSrcDesc, ejShortName, diff)
		t.FailNow()
	}
}

func runTest(t *testing.T, file string) {
	filepath := path.Join(dataDir, file)
	content, err := os.ReadFile(filepath)
	require.NoError(t, err)

	// Remove ".json" from filename.
	file = file[:len(file)-5]
	testName := "bson_corpus--" + file

	t.Run(testName, func(t *testing.T) {
		var test testCase
		require.NoError(t, json.Unmarshal(content, &test))

		t.Run("valid", func(t *testing.T) {
			for _, v := range test.Valid {
				t.Run(v.Description, func(t *testing.T) {
					// get canonical BSON
					cB, err := hex.DecodeString(v.CanonicalBson)
					require.NoErrorf(t, err, "%s: reading canonical BSON", v.Description)

					// get canonical extended JSON
					var compactEJ bytes.Buffer
					require.NoError(t, json.Compact(&compactEJ, []byte(v.CanonicalExtJSON)))
					cEJ := unescapeUnicode(compactEJ.String(), test.BsonType)
					if test.BsonType == "0x01" {
						cEJ = normalizeCanonicalDouble(t, *test.TestKey, cEJ)
					}

					/*** canonical BSON round-trip tests ***/
					doc := bsonToNative(t, cB, "canonical", v.Description)

					// native_to_bson(bson_to_native(cB)) = cB
					nativeToBSON(t, cB, doc, v.Description, "canonical", "bson_to_native(cB)")

					// native_to_canonical_extended_json(bson_to_native(cB)) = cEJ
					nativeToJSON(t, cEJ, doc, v.Description, "canonical", "cEJ", "bson_to_native(cB)")

					// native_to_relaxed_extended_json(bson_to_native(cB)) = rEJ (if rEJ exists)
					if v.RelaxedExtJSON != nil {
						var compactEJ bytes.Buffer
						require.NoError(t, json.Compact(&compactEJ, []byte(*v.RelaxedExtJSON)))
						rEJ := unescapeUnicode(compactEJ.String(), test.BsonType)
						if test.BsonType == "0x01" {
							rEJ = normalizeRelaxedDouble(t, *test.TestKey, rEJ)
						}

						nativeToJSON(t, rEJ, doc, v.Description, "relaxed", "rEJ", "bson_to_native(cB)")

						/*** relaxed extended JSON round-trip tests (if exists) ***/
						doc, err = jsonToNative(rEJ, "relaxed", v.Description)
						require.NoError(t, err)

						// native_to_relaxed_extended_json(json_to_native(rEJ)) = rEJ
						nativeToJSON(t, rEJ, doc, v.Description, "relaxed", "eJR", "json_to_native(rEJ)")
					}

					/*** canonical extended JSON round-trip tests ***/
					doc, err = jsonToNative(cEJ, "canonical", v.Description)
					require.NoError(t, err)

					// native_to_canonical_extended_json(json_to_native(cEJ)) = cEJ
					nativeToJSON(t, cEJ, doc, v.Description, "canonical", "cEJ", "json_to_native(cEJ)")

					// native_to_bson(json_to_native(cEJ)) = cB (unless lossy)
					if v.Lossy == nil || !*v.Lossy {
						nativeToBSON(t, cB, doc, v.Description, "canonical", "json_to_native(cEJ)")
					}

					/*** degenerate BSON round-trip tests (if exists) ***/
					if v.DegenerateBSON != nil {
						dB, err := hex.DecodeString(*v.DegenerateBSON)
						require.NoErrorf(t, err, "%s: reading degenerate BSON", v.Description)

						doc = bsonToNative(t, dB, "degenerate", v.Description)

						// native_to_bson(bson_to_native(dB)) = cB
						nativeToBSON(t, cB, doc, v.Description, "degenerate", "bson_to_native(dB)")
					}

					/*** degenerate JSON round-trip tests (if exists) ***/
					if v.DegenerateExtJSON != nil {
						var compactEJ bytes.Buffer
						require.NoError(t, json.Compact(&compactEJ, []byte(*v.DegenerateExtJSON)))
						dEJ := unescapeUnicode(compactEJ.String(), test.BsonType)
						if test.BsonType == "0x01" {
							dEJ = normalizeCanonicalDouble(t, *test.TestKey, dEJ)
						}

						doc, err = jsonToNative(dEJ, "degenerate canonical", v.Description)
						require.NoError(t, err)

						// native_to_canonical_extended_json(json_to_native(dEJ)) = cEJ
						nativeToJSON(t, cEJ, doc, v.Description, "degenerate canonical", "cEJ", "json_to_native(dEJ)")

						// native_to_bson(json_to_native(dEJ)) = cB (unless lossy)
						if v.Lossy == nil || !*v.Lossy {
							nativeToBSON(t, cB, doc, v.Description, "canonical", "json_to_native(dEJ)")
						}
					}
				})
			}
		})

		t.Run("decode error", func(t *testing.T) {
			for _, d := range test.DecodeErrors {
				t.Run(d.Description, func(t *testing.T) {
					b, err := hex.DecodeString(d.Bson)
					require.NoError(t, err, d.Description)

					var doc D
					err = Unmarshal(b, &doc)

					// The driver unmarshals invalid UTF-8 strings without error. Loop over the unmarshalled elements
					// and assert that there was no error if any of the string or DBPointer values contain invalid UTF-8
					// characters.
					for _, elem := range doc {
						value := reflect.ValueOf(elem.Value)
						invalidString := (value.Kind() == reflect.String) && !utf8.ValidString(value.String())
						dbPtr, ok := elem.Value.(DBPointer)
						invalidDBPtr := ok && !utf8.ValidString(dbPtr.DB)

						if invalidString || invalidDBPtr {
							require.NoError(t, err, d.Description)
							return
						}
					}

					require.Errorf(t, err, "%s: expected decode error", d.Description)
				})
			}
		})

		t.Run("parse error", func(t *testing.T) {
			for _, p := range test.ParseErrors {
				t.Run(p.Description, func(t *testing.T) {
					s := unescapeUnicode(p.String, test.BsonType)
					if test.BsonType == "0x13" {
						s = fmt.Sprintf(`{"decimal128": {"$numberDecimal": "%s"}}`, s)
					}

					switch test.BsonType {
					case "0x00", "0x05", "0x13":
						var doc D
						err := UnmarshalExtJSON([]byte(s), true, &doc)
						// Null bytes are validated when marshaling to BSON
						if strings.Contains(p.Description, "Null") {
							_, err = Marshal(doc)
						}
						require.Errorf(t, err, "%s: expected parse error", p.Description)
					default:
						t.Errorf("Update test to check for parse errors for type %s", test.BsonType)
						t.Fail()
					}
				})
			}
		})
	})
}

func Test_BsonCorpus(t *testing.T) {
	jsonFiles, err := findJSONFilesInDir(dataDir)
	require.NoErrorf(t, err, "error finding JSON files in %s: %v", dataDir, err)

	for _, file := range jsonFiles {
		runTest(t, file)
	}
}

func TestRelaxedUUIDValidation(t *testing.T) {
	testCases := []struct {
		description       string
		canonicalExtJSON  string
		degenerateExtJSON string
		expectedErr       string
	}{
		{
			"valid uuid",
			"{\"x\" : { \"$binary\" : {\"base64\" : \"c//SZESzTGmQ6OfR38A11A==\", \"subType\" : \"04\"}}}",
			"{\"x\" : { \"$uuid\" : \"73ffd264-44b3-4c69-90e8-e7d1dfc035d4\"}}",
			"",
		},
		{
			"invalid uuid--no hyphens",
			"",
			"{\"x\" : { \"$uuid\" : \"73ffd26444b34c6990e8e7d1dfc035d4\"}}",
			"$uuid value does not follow RFC 4122 format regarding length and hyphens",
		},
		{
			"invalid uuid--trailing hyphens",
			"",
			"{\"x\" : { \"$uuid\" : \"73ffd264-44b3-4c69-90e8-e7d1dfc035--\"}}",
			"$uuid value does not follow RFC 4122 format regarding length and hyphens",
		},
		{
			"invalid uuid--malformed hex",
			"",
			"{\"x\" : { \"$uuid\" : \"q3@fd26l-44b3-4c69-90e8-e7d1dfc035d4\"}}",
			"$uuid value does not follow RFC 4122 format regarding hex bytes: encoding/hex: invalid byte: U+0071 'q'",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.description, func(t *testing.T) {
			// get canonical extended JSON (if provided)
			cEJ := ""
			if tc.canonicalExtJSON != "" {
				var compactCEJ bytes.Buffer
				require.NoError(t, json.Compact(&compactCEJ, []byte(tc.canonicalExtJSON)))
				cEJ = unescapeUnicode(compactCEJ.String(), "0x05")
			}

			// get degenerate extended JSON
			var compactDEJ bytes.Buffer
			require.NoError(t, json.Compact(&compactDEJ, []byte(tc.degenerateExtJSON)))
			dEJ := unescapeUnicode(compactDEJ.String(), "0x05")

			// convert dEJ to native doc
			var doc D
			err := UnmarshalExtJSON([]byte(dEJ), true, &doc)

			if tc.expectedErr != "" {
				assert.Equal(t, tc.expectedErr, err.Error(), "expected error %v, got %v", tc.expectedErr, err)
			} else {
				assert.Nil(t, err, "expected no error, got error: %v", err)

				// Marshal doc into extended JSON and compare with cEJ
				nativeToJSON(t, cEJ, doc, tc.description, "degenerate canonical", "cEJ", "json_to_native(dEJ)")
			}
		})
	}
}
679
bson_test.go
Normal file
@ -0,0 +1,679 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"bytes"
	"encoding/json"
	"fmt"
	"reflect"
	"strconv"
	"strings"
	"testing"
	"time"

	"gitea.psichedelico.com/go/bson/internal/assert"
	"gitea.psichedelico.com/go/bson/internal/require"
	"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
	"github.com/google/go-cmp/cmp"
)

func noerr(t *testing.T, err error) {
	if err != nil {
		t.Helper()
		t.Errorf("Unexpected error: (%T)%v", err, err)
		t.FailNow()
	}
}

func TestTimestamp(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		description     string
		tp              Timestamp
		tp2             Timestamp
		expectedAfter   bool
		expectedBefore  bool
		expectedEqual   bool
		expectedCompare int
	}{
		{
			description:     "equal",
			tp:              Timestamp{T: 12345, I: 67890},
			tp2:             Timestamp{T: 12345, I: 67890},
			expectedBefore:  false,
			expectedAfter:   false,
			expectedEqual:   true,
			expectedCompare: 0,
		},
		{
			description:     "T greater than",
			tp:              Timestamp{T: 12345, I: 67890},
			tp2:             Timestamp{T: 2345, I: 67890},
			expectedBefore:  false,
			expectedAfter:   true,
			expectedEqual:   false,
			expectedCompare: 1,
		},
		{
			description:     "I greater than",
			tp:              Timestamp{T: 12345, I: 67890},
			tp2:             Timestamp{T: 12345, I: 7890},
			expectedBefore:  false,
			expectedAfter:   true,
			expectedEqual:   false,
			expectedCompare: 1,
		},
		{
			description:     "T less than",
			tp:              Timestamp{T: 12345, I: 67890},
			tp2:             Timestamp{T: 112345, I: 67890},
			expectedBefore:  true,
			expectedAfter:   false,
			expectedEqual:   false,
			expectedCompare: -1,
		},
		{
			description:     "I less than",
			tp:              Timestamp{T: 12345, I: 67890},
			tp2:             Timestamp{T: 12345, I: 167890},
			expectedBefore:  true,
			expectedAfter:   false,
			expectedEqual:   false,
			expectedCompare: -1,
		},
	}

	for _, tc := range testCases {
		tc := tc // Capture range variable.

		t.Run(tc.description, func(t *testing.T) {
			t.Parallel()

			assert.Equal(t, tc.expectedAfter, tc.tp.After(tc.tp2), "expected After results to be the same")
			assert.Equal(t, tc.expectedBefore, tc.tp.Before(tc.tp2), "expected Before results to be the same")
			assert.Equal(t, tc.expectedEqual, tc.tp.Equal(tc.tp2), "expected Equal results to be the same")
			assert.Equal(t, tc.expectedCompare, tc.tp.Compare(tc.tp2), "expected Compare result to be the same")
		})
	}
}
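The ordering asserted by the TestTimestamp cases above can be sketched standalone (not part of the diff; the local `Timestamp` type and `compare` helper here are hypothetical): `(T, I)` pairs compare lexicographically, `T` first, then `I` on ties.

```go
package main

import "fmt"

// Timestamp mirrors the BSON timestamp's two uint32 components.
type Timestamp struct{ T, I uint32 }

// compare orders timestamps by T, breaking ties with I.
func compare(a, b Timestamp) int {
	switch {
	case a.T < b.T:
		return -1
	case a.T > b.T:
		return 1
	case a.I < b.I:
		return -1
	case a.I > b.I:
		return 1
	default:
		return 0
	}
}

func main() {
	fmt.Println(compare(Timestamp{T: 12345, I: 67890}, Timestamp{T: 12345, I: 67890}))  // 0
	fmt.Println(compare(Timestamp{T: 12345, I: 67890}, Timestamp{T: 2345, I: 67890}))   // 1
	fmt.Println(compare(Timestamp{T: 12345, I: 67890}, Timestamp{T: 12345, I: 167890})) // -1
}
```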
func TestPrimitiveIsZero(t *testing.T) {
	testcases := []struct {
		name    string
		zero    Zeroer
		nonzero Zeroer
	}{
		{"binary", Binary{}, Binary{Data: []byte{0x01, 0x02, 0x03}, Subtype: 0xFF}},
		{"decimal128", Decimal128{}, NewDecimal128(1, 2)},
		{"objectID", ObjectID{}, NewObjectID()},
		{"regex", Regex{}, Regex{Pattern: "foo", Options: "bar"}},
		{"dbPointer", DBPointer{}, DBPointer{DB: "foobar", Pointer: ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}}},
		{"timestamp", Timestamp{}, Timestamp{T: 12345, I: 67890}},
	}

	for _, tc := range testcases {
		t.Run(tc.name, func(t *testing.T) {
			require.True(t, tc.zero.IsZero())
			require.False(t, tc.nonzero.IsZero())
		})
	}
}

func TestRegexCompare(t *testing.T) {
	testcases := []struct {
		name string
		r1   Regex
		r2   Regex
		eq   bool
	}{
		{"equal", Regex{Pattern: "foo1", Options: "bar1"}, Regex{Pattern: "foo1", Options: "bar1"}, true},
		{"not equal", Regex{Pattern: "foo1", Options: "bar1"}, Regex{Pattern: "foo2", Options: "bar2"}, false},
		{"not equal", Regex{Pattern: "foo1", Options: "bar1"}, Regex{Pattern: "foo1", Options: "bar2"}, false},
		{"not equal", Regex{Pattern: "foo1", Options: "bar1"}, Regex{Pattern: "foo2", Options: "bar1"}, false},
	}

	for _, tc := range testcases {
		t.Run(tc.name, func(t *testing.T) {
			require.True(t, tc.r1.Equal(tc.r2) == tc.eq)
		})
	}
}

func TestDateTime(t *testing.T) {
	t.Run("json", func(t *testing.T) {
		t.Run("round trip", func(t *testing.T) {
			original := DateTime(1000)
			jsonBytes, err := json.Marshal(original)
			assert.Nil(t, err, "Marshal error: %v", err)

			var unmarshalled DateTime
			err = json.Unmarshal(jsonBytes, &unmarshalled)
			assert.Nil(t, err, "Unmarshal error: %v", err)

			assert.Equal(t, original, unmarshalled, "expected DateTime %v, got %v", original, unmarshalled)
		})
		t.Run("decode null", func(t *testing.T) {
			jsonBytes := []byte("null")
			var dt DateTime
			err := json.Unmarshal(jsonBytes, &dt)
			assert.Nil(t, err, "Unmarshal error: %v", err)
			assert.Equal(t, DateTime(0), dt, "expected DateTime value to be 0, got %v", dt)
		})
		t.Run("UTC", func(t *testing.T) {
			dt := DateTime(1681145535123)
			jsonBytes, err := json.Marshal(dt)
			assert.Nil(t, err, "Marshal error: %v", err)
			assert.Equal(t, `"2023-04-10T16:52:15.123Z"`, string(jsonBytes))
		})
	})
	t.Run("NewDateTimeFromTime", func(t *testing.T) {
		t.Run("range is not limited", func(t *testing.T) {
			// If the implementation internally calls time.Time.UnixNano(), the constructor cannot handle times after
			// the year 2262.

			timeFormat := "2006-01-02T15:04:05.999Z07:00"
			timeString := "3001-01-01T00:00:00Z"
			tt, err := time.Parse(timeFormat, timeString)
			assert.Nil(t, err, "Parse error: %v", err)

			dt := NewDateTimeFromTime(tt)
			assert.True(t, dt > 0, "expected a valid DateTime greater than 0, got %v", dt)
		})
	})
}

func TestTimeRoundTrip(t *testing.T) {
	val := struct {
		Value time.Time
		ID    string
	}{
		ID: "time-rt-test",
	}

	if !val.Value.IsZero() {
		t.Errorf("Did not get zero time as expected.")
	}

	bsonOut, err := Marshal(val)
	noerr(t, err)
	rtval := struct {
		Value time.Time
		ID    string
	}{}

	err = Unmarshal(bsonOut, &rtval)
	noerr(t, err)
	if !cmp.Equal(val, rtval) {
		t.Errorf("Did not round trip properly. got %v; want %v", val, rtval)
	}
	if !rtval.Value.IsZero() {
		t.Errorf("Did not get zero time as expected.")
	}
}

func TestNonNullTimeRoundTrip(t *testing.T) {
	now := time.Now()
	now = time.Unix(now.Unix(), 0)
	val := struct {
		Value time.Time
		ID    string
	}{
		ID:    "time-rt-test",
		Value: now,
	}

	bsonOut, err := Marshal(val)
	noerr(t, err)
	rtval := struct {
		Value time.Time
		ID    string
	}{}

	err = Unmarshal(bsonOut, &rtval)
	noerr(t, err)
	if !cmp.Equal(val, rtval) {
		t.Errorf("Did not round trip properly. got %v; want %v", val, rtval)
	}
}

func TestD(t *testing.T) {
	t.Run("can marshal", func(t *testing.T) {
		d := D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
		idx, want := bsoncore.AppendDocumentStart(nil)
		want = bsoncore.AppendStringElement(want, "foo", "bar")
		want = bsoncore.AppendStringElement(want, "hello", "world")
		want = bsoncore.AppendDoubleElement(want, "pi", 3.14159)
		want, err := bsoncore.AppendDocumentEnd(want, idx)
		noerr(t, err)
		got, err := Marshal(d)
		noerr(t, err)
		if !bytes.Equal(got, want) {
			t.Errorf("Marshaled documents do not match. got %v; want %v", Raw(got), Raw(want))
		}
	})
	t.Run("can unmarshal", func(t *testing.T) {
		want := D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
		idx, doc := bsoncore.AppendDocumentStart(nil)
		doc = bsoncore.AppendStringElement(doc, "foo", "bar")
		doc = bsoncore.AppendStringElement(doc, "hello", "world")
		doc = bsoncore.AppendDoubleElement(doc, "pi", 3.14159)
		doc, err := bsoncore.AppendDocumentEnd(doc, idx)
		noerr(t, err)
		var got D
		err = Unmarshal(doc, &got)
		noerr(t, err)
		if !cmp.Equal(got, want) {
			t.Errorf("Unmarshaled documents do not match. got %v; want %v", got, want)
		}
	})
}

func TestDStringer(t *testing.T) {
	got := D{{"a", 1}, {"b", 2}}.String()
	want := `{"a":{"$numberInt":"1"},"b":{"$numberInt":"2"}}`
	assert.Equal(t, want, got, "expected: %s, got: %s", want, got)
}

func TestMStringer(t *testing.T) {
	type msg struct {
		A json.RawMessage `json:"a"`
		B json.RawMessage `json:"b"`
	}

	var res msg
	got := M{"a": 1, "b": 2}.String()
	err := json.Unmarshal([]byte(got), &res)
|
||||||
|
require.NoError(t, err, "Unmarshal error")
|
||||||
|
|
||||||
|
want := msg{
|
||||||
|
A: json.RawMessage(`{"$numberInt":"1"}`),
|
||||||
|
B: json.RawMessage(`{"$numberInt":"2"}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, want, res, "returned string did not unmarshal to the expected document, returned string: %s", got)
|
||||||
|
}
|
||||||
|
func TestD_MarshalJSON(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
testcases := []struct {
|
||||||
|
name string
|
||||||
|
test D
|
||||||
|
expected interface{}
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"nil",
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"empty",
|
||||||
|
D{},
|
||||||
|
struct{}{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"non-empty",
|
||||||
|
D{
|
||||||
|
{"a", 42},
|
||||||
|
{"b", true},
|
||||||
|
{"c", "answer"},
|
||||||
|
{"d", nil},
|
||||||
|
{"e", 2.71828},
|
||||||
|
{"f", A{42, true, "answer", nil, 2.71828}},
|
||||||
|
{"g", D{{"foo", "bar"}}},
|
||||||
|
},
|
||||||
|
struct {
|
||||||
|
A int `json:"a"`
|
||||||
|
B bool `json:"b"`
|
||||||
|
C string `json:"c"`
|
||||||
|
D interface{} `json:"d"`
|
||||||
|
E float32 `json:"e"`
|
||||||
|
F []interface{} `json:"f"`
|
||||||
|
G map[string]interface{} `json:"g"`
|
||||||
|
}{
|
||||||
|
A: 42,
|
||||||
|
B: true,
|
||||||
|
C: "answer",
|
||||||
|
D: nil,
|
||||||
|
E: 2.71828,
|
||||||
|
F: []interface{}{42, true, "answer", nil, 2.71828},
|
||||||
|
G: map[string]interface{}{"foo": "bar"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range testcases {
|
||||||
|
tc := tc
|
||||||
|
t.Run("json.Marshal "+tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
got, err := json.Marshal(tc.test)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
want, _ := json.Marshal(tc.expected)
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for _, tc := range testcases {
|
||||||
|
tc := tc
|
||||||
|
t.Run("json.MarshalIndent "+tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
got, err := json.MarshalIndent(tc.test, "<prefix>", "<indent>")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
want, _ := json.MarshalIndent(tc.expected, "<prefix>", "<indent>")
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestD_UnmarshalJSON(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("success", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for _, tc := range []struct {
|
||||||
|
name string
|
||||||
|
test []byte
|
||||||
|
expected D
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"nil",
|
||||||
|
[]byte(`null`),
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"empty",
|
||||||
|
[]byte(`{}`),
|
||||||
|
D{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"non-empty",
|
||||||
|
[]byte(`{"hello":"world","pi":3.142,"boolean":true,"nothing":null,"list":["hello world",3.142,false,null,{"Lorem":"ipsum"}],"document":{"foo":"bar"}}`),
|
||||||
|
D{
|
||||||
|
{"hello", "world"},
|
||||||
|
{"pi", 3.142},
|
||||||
|
{"boolean", true},
|
||||||
|
{"nothing", nil},
|
||||||
|
{"list", []interface{}{"hello world", 3.142, false, nil, D{{"Lorem", "ipsum"}}}},
|
||||||
|
{"document", D{{"foo", "bar"}}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var got D
|
||||||
|
err := json.Unmarshal(tc.test, &got)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tc.expected, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("failure", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
for _, tc := range []struct {
|
||||||
|
name string
|
||||||
|
test string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"illegal",
|
||||||
|
`nil`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"invalid",
|
||||||
|
`{"pi": 3.142ipsum}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"malformatted",
|
||||||
|
`{"pi", 3.142}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"truncated",
|
||||||
|
`{"pi": 3.142`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"array type",
|
||||||
|
`["pi", 3.142]`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"boolean type",
|
||||||
|
`true`,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var a map[string]interface{}
|
||||||
|
want := json.Unmarshal([]byte(tc.test), &a)
|
||||||
|
var b D
|
||||||
|
got := json.Unmarshal([]byte(tc.test), &b)
|
||||||
|
switch w := want.(type) {
|
||||||
|
case *json.UnmarshalTypeError:
|
||||||
|
w.Type = reflect.TypeOf(b)
|
||||||
|
require.IsType(t, want, got)
|
||||||
|
g := got.(*json.UnmarshalTypeError)
|
||||||
|
assert.Equal(t, w, g)
|
||||||
|
default:
|
||||||
|
assert.Equal(t, want, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type stringerString string
|
||||||
|
|
||||||
|
func (ss stringerString) String() string {
|
||||||
|
return "bar"
|
||||||
|
}
|
||||||
|
|
||||||
|
type keyBool bool
|
||||||
|
|
||||||
|
func (kb keyBool) MarshalKey() (string, error) {
|
||||||
|
return fmt.Sprintf("%v", kb), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kb *keyBool) UnmarshalKey(key string) error {
|
||||||
|
switch key {
|
||||||
|
case "true":
|
||||||
|
*kb = true
|
||||||
|
case "false":
|
||||||
|
*kb = false
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("invalid bool value %v", key)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type keyStruct struct {
|
||||||
|
val int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k keyStruct) MarshalText() (text []byte, err error) {
|
||||||
|
str := strconv.FormatInt(k.val, 10)
|
||||||
|
|
||||||
|
return []byte(str), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *keyStruct) UnmarshalText(text []byte) error {
|
||||||
|
val, err := strconv.ParseInt(string(text), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*k = keyStruct{
|
||||||
|
val: val,
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapCodec(t *testing.T) {
|
||||||
|
t.Run("EncodeKeysWithStringer", func(t *testing.T) {
|
||||||
|
strstr := stringerString("foo")
|
||||||
|
mapObj := map[stringerString]int{strstr: 1}
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
mapCodec *mapCodec
|
||||||
|
key string
|
||||||
|
}{
|
||||||
|
{"default", &mapCodec{}, "foo"},
|
||||||
|
{"true", &mapCodec{encodeKeysWithStringer: true}, "bar"},
|
||||||
|
{"false", &mapCodec{encodeKeysWithStringer: false}, "foo"},
|
||||||
|
}
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
mapRegistry := NewRegistry()
|
||||||
|
mapRegistry.RegisterKindEncoder(reflect.Map, tc.mapCodec)
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
vw := NewDocumentWriter(buf)
|
||||||
|
enc := NewEncoder(vw)
|
||||||
|
enc.SetRegistry(mapRegistry)
|
||||||
|
err := enc.Encode(mapObj)
|
||||||
|
assert.Nil(t, err, "Encode error: %v", err)
|
||||||
|
str := buf.String()
|
||||||
|
assert.True(t, strings.Contains(str, tc.key), "expected result to contain %v, got: %v", tc.key, str)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("keys implements keyMarshaler and keyUnmarshaler", func(t *testing.T) {
|
||||||
|
mapObj := map[keyBool]int{keyBool(true): 1}
|
||||||
|
|
||||||
|
doc, err := Marshal(mapObj)
|
||||||
|
assert.Nil(t, err, "Marshal error: %v", err)
|
||||||
|
idx, want := bsoncore.AppendDocumentStart(nil)
|
||||||
|
want = bsoncore.AppendInt32Element(want, "true", 1)
|
||||||
|
want, _ = bsoncore.AppendDocumentEnd(want, idx)
|
||||||
|
assert.Equal(t, want, doc, "expected result %v, got %v", string(want), string(doc))
|
||||||
|
|
||||||
|
var got map[keyBool]int
|
||||||
|
err = Unmarshal(doc, &got)
|
||||||
|
assert.Nil(t, err, "Unmarshal error: %v", err)
|
||||||
|
assert.Equal(t, mapObj, got, "expected result %v, got %v", mapObj, got)
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("keys implements encoding.TextMarshaler and encoding.TextUnmarshaler", func(t *testing.T) {
|
||||||
|
mapObj := map[keyStruct]int{
|
||||||
|
{val: 10}: 100,
|
||||||
|
}
|
||||||
|
|
||||||
|
doc, err := Marshal(mapObj)
|
||||||
|
assert.Nil(t, err, "Marshal error: %v", err)
|
||||||
|
idx, want := bsoncore.AppendDocumentStart(nil)
|
||||||
|
want = bsoncore.AppendInt32Element(want, "10", 100)
|
||||||
|
want, _ = bsoncore.AppendDocumentEnd(want, idx)
|
||||||
|
assert.Equal(t, want, doc, "expected result %v, got %v", string(want), string(doc))
|
||||||
|
|
||||||
|
var got map[keyStruct]int
|
||||||
|
err = Unmarshal(doc, &got)
|
||||||
|
assert.Nil(t, err, "Unmarshal error: %v", err)
|
||||||
|
assert.Equal(t, mapObj, got, "expected result %v, got %v", mapObj, got)
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtJSONEscapeKey(t *testing.T) {
|
||||||
|
doc := D{
|
||||||
|
{
|
||||||
|
Key: "\\usb#",
|
||||||
|
Value: int32(1),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: "regex",
|
||||||
|
Value: Regex{Pattern: "ab\\\\\\\"ab", Options: "\""},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b, err := MarshalExtJSON(&doc, false, false)
|
||||||
|
noerr(t, err)
|
||||||
|
|
||||||
|
want := `{"\\usb#":1,"regex":{"$regularExpression":{"pattern":"ab\\\\\\\"ab","options":"\""}}}`
|
||||||
|
if diff := cmp.Diff(want, string(b)); diff != "" {
|
||||||
|
t.Errorf("Marshaled documents do not match. got %v, want %v", string(b), want)
|
||||||
|
}
|
||||||
|
|
||||||
|
var got D
|
||||||
|
err = UnmarshalExtJSON(b, false, &got)
|
||||||
|
noerr(t, err)
|
||||||
|
if !cmp.Equal(got, doc) {
|
||||||
|
t.Errorf("Unmarshaled documents do not match. got %v; want %v", got, doc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBsoncoreArray(t *testing.T) {
|
||||||
|
type BSONDocumentArray struct {
|
||||||
|
Array []D `bson:"array"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BSONArray struct {
|
||||||
|
Array bsoncore.Array `bson:"array"`
|
||||||
|
}
|
||||||
|
|
||||||
|
bda := BSONDocumentArray{
|
||||||
|
Array: []D{
|
||||||
|
{{"x", 1}},
|
||||||
|
{{"x", 2}},
|
||||||
|
{{"x", 3}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedBSON, err := Marshal(bda)
|
||||||
|
assert.Nil(t, err, "Marshal bsoncore.Document array error: %v", err)
|
||||||
|
|
||||||
|
var ba BSONArray
|
||||||
|
err = Unmarshal(expectedBSON, &ba)
|
||||||
|
assert.Nil(t, err, "Unmarshal error: %v", err)
|
||||||
|
|
||||||
|
actualBSON, err := Marshal(ba)
|
||||||
|
assert.Nil(t, err, "Marshal bsoncore.Array error: %v", err)
|
||||||
|
|
||||||
|
assert.Equal(t, expectedBSON, actualBSON,
|
||||||
|
"expected BSON to be %v after Marshalling again; got %v", expectedBSON, actualBSON)
|
||||||
|
|
||||||
|
doc := bsoncore.Document(actualBSON)
|
||||||
|
v := doc.Lookup("array")
|
||||||
|
assert.Equal(t, bsoncore.TypeArray, v.Type, "expected type array, got %v", v.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
var baseTime = time.Date(2024, 10, 11, 12, 13, 14, 12345678, time.UTC)
|
||||||
|
|
||||||
|
func BenchmarkDateTimeMarshalJSON(b *testing.B) {
|
||||||
|
t := NewDateTimeFromTime(baseTime)
|
||||||
|
data, err := t.MarshalJSON()
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(len(data)))
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if _, err := t.MarshalJSON(); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDateTimeUnmarshalJSON(b *testing.B) {
|
||||||
|
t := NewDateTimeFromTime(baseTime)
|
||||||
|
data, err := t.MarshalJSON()
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(len(data)))
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
var dt DateTime
|
||||||
|
if err := dt.UnmarshalJSON(data); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
199
bsoncodec.go
Normal file
@ -0,0 +1,199 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"fmt"
	"reflect"
	"strings"
)

var (
	emptyValue = reflect.Value{}
)

// ValueEncoderError is an error returned from a ValueEncoder when the provided value can't be
// encoded by the ValueEncoder.
type ValueEncoderError struct {
	Name     string
	Types    []reflect.Type
	Kinds    []reflect.Kind
	Received reflect.Value
}

func (vee ValueEncoderError) Error() string {
	typeKinds := make([]string, 0, len(vee.Types)+len(vee.Kinds))
	for _, t := range vee.Types {
		typeKinds = append(typeKinds, t.String())
	}
	for _, k := range vee.Kinds {
		if k == reflect.Map {
			typeKinds = append(typeKinds, "map[string]*")
			continue
		}
		typeKinds = append(typeKinds, k.String())
	}
	received := vee.Received.Kind().String()
	if vee.Received.IsValid() {
		received = vee.Received.Type().String()
	}
	return fmt.Sprintf("%s can only encode valid %s, but got %s", vee.Name, strings.Join(typeKinds, ", "), received)
}

// ValueDecoderError is an error returned from a ValueDecoder when the provided value can't be
// decoded by the ValueDecoder.
type ValueDecoderError struct {
	Name     string
	Types    []reflect.Type
	Kinds    []reflect.Kind
	Received reflect.Value
}

func (vde ValueDecoderError) Error() string {
	typeKinds := make([]string, 0, len(vde.Types)+len(vde.Kinds))
	for _, t := range vde.Types {
		typeKinds = append(typeKinds, t.String())
	}
	for _, k := range vde.Kinds {
		if k == reflect.Map {
			typeKinds = append(typeKinds, "map[string]*")
			continue
		}
		typeKinds = append(typeKinds, k.String())
	}
	received := vde.Received.Kind().String()
	if vde.Received.IsValid() {
		received = vde.Received.Type().String()
	}
	return fmt.Sprintf("%s can only decode valid and settable %s, but got %s", vde.Name, strings.Join(typeKinds, ", "), received)
}

// EncodeContext is the contextual information required for a Codec to encode a
// value.
type EncodeContext struct {
	*Registry

	// minSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64,
	// uint, uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits)
	// that can represent the integer value.
	minSize bool

	errorOnInlineDuplicates bool
	stringifyMapKeysWithFmt bool
	nilMapAsEmpty           bool
	nilSliceAsEmpty         bool
	nilByteSliceAsEmpty     bool
	omitZeroStruct          bool
	useJSONStructTags       bool
}

// DecodeContext is the contextual information required for a Codec to decode a
// value.
type DecodeContext struct {
	*Registry

	// truncate, if true, instructs decoders to truncate the fractional part of BSON "double"
	// values when attempting to unmarshal them into a Go integer (int, int8, int16, int32, int64,
	// uint, uint8, uint16, uint32, or uint64) struct field. The truncation logic does not apply to
	// BSON "decimal128" values.
	truncate bool

	// defaultDocumentType specifies the Go type to decode top-level and nested BSON documents into. In particular, the
	// usage for this field is restricted to data typed as "interface{}" or "map[string]interface{}". If DocumentType is
	// set to a type that a BSON document cannot be unmarshaled into (e.g. "string"), unmarshalling will result in an
	// error.
	defaultDocumentType reflect.Type

	binaryAsSlice bool

	// a false value results in a decoding error.
	objectIDAsHexString bool

	useJSONStructTags bool
	useLocalTimeZone  bool
	zeroMaps          bool
	zeroStructs       bool
}

// ValueEncoder is the interface implemented by types that can encode a provided Go type to BSON.
// The value to encode is provided as a reflect.Value and a bson.ValueWriter is used within the
// EncodeValue method to actually create the BSON representation. For convenience, ValueEncoderFunc
// is provided to allow use of a function with the correct signature as a ValueEncoder. An
// EncodeContext instance is provided to allow implementations to lookup further ValueEncoders and
// to provide configuration information.
type ValueEncoder interface {
	EncodeValue(EncodeContext, ValueWriter, reflect.Value) error
}

// ValueEncoderFunc is an adapter function that allows a function with the correct signature to be
// used as a ValueEncoder.
type ValueEncoderFunc func(EncodeContext, ValueWriter, reflect.Value) error

// EncodeValue implements the ValueEncoder interface.
func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
	return fn(ec, vw, val)
}

// ValueDecoder is the interface implemented by types that can decode BSON to a provided Go type.
// Implementations should ensure that the value they receive is settable. Similar to ValueEncoderFunc,
// ValueDecoderFunc is provided to allow the use of a function with the correct signature as a
// ValueDecoder. A DecodeContext instance is provided and serves similar functionality to the
// EncodeContext.
type ValueDecoder interface {
	DecodeValue(DecodeContext, ValueReader, reflect.Value) error
}

// ValueDecoderFunc is an adapter function that allows a function with the correct signature to be
// used as a ValueDecoder.
type ValueDecoderFunc func(DecodeContext, ValueReader, reflect.Value) error

// DecodeValue implements the ValueDecoder interface.
func (fn ValueDecoderFunc) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	return fn(dc, vr, val)
}

// typeDecoder is the interface implemented by types that can handle the decoding of a value given its type.
type typeDecoder interface {
	decodeType(DecodeContext, ValueReader, reflect.Type) (reflect.Value, error)
}

// typeDecoderFunc is an adapter function that allows a function with the correct signature to be used as a typeDecoder.
type typeDecoderFunc func(DecodeContext, ValueReader, reflect.Type) (reflect.Value, error)

func (fn typeDecoderFunc) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	return fn(dc, vr, t)
}

// decodeAdapter allows two functions with the correct signatures to be used as both a ValueDecoder and typeDecoder.
type decodeAdapter struct {
	ValueDecoderFunc
	typeDecoderFunc
}

var _ ValueDecoder = decodeAdapter{}
var _ typeDecoder = decodeAdapter{}

func decodeTypeOrValueWithInfo(vd ValueDecoder, dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if td, _ := vd.(typeDecoder); td != nil {
		val, err := td.decodeType(dc, vr, t)
		if err == nil && val.Type() != t {
			// This conversion step is necessary for slices and maps. If a user declares variables like:
			//
			//   type myBool bool
			//   var m map[string]myBool
			//
			// and tries to decode BSON bytes into the map, the decoding will fail if this conversion is not present
			// because we'll try to assign a value of type bool to one of type myBool.
			val = val.Convert(t)
		}
		return val, err
	}

	val := reflect.New(t).Elem()
	err := vd.DecodeValue(dc, vr, val)
	return val, err
}
72
bsoncodec_test.go
Normal file
@ -0,0 +1,72 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"fmt"
	"reflect"
	"testing"
)

func ExampleValueEncoder() {
	var _ ValueEncoderFunc = func(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
		if val.Kind() != reflect.String {
			return ValueEncoderError{Name: "StringEncodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val}
		}

		return vw.WriteString(val.String())
	}
}

func ExampleValueDecoder() {
	var _ ValueDecoderFunc = func(_ DecodeContext, vr ValueReader, val reflect.Value) error {
		if !val.CanSet() || val.Kind() != reflect.String {
			return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val}
		}

		if vr.Type() != TypeString {
			return fmt.Errorf("cannot decode %v into a string type", vr.Type())
		}

		str, err := vr.ReadString()
		if err != nil {
			return err
		}
		val.SetString(str)
		return nil
	}
}

type llCodec struct {
	t         *testing.T
	decodeval interface{}
	encodeval interface{}
	err       error
}

func (llc *llCodec) EncodeValue(_ EncodeContext, _ ValueWriter, i interface{}) error {
	if llc.err != nil {
		return llc.err
	}

	llc.encodeval = i
	return nil
}

func (llc *llCodec) DecodeValue(_ DecodeContext, _ ValueReader, val reflect.Value) error {
	if llc.err != nil {
		return llc.err
	}

	if !reflect.TypeOf(llc.decodeval).AssignableTo(val.Type()) {
		llc.t.Errorf("decodeval must be assignable to val provided to DecodeValue, but is not. decodeval %T; val %T", llc.decodeval, val)
		return nil
	}

	val.Set(reflect.ValueOf(llc.decodeval))
	return nil
}
846
bsonrw_test.go
Normal file
@ -0,0 +1,846 @@
|
|||||||
|
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
// not use this file except in compliance with the License. You may obtain
|
||||||
|
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
package bson
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ ValueReader = &valueReaderWriter{}
|
||||||
|
_ ValueWriter = &valueReaderWriter{}
|
||||||
|
)
|
||||||
|
|
||||||
|
// invoked is a type used to indicate what method was called last.
|
||||||
|
type invoked byte
|
||||||
|
|
||||||
|
// These are the different methods that can be invoked.
|
||||||
|
const (
|
||||||
|
nothing invoked = iota
|
||||||
|
readArray
|
||||||
|
readBinary
|
||||||
|
readBoolean
|
||||||
|
readDocument
|
||||||
|
readCodeWithScope
|
||||||
|
readDBPointer
|
||||||
|
readDateTime
|
||||||
|
readDecimal128
|
||||||
|
readDouble
|
||||||
|
readInt32
|
||||||
|
readInt64
|
||||||
|
readJavascript
|
||||||
|
readMaxKey
|
||||||
|
readMinKey
|
||||||
|
readNull
|
||||||
|
readObjectID
|
||||||
|
readRegex
|
||||||
|
readString
|
||||||
|
readSymbol
|
||||||
|
readTimestamp
|
||||||
|
readUndefined
|
||||||
|
readElement
|
||||||
|
readValue
|
||||||
|
writeArray
|
||||||
|
writeBinary
|
||||||
|
writeBinaryWithSubtype
|
||||||
|
writeBoolean
|
||||||
|
writeCodeWithScope
|
||||||
|
writeDBPointer
|
||||||
|
writeDateTime
|
||||||
|
writeDecimal128
|
||||||
|
writeDouble
|
||||||
|
writeInt32
|
||||||
|
writeInt64
|
||||||
|
writeJavascript
|
||||||
|
writeMaxKey
|
||||||
|
writeMinKey
|
||||||
|
writeNull
|
||||||
|
writeObjectID
|
||||||
|
writeRegex
|
||||||
|
writeString
|
||||||
|
writeDocument
|
||||||
|
writeSymbol
|
||||||
|
writeTimestamp
|
||||||
|
writeUndefined
|
||||||
|
writeDocumentElement
|
||||||
|
writeDocumentEnd
|
||||||
|
writeArrayElement
|
||||||
|
writeArrayEnd
|
||||||
|
skip
|
||||||
|
)
|
||||||
|
|
||||||
|
func (i invoked) String() string {
|
||||||
|
switch i {
|
||||||
|
case nothing:
|
||||||
|
return "Nothing"
|
||||||
|
case readArray:
|
||||||
|
return "ReadArray"
|
||||||
|
case readBinary:
|
||||||
|
return "ReadBinary"
|
||||||
|
case readBoolean:
|
||||||
|
return "ReadBoolean"
|
||||||
|
case readDocument:
|
||||||
|
return "ReadDocument"
|
||||||
|
case readCodeWithScope:
|
||||||
|
return "ReadCodeWithScope"
|
||||||
|
case readDBPointer:
|
||||||
|
return "ReadDBPointer"
|
||||||
|
case readDateTime:
|
||||||
|
return "ReadDateTime"
|
||||||
|
case readDecimal128:
|
||||||
|
return "ReadDecimal128"
|
||||||
|
case readDouble:
|
||||||
|
return "ReadDouble"
|
||||||
|
case readInt32:
|
||||||
|
return "ReadInt32"
|
||||||
|
case readInt64:
|
||||||
|
return "ReadInt64"
|
||||||
|
case readJavascript:
|
||||||
|
return "ReadJavascript"
|
||||||
|
case readMaxKey:
|
||||||
|
return "ReadMaxKey"
|
||||||
|
case readMinKey:
|
||||||
|
return "ReadMinKey"
|
||||||
|
case readNull:
|
||||||
|
return "ReadNull"
|
||||||
|
case readObjectID:
|
||||||
|
return "ReadObjectID"
|
||||||
|
case readRegex:
|
||||||
|
return "ReadRegex"
|
||||||
|
case readString:
|
||||||
|
return "ReadString"
|
||||||
|
case readSymbol:
|
||||||
|
return "ReadSymbol"
|
||||||
|
case readTimestamp:
|
||||||
|
return "ReadTimestamp"
|
||||||
|
case readUndefined:
|
||||||
|
return "ReadUndefined"
|
||||||
|
case readElement:
|
||||||
|
return "ReadElement"
|
||||||
|
case readValue:
|
||||||
|
return "ReadValue"
|
||||||
|
case writeArray:
|
||||||
|
return "WriteArray"
|
||||||
|
case writeBinary:
|
||||||
|
return "WriteBinary"
|
||||||
|
case writeBinaryWithSubtype:
|
||||||
|
return "WriteBinaryWithSubtype"
|
||||||
|
case writeBoolean:
|
||||||
|
return "WriteBoolean"
|
||||||
|
case writeCodeWithScope:
|
||||||
|
return "WriteCodeWithScope"
|
||||||
|
case writeDBPointer:
|
||||||
|
return "WriteDBPointer"
|
||||||
|
case writeDateTime:
|
||||||
|
return "WriteDateTime"
|
||||||
|
case writeDecimal128:
|
||||||
|
return "WriteDecimal128"
|
||||||
|
case writeDouble:
|
||||||
|
return "WriteDouble"
|
||||||
|
case writeInt32:
|
||||||
|
return "WriteInt32"
|
||||||
|
case writeInt64:
|
||||||
|
return "WriteInt64"
|
||||||
|
case writeJavascript:
|
||||||
|
return "WriteJavascript"
|
||||||
|
case writeMaxKey:
|
||||||
|
return "WriteMaxKey"
|
||||||
|
case writeMinKey:
|
||||||
|
return "WriteMinKey"
|
||||||
|
case writeNull:
|
||||||
|
return "WriteNull"
|
||||||
|
case writeObjectID:
|
||||||
|
return "WriteObjectID"
|
||||||
|
case writeRegex:
|
||||||
|
return "WriteRegex"
|
||||||
|
case writeString:
|
||||||
|
return "WriteString"
|
||||||
|
case writeDocument:
|
||||||
|
return "WriteDocument"
|
||||||
|
case writeSymbol:
|
||||||
|
return "WriteSymbol"
|
||||||
|
case writeTimestamp:
|
||||||
|
return "WriteTimestamp"
|
||||||
|
case writeUndefined:
|
||||||
|
return "WriteUndefined"
|
||||||
|
case writeDocumentElement:
|
||||||
|
return "WriteDocumentElement"
|
||||||
|
case writeDocumentEnd:
|
||||||
|
return "WriteDocumentEnd"
|
||||||
|
case writeArrayElement:
|
||||||
|
return "WriteArrayElement"
|
||||||
|
case writeArrayEnd:
|
||||||
|
return "WriteArrayEnd"
|
||||||
|
default:
|
||||||
|
return "<unknown>"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// valueReaderWriter is a test implementation of a bsonrw.ValueReader and bsonrw.ValueWriter
type valueReaderWriter struct {
	T        *testing.T
	invoked  invoked
	Return   interface{} // Can be a primitive or a bsoncore.Value
	BSONType Type
	Err      error
	ErrAfter invoked // error after this method is called
	depth    uint64
}

// checkdepth panics past a depth limit to prevent infinite recursion.
func (llvrw *valueReaderWriter) checkdepth() {
	llvrw.depth++
	if llvrw.depth > 1000 {
		panic("max depth exceeded")
	}
}
// Type implements the ValueReader interface.
func (llvrw *valueReaderWriter) Type() Type {
	llvrw.checkdepth()
	return llvrw.BSONType
}

// Skip implements the ValueReader interface.
func (llvrw *valueReaderWriter) Skip() error {
	llvrw.checkdepth()
	llvrw.invoked = skip
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// ReadArray implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadArray() (ArrayReader, error) {
	llvrw.checkdepth()
	llvrw.invoked = readArray
	if llvrw.ErrAfter == llvrw.invoked {
		return nil, llvrw.Err
	}

	return llvrw, nil
}

// ReadBinary implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadBinary() (b []byte, btype byte, err error) {
	llvrw.checkdepth()
	llvrw.invoked = readBinary
	if llvrw.ErrAfter == llvrw.invoked {
		return nil, 0x00, llvrw.Err
	}

	switch tt := llvrw.Return.(type) {
	case bsoncore.Value:
		subtype, data, _, ok := bsoncore.ReadBinary(tt.Data)
		if !ok {
			llvrw.T.Error("Invalid Value provided for return value of ReadBinary.")
			return nil, 0x00, nil
		}
		return data, subtype, nil
	default:
		llvrw.T.Errorf("Incorrect type provided for return value of ReadBinary: %T", llvrw.Return)
		return nil, 0x00, nil
	}
}

// ReadBoolean implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadBoolean() (bool, error) {
	llvrw.checkdepth()
	llvrw.invoked = readBoolean
	if llvrw.ErrAfter == llvrw.invoked {
		return false, llvrw.Err
	}

	switch tt := llvrw.Return.(type) {
	case bool:
		return tt, nil
	case bsoncore.Value:
		b, _, ok := bsoncore.ReadBoolean(tt.Data)
		if !ok {
			llvrw.T.Error("Invalid Value provided for return value of ReadBoolean.")
			return false, nil
		}
		return b, nil
	default:
		llvrw.T.Errorf("Incorrect type provided for return value of ReadBoolean: %T", llvrw.Return)
		return false, nil
	}
}

// ReadDocument implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadDocument() (DocumentReader, error) {
	llvrw.checkdepth()
	llvrw.invoked = readDocument
	if llvrw.ErrAfter == llvrw.invoked {
		return nil, llvrw.Err
	}

	return llvrw, nil
}

// ReadCodeWithScope implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadCodeWithScope() (code string, dr DocumentReader, err error) {
	llvrw.checkdepth()
	llvrw.invoked = readCodeWithScope
	if llvrw.ErrAfter == llvrw.invoked {
		return "", nil, llvrw.Err
	}

	return "", llvrw, nil
}

// ReadDBPointer implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadDBPointer() (ns string, oid ObjectID, err error) {
	llvrw.checkdepth()
	llvrw.invoked = readDBPointer
	if llvrw.ErrAfter == llvrw.invoked {
		return "", ObjectID{}, llvrw.Err
	}

	switch tt := llvrw.Return.(type) {
	case bsoncore.Value:
		ns, oid, _, ok := bsoncore.ReadDBPointer(tt.Data)
		if !ok {
			llvrw.T.Error("Invalid Value instance provided for return value of ReadDBPointer")
			return "", ObjectID{}, nil
		}
		return ns, oid, nil
	default:
		llvrw.T.Errorf("Incorrect type provided for return value of ReadDBPointer: %T", llvrw.Return)
		return "", ObjectID{}, nil
	}
}

// ReadDateTime implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadDateTime() (int64, error) {
	llvrw.checkdepth()
	llvrw.invoked = readDateTime
	if llvrw.ErrAfter == llvrw.invoked {
		return 0, llvrw.Err
	}

	dt, ok := llvrw.Return.(int64)
	if !ok {
		llvrw.T.Errorf("Incorrect type provided for return value of ReadDateTime: %T", llvrw.Return)
		return 0, nil
	}

	return dt, nil
}

// ReadDecimal128 implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadDecimal128() (Decimal128, error) {
	llvrw.checkdepth()
	llvrw.invoked = readDecimal128
	if llvrw.ErrAfter == llvrw.invoked {
		return Decimal128{}, llvrw.Err
	}

	d128, ok := llvrw.Return.(Decimal128)
	if !ok {
		llvrw.T.Errorf("Incorrect type provided for return value of ReadDecimal128: %T", llvrw.Return)
		return Decimal128{}, nil
	}

	return d128, nil
}

// ReadDouble implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadDouble() (float64, error) {
	llvrw.checkdepth()
	llvrw.invoked = readDouble
	if llvrw.ErrAfter == llvrw.invoked {
		return 0, llvrw.Err
	}

	f64, ok := llvrw.Return.(float64)
	if !ok {
		llvrw.T.Errorf("Incorrect type provided for return value of ReadDouble: %T", llvrw.Return)
		return 0, nil
	}

	return f64, nil
}

// ReadInt32 implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadInt32() (int32, error) {
	llvrw.checkdepth()
	llvrw.invoked = readInt32
	if llvrw.ErrAfter == llvrw.invoked {
		return 0, llvrw.Err
	}

	i32, ok := llvrw.Return.(int32)
	if !ok {
		llvrw.T.Errorf("Incorrect type provided for return value of ReadInt32: %T", llvrw.Return)
		return 0, nil
	}

	return i32, nil
}

// ReadInt64 implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadInt64() (int64, error) {
	llvrw.checkdepth()
	llvrw.invoked = readInt64
	if llvrw.ErrAfter == llvrw.invoked {
		return 0, llvrw.Err
	}

	i64, ok := llvrw.Return.(int64)
	if !ok {
		llvrw.T.Errorf("Incorrect type provided for return value of ReadInt64: %T", llvrw.Return)
		return 0, nil
	}

	return i64, nil
}

// ReadJavascript implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadJavascript() (code string, err error) {
	llvrw.checkdepth()
	llvrw.invoked = readJavascript
	if llvrw.ErrAfter == llvrw.invoked {
		return "", llvrw.Err
	}

	js, ok := llvrw.Return.(string)
	if !ok {
		llvrw.T.Errorf("Incorrect type provided for return value of ReadJavascript: %T", llvrw.Return)
		return "", nil
	}

	return js, nil
}

// ReadMaxKey implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadMaxKey() error {
	llvrw.checkdepth()
	llvrw.invoked = readMaxKey
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}

	return nil
}

// ReadMinKey implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadMinKey() error {
	llvrw.checkdepth()
	llvrw.invoked = readMinKey
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}

	return nil
}

// ReadNull implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadNull() error {
	llvrw.checkdepth()
	llvrw.invoked = readNull
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}

	return nil
}

// ReadObjectID implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadObjectID() (ObjectID, error) {
	llvrw.checkdepth()
	llvrw.invoked = readObjectID
	if llvrw.ErrAfter == llvrw.invoked {
		return ObjectID{}, llvrw.Err
	}

	oid, ok := llvrw.Return.(ObjectID)
	if !ok {
		llvrw.T.Errorf("Incorrect type provided for return value of ReadObjectID: %T", llvrw.Return)
		return ObjectID{}, nil
	}

	return oid, nil
}

// ReadRegex implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadRegex() (pattern string, options string, err error) {
	llvrw.checkdepth()
	llvrw.invoked = readRegex
	if llvrw.ErrAfter == llvrw.invoked {
		return "", "", llvrw.Err
	}

	switch tt := llvrw.Return.(type) {
	case bsoncore.Value:
		pattern, options, _, ok := bsoncore.ReadRegex(tt.Data)
		if !ok {
			llvrw.T.Error("Invalid Value instance provided for ReadRegex")
			return "", "", nil
		}
		return pattern, options, nil
	default:
		llvrw.T.Errorf("Incorrect type provided for return value of ReadRegex: %T", llvrw.Return)
		return "", "", nil
	}
}

// ReadString implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadString() (string, error) {
	llvrw.checkdepth()
	llvrw.invoked = readString
	if llvrw.ErrAfter == llvrw.invoked {
		return "", llvrw.Err
	}

	str, ok := llvrw.Return.(string)
	if !ok {
		llvrw.T.Errorf("Incorrect type provided for return value of ReadString: %T", llvrw.Return)
		return "", nil
	}

	return str, nil
}

// ReadSymbol implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadSymbol() (symbol string, err error) {
	llvrw.checkdepth()
	llvrw.invoked = readSymbol
	if llvrw.ErrAfter == llvrw.invoked {
		return "", llvrw.Err
	}

	switch tt := llvrw.Return.(type) {
	case string:
		return tt, nil
	case bsoncore.Value:
		symbol, _, ok := bsoncore.ReadSymbol(tt.Data)
		if !ok {
			llvrw.T.Error("Invalid Value instance provided for ReadSymbol")
			return "", nil
		}
		return symbol, nil
	default:
		llvrw.T.Errorf("Incorrect type provided for return value of ReadSymbol: %T", llvrw.Return)
		return "", nil
	}
}

// ReadTimestamp implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadTimestamp() (t uint32, i uint32, err error) {
	llvrw.checkdepth()
	llvrw.invoked = readTimestamp
	if llvrw.ErrAfter == llvrw.invoked {
		return 0, 0, llvrw.Err
	}

	switch tt := llvrw.Return.(type) {
	case bsoncore.Value:
		t, i, _, ok := bsoncore.ReadTimestamp(tt.Data)
		if !ok {
			llvrw.T.Errorf("Invalid Value instance provided for return value of ReadTimestamp")
			return 0, 0, nil
		}
		return t, i, nil
	default:
		llvrw.T.Errorf("Incorrect type provided for return value of ReadTimestamp: %T", llvrw.Return)
		return 0, 0, nil
	}
}

// ReadUndefined implements the ValueReader interface.
func (llvrw *valueReaderWriter) ReadUndefined() error {
	llvrw.checkdepth()
	llvrw.invoked = readUndefined
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}

	return nil
}
// WriteArray implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteArray() (ArrayWriter, error) {
	llvrw.checkdepth()
	llvrw.invoked = writeArray
	if llvrw.ErrAfter == llvrw.invoked {
		return nil, llvrw.Err
	}
	return llvrw, nil
}

// WriteBinary implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteBinary([]byte) error {
	llvrw.checkdepth()
	llvrw.invoked = writeBinary
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// WriteBinaryWithSubtype implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteBinaryWithSubtype([]byte, byte) error {
	llvrw.checkdepth()
	llvrw.invoked = writeBinaryWithSubtype
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// WriteBoolean implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteBoolean(bool) error {
	llvrw.checkdepth()
	llvrw.invoked = writeBoolean
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// WriteCodeWithScope implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteCodeWithScope(string) (DocumentWriter, error) {
	llvrw.checkdepth()
	llvrw.invoked = writeCodeWithScope
	if llvrw.ErrAfter == llvrw.invoked {
		return nil, llvrw.Err
	}
	return llvrw, nil
}

// WriteDBPointer implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteDBPointer(string, ObjectID) error {
	llvrw.checkdepth()
	llvrw.invoked = writeDBPointer
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// WriteDateTime implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteDateTime(int64) error {
	llvrw.checkdepth()
	llvrw.invoked = writeDateTime
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// WriteDecimal128 implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteDecimal128(Decimal128) error {
	llvrw.checkdepth()
	llvrw.invoked = writeDecimal128
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// WriteDouble implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteDouble(float64) error {
	llvrw.checkdepth()
	llvrw.invoked = writeDouble
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// WriteInt32 implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteInt32(int32) error {
	llvrw.checkdepth()
	llvrw.invoked = writeInt32
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// WriteInt64 implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteInt64(int64) error {
	llvrw.checkdepth()
	llvrw.invoked = writeInt64
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// WriteJavascript implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteJavascript(string) error {
	llvrw.checkdepth()
	llvrw.invoked = writeJavascript
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// WriteMaxKey implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteMaxKey() error {
	llvrw.checkdepth()
	llvrw.invoked = writeMaxKey
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// WriteMinKey implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteMinKey() error {
	llvrw.checkdepth()
	llvrw.invoked = writeMinKey
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// WriteNull implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteNull() error {
	llvrw.checkdepth()
	llvrw.invoked = writeNull
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// WriteObjectID implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteObjectID(ObjectID) error {
	llvrw.checkdepth()
	llvrw.invoked = writeObjectID
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// WriteRegex implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteRegex(string, string) error {
	llvrw.checkdepth()
	llvrw.invoked = writeRegex
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// WriteString implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteString(string) error {
	llvrw.checkdepth()
	llvrw.invoked = writeString
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// WriteDocument implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteDocument() (DocumentWriter, error) {
	llvrw.checkdepth()
	llvrw.invoked = writeDocument
	if llvrw.ErrAfter == llvrw.invoked {
		return nil, llvrw.Err
	}
	return llvrw, nil
}

// WriteSymbol implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteSymbol(string) error {
	llvrw.checkdepth()
	llvrw.invoked = writeSymbol
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// WriteTimestamp implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteTimestamp(uint32, uint32) error {
	llvrw.checkdepth()
	llvrw.invoked = writeTimestamp
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// WriteUndefined implements the ValueWriter interface.
func (llvrw *valueReaderWriter) WriteUndefined() error {
	llvrw.checkdepth()
	llvrw.invoked = writeUndefined
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}
	return nil
}

// ReadElement implements the DocumentReader interface.
func (llvrw *valueReaderWriter) ReadElement() (string, ValueReader, error) {
	llvrw.checkdepth()
	llvrw.invoked = readElement
	if llvrw.ErrAfter == llvrw.invoked {
		return "", nil, llvrw.Err
	}

	return "", llvrw, nil
}

// WriteDocumentElement implements the DocumentWriter interface.
func (llvrw *valueReaderWriter) WriteDocumentElement(string) (ValueWriter, error) {
	llvrw.checkdepth()
	llvrw.invoked = writeDocumentElement
	if llvrw.ErrAfter == llvrw.invoked {
		return nil, llvrw.Err
	}

	return llvrw, nil
}

// WriteDocumentEnd implements the DocumentWriter interface.
func (llvrw *valueReaderWriter) WriteDocumentEnd() error {
	llvrw.checkdepth()
	llvrw.invoked = writeDocumentEnd
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}

	return nil
}

// ReadValue implements the ArrayReader interface.
func (llvrw *valueReaderWriter) ReadValue() (ValueReader, error) {
	llvrw.checkdepth()
	llvrw.invoked = readValue
	if llvrw.ErrAfter == llvrw.invoked {
		return nil, llvrw.Err
	}

	return llvrw, nil
}

// WriteArrayElement implements the ArrayWriter interface.
func (llvrw *valueReaderWriter) WriteArrayElement() (ValueWriter, error) {
	llvrw.checkdepth()
	llvrw.invoked = writeArrayElement
	if llvrw.ErrAfter == llvrw.invoked {
		return nil, llvrw.Err
	}

	return llvrw, nil
}

// WriteArrayEnd implements the ArrayWriter interface.
func (llvrw *valueReaderWriter) WriteArrayEnd() error {
	llvrw.checkdepth()
	llvrw.invoked = writeArrayEnd
	if llvrw.ErrAfter == llvrw.invoked {
		return llvrw.Err
	}

	return nil
}
97
byte_slice_codec.go
Normal file
@@ -0,0 +1,97 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"fmt"
	"reflect"
)

// byteSliceCodec is the Codec used for []byte values.
type byteSliceCodec struct {
	// encodeNilAsEmpty causes EncodeValue to marshal nil Go byte slices as empty BSON binary values
	// instead of BSON null.
	encodeNilAsEmpty bool
}

// Assert that byteSliceCodec satisfies the typeDecoder interface, which allows it to be
// used by collection type decoders (e.g. map, slice, etc) to set individual values in a
// collection.
var _ typeDecoder = &byteSliceCodec{}

// EncodeValue is the ValueEncoder for []byte.
func (bsc *byteSliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tByteSlice {
		return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val}
	}
	if val.IsNil() && !bsc.encodeNilAsEmpty && !ec.nilByteSliceAsEmpty {
		return vw.WriteNull()
	}
	return vw.WriteBinary(val.Interface().([]byte))
}

func (bsc *byteSliceCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tByteSlice {
		return emptyValue, ValueDecoderError{
			Name:     "ByteSliceDecodeValue",
			Types:    []reflect.Type{tByteSlice},
			Received: reflect.Zero(t),
		}
	}

	var data []byte
	var err error
	switch vrType := vr.Type(); vrType {
	case TypeString:
		str, err := vr.ReadString()
		if err != nil {
			return emptyValue, err
		}
		data = []byte(str)
	case TypeSymbol:
		sym, err := vr.ReadSymbol()
		if err != nil {
			return emptyValue, err
		}
		data = []byte(sym)
	case TypeBinary:
		var subtype byte
		data, subtype, err = vr.ReadBinary()
		if err != nil {
			return emptyValue, err
		}
		if subtype != TypeBinaryGeneric && subtype != TypeBinaryBinaryOld {
			return emptyValue, decodeBinaryError{subtype: subtype, typeName: "[]byte"}
		}
	case TypeNull:
		err = vr.ReadNull()
	case TypeUndefined:
		err = vr.ReadUndefined()
	default:
		return emptyValue, fmt.Errorf("cannot decode %v into a []byte", vrType)
	}
	if err != nil {
		return emptyValue, err
	}

	return reflect.ValueOf(data), nil
}

// DecodeValue is the ValueDecoder for []byte.
func (bsc *byteSliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !val.CanSet() || val.Type() != tByteSlice {
		return ValueDecoderError{Name: "ByteSliceDecodeValue", Types: []reflect.Type{tByteSlice}, Received: val}
	}

	elem, err := bsc.decodeType(dc, vr, tByteSlice)
	if err != nil {
		return err
	}

	val.Set(elem)
	return nil
}
166
codec_cache.go
Normal file
@@ -0,0 +1,166 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"reflect"
	"sync"
	"sync/atomic"
)

// Runtime check that the kind encoder and decoder caches can store any valid
// reflect.Kind constant.
func init() {
	if s := reflect.Kind(len(kindEncoderCache{}.entries)).String(); s != "kind27" {
		panic("The capacity of kindEncoderCache is too small.\n" +
			"This is due to a new type being added to reflect.Kind.")
	}
}

// statically assert array size
var _ = (kindEncoderCache{}).entries[reflect.UnsafePointer]
var _ = (kindDecoderCache{}).entries[reflect.UnsafePointer]

type typeEncoderCache struct {
	cache sync.Map // map[reflect.Type]ValueEncoder
}

func (c *typeEncoderCache) Store(rt reflect.Type, enc ValueEncoder) {
	c.cache.Store(rt, enc)
}

func (c *typeEncoderCache) Load(rt reflect.Type) (ValueEncoder, bool) {
	if v, _ := c.cache.Load(rt); v != nil {
		return v.(ValueEncoder), true
	}
	return nil, false
}

func (c *typeEncoderCache) LoadOrStore(rt reflect.Type, enc ValueEncoder) ValueEncoder {
	if v, loaded := c.cache.LoadOrStore(rt, enc); loaded {
		enc = v.(ValueEncoder)
	}
	return enc
}

func (c *typeEncoderCache) Clone() *typeEncoderCache {
	cc := new(typeEncoderCache)
	c.cache.Range(func(k, v interface{}) bool {
		if k != nil && v != nil {
			cc.cache.Store(k, v)
		}
		return true
	})
	return cc
}

type typeDecoderCache struct {
	cache sync.Map // map[reflect.Type]ValueDecoder
}

func (c *typeDecoderCache) Store(rt reflect.Type, dec ValueDecoder) {
	c.cache.Store(rt, dec)
}

func (c *typeDecoderCache) Load(rt reflect.Type) (ValueDecoder, bool) {
	if v, _ := c.cache.Load(rt); v != nil {
		return v.(ValueDecoder), true
	}
	return nil, false
}

func (c *typeDecoderCache) LoadOrStore(rt reflect.Type, dec ValueDecoder) ValueDecoder {
	if v, loaded := c.cache.LoadOrStore(rt, dec); loaded {
		dec = v.(ValueDecoder)
	}
	return dec
}

func (c *typeDecoderCache) Clone() *typeDecoderCache {
	cc := new(typeDecoderCache)
	c.cache.Range(func(k, v interface{}) bool {
		if k != nil && v != nil {
			cc.cache.Store(k, v)
		}
		return true
	})
	return cc
}

// atomic.Value requires that all calls to Store() have the same concrete type
// so we wrap the ValueEncoder with a kindEncoderCacheEntry to ensure the type
// is always the same (since different concrete types may implement the
// ValueEncoder interface).
type kindEncoderCacheEntry struct {
	enc ValueEncoder
}

type kindEncoderCache struct {
	entries [reflect.UnsafePointer + 1]atomic.Value // *kindEncoderCacheEntry
}

func (c *kindEncoderCache) Store(rt reflect.Kind, enc ValueEncoder) {
	if enc != nil && rt < reflect.Kind(len(c.entries)) {
		c.entries[rt].Store(&kindEncoderCacheEntry{enc: enc})
	}
}

func (c *kindEncoderCache) Load(rt reflect.Kind) (ValueEncoder, bool) {
	if rt < reflect.Kind(len(c.entries)) {
		if ent, ok := c.entries[rt].Load().(*kindEncoderCacheEntry); ok {
			return ent.enc, ent.enc != nil
		}
	}
	return nil, false
}

func (c *kindEncoderCache) Clone() *kindEncoderCache {
	cc := new(kindEncoderCache)
	for i, v := range c.entries {
		if val := v.Load(); val != nil {
			cc.entries[i].Store(val)
		}
	}
	return cc
}

// atomic.Value requires that all calls to Store() have the same concrete type
// so we wrap the ValueDecoder with a kindDecoderCacheEntry to ensure the type
// is always the same (since different concrete types may implement the
// ValueDecoder interface).
type kindDecoderCacheEntry struct {
	dec ValueDecoder
}

type kindDecoderCache struct {
	entries [reflect.UnsafePointer + 1]atomic.Value // *kindDecoderCacheEntry
}

func (c *kindDecoderCache) Store(rt reflect.Kind, dec ValueDecoder) {
	if rt < reflect.Kind(len(c.entries)) {
		c.entries[rt].Store(&kindDecoderCacheEntry{dec: dec})
	}
}

func (c *kindDecoderCache) Load(rt reflect.Kind) (ValueDecoder, bool) {
	if rt < reflect.Kind(len(c.entries)) {
		if ent, ok := c.entries[rt].Load().(*kindDecoderCacheEntry); ok {
			return ent.dec, ent.dec != nil
		}
	}
	return nil, false
}

func (c *kindDecoderCache) Clone() *kindDecoderCache {
	cc := new(kindDecoderCache)
	for i, v := range c.entries {
		if val := v.Load(); val != nil {
			cc.entries[i].Store(val)
		}
	}
	return cc
}
176
codec_cache_test.go
Normal file
@@ -0,0 +1,176 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"reflect"
	"strconv"
	"strings"
	"testing"
)

// NB(charlie): the array size is a power of 2 because we use the remainder of
// it (mod) in benchmarks and that is faster when the size is a power of 2.
var codecCacheTestTypes = [16]reflect.Type{
	reflect.TypeOf(uint8(0)),
	reflect.TypeOf(uint16(0)),
	reflect.TypeOf(uint32(0)),
	reflect.TypeOf(uint64(0)),
	reflect.TypeOf(uint(0)),
	reflect.TypeOf(uintptr(0)),
	reflect.TypeOf(int8(0)),
	reflect.TypeOf(int16(0)),
	reflect.TypeOf(int32(0)),
	reflect.TypeOf(int64(0)),
	reflect.TypeOf(int(0)),
	reflect.TypeOf(float32(0)),
	reflect.TypeOf(float64(0)),
	reflect.TypeOf(true),
	reflect.TypeOf(struct{ A int }{}),
	reflect.TypeOf(map[int]int{}),
}

func TestTypeCache(t *testing.T) {
	rt := reflect.TypeOf(int(0))
	ec := new(typeEncoderCache)
	dc := new(typeDecoderCache)

	codec := new(fakeCodec)
	ec.Store(rt, codec)
	dc.Store(rt, codec)
	if v, ok := ec.Load(rt); !ok || !reflect.DeepEqual(v, codec) {
		t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, codec, true)
	}
	if v, ok := dc.Load(rt); !ok || !reflect.DeepEqual(v, codec) {
		t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, codec, true)
	}

	// Make sure we overwrite the stored value with nil
	ec.Store(rt, nil)
	dc.Store(rt, nil)
	if v, ok := ec.Load(rt); ok || v != nil {
		t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, nil, false)
	}
	if v, ok := dc.Load(rt); ok || v != nil {
		t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, nil, false)
	}
}

func TestTypeCacheClone(t *testing.T) {
	codec := new(fakeCodec)
	ec1 := new(typeEncoderCache)
	dc1 := new(typeDecoderCache)
	for _, rt := range codecCacheTestTypes {
		ec1.Store(rt, codec)
		dc1.Store(rt, codec)
	}
	ec2 := ec1.Clone()
	dc2 := dc1.Clone()
	for _, rt := range codecCacheTestTypes {
		if v, _ := ec2.Load(rt); !reflect.DeepEqual(v, codec) {
			t.Errorf("Load(%s) = %#v; want: %#v", rt, v, codec)
		}
		if v, _ := dc2.Load(rt); !reflect.DeepEqual(v, codec) {
			t.Errorf("Load(%s) = %#v; want: %#v", rt, v, codec)
		}
	}
}

func TestKindCacheArray(t *testing.T) {
	// Check array bounds
	var c kindEncoderCache
	codec := new(fakeCodec)
	c.Store(reflect.UnsafePointer, codec)   // valid
	c.Store(reflect.UnsafePointer+1, codec) // ignored
	if v, ok := c.Load(reflect.UnsafePointer); !ok || v != codec {
		t.Errorf("Load(reflect.UnsafePointer) = %v, %t; want: %v, %t", v, ok, codec, true)
	}
	if v, ok := c.Load(reflect.UnsafePointer + 1); ok || v != nil {
		t.Errorf("Load(reflect.UnsafePointer + 1) = %v, %t; want: %v, %t", v, ok, nil, false)
	}

	// Make sure that reflect.UnsafePointer is the last/largest reflect.Kind.
	//
	// The String() method of an invalid reflect.Kind is of the format
	// "kind{NUMBER}".
	for rt := reflect.UnsafePointer + 1; rt < reflect.UnsafePointer+16; rt++ {
		s := rt.String()
		if !strings.Contains(s, strconv.Itoa(int(rt))) {
			t.Errorf("reflect.Kind(%d) appears to be valid: %q", rt, s)
		}
	}
}

func TestKindCacheClone(t *testing.T) {
	e1 := new(kindEncoderCache)
	d1 := new(kindDecoderCache)
	codec := new(fakeCodec)
	for k := reflect.Invalid; k <= reflect.UnsafePointer; k++ {
		e1.Store(k, codec)
		d1.Store(k, codec)
	}
	e2 := e1.Clone()
	for k := reflect.Invalid; k <= reflect.UnsafePointer; k++ {
		v1, ok1 := e1.Load(k)
		v2, ok2 := e2.Load(k)
		if ok1 != ok2 || !reflect.DeepEqual(v1, v2) || v1 == nil || v2 == nil {
			t.Errorf("Encoder(%s): %#v, %t != %#v, %t", k, v1, ok1, v2, ok2)
		}
	}
	d2 := d1.Clone()
	for k := reflect.Invalid; k <= reflect.UnsafePointer; k++ {
		v1, ok1 := d1.Load(k)
		v2, ok2 := d2.Load(k)
		if ok1 != ok2 || !reflect.DeepEqual(v1, v2) || v1 == nil || v2 == nil {
			t.Errorf("Decoder(%s): %#v, %t != %#v, %t", k, v1, ok1, v2, ok2)
		}
	}
}

func TestKindCacheEncoderNilEncoder(t *testing.T) {
	t.Run("Encoder", func(t *testing.T) {
		c := new(kindEncoderCache)
		c.Store(reflect.Invalid, ValueEncoder(nil))
		v, ok := c.Load(reflect.Invalid)
		if v != nil || ok {
			t.Errorf("Load of nil ValueEncoder should return: nil, false; got: %v, %t", v, ok)
		}
	})
	t.Run("Decoder", func(t *testing.T) {
		c := new(kindDecoderCache)
		c.Store(reflect.Invalid, ValueDecoder(nil))
		v, ok := c.Load(reflect.Invalid)
		if v != nil || ok {
			t.Errorf("Load of nil ValueDecoder should return: nil, false; got: %v, %t", v, ok)
		}
	})
}

func BenchmarkEncoderCacheLoad(b *testing.B) {
	c := new(typeEncoderCache)
	codec := new(fakeCodec)
	typs := codecCacheTestTypes
	for _, t := range typs {
		c.Store(t, codec)
	}
	b.RunParallel(func(pb *testing.PB) {
		for i := 0; pb.Next(); i++ {
			c.Load(typs[i%len(typs)])
		}
	})
}

func BenchmarkEncoderCacheStore(b *testing.B) {
	c := new(typeEncoderCache)
	codec := new(fakeCodec)
	b.RunParallel(func(pb *testing.PB) {
		typs := codecCacheTestTypes
		for i := 0; pb.Next(); i++ {
			c.Store(typs[i%len(typs)], codec)
		}
	})
}
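The `NB(charlie)` comment in the test file above says the benchmark array length is a power of 2 because the modulo used to index it is faster then. The reason is that for a power-of-two divisor, the compiler can replace the division with a bitwise AND. A standalone sketch (not part of the commit) of the identity being relied on:

```go
package main

import "fmt"

func main() {
	const n = 16 // power of two, like len(codecCacheTestTypes)

	// For non-negative i and power-of-two n, i % n == i & (n-1),
	// which is why the compiler can emit an AND instead of a DIV.
	for _, i := range []int{0, 5, 15, 16, 31, 100} {
		fmt.Printf("%3d %% %d = %2d, %3d & %d = %2d\n",
			i, n, i%n, i, n-1, i&(n-1))
	}
}
```

This matters in the `RunParallel` benchmark loops because the `typs[i%len(typs)]` index is computed on every iteration, so the cost of the modulo would otherwise show up in the measurement.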
61
cond_addr_codec.go
Normal file
@@ -0,0 +1,61 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"reflect"
)

// condAddrEncoder is the encoder used when a pointer to the encoding value has an encoder.
type condAddrEncoder struct {
	canAddrEnc ValueEncoder
	elseEnc    ValueEncoder
}

var _ ValueEncoder = &condAddrEncoder{}

// newCondAddrEncoder returns a condAddrEncoder.
func newCondAddrEncoder(canAddrEnc, elseEnc ValueEncoder) *condAddrEncoder {
	encoder := condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc}
	return &encoder
}

// EncodeValue is the ValueEncoderFunc for a value that may be addressable.
func (cae *condAddrEncoder) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
	if val.CanAddr() {
		return cae.canAddrEnc.EncodeValue(ec, vw, val)
	}
	if cae.elseEnc != nil {
		return cae.elseEnc.EncodeValue(ec, vw, val)
	}
	return errNoEncoder{Type: val.Type()}
}

// condAddrDecoder is the decoder used when a pointer to the value has a decoder.
type condAddrDecoder struct {
	canAddrDec ValueDecoder
	elseDec    ValueDecoder
}

var _ ValueDecoder = &condAddrDecoder{}

// newCondAddrDecoder returns a condAddrDecoder.
func newCondAddrDecoder(canAddrDec, elseDec ValueDecoder) *condAddrDecoder {
	decoder := condAddrDecoder{canAddrDec: canAddrDec, elseDec: elseDec}
	return &decoder
}

// DecodeValue is the ValueDecoderFunc for a value that may be addressable.
func (cad *condAddrDecoder) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if val.CanAddr() {
		return cad.canAddrDec.DecodeValue(dc, vr, val)
	}
	if cad.elseDec != nil {
		return cad.elseDec.DecodeValue(dc, vr, val)
	}
	return errNoDecoder{Type: val.Type()}
}
95
cond_addr_codec_test.go
Normal file
@@ -0,0 +1,95 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"reflect"
	"testing"

	"gitea.psichedelico.com/go/bson/internal/assert"
)

func TestCondAddrCodec(t *testing.T) {
	var inner int
	canAddrVal := reflect.ValueOf(&inner)
	addressable := canAddrVal.Elem()
	unaddressable := reflect.ValueOf(inner)
	rw := &valueReaderWriter{}

	t.Run("addressEncode", func(t *testing.T) {
		invoked := 0
		encode1 := ValueEncoderFunc(func(EncodeContext, ValueWriter, reflect.Value) error {
			invoked = 1
			return nil
		})
		encode2 := ValueEncoderFunc(func(EncodeContext, ValueWriter, reflect.Value) error {
			invoked = 2
			return nil
		})
		condEncoder := newCondAddrEncoder(encode1, encode2)

		testCases := []struct {
			name    string
			val     reflect.Value
			invoked int
		}{
			{"canAddr", addressable, 1},
			{"else", unaddressable, 2},
		}
		for _, tc := range testCases {
			t.Run(tc.name, func(t *testing.T) {
				err := condEncoder.EncodeValue(EncodeContext{}, rw, tc.val)
				assert.Nil(t, err, "CondAddrEncoder error: %v", err)

				assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked)
			})
		}

		t.Run("error", func(t *testing.T) {
			errEncoder := newCondAddrEncoder(encode1, nil)
			err := errEncoder.EncodeValue(EncodeContext{}, rw, unaddressable)
			want := errNoEncoder{Type: unaddressable.Type()}
			assert.Equal(t, err, want, "expected error %v, got %v", want, err)
		})
	})
	t.Run("addressDecode", func(t *testing.T) {
		invoked := 0
		decode1 := ValueDecoderFunc(func(DecodeContext, ValueReader, reflect.Value) error {
			invoked = 1
			return nil
		})
		decode2 := ValueDecoderFunc(func(DecodeContext, ValueReader, reflect.Value) error {
			invoked = 2
			return nil
		})
		condDecoder := newCondAddrDecoder(decode1, decode2)

		testCases := []struct {
			name    string
			val     reflect.Value
			invoked int
		}{
			{"canAddr", addressable, 1},
			{"else", unaddressable, 2},
		}
		for _, tc := range testCases {
			t.Run(tc.name, func(t *testing.T) {
				err := condDecoder.DecodeValue(DecodeContext{}, rw, tc.val)
				assert.Nil(t, err, "CondAddrDecoder error: %v", err)

				assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked)
			})
		}

		t.Run("error", func(t *testing.T) {
			errDecoder := newCondAddrDecoder(decode1, nil)
			err := errDecoder.DecodeValue(DecodeContext{}, rw, unaddressable)
			want := errNoDecoder{Type: unaddressable.Type()}
			assert.Equal(t, err, want, "expected error %v, got %v", want, err)
		})
	})
}
|
431
copier.go
Normal file
431
copier.go
Normal file
@ -0,0 +1,431 @@
|
|||||||
|
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
// not use this file except in compliance with the License. You may obtain
|
||||||
|
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
package bson
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
|
||||||
|
)
|
||||||
|
|
||||||
|
// copyDocument handles copying one document from the src to the dst.
|
||||||
|
func copyDocument(dst ValueWriter, src ValueReader) error {
|
||||||
|
dr, err := src.ReadDocument()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dw, err := dst.WriteDocument()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return copyDocumentCore(dw, dr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// copyArrayFromBytes copies the values from a BSON array represented as a
|
||||||
|
// []byte to a ValueWriter.
|
||||||
|
func copyArrayFromBytes(dst ValueWriter, src []byte) error {
|
||||||
|
aw, err := dst.WriteArray()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = copyBytesToArrayWriter(aw, src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return aw.WriteArrayEnd()
|
||||||
|
}
|
||||||
|
|
||||||
|
// copyDocumentFromBytes copies the values from a BSON document represented as a
|
||||||
|
// []byte to a ValueWriter.
|
||||||
|
func copyDocumentFromBytes(dst ValueWriter, src []byte) error {
|
||||||
|
dw, err := dst.WriteDocument()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = copyBytesToDocumentWriter(dw, src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return dw.WriteDocumentEnd()
|
||||||
|
}
|
||||||
|
|
||||||
|
type writeElementFn func(key string) (ValueWriter, error)
|
||||||
|
|
||||||
|
// copyBytesToArrayWriter copies the values from a BSON Array represented as a []byte to an
|
||||||
|
// ArrayWriter.
|
||||||
|
func copyBytesToArrayWriter(dst ArrayWriter, src []byte) error {
|
||||||
|
wef := func(_ string) (ValueWriter, error) {
|
||||||
|
return dst.WriteArrayElement()
|
||||||
|
}
|
||||||
|
|
||||||
|
return copyBytesToValueWriter(src, wef)
|
||||||
|
}
|
||||||
|
|
||||||
|
// copyBytesToDocumentWriter copies the values from a BSON document represented as a []byte to a
|
||||||
|
// DocumentWriter.
|
||||||
|
func copyBytesToDocumentWriter(dst DocumentWriter, src []byte) error {
|
||||||
|
wef := func(key string) (ValueWriter, error) {
|
||||||
|
return dst.WriteDocumentElement(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
return copyBytesToValueWriter(src, wef)
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyBytesToValueWriter(src []byte, wef writeElementFn) error {
|
||||||
|
// TODO(skriptble): Create errors types here. Anything that is a tag should be a property.
|
||||||
|
length, rem, ok := bsoncore.ReadLength(src)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("couldn't read length from src, not enough bytes. length=%d", len(src))
|
||||||
|
}
|
||||||
|
if len(src) < int(length) {
|
||||||
|
return fmt.Errorf("length read exceeds number of bytes available. length=%d bytes=%d", len(src), length)
|
||||||
|
}
|
||||||
|
rem = rem[:length-4]
|
||||||
|
|
||||||
|
var t bsoncore.Type
|
||||||
|
var key string
|
||||||
|
var val bsoncore.Value
|
||||||
|
for {
|
||||||
|
t, rem, ok = bsoncore.ReadType(rem)
|
||||||
|
if !ok {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
if t == bsoncore.Type(0) {
|
||||||
|
if len(rem) != 0 {
|
||||||
|
return fmt.Errorf("document end byte found before end of document. remaining bytes=%v", rem)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
key, rem, ok = bsoncore.ReadKey(rem)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("invalid key found. remaining bytes=%v", rem)
|
||||||
|
}
|
||||||
|
|
||||||
|
// write as either array element or document element using writeElementFn
|
||||||
|
vw, err := wef(key)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
val, rem, ok = bsoncore.ReadValue(rem, t)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("not enough bytes available to read type. bytes=%d type=%s", len(rem), t)
|
||||||
|
}
|
||||||
|
err = copyValueFromBytes(vw, Type(t), val.Data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// copyDocumentToBytes copies an entire document from the ValueReader and
|
||||||
|
// returns it as bytes.
|
||||||
|
func copyDocumentToBytes(src ValueReader) ([]byte, error) {
|
||||||
|
return appendDocumentBytes(nil, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// appendDocumentBytes functions the same as CopyDocumentToBytes, but will
|
||||||
|
// append the result to dst.
|
||||||
|
func appendDocumentBytes(dst []byte, src ValueReader) ([]byte, error) {
|
||||||
|
if br, ok := src.(bytesReader); ok {
|
||||||
|
_, dst, err := br.readValueBytes(dst)
|
||||||
|
return dst, err
|
||||||
|
}
|
||||||
|
|
||||||
|
vw := vwPool.Get().(*valueWriter)
|
||||||
|
defer putValueWriter(vw)
|
||||||
|
|
||||||
|
vw.reset(dst)
|
||||||
|
|
||||||
|
err := copyDocument(vw, src)
|
||||||
|
dst = vw.buf
|
||||||
|
return dst, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// appendArrayBytes copies an array from the ValueReader to dst.
|
||||||
|
func appendArrayBytes(dst []byte, src ValueReader) ([]byte, error) {
|
||||||
|
if br, ok := src.(bytesReader); ok {
|
||||||
|
_, dst, err := br.readValueBytes(dst)
|
||||||
|
return dst, err
|
||||||
|
}
|
||||||
|
|
||||||
|
vw := vwPool.Get().(*valueWriter)
|
||||||
|
defer putValueWriter(vw)
|
||||||
|
|
||||||
|
vw.reset(dst)
|
||||||
|
|
||||||
|
err := copyArray(vw, src)
|
||||||
|
dst = vw.buf
|
||||||
|
return dst, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// copyValueFromBytes will write the value represtend by t and src to dst.
|
||||||
|
func copyValueFromBytes(dst ValueWriter, t Type, src []byte) error {
|
||||||
|
if wvb, ok := dst.(bytesWriter); ok {
|
||||||
|
return wvb.writeValueBytes(t, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
vr := newDocumentReader(bytes.NewReader(src))
|
||||||
|
vr.pushElement(t)
|
||||||
|
|
||||||
|
return copyValue(dst, vr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// copyValueToBytes copies a value from src and returns it as a Type and a
|
||||||
|
// []byte.
|
||||||
|
func copyValueToBytes(src ValueReader) (Type, []byte, error) {
|
||||||
|
if br, ok := src.(bytesReader); ok {
|
||||||
|
return br.readValueBytes(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
vw := vwPool.Get().(*valueWriter)
|
||||||
|
defer putValueWriter(vw)
|
||||||
|
|
||||||
|
vw.reset(nil)
|
||||||
|
vw.push(mElement)
|
||||||
|
|
||||||
|
err := copyValue(vw, src)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return Type(vw.buf[0]), vw.buf[2:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// copyValue will copy a single value from src to dst.
|
||||||
|
func copyValue(dst ValueWriter, src ValueReader) error {
|
||||||
|
var err error
|
||||||
|
switch src.Type() {
|
||||||
|
case TypeDouble:
|
||||||
|
var f64 float64
|
||||||
|
f64, err = src.ReadDouble()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = dst.WriteDouble(f64)
|
||||||
|
case TypeString:
|
||||||
|
var str string
|
||||||
|
str, err = src.ReadString()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = dst.WriteString(str)
|
||||||
|
case TypeEmbeddedDocument:
|
||||||
|
err = copyDocument(dst, src)
|
||||||
|
case TypeArray:
|
||||||
|
err = copyArray(dst, src)
|
||||||
|
case TypeBinary:
|
||||||
|
var data []byte
|
||||||
|
var subtype byte
|
||||||
|
data, subtype, err = src.ReadBinary()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = dst.WriteBinaryWithSubtype(data, subtype)
|
||||||
|
case TypeUndefined:
|
||||||
|
err = src.ReadUndefined()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = dst.WriteUndefined()
|
||||||
|
case TypeObjectID:
|
||||||
|
var oid ObjectID
|
||||||
|
oid, err = src.ReadObjectID()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = dst.WriteObjectID(oid)
|
||||||
|
case TypeBoolean:
|
||||||
|
var b bool
|
||||||
|
b, err = src.ReadBoolean()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = dst.WriteBoolean(b)
|
||||||
|
case TypeDateTime:
|
||||||
|
var dt int64
|
||||||
|
dt, err = src.ReadDateTime()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = dst.WriteDateTime(dt)
|
||||||
|
case TypeNull:
|
||||||
|
err = src.ReadNull()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = dst.WriteNull()
|
||||||
|
case TypeRegex:
|
||||||
|
var pattern, options string
|
||||||
|
pattern, options, err = src.ReadRegex()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = dst.WriteRegex(pattern, options)
|
||||||
|
case TypeDBPointer:
|
||||||
|
var ns string
|
||||||
|
var pointer ObjectID
|
||||||
|
ns, pointer, err = src.ReadDBPointer()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = dst.WriteDBPointer(ns, pointer)
|
||||||
|
case TypeJavaScript:
|
||||||
|
var js string
|
||||||
|
js, err = src.ReadJavascript()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = dst.WriteJavascript(js)
|
||||||
|
case TypeSymbol:
|
||||||
|
var symbol string
|
||||||
|
symbol, err = src.ReadSymbol()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = dst.WriteSymbol(symbol)
|
||||||
|
case TypeCodeWithScope:
|
||||||
|
var code string
|
||||||
|
var srcScope DocumentReader
|
||||||
|
code, srcScope, err = src.ReadCodeWithScope()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
var dstScope DocumentWriter
|
||||||
|
dstScope, err = dst.WriteCodeWithScope(code)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = copyDocumentCore(dstScope, srcScope)
|
||||||
|
case TypeInt32:
|
||||||
|
var i32 int32
|
||||||
|
i32, err = src.ReadInt32()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = dst.WriteInt32(i32)
|
||||||
|
case TypeTimestamp:
|
||||||
|
var t, i uint32
|
||||||
|
t, i, err = src.ReadTimestamp()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = dst.WriteTimestamp(t, i)
|
||||||
|
case TypeInt64:
|
||||||
|
var i64 int64
|
||||||
|
i64, err = src.ReadInt64()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = dst.WriteInt64(i64)
|
||||||
|
case TypeDecimal128:
|
||||||
|
var d128 Decimal128
|
||||||
|
d128, err = src.ReadDecimal128()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = dst.WriteDecimal128(d128)
|
||||||
|
case TypeMinKey:
|
||||||
|
err = src.ReadMinKey()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = dst.WriteMinKey()
|
||||||
|
case TypeMaxKey:
|
||||||
|
err = src.ReadMaxKey()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = dst.WriteMaxKey()
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("cannot copy unknown BSON type %s", src.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyArray(dst ValueWriter, src ValueReader) error {
|
||||||
|
ar, err := src.ReadArray()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
aw, err := dst.WriteArray()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
vr, err := ar.ReadValue()
|
||||||
|
if errors.Is(err, ErrEOA) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
vw, err := aw.WriteArrayElement()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = copyValue(vw, vr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return aw.WriteArrayEnd()
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyDocumentCore(dw DocumentWriter, dr DocumentReader) error {
|
||||||
|
for {
|
||||||
|
key, vr, err := dr.ReadElement()
|
||||||
|
if errors.Is(err, ErrEOD) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
vw, err := dw.WriteDocumentElement(key)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = copyValue(vw, vr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return dw.WriteDocumentEnd()
|
||||||
|
}
|
||||||
|
|
||||||
|
// bytesReader is the interface used to read BSON bytes from a valueReader.
|
||||||
|
//
|
||||||
|
// The bytes of the value will be appended to dst.
|
||||||
|
type bytesReader interface {
|
||||||
|
readValueBytes(dst []byte) (Type, []byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// bytesWriter is the interface used to write BSON bytes to a valueWriter.
|
||||||
|
type bytesWriter interface {
|
||||||
|
writeValueBytes(t Type, b []byte) error
|
||||||
|
}
|
528
copier_test.go
Normal file
@@ -0,0 +1,528 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"bytes"
	"errors"
	"fmt"
	"testing"

	"gitea.psichedelico.com/go/bson/internal/assert"
	"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
)

func TestCopier(t *testing.T) {
	t.Run("CopyDocument", func(t *testing.T) {
		t.Run("ReadDocument Error", func(t *testing.T) {
			want := errors.New("ReadDocumentError")
			src := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwReadDocument}
			got := copyDocument(nil, src)
			if !assert.CompareErrors(got, want) {
				t.Errorf("Did not receive correct error. got %v; want %v", got, want)
			}
		})
		t.Run("WriteDocument Error", func(t *testing.T) {
			want := errors.New("WriteDocumentError")
			src := &TestValueReaderWriter{}
			dst := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwWriteDocument}
			got := copyDocument(dst, src)
			if !assert.CompareErrors(got, want) {
				t.Errorf("Did not receive correct error. got %v; want %v", got, want)
			}
		})
		t.Run("success", func(t *testing.T) {
			idx, doc := bsoncore.AppendDocumentStart(nil)
			doc = bsoncore.AppendStringElement(doc, "Hello", "world")
			doc, err := bsoncore.AppendDocumentEnd(doc, idx)
			noerr(t, err)
			src := newDocumentReader(bytes.NewReader(doc))
			dst := newValueWriterFromSlice(make([]byte, 0))
			want := doc
			err = copyDocument(dst, src)
			noerr(t, err)
			got := dst.buf
			if !bytes.Equal(got, want) {
				t.Errorf("Bytes are not equal. got %v; want %v", got, want)
			}
		})
	})
	t.Run("copyArray", func(t *testing.T) {
		t.Run("ReadArray Error", func(t *testing.T) {
			want := errors.New("ReadArrayError")
			src := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwReadArray}
			got := copyArray(nil, src)
			if !assert.CompareErrors(got, want) {
				t.Errorf("Did not receive correct error. got %v; want %v", got, want)
			}
		})
		t.Run("WriteArray Error", func(t *testing.T) {
			want := errors.New("WriteArrayError")
			src := &TestValueReaderWriter{}
			dst := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwWriteArray}
			got := copyArray(dst, src)
			if !assert.CompareErrors(got, want) {
				t.Errorf("Did not receive correct error. got %v; want %v", got, want)
			}
		})
		t.Run("success", func(t *testing.T) {
			idx, doc := bsoncore.AppendDocumentStart(nil)
			aidx, doc := bsoncore.AppendArrayElementStart(doc, "foo")
			doc = bsoncore.AppendStringElement(doc, "0", "Hello, world!")
			doc, err := bsoncore.AppendArrayEnd(doc, aidx)
			noerr(t, err)
			doc, err = bsoncore.AppendDocumentEnd(doc, idx)
			noerr(t, err)
			src := newDocumentReader(bytes.NewReader(doc))

			_, err = src.ReadDocument()
			noerr(t, err)
			_, _, err = src.ReadElement()
			noerr(t, err)

			dst := newValueWriterFromSlice(make([]byte, 0))
			_, err = dst.WriteDocument()
			noerr(t, err)
			_, err = dst.WriteDocumentElement("foo")
			noerr(t, err)
			want := doc

			err = copyArray(dst, src)
			noerr(t, err)

			err = dst.WriteDocumentEnd()
			noerr(t, err)

			got := dst.buf
			if !bytes.Equal(got, want) {
				t.Errorf("Bytes are not equal. got %v; want %v", got, want)
			}
		})
	})
	t.Run("CopyValue", func(t *testing.T) {
		testCases := []struct {
			name string
			dst  *TestValueReaderWriter
			src  *TestValueReaderWriter
			err  error
		}{
			{
				"Double/src/error",
				&TestValueReaderWriter{},
				&TestValueReaderWriter{bsontype: TypeDouble, err: errors.New("1"), errAfter: llvrwReadDouble},
				errors.New("1"),
			},
			{
				"Double/dst/error",
				&TestValueReaderWriter{bsontype: TypeDouble, err: errors.New("2"), errAfter: llvrwWriteDouble},
				&TestValueReaderWriter{bsontype: TypeDouble, readval: float64(3.14159)},
				errors.New("2"),
			},
			{
				"String/src/error",
				&TestValueReaderWriter{},
				&TestValueReaderWriter{bsontype: TypeString, err: errors.New("1"), errAfter: llvrwReadString},
				errors.New("1"),
			},
			{
				"String/dst/error",
				&TestValueReaderWriter{bsontype: TypeString, err: errors.New("2"), errAfter: llvrwWriteString},
				&TestValueReaderWriter{bsontype: TypeString, readval: "hello, world"},
				errors.New("2"),
			},
			{
				"Document/src/error",
				&TestValueReaderWriter{},
				&TestValueReaderWriter{bsontype: TypeEmbeddedDocument, err: errors.New("1"), errAfter: llvrwReadDocument},
				errors.New("1"),
			},
			{
				"Array/dst/error",
				&TestValueReaderWriter{},
				&TestValueReaderWriter{bsontype: TypeArray, err: errors.New("2"), errAfter: llvrwReadArray},
				errors.New("2"),
			},
			{
				"Binary/src/error",
				&TestValueReaderWriter{},
				&TestValueReaderWriter{bsontype: TypeBinary, err: errors.New("1"), errAfter: llvrwReadBinary},
				errors.New("1"),
			},
			{
				"Binary/dst/error",
				&TestValueReaderWriter{bsontype: TypeBinary, err: errors.New("2"), errAfter: llvrwWriteBinaryWithSubtype},
				&TestValueReaderWriter{
					bsontype: TypeBinary,
					readval: bsoncore.Value{
|
||||||
|
Type: bsoncore.TypeBinary,
|
||||||
|
Data: []byte{0x03, 0x00, 0x00, 0x00, 0xFF, 0x01, 0x02, 0x03},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
errors.New("2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Undefined/src/error",
|
||||||
|
&TestValueReaderWriter{},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeUndefined, err: errors.New("1"), errAfter: llvrwReadUndefined},
|
||||||
|
errors.New("1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Undefined/dst/error",
|
||||||
|
&TestValueReaderWriter{bsontype: TypeUndefined, err: errors.New("2"), errAfter: llvrwWriteUndefined},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeUndefined},
|
||||||
|
errors.New("2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ObjectID/src/error",
|
||||||
|
&TestValueReaderWriter{},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeObjectID, err: errors.New("1"), errAfter: llvrwReadObjectID},
|
||||||
|
errors.New("1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ObjectID/dst/error",
|
||||||
|
&TestValueReaderWriter{bsontype: TypeObjectID, err: errors.New("2"), errAfter: llvrwWriteObjectID},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeObjectID, readval: ObjectID{0x01, 0x02, 0x03}},
|
||||||
|
errors.New("2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Boolean/src/error",
|
||||||
|
&TestValueReaderWriter{},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeBoolean, err: errors.New("1"), errAfter: llvrwReadBoolean},
|
||||||
|
errors.New("1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Boolean/dst/error",
|
||||||
|
&TestValueReaderWriter{bsontype: TypeBoolean, err: errors.New("2"), errAfter: llvrwWriteBoolean},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeBoolean, readval: bool(true)},
|
||||||
|
errors.New("2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"DateTime/src/error",
|
||||||
|
&TestValueReaderWriter{},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeDateTime, err: errors.New("1"), errAfter: llvrwReadDateTime},
|
||||||
|
errors.New("1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"DateTime/dst/error",
|
||||||
|
&TestValueReaderWriter{bsontype: TypeDateTime, err: errors.New("2"), errAfter: llvrwWriteDateTime},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeDateTime, readval: int64(1234567890)},
|
||||||
|
errors.New("2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Null/src/error",
|
||||||
|
&TestValueReaderWriter{},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeNull, err: errors.New("1"), errAfter: llvrwReadNull},
|
||||||
|
errors.New("1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Null/dst/error",
|
||||||
|
&TestValueReaderWriter{bsontype: TypeNull, err: errors.New("2"), errAfter: llvrwWriteNull},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeNull},
|
||||||
|
errors.New("2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Regex/src/error",
|
||||||
|
&TestValueReaderWriter{},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeRegex, err: errors.New("1"), errAfter: llvrwReadRegex},
|
||||||
|
errors.New("1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Regex/dst/error",
|
||||||
|
&TestValueReaderWriter{bsontype: TypeRegex, err: errors.New("2"), errAfter: llvrwWriteRegex},
|
||||||
|
&TestValueReaderWriter{
|
||||||
|
bsontype: TypeRegex,
|
||||||
|
readval: bsoncore.Value{
|
||||||
|
Type: bsoncore.TypeRegex,
|
||||||
|
Data: bsoncore.AppendRegex(nil, "hello", "world"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
errors.New("2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"DBPointer/src/error",
|
||||||
|
&TestValueReaderWriter{},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeDBPointer, err: errors.New("1"), errAfter: llvrwReadDBPointer},
|
||||||
|
errors.New("1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"DBPointer/dst/error",
|
||||||
|
&TestValueReaderWriter{bsontype: TypeDBPointer, err: errors.New("2"), errAfter: llvrwWriteDBPointer},
|
||||||
|
&TestValueReaderWriter{
|
||||||
|
bsontype: TypeDBPointer,
|
||||||
|
readval: bsoncore.Value{
|
||||||
|
Type: bsoncore.TypeDBPointer,
|
||||||
|
Data: bsoncore.AppendDBPointer(nil, "foo", ObjectID{0x01, 0x02, 0x03}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
errors.New("2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Javascript/src/error",
|
||||||
|
&TestValueReaderWriter{},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeJavaScript, err: errors.New("1"), errAfter: llvrwReadJavascript},
|
||||||
|
errors.New("1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Javascript/dst/error",
|
||||||
|
&TestValueReaderWriter{bsontype: TypeJavaScript, err: errors.New("2"), errAfter: llvrwWriteJavascript},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeJavaScript, readval: "hello, world"},
|
||||||
|
errors.New("2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Symbol/src/error",
|
||||||
|
&TestValueReaderWriter{},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeSymbol, err: errors.New("1"), errAfter: llvrwReadSymbol},
|
||||||
|
errors.New("1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Symbol/dst/error",
|
||||||
|
&TestValueReaderWriter{bsontype: TypeSymbol, err: errors.New("2"), errAfter: llvrwWriteSymbol},
|
||||||
|
&TestValueReaderWriter{
|
||||||
|
bsontype: TypeSymbol,
|
||||||
|
readval: bsoncore.Value{
|
||||||
|
Type: bsoncore.TypeSymbol,
|
||||||
|
Data: bsoncore.AppendSymbol(nil, "hello, world"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
errors.New("2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"CodeWithScope/src/error",
|
||||||
|
&TestValueReaderWriter{},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeCodeWithScope, err: errors.New("1"), errAfter: llvrwReadCodeWithScope},
|
||||||
|
errors.New("1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"CodeWithScope/dst/error",
|
||||||
|
&TestValueReaderWriter{bsontype: TypeCodeWithScope, err: errors.New("2"), errAfter: llvrwWriteCodeWithScope},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeCodeWithScope},
|
||||||
|
errors.New("2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"CodeWithScope/dst/copyDocumentCore error",
|
||||||
|
&TestValueReaderWriter{err: errors.New("3"), errAfter: llvrwWriteDocumentElement},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeCodeWithScope},
|
||||||
|
errors.New("3"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Int32/src/error",
|
||||||
|
&TestValueReaderWriter{},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeInt32, err: errors.New("1"), errAfter: llvrwReadInt32},
|
||||||
|
errors.New("1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Int32/dst/error",
|
||||||
|
&TestValueReaderWriter{bsontype: TypeInt32, err: errors.New("2"), errAfter: llvrwWriteInt32},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeInt32, readval: int32(12345)},
|
||||||
|
errors.New("2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Timestamp/src/error",
|
||||||
|
&TestValueReaderWriter{},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeTimestamp, err: errors.New("1"), errAfter: llvrwReadTimestamp},
|
||||||
|
errors.New("1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Timestamp/dst/error",
|
||||||
|
&TestValueReaderWriter{bsontype: TypeTimestamp, err: errors.New("2"), errAfter: llvrwWriteTimestamp},
|
||||||
|
&TestValueReaderWriter{
|
||||||
|
bsontype: TypeTimestamp,
|
||||||
|
readval: bsoncore.Value{
|
||||||
|
Type: bsoncore.TypeTimestamp,
|
||||||
|
Data: bsoncore.AppendTimestamp(nil, 12345, 67890),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
errors.New("2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Int64/src/error",
|
||||||
|
&TestValueReaderWriter{},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeInt64, err: errors.New("1"), errAfter: llvrwReadInt64},
|
||||||
|
errors.New("1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Int64/dst/error",
|
||||||
|
&TestValueReaderWriter{bsontype: TypeInt64, err: errors.New("2"), errAfter: llvrwWriteInt64},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeInt64, readval: int64(1234567890)},
|
||||||
|
errors.New("2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Decimal128/src/error",
|
||||||
|
&TestValueReaderWriter{},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeDecimal128, err: errors.New("1"), errAfter: llvrwReadDecimal128},
|
||||||
|
errors.New("1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Decimal128/dst/error",
|
||||||
|
&TestValueReaderWriter{bsontype: TypeDecimal128, err: errors.New("2"), errAfter: llvrwWriteDecimal128},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeDecimal128, readval: NewDecimal128(12345, 67890)},
|
||||||
|
errors.New("2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"MinKey/src/error",
|
||||||
|
&TestValueReaderWriter{},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeMinKey, err: errors.New("1"), errAfter: llvrwReadMinKey},
|
||||||
|
errors.New("1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"MinKey/dst/error",
|
||||||
|
&TestValueReaderWriter{bsontype: TypeMinKey, err: errors.New("2"), errAfter: llvrwWriteMinKey},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeMinKey},
|
||||||
|
errors.New("2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"MaxKey/src/error",
|
||||||
|
&TestValueReaderWriter{},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeMaxKey, err: errors.New("1"), errAfter: llvrwReadMaxKey},
|
||||||
|
errors.New("1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"MaxKey/dst/error",
|
||||||
|
&TestValueReaderWriter{bsontype: TypeMaxKey, err: errors.New("2"), errAfter: llvrwWriteMaxKey},
|
||||||
|
&TestValueReaderWriter{bsontype: TypeMaxKey},
|
||||||
|
errors.New("2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Unknown BSON type error",
|
||||||
|
&TestValueReaderWriter{},
|
||||||
|
&TestValueReaderWriter{},
|
||||||
|
fmt.Errorf("cannot copy unknown BSON type %s", Type(0)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
tc.dst.t, tc.src.t = t, t
|
||||||
|
err := copyValue(tc.dst, tc.src)
|
||||||
|
if !assert.CompareErrors(err, tc.err) {
|
||||||
|
t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("CopyValueFromBytes", func(t *testing.T) {
|
||||||
|
t.Run("BytesWriter", func(t *testing.T) {
|
||||||
|
vw := newValueWriterFromSlice(make([]byte, 0))
|
||||||
|
_, err := vw.WriteDocument()
|
||||||
|
noerr(t, err)
|
||||||
|
_, err = vw.WriteDocumentElement("foo")
|
||||||
|
noerr(t, err)
|
||||||
|
err = copyValueFromBytes(vw, TypeString, bsoncore.AppendString(nil, "bar"))
|
||||||
|
noerr(t, err)
|
||||||
|
err = vw.WriteDocumentEnd()
|
||||||
|
noerr(t, err)
|
||||||
|
var idx int32
|
||||||
|
want, err := bsoncore.AppendDocumentEnd(
|
||||||
|
bsoncore.AppendStringElement(
|
||||||
|
bsoncore.AppendDocumentStartInline(nil, &idx),
|
||||||
|
"foo", "bar",
|
||||||
|
),
|
||||||
|
idx,
|
||||||
|
)
|
||||||
|
noerr(t, err)
|
||||||
|
got := vw.buf
|
||||||
|
if !bytes.Equal(got, want) {
|
||||||
|
t.Errorf("Bytes are not equal. got %v; want %v", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("Non BytesWriter", func(t *testing.T) {
|
||||||
|
llvrw := &TestValueReaderWriter{t: t}
|
||||||
|
err := copyValueFromBytes(llvrw, TypeString, bsoncore.AppendString(nil, "bar"))
|
||||||
|
noerr(t, err)
|
||||||
|
got, want := llvrw.invoked, llvrwWriteString
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("Incorrect method invoked on llvrw. got %v; want %v", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("CopyValueToBytes", func(t *testing.T) {
|
||||||
|
t.Run("BytesReader", func(t *testing.T) {
|
||||||
|
var idx int32
|
||||||
|
b, err := bsoncore.AppendDocumentEnd(
|
||||||
|
bsoncore.AppendStringElement(
|
||||||
|
bsoncore.AppendDocumentStartInline(nil, &idx),
|
||||||
|
"hello", "world",
|
||||||
|
),
|
||||||
|
idx,
|
||||||
|
)
|
||||||
|
noerr(t, err)
|
||||||
|
vr := newDocumentReader(bytes.NewReader(b))
|
||||||
|
_, err = vr.ReadDocument()
|
||||||
|
noerr(t, err)
|
||||||
|
_, _, err = vr.ReadElement()
|
||||||
|
noerr(t, err)
|
||||||
|
btype, got, err := copyValueToBytes(vr)
|
||||||
|
noerr(t, err)
|
||||||
|
want := bsoncore.AppendString(nil, "world")
|
||||||
|
if btype != TypeString {
|
||||||
|
t.Errorf("Incorrect type returned. got %v; want %v", btype, TypeString)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(got, want) {
|
||||||
|
t.Errorf("Bytes do not match. got %v; want %v", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("Non BytesReader", func(t *testing.T) {
|
||||||
|
llvrw := &TestValueReaderWriter{t: t, bsontype: TypeString, readval: "Hello, world!"}
|
||||||
|
btype, got, err := copyValueToBytes(llvrw)
|
||||||
|
noerr(t, err)
|
||||||
|
want := bsoncore.AppendString(nil, "Hello, world!")
|
||||||
|
if btype != TypeString {
|
||||||
|
t.Errorf("Incorrect type returned. got %v; want %v", btype, TypeString)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(got, want) {
|
||||||
|
t.Errorf("Bytes do not match. got %v; want %v", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("AppendValueBytes", func(t *testing.T) {
|
||||||
|
t.Run("BytesReader", func(t *testing.T) {
|
||||||
|
var idx int32
|
||||||
|
b, err := bsoncore.AppendDocumentEnd(
|
||||||
|
bsoncore.AppendStringElement(
|
||||||
|
bsoncore.AppendDocumentStartInline(nil, &idx),
|
||||||
|
"hello", "world",
|
||||||
|
),
|
||||||
|
idx,
|
||||||
|
)
|
||||||
|
noerr(t, err)
|
||||||
|
vr := newDocumentReader(bytes.NewReader(b))
|
||||||
|
_, err = vr.ReadDocument()
|
||||||
|
noerr(t, err)
|
||||||
|
_, _, err = vr.ReadElement()
|
||||||
|
noerr(t, err)
|
||||||
|
btype, got, err := copyValueToBytes(vr)
|
||||||
|
noerr(t, err)
|
||||||
|
want := bsoncore.AppendString(nil, "world")
|
||||||
|
if btype != TypeString {
|
||||||
|
t.Errorf("Incorrect type returned. got %v; want %v", btype, TypeString)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(got, want) {
|
||||||
|
t.Errorf("Bytes do not match. got %v; want %v", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("Non BytesReader", func(t *testing.T) {
|
||||||
|
llvrw := &TestValueReaderWriter{t: t, bsontype: TypeString, readval: "Hello, world!"}
|
||||||
|
btype, got, err := copyValueToBytes(llvrw)
|
||||||
|
noerr(t, err)
|
||||||
|
want := bsoncore.AppendString(nil, "Hello, world!")
|
||||||
|
if btype != TypeString {
|
||||||
|
t.Errorf("Incorrect type returned. got %v; want %v", btype, TypeString)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(got, want) {
|
||||||
|
t.Errorf("Bytes do not match. got %v; want %v", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("CopyValue error", func(t *testing.T) {
|
||||||
|
want := errors.New("CopyValue error")
|
||||||
|
llvrw := &TestValueReaderWriter{t: t, bsontype: TypeString, err: want, errAfter: llvrwReadString}
|
||||||
|
_, _, got := copyValueToBytes(llvrw)
|
||||||
|
if !assert.CompareErrors(got, want) {
|
||||||
|
t.Errorf("Errors do not match. got %v; want %v", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
339	decimal.go	Normal file
@@ -0,0 +1,339 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on gopkg.in/mgo.v2/bson by Gustavo Niemeyer
// See THIRD-PARTY-NOTICES for original license terms.

package bson

import (
	"encoding/json"
	"errors"
	"fmt"
	"math/big"
	"regexp"
	"strconv"
	"strings"

	"gitea.psichedelico.com/go/bson/internal/decimal128"
)

// These constants are the maximum and minimum values for the exponent field in a decimal128 value.
const (
	MaxDecimal128Exp = 6111
	MinDecimal128Exp = -6176
)

// These errors are returned when an invalid value is parsed as a big.Int.
var (
	ErrParseNaN    = errors.New("cannot parse NaN as a *big.Int")
	ErrParseInf    = errors.New("cannot parse Infinity as a *big.Int")
	ErrParseNegInf = errors.New("cannot parse -Infinity as a *big.Int")
)

// Decimal128 holds decimal128 BSON values.
type Decimal128 struct {
	h, l uint64
}

// NewDecimal128 creates a Decimal128 using the provided high and low uint64s.
func NewDecimal128(h, l uint64) Decimal128 {
	return Decimal128{h: h, l: l}
}

// GetBytes returns the underlying bytes of the BSON decimal value as two uint64 values. The first
// contains the first 8 bytes of the value and the second contains the latter 8.
func (d Decimal128) GetBytes() (uint64, uint64) {
	return d.h, d.l
}

// String returns a string representation of the decimal value.
func (d Decimal128) String() string {
	return decimal128.String(d.h, d.l)
}

// BigInt returns the significand as a big.Int and the exponent, i.e. the value is bi * 10 ^ exp.
func (d Decimal128) BigInt() (*big.Int, int, error) {
	high, low := d.GetBytes()
	posSign := high>>63&1 == 0 // positive sign

	switch high >> 58 & (1<<5 - 1) {
	case 0x1F:
		return nil, 0, ErrParseNaN
	case 0x1E:
		if posSign {
			return nil, 0, ErrParseInf
		}
		return nil, 0, ErrParseNegInf
	}

	var exp int
	if high>>61&3 == 3 {
		// Bits: 1*sign 2*ignored 14*exponent 111*significand.
		// Implicit 0b100 prefix in significand.
		exp = int(high >> 47 & (1<<14 - 1))
		// Spec says all of these values are out of range.
		high, low = 0, 0
	} else {
		// Bits: 1*sign 14*exponent 113*significand
		exp = int(high >> 49 & (1<<14 - 1))
		high &= (1<<49 - 1)
	}
	exp += MinDecimal128Exp

	// Would be handled by the logic below, but that's trivial and common.
	if high == 0 && low == 0 && exp == 0 {
		return new(big.Int), 0, nil
	}

	bi := big.NewInt(0)
	const host32bit = ^uint(0)>>32 == 0
	if host32bit {
		bi.SetBits([]big.Word{big.Word(low), big.Word(low >> 32), big.Word(high), big.Word(high >> 32)})
	} else {
		bi.SetBits([]big.Word{big.Word(low), big.Word(high)})
	}

	if !posSign {
		return bi.Neg(bi), exp, nil
	}
	return bi, exp, nil
}
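The bit extraction in `BigInt` can be checked in isolation. The following standalone sketch (the `decode` helper is illustrative, not part of the package) pulls the sign, the 14-bit biased exponent at bit 49, and the significand out of a finite decimal128 high/low pair, using the `"12345"` vector from `decimal_test.go` (`h=0x3040000000000000, l=12345`):

```go
package main

import (
	"fmt"
	"math/big"
)

// minExp mirrors MinDecimal128Exp: the exponent field is stored with a bias of 6176.
const minExp = -6176

// decode extracts sign, significand, and exponent from a finite decimal128
// value that uses the common (non-0b11 combination) encoding.
func decode(h, l uint64) (neg bool, sig *big.Int, exp int) {
	neg = h>>63&1 == 1
	exp = int(h>>49&(1<<14-1)) + minExp // remove the bias from the 14-bit exponent
	h &= 1<<49 - 1                      // keep only the 49 significand bits of the high word
	sig = new(big.Int).Or(
		new(big.Int).Lsh(new(big.Int).SetUint64(h), 64),
		new(big.Int).SetUint64(l),
	)
	return neg, sig, exp
}

func main() {
	// Test vector from decimal_test.go: "12345" <=> h=0x3040000000000000, l=12345.
	neg, sig, exp := decode(0x3040000000000000, 12345)
	fmt.Println(neg, sig, exp) // false 12345 0
}
```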

// IsNaN returns whether d is NaN.
func (d Decimal128) IsNaN() bool {
	return d.h>>58&(1<<5-1) == 0x1F
}

// IsInf returns:
//
//	+1 d == Infinity
//	 0 other case
//	-1 d == -Infinity
func (d Decimal128) IsInf() int {
	if d.h>>58&(1<<5-1) != 0x1E {
		return 0
	}

	if d.h>>63&1 == 0 {
		return 1
	}
	return -1
}
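`IsNaN` and `IsInf` both look at the five combination-field bits below the sign bit: `0b11111` means NaN, `0b11110` means infinity, and bit 63 carries the sign. A standalone sketch of that classification (the `classify` helper is illustrative; the sentinel words match `dNaN`, `dPosInf`, and `dNegInf` defined later in this file):

```go
package main

import "fmt"

// classify reproduces the combination-field tests used by IsNaN and IsInf.
func classify(h uint64) string {
	switch {
	case h>>58&(1<<5-1) == 0x1F:
		return "NaN" // all five combination bits set
	case h>>58&(1<<5-1) == 0x1E && h>>63 == 0:
		return "+Inf" // infinity pattern, sign bit clear
	case h>>58&(1<<5-1) == 0x1E:
		return "-Inf" // infinity pattern, sign bit set
	}
	return "finite"
}

func main() {
	fmt.Println(classify(0x1F << 58))         // NaN   (dNaN's high word)
	fmt.Println(classify(0x1E << 58))         // +Inf  (dPosInf's high word)
	fmt.Println(classify(0x3E << 58))         // -Inf  (dNegInf's high word)
	fmt.Println(classify(0x3040000000000000)) // finite (the "12345" vector)
}
```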

// IsZero returns true if d is the empty Decimal128.
func (d Decimal128) IsZero() bool {
	return d.h == 0 && d.l == 0
}

// MarshalJSON returns Decimal128 as a string.
func (d Decimal128) MarshalJSON() ([]byte, error) {
	return json.Marshal(d.String())
}

// UnmarshalJSON creates a Decimal128 from a JSON string, an extended JSON $numberDecimal value, or the string
// "null". If b is a JSON string or extended JSON value, d will have the value of that string, and if b is "null", d will
// be unchanged.
func (d *Decimal128) UnmarshalJSON(b []byte) error {
	// Ignore "null" to keep parity with the standard library. Decoding a JSON null into a non-pointer Decimal128 field
	// will leave the field unchanged. For pointer values, encoding/json will set the pointer to nil and will not
	// enter the UnmarshalJSON hook.
	if string(b) == "null" {
		return nil
	}

	var res interface{}
	err := json.Unmarshal(b, &res)
	if err != nil {
		return err
	}
	str, ok := res.(string)

	// Extended JSON
	if !ok {
		m, ok := res.(map[string]interface{})
		if !ok {
			return errors.New("not an extended JSON Decimal128: expected document")
		}
		d128, ok := m["$numberDecimal"]
		if !ok {
			return errors.New("not an extended JSON Decimal128: expected key $numberDecimal")
		}
		str, ok = d128.(string)
		if !ok {
			return errors.New("not an extended JSON Decimal128: expected decimal to be string")
		}
	}

	*d, err = ParseDecimal128(str)
	return err
}

var dNaN = Decimal128{0x1F << 58, 0}
var dPosInf = Decimal128{0x1E << 58, 0}
var dNegInf = Decimal128{0x3E << 58, 0}

func dErr(s string) (Decimal128, error) {
	return dNaN, fmt.Errorf("cannot parse %q as a decimal128", s)
}

// normalNumber matches a number in scientific notation, for example -10.15e-18.
var normalNumber = regexp.MustCompile(`^(?P<int>[-+]?\d*)?(?:\.(?P<dec>\d*))?(?:[Ee](?P<exp>[-+]?\d+))?$`)

// ParseDecimal128 takes the given string and attempts to parse it into a valid
// Decimal128 value.
func ParseDecimal128(s string) (Decimal128, error) {
	if s == "" {
		return dErr(s)
	}

	matches := normalNumber.FindStringSubmatch(s)
	if len(matches) == 0 {
		orig := s
		neg := s[0] == '-'
		if neg || s[0] == '+' {
			s = s[1:]
		}

		if s == "NaN" || s == "nan" || strings.EqualFold(s, "nan") {
			return dNaN, nil
		}
		if s == "Inf" || s == "inf" || strings.EqualFold(s, "inf") || strings.EqualFold(s, "infinity") {
			if neg {
				return dNegInf, nil
			}
			return dPosInf, nil
		}
		return dErr(orig)
	}

	intPart := matches[1]
	decPart := matches[2]
	expPart := matches[3]

	var err error
	exp := 0
	if expPart != "" {
		exp, err = strconv.Atoi(expPart)
		if err != nil {
			return dErr(s)
		}
	}
	if decPart != "" {
		exp -= len(decPart)
	}

	if len(strings.Trim(intPart+decPart, "-0")) > 35 {
		return dErr(s)
	}

	// Parse the significand (i.e. the non-exponent part) as a big.Int.
	bi, ok := new(big.Int).SetString(intPart+decPart, 10)
	if !ok {
		return dErr(s)
	}

	d, ok := ParseDecimal128FromBigInt(bi, exp)
	if !ok {
		return dErr(s)
	}

	if bi.Sign() == 0 && s[0] == '-' {
		d.h |= 1 << 63
	}

	return d, nil
}
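The `normalNumber` regex does the heavy lifting in `ParseDecimal128`: it splits a decimal string into an integer part, a fractional part, and an exponent, after which the significand is `int+dec` and the effective exponent is `exp - len(dec)`. A standalone demonstration using the comment's own example, `-10.15e-18`:

```go
package main

import (
	"fmt"
	"regexp"
)

// The same pattern ParseDecimal128 uses to split a decimal string.
var normalNumber = regexp.MustCompile(`^(?P<int>[-+]?\d*)?(?:\.(?P<dec>\d*))?(?:[Ee](?P<exp>[-+]?\d+))?$`)

func main() {
	m := normalNumber.FindStringSubmatch("-10.15e-18")
	// m[1] is the integer part, m[2] the fractional digits, m[3] the exponent.
	fmt.Printf("int=%q dec=%q exp=%q\n", m[1], m[2], m[3]) // int="-10" dec="15" exp="-18"
	// Significand: m[1]+m[2] = "-1015"; effective exponent: -18 - len(m[2]) = -20,
	// so -10.15e-18 is parsed as -1015 * 10^-20.
}
```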

var (
	ten  = big.NewInt(10)
	zero = new(big.Int)

	maxS, _ = new(big.Int).SetString("9999999999999999999999999999999999", 10)
)

// ParseDecimal128FromBigInt attempts to parse the given significand and exponent into a valid Decimal128 value.
func ParseDecimal128FromBigInt(bi *big.Int, exp int) (Decimal128, bool) {
	// copy
	bi = new(big.Int).Set(bi)

	q := new(big.Int)
	r := new(big.Int)

	// If the significand is zero, the logical value will always be zero, independent of the
	// exponent. However, the loops for handling out-of-range exponent values below may be extremely
	// slow for zero values because the significand never changes. Limit the exponent value to the
	// supported range here to prevent entering the loops below.
	if bi.Cmp(zero) == 0 {
		if exp > MaxDecimal128Exp {
			exp = MaxDecimal128Exp
		}
		if exp < MinDecimal128Exp {
			exp = MinDecimal128Exp
		}
	}

	for bigIntCmpAbs(bi, maxS) == 1 {
		bi, _ = q.QuoRem(bi, ten, r)
		if r.Cmp(zero) != 0 {
			return Decimal128{}, false
		}
		exp++
		if exp > MaxDecimal128Exp {
			return Decimal128{}, false
		}
	}

	for exp < MinDecimal128Exp {
		// Subnormal.
		bi, _ = q.QuoRem(bi, ten, r)
		if r.Cmp(zero) != 0 {
			return Decimal128{}, false
		}
		exp++
	}
	for exp > MaxDecimal128Exp {
		// Clamped.
		bi.Mul(bi, ten)
		if bigIntCmpAbs(bi, maxS) == 1 {
			return Decimal128{}, false
		}
		exp--
	}

	b := bi.Bytes()
	var h, l uint64
	for i := 0; i < len(b); i++ {
		if i < len(b)-8 {
			h = h<<8 | uint64(b[i])
			continue
		}
		l = l<<8 | uint64(b[i])
	}

	h |= uint64(exp-MinDecimal128Exp) & uint64(1<<14-1) << 49
	if bi.Sign() == -1 {
		h |= 1 << 63
	}

	return Decimal128{h: h, l: l}, true
}
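The final packing step above ORs the biased 14-bit exponent into the high word at bit 49 and sets bit 63 for negative values. A standalone sketch of just that step (the `pack` helper is illustrative), checked against the `"90123456.789012"` vector from `decimal_test.go` (`h=0x3034000000000000, l=90123456789012, exp=-6`):

```go
package main

import "fmt"

// minExp mirrors MinDecimal128Exp; the stored exponent is exp - minExp.
const minExp = -6176

// pack builds the high/low words for a significand whose high word already
// holds the bits above the low 64, mirroring ParseDecimal128FromBigInt's
// final packing step.
func pack(sigHigh, sigLow uint64, exp int, neg bool) (h, l uint64) {
	h = sigHigh
	h |= uint64(exp-minExp) & (1<<14 - 1) << 49 // biased 14-bit exponent at bit 49
	if neg {
		h |= 1 << 63 // sign bit
	}
	return h, sigLow
}

func main() {
	// 90123456.789012 = 90123456789012 * 10^-6; the significand fits in the
	// low word, so the significand's high word is 0.
	h, l := pack(0, 90123456789012, -6, false)
	fmt.Printf("0x%X %d\n", h, l) // 0x3034000000000000 90123456789012
}
```

Note the Go precedence here: `&` and `<<` sit at the same level and associate left to right, so the mask is applied to the biased exponent before the shift, exactly as in the method above.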

// bigIntCmpAbs computes big.Int.Cmp(absoluteValue(x), absoluteValue(y)).
func bigIntCmpAbs(x, y *big.Int) int {
	xAbs := bigIntAbsValue(x)
	yAbs := bigIntAbsValue(y)
	return xAbs.Cmp(yAbs)
}

// bigIntAbsValue returns a big.Int containing the absolute value of b.
// If b is already a non-negative number, it is returned without any changes or copies.
func bigIntAbsValue(b *big.Int) *big.Int {
	if b.Sign() >= 0 {
		return b // already positive
	}
	return new(big.Int).Abs(b)
}
236	decimal_test.go	Normal file
@@ -0,0 +1,236 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"encoding/json"
	"fmt"
	"math/big"
	"testing"

	"gitea.psichedelico.com/go/bson/internal/assert"
	"gitea.psichedelico.com/go/bson/internal/require"
)

type bigIntTestCase struct {
	s string

	h uint64
	l uint64

	bi  *big.Int
	exp int

	remark string
}

func parseBigInt(s string) *big.Int {
	bi, _ := new(big.Int).SetString(s, 10)
	return bi
}

var (
	one = big.NewInt(1)

	biMaxS  = new(big.Int).SetBytes([]byte{0x1, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
	biNMaxS = new(big.Int).Neg(biMaxS)

	biOverflow  = new(big.Int).Add(biMaxS, one)
	biNOverflow = new(big.Int).Neg(biOverflow)

	bi12345  = parseBigInt("12345")
	biN12345 = parseBigInt("-12345")

	bi9_14  = parseBigInt("90123456789012")
	biN9_14 = parseBigInt("-90123456789012")

	bi9_34  = parseBigInt("9999999999999999999999999999999999")
	biN9_34 = parseBigInt("-9999999999999999999999999999999999")
)

var bigIntTestCases = []bigIntTestCase{
	{s: "12345", h: 0x3040000000000000, l: 12345, bi: bi12345},
	{s: "-12345", h: 0xB040000000000000, l: 12345, bi: biN12345},

	{s: "90123456.789012", h: 0x3034000000000000, l: 90123456789012, bi: bi9_14, exp: -6},
	{s: "-90123456.789012", h: 0xB034000000000000, l: 90123456789012, bi: biN9_14, exp: -6},
	{s: "9.0123456789012E+22", h: 0x3052000000000000, l: 90123456789012, bi: bi9_14, exp: 9},
	{s: "-9.0123456789012E+22", h: 0xB052000000000000, l: 90123456789012, bi: biN9_14, exp: 9},
	{s: "9.0123456789012E-8", h: 0x3016000000000000, l: 90123456789012, bi: bi9_14, exp: -21},
	{s: "-9.0123456789012E-8", h: 0xB016000000000000, l: 90123456789012, bi: biN9_14, exp: -21},

	{s: "9999999999999999999999999999999999", h: 3477321013416265664, l: 4003012203950112767, bi: bi9_34},
	{s: "-9999999999999999999999999999999999", h: 12700693050271041472, l: 4003012203950112767, bi: biN9_34},
	{s: "0.9999999999999999999999999999999999", h: 3458180714999941056, l: 4003012203950112767, bi: bi9_34, exp: -34},
	{s: "-0.9999999999999999999999999999999999", h: 12681552751854716864, l: 4003012203950112767, bi: biN9_34, exp: -34},
	{s: "99999999999999999.99999999999999999", h: 3467750864208103360, l: 4003012203950112767, bi: bi9_34, exp: -17},
	{s: "-99999999999999999.99999999999999999", h: 12691122901062879168, l: 4003012203950112767, bi: biN9_34, exp: -17},
	{s: "9.999999999999999999999999999999999E+35", h: 3478446913323108288, l: 4003012203950112767, bi: bi9_34, exp: 2},
	{s: "-9.999999999999999999999999999999999E+35", h: 12701818950177884096, l: 4003012203950112767, bi: biN9_34, exp: 2},
	{s: "9.999999999999999999999999999999999E+40", h: 3481261663090214848, l: 4003012203950112767, bi: bi9_34, exp: 7},
	{s: "-9.999999999999999999999999999999999E+40", h: 12704633699944990656, l: 4003012203950112767, bi: biN9_34, exp: 7},
	{s: "99999999999999999999999999999.99999", h: 3474506263649159104, l: 4003012203950112767, bi: bi9_34, exp: -5},
	{s: "-99999999999999999999999999999.99999", h: 12697878300503934912, l: 4003012203950112767, bi: biN9_34, exp: -5},

	{s: "1.038459371706965525706099265844019E-6143", remark: "subnormal", h: 0x333333333333, l: 0x3333333333333333, bi: parseBigInt("10384593717069655257060992658440190"), exp: MinDecimal128Exp - 1},
	{s: "-1.038459371706965525706099265844019E-6143", remark: "subnormal", h: 0x8000333333333333, l: 0x3333333333333333, bi: parseBigInt("-10384593717069655257060992658440190"), exp: MinDecimal128Exp - 1},

	{s: "rounding overflow 1", remark: "overflow", bi: parseBigInt("103845937170696552570609926584401910"), exp: MaxDecimal128Exp},
	{s: "rounding overflow 2", remark: "overflow", bi: parseBigInt("103845937170696552570609926584401910"), exp: MaxDecimal128Exp},

	{s: "subnormal overflow 1", remark: "overflow", bi: biMaxS, exp: MinDecimal128Exp - 1},
	{s: "subnormal overflow 2", remark: "overflow", bi: biNMaxS, exp: MinDecimal128Exp - 1},

	{s: "clamped overflow 1", remark: "overflow", bi: biMaxS, exp: MaxDecimal128Exp + 1},
	{s: "clamped overflow 2", remark: "overflow", bi: biNMaxS, exp: MaxDecimal128Exp + 1},

	{s: "biMaxS+1 overflow", remark: "overflow", bi: biOverflow, exp: MaxDecimal128Exp},
	{s: "biNMaxS-1 overflow", remark: "overflow", bi: biNOverflow, exp: MaxDecimal128Exp},

	{s: "NaN", h: 0x7c00000000000000, l: 0, remark: "NaN"},
|
||||||
|
{s: "Infinity", h: 0x7800000000000000, l: 0, remark: "Infinity"},
|
||||||
|
{s: "-Infinity", h: 0xf800000000000000, l: 0, remark: "-Infinity"},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecimal128_BigInt(t *testing.T) {
|
||||||
|
for _, c := range bigIntTestCases {
|
||||||
|
t.Run(c.s, func(t *testing.T) {
|
||||||
|
switch c.remark {
|
||||||
|
case "NaN", "Infinity", "-Infinity":
|
||||||
|
d128 := NewDecimal128(c.h, c.l)
|
||||||
|
_, _, err := d128.BigInt()
|
||||||
|
require.Error(t, err, "case %s", c.s)
|
||||||
|
case "":
|
||||||
|
d128 := NewDecimal128(c.h, c.l)
|
||||||
|
bi, e, err := d128.BigInt()
|
||||||
|
require.NoError(t, err, "case %s", c.s)
|
||||||
|
require.Equal(t, 0, c.bi.Cmp(bi), "case %s e:%s a:%s", c.s, c.bi.String(), bi.String())
|
||||||
|
				require.Equal(t, c.exp, e, "case %s %s", c.s, d128.String())
			}
		})
	}
}

func TestParseDecimal128FromBigInt(t *testing.T) {
	for _, c := range bigIntTestCases {
		switch c.remark {
		case "overflow":
			d128, ok := ParseDecimal128FromBigInt(c.bi, c.exp)
			require.Equal(t, false, ok, "case %s %s %s", c.s, d128.String(), c.remark)
		case "", "rounding", "subnormal", "clamped":
			d128, ok := ParseDecimal128FromBigInt(c.bi, c.exp)
			require.Equal(t, true, ok, "case %s", c.s)
			require.Equal(t, c.s, d128.String(), "case %s", c.s)

			require.Equal(t, c.h, d128.h, "case %s l:%v", c.s, d128.l)
			require.Equal(t, c.l, d128.l, "case %s h:%v", c.s, d128.h)
		}
	}
}

func TestParseDecimal128(t *testing.T) {
	cases := make([]bigIntTestCase, 0, len(bigIntTestCases))
	cases = append(cases, bigIntTestCases...)
	cases = append(cases,
		bigIntTestCase{s: "-0001231.453454000000565600000000E-21", h: 0xafe6000003faa269, l: 0x81cfeceaabdb1800},
		bigIntTestCase{s: "12345E+21", h: 0x306a000000000000, l: 12345},
		bigIntTestCase{s: "0.10000000000000000000000000000000000000000001", remark: "parse fail"},
		bigIntTestCase{s: ".125e1", h: 0x303c000000000000, l: 125},
		bigIntTestCase{s: ".125", h: 0x303a000000000000, l: 125},
		// Test that parsing negative zero returns negative zero with a zero exponent.
		bigIntTestCase{s: "-0", h: 0xb040000000000000, l: 0},
		// Test that parsing negative zero with an in-range exponent returns negative zero and
		// preserves the specified exponent value.
		bigIntTestCase{s: "-0E999", h: 0xb80e000000000000, l: 0},
		// Test that parsing zero with an out-of-range positive exponent returns zero with the
		// maximum positive exponent (i.e. 0e+6111).
		bigIntTestCase{s: "0E2000000000000", h: 0x5ffe000000000000, l: 0},
		// Test that parsing negative zero with an out-of-range positive exponent returns
		// negative zero with the maximum positive exponent (i.e. -0e+6111).
		bigIntTestCase{s: "-0E2000000000000", h: 0xdffe000000000000, l: 0},
		bigIntTestCase{s: "", remark: "parse fail"})

	for _, c := range cases {
		t.Run(c.s, func(t *testing.T) {
			switch c.remark {
			case "overflow", "parse fail":
				_, err := ParseDecimal128(c.s)
				assert.Error(t, err, "ParseDecimal128(%q) should return an error", c.s)
			default:
				got, err := ParseDecimal128(c.s)
				require.NoError(t, err, "ParseDecimal128(%q) error", c.s)

				want := Decimal128{h: c.h, l: c.l}
				// Decimal128 doesn't implement an equality function, so compare the expected
				// low/high uint64 values directly. Also print the string representation of each
				// number to make debugging failures easier.
				assert.Equal(t, want, got, "ParseDecimal128(%q) = %s, want %s", c.s, got, want)
			}
		})
	}
}

func TestDecimal128_JSON(t *testing.T) {
	t.Run("roundTrip", func(t *testing.T) {
		decimal := NewDecimal128(0x3040000000000000, 12345)
		bytes, err := json.Marshal(decimal)
		assert.Nil(t, err, "json.Marshal error: %v", err)
		got := NewDecimal128(0, 0)
		err = json.Unmarshal(bytes, &got)
		assert.Nil(t, err, "json.Unmarshal error: %v", err)
		assert.Equal(t, decimal.h, got.h, "expected h: %v got: %v", decimal.h, got.h)
		assert.Equal(t, decimal.l, got.l, "expected l: %v got: %v", decimal.l, got.l)
	})
	t.Run("unmarshal extendedJSON", func(t *testing.T) {
		want := NewDecimal128(0x3040000000000000, 12345)
		extJSON := fmt.Sprintf(`{"$numberDecimal": %q}`, want.String())

		got := NewDecimal128(0, 0)
		err := json.Unmarshal([]byte(extJSON), &got)
		assert.Nil(t, err, "json.Unmarshal error: %v", err)
		assert.Equal(t, want.h, got.h, "expected h: %v got: %v", want.h, got.h)
		assert.Equal(t, want.l, got.l, "expected l: %v got: %v", want.l, got.l)
	})
	t.Run("unmarshal null", func(t *testing.T) {
		want := NewDecimal128(0, 0)
		extJSON := `null`

		got := NewDecimal128(0, 0)
		err := json.Unmarshal([]byte(extJSON), &got)
		assert.Nil(t, err, "json.Unmarshal error: %v", err)
		assert.Equal(t, want.h, got.h, "expected h: %v got: %v", want.h, got.h)
		assert.Equal(t, want.l, got.l, "expected l: %v got: %v", want.l, got.l)
	})
	t.Run("unmarshal", func(t *testing.T) {
		cases := make([]bigIntTestCase, 0, len(bigIntTestCases))
		cases = append(cases, bigIntTestCases...)
		cases = append(cases,
			bigIntTestCase{s: "-0001231.453454000000565600000000E-21", h: 0xafe6000003faa269, l: 0x81cfeceaabdb1800},
			bigIntTestCase{s: "12345E+21", h: 0x306a000000000000, l: 12345},
			bigIntTestCase{s: "0.10000000000000000000000000000000000000000001", remark: "parse fail"},
			bigIntTestCase{s: ".125e1", h: 0x303c000000000000, l: 125},
			bigIntTestCase{s: ".125", h: 0x303a000000000000, l: 125})

		for _, c := range cases {
			t.Run(c.s, func(t *testing.T) {
				input := fmt.Sprintf(`{"foo": %q}`, c.s)
				var got map[string]Decimal128
				err := json.Unmarshal([]byte(input), &got)

				switch c.remark {
				case "overflow", "parse fail":
					assert.NotNil(t, err, "expected Unmarshal error, got nil")
				default:
					assert.Nil(t, err, "Unmarshal error: %v", err)
					gotDecimal := got["foo"]
					assert.Equal(t, c.h, gotDecimal.h, "expected h: %v got: %v", c.h, gotDecimal.h)
					assert.Equal(t, c.l, gotDecimal.l, "expected l: %v got: %v", c.l, gotDecimal.l)
				}
			})
		}
	})
}
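The h words in the test table above follow the decimal128 binary integer decimal (BID) layout: when the coefficient fits in the low 64 bits, h carries the sign bit and the biased exponent (bias 6176) in bits 49-62, while l holds the coefficient. A minimal stdlib-only sketch (the bias and shift come from the IEEE 754-2008 layout, not from this package's code) reproduces the h fields of the "12345", "-12345", and "90123456.789012" cases:

```go
package main

import "fmt"

func main() {
	const bias = 6176 // decimal128 exponent bias
	const sign = uint64(1) << 63

	h0 := uint64(0+bias) << 49   // exponent 0, as in the "12345" case
	hm6 := uint64(-6+bias) << 49 // exponent -6, as in the "90123456.789012" case

	fmt.Printf("%#x %#x %#x\n", h0, h0|sign, hm6)
	// prints 0x3040000000000000 0xb040000000000000 0x3034000000000000
}
```

Those three words match the h fields of the corresponding table entries, with the sign bit accounting for the 0x3040/0xB040 difference.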
136
decoder.go
Normal file
@ -0,0 +1,136 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"errors"
	"fmt"
	"reflect"
	"sync"
)

// ErrDecodeToNil is the error returned when trying to decode to a nil value
var ErrDecodeToNil = errors.New("cannot Decode to nil value")

// This pool is used to keep the allocations of Decoders down. This is only used for the Unmarshal*
// methods and is not consumable from outside of this package. The Decoders retrieved from this pool
// must have both Reset and SetRegistry called on them.
var decPool = sync.Pool{
	New: func() interface{} {
		return new(Decoder)
	},
}

// A Decoder reads and decodes BSON documents from a stream. It reads from a ValueReader as
// the source of BSON data.
type Decoder struct {
	dc DecodeContext
	vr ValueReader
}

// NewDecoder returns a new decoder that reads from vr.
func NewDecoder(vr ValueReader) *Decoder {
	return &Decoder{
		dc: DecodeContext{Registry: defaultRegistry},
		vr: vr,
	}
}

// Decode reads the next BSON document from the stream and decodes it into the
// value pointed to by val.
//
// See [Unmarshal] for details about BSON unmarshaling behavior.
func (d *Decoder) Decode(val interface{}) error {
	if unmarshaler, ok := val.(Unmarshaler); ok {
		// TODO(skriptble): Reuse a []byte here and use the AppendDocumentBytes method.
		buf, err := copyDocumentToBytes(d.vr)
		if err != nil {
			return err
		}
		return unmarshaler.UnmarshalBSON(buf)
	}

	rval := reflect.ValueOf(val)
	switch rval.Kind() {
	case reflect.Ptr:
		if rval.IsNil() {
			return ErrDecodeToNil
		}
		rval = rval.Elem()
	case reflect.Map:
		if rval.IsNil() {
			return ErrDecodeToNil
		}
	default:
		return fmt.Errorf("argument to Decode must be a pointer or a map, but got %v", rval)
	}
	decoder, err := d.dc.LookupDecoder(rval.Type())
	if err != nil {
		return err
	}

	return decoder.DecodeValue(d.dc, d.vr, rval)
}

// Reset will reset the state of the decoder, using the same *DecodeContext used in
// the original construction but using vr for reading.
func (d *Decoder) Reset(vr ValueReader) {
	d.vr = vr
}

// SetRegistry replaces the current registry of the decoder with r.
func (d *Decoder) SetRegistry(r *Registry) {
	d.dc.Registry = r
}

// DefaultDocumentM causes the Decoder to always unmarshal documents into the bson.M type. This
// behavior is restricted to data typed as "interface{}" or "map[string]interface{}".
func (d *Decoder) DefaultDocumentM() {
	d.dc.defaultDocumentType = reflect.TypeOf(M{})
}

// AllowTruncatingDoubles causes the Decoder to truncate the fractional part of BSON "double" values
// when attempting to unmarshal them into a Go integer (int, int8, int16, int32, or int64) struct
// field. The truncation logic does not apply to BSON "decimal128" values.
func (d *Decoder) AllowTruncatingDoubles() {
	d.dc.truncate = true
}

// BinaryAsSlice causes the Decoder to unmarshal BSON binary field values that are the "Generic" or
// "Old" BSON binary subtype as a Go byte slice instead of a bson.Binary.
func (d *Decoder) BinaryAsSlice() {
	d.dc.binaryAsSlice = true
}

// ObjectIDAsHexString causes the Decoder to decode object IDs to their hex representation.
func (d *Decoder) ObjectIDAsHexString() {
	d.dc.objectIDAsHexString = true
}

// UseJSONStructTags causes the Decoder to fall back to using the "json" struct tag if a "bson"
// struct tag is not specified.
func (d *Decoder) UseJSONStructTags() {
	d.dc.useJSONStructTags = true
}

// UseLocalTimeZone causes the Decoder to unmarshal time.Time values in the local timezone instead
// of the UTC timezone.
func (d *Decoder) UseLocalTimeZone() {
	d.dc.useLocalTimeZone = true
}

// ZeroMaps causes the Decoder to delete any existing values from Go maps in the destination value
// passed to Decode before unmarshaling BSON documents into them.
func (d *Decoder) ZeroMaps() {
	d.dc.zeroMaps = true
}

// ZeroStructs causes the Decoder to delete any existing values from Go structs in the destination
// value passed to Decode before unmarshaling BSON documents into them.
func (d *Decoder) ZeroStructs() {
	d.dc.zeroStructs = true
}
208
decoder_example_test.go
Normal file
@ -0,0 +1,208 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson_test

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"

	"gitea.psichedelico.com/go/bson"
)

func ExampleDecoder() {
	// Marshal a BSON document that contains the name, SKU, and price (in cents)
	// of a product.
	doc := bson.D{
		{Key: "name", Value: "Cereal Rounds"},
		{Key: "sku", Value: "AB12345"},
		{Key: "price_cents", Value: 399},
	}
	data, err := bson.Marshal(doc)
	if err != nil {
		panic(err)
	}

	// Create a Decoder that reads the marshaled BSON document and use it to
	// unmarshal the document into a Product struct.
	decoder := bson.NewDecoder(bson.NewDocumentReader(bytes.NewReader(data)))

	type Product struct {
		Name  string `bson:"name"`
		SKU   string `bson:"sku"`
		Price int64  `bson:"price_cents"`
	}

	var res Product
	err = decoder.Decode(&res)
	if err != nil {
		panic(err)
	}

	fmt.Printf("%+v\n", res)
	// Output: {Name:Cereal Rounds SKU:AB12345 Price:399}
}

func ExampleDecoder_DefaultDocumentM() {
	// Marshal a BSON document that contains a city name and a nested document
	// with various city properties.
	doc := bson.D{
		{Key: "name", Value: "New York"},
		{Key: "properties", Value: bson.D{
			{Key: "state", Value: "NY"},
			{Key: "population", Value: 8_804_190},
			{Key: "elevation", Value: 10},
		}},
	}
	data, err := bson.Marshal(doc)
	if err != nil {
		panic(err)
	}

	// Create a Decoder that reads the marshaled BSON document and use it to unmarshal the document
	// into a City struct.
	decoder := bson.NewDecoder(bson.NewDocumentReader(bytes.NewReader(data)))

	type City struct {
		Name       string      `bson:"name"`
		Properties interface{} `bson:"properties"`
	}

	// Configure the Decoder to default to decoding BSON documents as the M
	// type if the decode destination has no type information. The Properties
	// field in the City struct will be decoded as a "M" (i.e. map) instead
	// of the default "D".
	decoder.DefaultDocumentM()

	var res City
	err = decoder.Decode(&res)
	if err != nil {
		panic(err)
	}

	data, err = json.Marshal(res)
	if err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", string(data))
	// Output: {"Name":"New York","Properties":{"elevation":10,"population":8804190,"state":"NY"}}
}

func ExampleDecoder_UseJSONStructTags() {
	// Marshal a BSON document that contains the name, SKU, and price (in cents)
	// of a product.
	doc := bson.D{
		{Key: "name", Value: "Cereal Rounds"},
		{Key: "sku", Value: "AB12345"},
		{Key: "price_cents", Value: 399},
	}
	data, err := bson.Marshal(doc)
	if err != nil {
		panic(err)
	}

	// Create a Decoder that reads the marshaled BSON document and use it to
	// unmarshal the document into a Product struct.
	decoder := bson.NewDecoder(bson.NewDocumentReader(bytes.NewReader(data)))

	type Product struct {
		Name  string `json:"name"`
		SKU   string `json:"sku"`
		Price int64  `json:"price_cents"`
	}

	// Configure the Decoder to use "json" struct tags when decoding if "bson"
	// struct tags are not present.
	decoder.UseJSONStructTags()

	var res Product
	err = decoder.Decode(&res)
	if err != nil {
		panic(err)
	}

	fmt.Printf("%+v\n", res)
	// Output: {Name:Cereal Rounds SKU:AB12345 Price:399}
}

func ExampleDecoder_extendedJSON() {
	// Define an Extended JSON document that contains the name, SKU, and price
	// (in cents) of a product.
	data := []byte(`{"name":"Cereal Rounds","sku":"AB12345","price_cents":{"$numberLong":"399"}}`)

	// Create a Decoder that reads the Extended JSON document and use it to
	// unmarshal the document into a Product struct.
	vr, err := bson.NewExtJSONValueReader(bytes.NewReader(data), true)
	if err != nil {
		panic(err)
	}
	decoder := bson.NewDecoder(vr)

	type Product struct {
		Name  string `bson:"name"`
		SKU   string `bson:"sku"`
		Price int64  `bson:"price_cents"`
	}

	var res Product
	err = decoder.Decode(&res)
	if err != nil {
		panic(err)
	}

	fmt.Printf("%+v\n", res)
	// Output: {Name:Cereal Rounds SKU:AB12345 Price:399}
}

func ExampleDecoder_multipleExtendedJSONDocuments() {
	// Define a newline-separated sequence of Extended JSON documents that
	// contain X,Y coordinates.
	data := []byte(`
{"x":{"$numberInt":"0"},"y":{"$numberInt":"0"}}
{"x":{"$numberInt":"1"},"y":{"$numberInt":"1"}}
{"x":{"$numberInt":"2"},"y":{"$numberInt":"2"}}
{"x":{"$numberInt":"3"},"y":{"$numberInt":"3"}}
{"x":{"$numberInt":"4"},"y":{"$numberInt":"4"}}
`)

	// Create a Decoder that reads the Extended JSON documents and use it to
	// unmarshal the documents into Coordinate structs.
	vr, err := bson.NewExtJSONValueReader(bytes.NewReader(data), true)
	if err != nil {
		panic(err)
	}
	decoder := bson.NewDecoder(vr)

	type Coordinate struct {
		X int
		Y int
	}

	// Read and unmarshal each Extended JSON document from the sequence. If
	// Decode returns error io.EOF, that means the Decoder has reached the end
	// of the input, so break the loop.
	for {
		var res Coordinate
		err = decoder.Decode(&res)
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			panic(err)
		}

		fmt.Printf("%+v\n", res)
	}
	// Output:
	// {X:0 Y:0}
	// {X:1 Y:1}
	// {X:2 Y:2}
	// {X:3 Y:3}
	// {X:4 Y:4}
}
699
decoder_test.go
Normal file
@ -0,0 +1,699 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"bytes"
	"errors"
	"reflect"
	"testing"
	"time"

	"gitea.psichedelico.com/go/bson/internal/assert"
	"gitea.psichedelico.com/go/bson/internal/require"
	"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
)

func TestDecodeValue(t *testing.T) {
	t.Parallel()

	for _, tc := range unmarshalingTestCases() {
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			got := reflect.New(tc.sType).Elem()
			vr := NewDocumentReader(bytes.NewReader(tc.data))
			reg := defaultRegistry
			decoder, err := reg.LookupDecoder(reflect.TypeOf(got))
			noerr(t, err)
			err = decoder.DecodeValue(DecodeContext{Registry: reg}, vr, got)
			noerr(t, err)
			assert.Equal(t, tc.want, got.Addr().Interface(), "Results do not match.")
		})
	}
}

func TestDecodingInterfaces(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name string
		stub func() ([]byte, interface{}, func(*testing.T))
	}
	testCases := []testCase{
		{
			name: "struct with interface containing a concrete value",
			stub: func() ([]byte, interface{}, func(*testing.T)) {
				type testStruct struct {
					Value interface{}
				}
				var value string

				data := docToBytes(struct {
					Value string
				}{
					Value: "foo",
				})

				receiver := testStruct{&value}

				check := func(t *testing.T) {
					t.Helper()
					assert.Equal(t, "foo", value)
				}

				return data, &receiver, check
			},
		},
		{
			name: "struct with interface containing a struct",
			stub: func() ([]byte, interface{}, func(*testing.T)) {
				type demo struct {
					Data string
				}

				type testStruct struct {
					Value interface{}
				}
				var value demo

				data := docToBytes(struct {
					Value demo
				}{
					Value: demo{"foo"},
				})

				receiver := testStruct{&value}

				check := func(t *testing.T) {
					t.Helper()
					assert.Equal(t, "foo", value.Data)
				}

				return data, &receiver, check
			},
		},
		{
			name: "struct with interface containing a slice",
			stub: func() ([]byte, interface{}, func(*testing.T)) {
				type testStruct struct {
					Values interface{}
				}
				var values []string

				data := docToBytes(struct {
					Values []string
				}{
					Values: []string{"foo", "bar"},
				})

				receiver := testStruct{&values}

				check := func(t *testing.T) {
					t.Helper()
					assert.Equal(t, []string{"foo", "bar"}, values)
				}

				return data, &receiver, check
			},
		},
		{
			name: "struct with interface containing an array",
			stub: func() ([]byte, interface{}, func(*testing.T)) {
				type testStruct struct {
					Values interface{}
				}
				var values [2]string

				data := docToBytes(struct {
					Values []string
				}{
					Values: []string{"foo", "bar"},
				})

				receiver := testStruct{&values}

				check := func(t *testing.T) {
					t.Helper()
					assert.Equal(t, [2]string{"foo", "bar"}, values)
				}

				return data, &receiver, check
			},
		},
		{
			name: "struct with interface array containing concrete values",
			stub: func() ([]byte, interface{}, func(*testing.T)) {
				type testStruct struct {
					Values [3]interface{}
				}
				var str string
				var i, j int

				data := docToBytes(struct {
					Values []interface{}
				}{
					Values: []interface{}{"foo", 42, nil},
				})

				receiver := testStruct{[3]interface{}{&str, &i, &j}}

				check := func(t *testing.T) {
					t.Helper()
					assert.Equal(t, "foo", str)
					assert.Equal(t, 42, i)
					assert.Equal(t, 0, j)
					assert.Equal(t, testStruct{[3]interface{}{&str, &i, nil}}, receiver)
				}

				return data, &receiver, check
			},
		},
	}
	for _, tc := range testCases {
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			data, receiver, check := tc.stub()
			got := reflect.ValueOf(receiver).Elem()
			vr := NewDocumentReader(bytes.NewReader(data))
			reg := defaultRegistry
			decoder, err := reg.LookupDecoder(got.Type())
			noerr(t, err)
			err = decoder.DecodeValue(DecodeContext{Registry: reg}, vr, got)
			noerr(t, err)
			check(t)
		})
	}
}

func TestDecoder(t *testing.T) {
	t.Parallel()

	t.Run("Decode", func(t *testing.T) {
		t.Parallel()

		t.Run("basic", func(t *testing.T) {
			t.Parallel()

			for _, tc := range unmarshalingTestCases() {
				tc := tc

				t.Run(tc.name, func(t *testing.T) {
					t.Parallel()

					got := reflect.New(tc.sType).Interface()
					vr := NewDocumentReader(bytes.NewReader(tc.data))
					dec := NewDecoder(vr)
					err := dec.Decode(got)
					noerr(t, err)
					assert.Equal(t, tc.want, got, "Results do not match.")
				})
			}
		})
		t.Run("stream", func(t *testing.T) {
			t.Parallel()

			var buf bytes.Buffer
			vr := NewDocumentReader(&buf)
			dec := NewDecoder(vr)
			for _, tc := range unmarshalingTestCases() {
				tc := tc

				t.Run(tc.name, func(t *testing.T) {
					buf.Write(tc.data)
					got := reflect.New(tc.sType).Interface()
					err := dec.Decode(got)
					noerr(t, err)
					assert.Equal(t, tc.want, got, "Results do not match.")
				})
			}
		})
		t.Run("lookup error", func(t *testing.T) {
			t.Parallel()

			type certainlydoesntexistelsewhereihope func(string, string) string
			// Avoid unused code lint error.
			_ = certainlydoesntexistelsewhereihope(func(string, string) string { return "" })

			cdeih := func(string, string) string { return "certainlydoesntexistelsewhereihope" }
			dec := NewDecoder(NewDocumentReader(bytes.NewReader([]byte{})))
			want := errNoDecoder{Type: reflect.TypeOf(cdeih)}
			got := dec.Decode(&cdeih)
			assert.Equal(t, want, got, "Received unexpected error.")
		})
		t.Run("Unmarshaler", func(t *testing.T) {
			t.Parallel()

			testCases := []struct {
				name    string
				err     error
				vr      ValueReader
				invoked bool
			}{
				{
					"error",
					errors.New("Unmarshaler error"),
					&valueReaderWriter{BSONType: TypeEmbeddedDocument, Err: ErrEOD, ErrAfter: readElement},
					true,
				},
				{
					"copy error",
					errors.New("copy error"),
					&valueReaderWriter{Err: errors.New("copy error"), ErrAfter: readDocument},
					false,
				},
				{
					"success",
					nil,
					&valueReaderWriter{BSONType: TypeEmbeddedDocument, Err: ErrEOD, ErrAfter: readElement},
					true,
				},
			}

			for _, tc := range testCases {
				tc := tc

				t.Run(tc.name, func(t *testing.T) {
					t.Parallel()

					unmarshaler := &testUnmarshaler{Err: tc.err}
					dec := NewDecoder(tc.vr)
					got := dec.Decode(unmarshaler)
					want := tc.err
					if !assert.CompareErrors(got, want) {
						t.Errorf("Did not receive expected error. got %v; want %v", got, want)
					}
					if unmarshaler.Invoked != tc.invoked {
						if tc.invoked {
							t.Error("Expected to have UnmarshalBSON invoked, but it wasn't.")
						} else {
							t.Error("Expected UnmarshalBSON to not be invoked, but it was.")
						}
					}
				})
			}

			t.Run("Unmarshaler/success ValueReader", func(t *testing.T) {
				t.Parallel()

				want := bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159))
				unmarshaler := &testUnmarshaler{}
				vr := NewDocumentReader(bytes.NewReader(want))
				dec := NewDecoder(vr)
				err := dec.Decode(unmarshaler)
				noerr(t, err)
				got := unmarshaler.Val
				if !bytes.Equal(got, want) {
					t.Errorf("Did not unmarshal properly. got %v; want %v", got, want)
				}
			})
		})
	})
	t.Run("NewDecoder", func(t *testing.T) {
		t.Parallel()

		t.Run("success", func(t *testing.T) {
			t.Parallel()

			got := NewDecoder(NewDocumentReader(bytes.NewReader([]byte{})))
			if got == nil {
				t.Errorf("Was expecting a non-nil Decoder, but got <nil>")
			}
		})
	})
	t.Run("NewDecoderWithContext", func(t *testing.T) {
		t.Parallel()

		t.Run("success", func(t *testing.T) {
			t.Parallel()

			got := NewDecoder(NewDocumentReader(bytes.NewReader([]byte{})))
			if got == nil {
				t.Errorf("Was expecting a non-nil Decoder, but got <nil>")
			}
		})
	})
	t.Run("Decode doesn't zero struct", func(t *testing.T) {
		t.Parallel()

		type foo struct {
			Item  string
			Qty   int
			Bonus int
		}
		var got foo
		got.Item = "apple"
		got.Bonus = 2
		data := docToBytes(D{{"item", "canvas"}, {"qty", 4}})
		vr := NewDocumentReader(bytes.NewReader(data))
		dec := NewDecoder(vr)
		err := dec.Decode(&got)
		noerr(t, err)
		want := foo{Item: "canvas", Qty: 4, Bonus: 2}
		assert.Equal(t, want, got, "Results do not match.")
	})
	t.Run("Reset", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
vr1, vr2 := NewDocumentReader(bytes.NewReader([]byte{})), NewDocumentReader(bytes.NewReader([]byte{}))
|
||||||
|
dec := NewDecoder(vr1)
|
||||||
|
if dec.vr != vr1 {
|
||||||
|
t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr1)
|
||||||
|
}
|
||||||
|
dec.Reset(vr2)
|
||||||
|
if dec.vr != vr2 {
|
||||||
|
t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr2)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("SetRegistry", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
r1, r2 := defaultRegistry, NewRegistry()
|
||||||
|
dc1 := DecodeContext{Registry: r1}
|
||||||
|
dc2 := DecodeContext{Registry: r2}
|
||||||
|
dec := NewDecoder(NewDocumentReader(bytes.NewReader([]byte{})))
|
||||||
|
if !reflect.DeepEqual(dec.dc, dc1) {
|
||||||
|
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc1)
|
||||||
|
}
|
||||||
|
dec.SetRegistry(r2)
|
||||||
|
if !reflect.DeepEqual(dec.dc, dc2) {
|
||||||
|
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc2)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("DecodeToNil", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
data := docToBytes(D{{"item", "canvas"}, {"qty", 4}})
|
||||||
|
vr := NewDocumentReader(bytes.NewReader(data))
|
||||||
|
dec := NewDecoder(vr)
|
||||||
|
|
||||||
|
var got *D
|
||||||
|
err := dec.Decode(got)
|
||||||
|
if !errors.Is(err, ErrDecodeToNil) {
|
||||||
|
t.Fatalf("Decode error mismatch; expected %v, got %v", ErrDecodeToNil, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type testUnmarshaler struct {
|
||||||
|
Invoked bool
|
||||||
|
Val []byte
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tu *testUnmarshaler) UnmarshalBSON(d []byte) error {
|
||||||
|
tu.Invoked = true
|
||||||
|
tu.Val = d
|
||||||
|
return tu.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecoderConfiguration(t *testing.T) {
|
||||||
|
type truncateDoublesTest struct {
|
||||||
|
MyInt int
|
||||||
|
MyInt8 int8
|
||||||
|
MyInt16 int16
|
||||||
|
MyInt32 int32
|
||||||
|
MyInt64 int64
|
||||||
|
MyUint uint
|
||||||
|
MyUint8 uint8
|
||||||
|
MyUint16 uint16
|
||||||
|
MyUint32 uint32
|
||||||
|
MyUint64 uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
type objectIDTest struct {
|
||||||
|
ID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type jsonStructTest struct {
|
||||||
|
StructFieldName string `json:"jsonFieldName"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type localTimeZoneTest struct {
|
||||||
|
MyTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type zeroMapsTest struct {
|
||||||
|
MyMap map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
type zeroStructsTest struct {
|
||||||
|
MyString string
|
||||||
|
MyInt int
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
description string
|
||||||
|
configure func(*Decoder)
|
||||||
|
input []byte
|
||||||
|
decodeInto func() interface{}
|
||||||
|
want interface{}
|
||||||
|
}{
|
||||||
|
// Test that AllowTruncatingDoubles causes the Decoder to unmarshal BSON doubles with
|
||||||
|
// fractional parts into Go integer types by truncating the fractional part.
|
||||||
|
{
|
||||||
|
description: "AllowTruncatingDoubles",
|
||||||
|
configure: func(dec *Decoder) {
|
||||||
|
dec.AllowTruncatingDoubles()
|
||||||
|
},
|
||||||
|
input: bsoncore.NewDocumentBuilder().
|
||||||
|
AppendDouble("myInt", 1.999).
|
||||||
|
AppendDouble("myInt8", 1.999).
|
||||||
|
AppendDouble("myInt16", 1.999).
|
||||||
|
AppendDouble("myInt32", 1.999).
|
||||||
|
AppendDouble("myInt64", 1.999).
|
||||||
|
AppendDouble("myUint", 1.999).
|
||||||
|
AppendDouble("myUint8", 1.999).
|
||||||
|
AppendDouble("myUint16", 1.999).
|
||||||
|
AppendDouble("myUint32", 1.999).
|
||||||
|
AppendDouble("myUint64", 1.999).
|
||||||
|
Build(),
|
||||||
|
decodeInto: func() interface{} { return &truncateDoublesTest{} },
|
||||||
|
want: &truncateDoublesTest{
|
||||||
|
MyInt: 1,
|
||||||
|
MyInt8: 1,
|
||||||
|
MyInt16: 1,
|
||||||
|
MyInt32: 1,
|
||||||
|
MyInt64: 1,
|
||||||
|
MyUint: 1,
|
||||||
|
MyUint8: 1,
|
||||||
|
MyUint16: 1,
|
||||||
|
MyUint32: 1,
|
||||||
|
MyUint64: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Test that BinaryAsSlice causes the Decoder to unmarshal BSON binary fields into Go byte
|
||||||
|
// slices when there is no type information (e.g when unmarshaling into a bson.D).
|
||||||
|
{
|
||||||
|
description: "BinaryAsSlice",
|
||||||
|
configure: func(dec *Decoder) {
|
||||||
|
dec.BinaryAsSlice()
|
||||||
|
},
|
||||||
|
input: bsoncore.NewDocumentBuilder().
|
||||||
|
AppendBinary("myBinary", TypeBinaryGeneric, []byte{}).
|
||||||
|
Build(),
|
||||||
|
decodeInto: func() interface{} { return &D{} },
|
||||||
|
want: &D{{Key: "myBinary", Value: []byte{}}},
|
||||||
|
},
|
||||||
|
// Test that the default decoder always decodes BSON documents into bson.D values,
|
||||||
|
// independent of the top-level Go value type.
|
||||||
|
{
|
||||||
|
description: "DocumentD nested by default",
|
||||||
|
configure: func(_ *Decoder) {},
|
||||||
|
input: bsoncore.NewDocumentBuilder().
|
||||||
|
AppendDocument("myDocument", bsoncore.NewDocumentBuilder().
|
||||||
|
AppendString("myString", "test value").
|
||||||
|
Build()).
|
||||||
|
Build(),
|
||||||
|
decodeInto: func() interface{} { return M{} },
|
||||||
|
want: M{
|
||||||
|
"myDocument": D{{Key: "myString", Value: "test value"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Test that DefaultDocumentM always decodes BSON documents into bson.M values,
|
||||||
|
// independent of the top-level Go value type.
|
||||||
|
{
|
||||||
|
description: "DefaultDocumentM nested",
|
||||||
|
configure: func(dec *Decoder) {
|
||||||
|
dec.DefaultDocumentM()
|
||||||
|
},
|
||||||
|
input: bsoncore.NewDocumentBuilder().
|
||||||
|
AppendDocument("myDocument", bsoncore.NewDocumentBuilder().
|
||||||
|
AppendString("myString", "test value").
|
||||||
|
Build()).
|
||||||
|
Build(),
|
||||||
|
decodeInto: func() interface{} { return &D{} },
|
||||||
|
want: &D{
|
||||||
|
{Key: "myDocument", Value: M{"myString": "test value"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Test that ObjectIDAsHexString causes the Decoder to decode object ID to hex.
|
||||||
|
{
|
||||||
|
description: "ObjectIDAsHexString",
|
||||||
|
configure: func(dec *Decoder) {
|
||||||
|
dec.ObjectIDAsHexString()
|
||||||
|
},
|
||||||
|
input: bsoncore.NewDocumentBuilder().
|
||||||
|
AppendObjectID("id", func() ObjectID {
|
||||||
|
id, _ := ObjectIDFromHex("5ef7fdd91c19e3222b41b839")
|
||||||
|
return id
|
||||||
|
}()).
|
||||||
|
Build(),
|
||||||
|
decodeInto: func() interface{} { return &objectIDTest{} },
|
||||||
|
want: &objectIDTest{ID: "5ef7fdd91c19e3222b41b839"},
|
||||||
|
},
|
||||||
|
// Test that UseJSONStructTags causes the Decoder to fall back to "json" struct tags if
|
||||||
|
// "bson" struct tags are not available.
|
||||||
|
{
|
||||||
|
description: "UseJSONStructTags",
|
||||||
|
configure: func(dec *Decoder) {
|
||||||
|
dec.UseJSONStructTags()
|
||||||
|
},
|
||||||
|
input: bsoncore.NewDocumentBuilder().
|
||||||
|
AppendString("jsonFieldName", "test value").
|
||||||
|
Build(),
|
||||||
|
decodeInto: func() interface{} { return &jsonStructTest{} },
|
||||||
|
want: &jsonStructTest{StructFieldName: "test value"},
|
||||||
|
},
|
||||||
|
// Test that UseLocalTimeZone causes the Decoder to use the local time zone for decoded
|
||||||
|
// time.Time values instead of UTC.
|
||||||
|
{
|
||||||
|
description: "UseLocalTimeZone",
|
||||||
|
configure: func(dec *Decoder) {
|
||||||
|
dec.UseLocalTimeZone()
|
||||||
|
},
|
||||||
|
input: bsoncore.NewDocumentBuilder().
|
||||||
|
AppendDateTime("myTime", 1684349179939).
|
||||||
|
Build(),
|
||||||
|
decodeInto: func() interface{} { return &localTimeZoneTest{} },
|
||||||
|
want: &localTimeZoneTest{MyTime: time.UnixMilli(1684349179939)},
|
||||||
|
},
|
||||||
|
// Test that ZeroMaps causes the Decoder to empty any Go map values before decoding BSON
|
||||||
|
// documents into them.
|
||||||
|
{
|
||||||
|
description: "ZeroMaps",
|
||||||
|
configure: func(dec *Decoder) {
|
||||||
|
dec.ZeroMaps()
|
||||||
|
},
|
||||||
|
input: bsoncore.NewDocumentBuilder().
|
||||||
|
AppendDocument("myMap", bsoncore.NewDocumentBuilder().
|
||||||
|
AppendString("myString", "test value").
|
||||||
|
Build()).
|
||||||
|
Build(),
|
||||||
|
decodeInto: func() interface{} {
|
||||||
|
return &zeroMapsTest{MyMap: map[string]string{"myExtraValue": "extra value"}}
|
||||||
|
},
|
||||||
|
want: &zeroMapsTest{MyMap: map[string]string{"myString": "test value"}},
|
||||||
|
},
|
||||||
|
// Test that ZeroStructs causes the Decoder to empty any Go struct values before decoding
|
||||||
|
// BSON documents into them.
|
||||||
|
{
|
||||||
|
description: "ZeroStructs",
|
||||||
|
configure: func(dec *Decoder) {
|
||||||
|
dec.ZeroStructs()
|
||||||
|
},
|
||||||
|
input: bsoncore.NewDocumentBuilder().
|
||||||
|
AppendString("myString", "test value").
|
||||||
|
Build(),
|
||||||
|
decodeInto: func() interface{} {
|
||||||
|
return &zeroStructsTest{MyInt: 1}
|
||||||
|
},
|
||||||
|
want: &zeroStructsTest{MyString: "test value"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
tc := tc // Capture range variable.
|
||||||
|
|
||||||
|
t.Run(tc.description, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
dec := NewDecoder(NewDocumentReader(bytes.NewReader(tc.input)))
|
||||||
|
|
||||||
|
tc.configure(dec)
|
||||||
|
|
||||||
|
got := tc.decodeInto()
|
||||||
|
err := dec.Decode(got)
|
||||||
|
require.NoError(t, err, "Decode error")
|
||||||
|
|
||||||
|
assert.Equal(t, tc.want, got, "expected and actual decode results do not match")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("Decoding an object ID to string", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
type objectIDTest struct {
|
||||||
|
ID string
|
||||||
|
}
|
||||||
|
|
||||||
|
doc := bsoncore.NewDocumentBuilder().
|
||||||
|
AppendObjectID("id", func() ObjectID {
|
||||||
|
id, _ := ObjectIDFromHex("5ef7fdd91c19e3222b41b839")
|
||||||
|
return id
|
||||||
|
}()).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
dec := NewDecoder(NewDocumentReader(bytes.NewReader(doc)))
|
||||||
|
|
||||||
|
var got objectIDTest
|
||||||
|
err := dec.Decode(&got)
|
||||||
|
const want = "error decoding key id: decoding an object ID into a string is not supported by default (set Decoder.ObjectIDAsHexString to enable decoding as a hexadecimal string)"
|
||||||
|
assert.EqualError(t, err, want)
|
||||||
|
})
|
||||||
|
t.Run("DefaultDocumentM top-level", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
input := bsoncore.NewDocumentBuilder().
|
||||||
|
AppendDocument("myDocument", bsoncore.NewDocumentBuilder().
|
||||||
|
AppendString("myString", "test value").
|
||||||
|
Build()).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
dec := NewDecoder(NewDocumentReader(bytes.NewReader(input)))
|
||||||
|
|
||||||
|
dec.DefaultDocumentM()
|
||||||
|
|
||||||
|
var got interface{}
|
||||||
|
err := dec.Decode(&got)
|
||||||
|
require.NoError(t, err, "Decode error")
|
||||||
|
|
||||||
|
want := M{
|
||||||
|
"myDocument": M{
|
||||||
|
"myString": "test value",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.Equal(t, want, got, "expected and actual decode results do not match")
|
||||||
|
})
|
||||||
|
t.Run("Default decodes DocumentD for top-level", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
input := bsoncore.NewDocumentBuilder().
|
||||||
|
AppendDocument("myDocument", bsoncore.NewDocumentBuilder().
|
||||||
|
AppendString("myString", "test value").
|
||||||
|
Build()).
|
||||||
|
Build()
|
||||||
|
|
||||||
|
dec := NewDecoder(NewDocumentReader(bytes.NewReader(input)))
|
||||||
|
|
||||||
|
var got interface{}
|
||||||
|
err := dec.Decode(&got)
|
||||||
|
require.NoError(t, err, "Decode error")
|
||||||
|
|
||||||
|
want := D{
|
||||||
|
{Key: "myDocument", Value: D{
|
||||||
|
{Key: "myString", Value: "test value"},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
assert.Equal(t, want, got, "expected and actual decode results do not match")
|
||||||
|
})
|
||||||
|
}
|
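The "Decode doesn't zero struct" test above asserts that, without ZeroStructs, decoding into a pre-populated struct only overwrites fields present in the input document. A minimal stdlib-only sketch of the same semantics, using `encoding/json` as a stand-in for the BSON Decoder (the `foo` struct mirrors the one in the test; this is an analogy, not this library's API):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// foo mirrors the struct used in the "Decode doesn't zero struct" test.
type foo struct {
	Item  string
	Qty   int
	Bonus int
}

// decodePreserving decodes doc into a pre-populated foo, showing that
// fields absent from the input keep their prior values.
func decodePreserving(doc string) foo {
	got := foo{Item: "apple", Bonus: 2}
	if err := json.Unmarshal([]byte(doc), &got); err != nil {
		panic(err)
	}
	return got
}

func main() {
	// Item and Qty come from the document; Bonus keeps its prior value 2.
	fmt.Println(decodePreserving(`{"Item":"canvas","Qty":4}`))
}
```

The BSON test exercises exactly this: `Item` and `Qty` are overwritten from the document while `Bonus` survives, unless ZeroStructs is enabled.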
1497
default_value_decoders.go
Normal file
File diff suppressed because it is too large
3806
default_value_decoders_test.go
Normal file
File diff suppressed because it is too large
517
default_value_encoders.go
Normal file
@@ -0,0 +1,517 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"encoding/json"
	"errors"
	"math"
	"net/url"
	"reflect"
	"sync"

	"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
)

var bvwPool = sync.Pool{
	New: func() interface{} {
		return new(valueWriter)
	},
}

var errInvalidValue = errors.New("cannot encode invalid element")

var sliceWriterPool = sync.Pool{
	New: func() interface{} {
		sw := make(sliceWriter, 0)
		return &sw
	},
}

func encodeElement(ec EncodeContext, dw DocumentWriter, e E) error {
	vw, err := dw.WriteDocumentElement(e.Key)
	if err != nil {
		return err
	}

	if e.Value == nil {
		return vw.WriteNull()
	}
	encoder, err := ec.LookupEncoder(reflect.TypeOf(e.Value))
	if err != nil {
		return err
	}

	err = encoder.EncodeValue(ec, vw, reflect.ValueOf(e.Value))
	if err != nil {
		return err
	}
	return nil
}

// registerDefaultEncoders will register the default encoder methods with the
// provided Registry.
func registerDefaultEncoders(reg *Registry) {
	mapEncoder := &mapCodec{}
	uintCodec := &uintCodec{}

	reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{})
	reg.RegisterTypeEncoder(tTime, &timeCodec{})
	reg.RegisterTypeEncoder(tEmpty, &emptyInterfaceCodec{})
	reg.RegisterTypeEncoder(tCoreArray, &arrayCodec{})
	reg.RegisterTypeEncoder(tOID, ValueEncoderFunc(objectIDEncodeValue))
	reg.RegisterTypeEncoder(tDecimal, ValueEncoderFunc(decimal128EncodeValue))
	reg.RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(jsonNumberEncodeValue))
	reg.RegisterTypeEncoder(tURL, ValueEncoderFunc(urlEncodeValue))
	reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue))
	reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue))
	reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue))
	reg.RegisterTypeEncoder(tVector, ValueEncoderFunc(vectorEncodeValue))
	reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue))
	reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue))
	reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue))
	reg.RegisterTypeEncoder(tRegex, ValueEncoderFunc(regexEncodeValue))
	reg.RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dbPointerEncodeValue))
	reg.RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(timestampEncodeValue))
	reg.RegisterTypeEncoder(tMinKey, ValueEncoderFunc(minKeyEncodeValue))
	reg.RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(maxKeyEncodeValue))
	reg.RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(coreDocumentEncodeValue))
	reg.RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(codeWithScopeEncodeValue))
	reg.RegisterKindEncoder(reflect.Bool, ValueEncoderFunc(booleanEncodeValue))
	reg.RegisterKindEncoder(reflect.Int, ValueEncoderFunc(intEncodeValue))
	reg.RegisterKindEncoder(reflect.Int8, ValueEncoderFunc(intEncodeValue))
	reg.RegisterKindEncoder(reflect.Int16, ValueEncoderFunc(intEncodeValue))
	reg.RegisterKindEncoder(reflect.Int32, ValueEncoderFunc(intEncodeValue))
	reg.RegisterKindEncoder(reflect.Int64, ValueEncoderFunc(intEncodeValue))
	reg.RegisterKindEncoder(reflect.Uint, uintCodec)
	reg.RegisterKindEncoder(reflect.Uint8, uintCodec)
	reg.RegisterKindEncoder(reflect.Uint16, uintCodec)
	reg.RegisterKindEncoder(reflect.Uint32, uintCodec)
	reg.RegisterKindEncoder(reflect.Uint64, uintCodec)
	reg.RegisterKindEncoder(reflect.Float32, ValueEncoderFunc(floatEncodeValue))
	reg.RegisterKindEncoder(reflect.Float64, ValueEncoderFunc(floatEncodeValue))
	reg.RegisterKindEncoder(reflect.Array, ValueEncoderFunc(arrayEncodeValue))
	reg.RegisterKindEncoder(reflect.Map, mapEncoder)
	reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{})
	reg.RegisterKindEncoder(reflect.String, &stringCodec{})
	reg.RegisterKindEncoder(reflect.Struct, newStructCodec(mapEncoder))
	reg.RegisterKindEncoder(reflect.Ptr, &pointerCodec{})
	reg.RegisterInterfaceEncoder(tValueMarshaler, ValueEncoderFunc(valueMarshalerEncodeValue))
	reg.RegisterInterfaceEncoder(tMarshaler, ValueEncoderFunc(marshalerEncodeValue))
}

// booleanEncodeValue is the ValueEncoderFunc for bool types.
func booleanEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Kind() != reflect.Bool {
		return ValueEncoderError{Name: "BooleanEncodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val}
	}
	return vw.WriteBoolean(val.Bool())
}

func fitsIn32Bits(i int64) bool {
	return math.MinInt32 <= i && i <= math.MaxInt32
}

// intEncodeValue is the ValueEncoderFunc for int types.
func intEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
	switch val.Kind() {
	case reflect.Int8, reflect.Int16, reflect.Int32:
		return vw.WriteInt32(int32(val.Int()))
	case reflect.Int:
		i64 := val.Int()
		if fitsIn32Bits(i64) {
			return vw.WriteInt32(int32(i64))
		}
		return vw.WriteInt64(i64)
	case reflect.Int64:
		i64 := val.Int()
		if ec.minSize && fitsIn32Bits(i64) {
			return vw.WriteInt32(int32(i64))
		}
		return vw.WriteInt64(i64)
	}

	return ValueEncoderError{
		Name:     "IntEncodeValue",
		Kinds:    []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int},
		Received: val,
	}
}

// floatEncodeValue is the ValueEncoderFunc for float types.
func floatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	switch val.Kind() {
	case reflect.Float32, reflect.Float64:
		return vw.WriteDouble(val.Float())
	}

	return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: val}
}

// objectIDEncodeValue is the ValueEncoderFunc for ObjectID.
func objectIDEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tOID {
		return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: val}
	}
	return vw.WriteObjectID(val.Interface().(ObjectID))
}

// decimal128EncodeValue is the ValueEncoderFunc for Decimal128.
func decimal128EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tDecimal {
		return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: val}
	}
	return vw.WriteDecimal128(val.Interface().(Decimal128))
}

// jsonNumberEncodeValue is the ValueEncoderFunc for json.Number.
func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tJSONNumber {
		return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val}
	}
	jsnum := val.Interface().(json.Number)

	// Attempt int first, then float64.
	if i64, err := jsnum.Int64(); err == nil {
		return intEncodeValue(ec, vw, reflect.ValueOf(i64))
	}

	f64, err := jsnum.Float64()
	if err != nil {
		return err
	}

	return floatEncodeValue(ec, vw, reflect.ValueOf(f64))
}

// urlEncodeValue is the ValueEncoderFunc for url.URL.
func urlEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tURL {
		return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val}
	}
	u := val.Interface().(url.URL)
	return vw.WriteString(u.String())
}

// arrayEncodeValue is the ValueEncoderFunc for array types.
func arrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Kind() != reflect.Array {
		return ValueEncoderError{Name: "ArrayEncodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val}
	}

	// If we have a []E we want to treat it as a document instead of as an array.
	if val.Type().Elem() == tE {
		dw, err := vw.WriteDocument()
		if err != nil {
			return err
		}

		for idx := 0; idx < val.Len(); idx++ {
			e := val.Index(idx).Interface().(E)
			err = encodeElement(ec, dw, e)
			if err != nil {
				return err
			}
		}

		return dw.WriteDocumentEnd()
	}

	// If we have a []byte we want to treat it as a binary instead of as an array.
	if val.Type().Elem() == tByte {
		var byteSlice []byte
		for idx := 0; idx < val.Len(); idx++ {
			byteSlice = append(byteSlice, val.Index(idx).Interface().(byte))
		}
		return vw.WriteBinary(byteSlice)
	}

	aw, err := vw.WriteArray()
	if err != nil {
		return err
	}

	elemType := val.Type().Elem()
	encoder, err := ec.LookupEncoder(elemType)
	if err != nil && elemType.Kind() != reflect.Interface {
		return err
	}

	for idx := 0; idx < val.Len(); idx++ {
		currEncoder, currVal, lookupErr := lookupElementEncoder(ec, encoder, val.Index(idx))
		if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) {
			return lookupErr
		}

		vw, err := aw.WriteArrayElement()
		if err != nil {
			return err
		}

		if errors.Is(lookupErr, errInvalidValue) {
			err = vw.WriteNull()
			if err != nil {
				return err
			}
			continue
		}

		err = currEncoder.EncodeValue(ec, vw, currVal)
		if err != nil {
			return err
		}
	}
	return aw.WriteArrayEnd()
}

func lookupElementEncoder(ec EncodeContext, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) {
	if origEncoder != nil || (currVal.Kind() != reflect.Interface) {
		return origEncoder, currVal, nil
	}
	currVal = currVal.Elem()
	if !currVal.IsValid() {
		return nil, currVal, errInvalidValue
	}
	currEncoder, err := ec.LookupEncoder(currVal.Type())

	return currEncoder, currVal, err
}

// valueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations.
func valueMarshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	// Either val or a pointer to val must implement ValueMarshaler.
	switch {
	case !val.IsValid():
		return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val}
	case val.Type().Implements(tValueMarshaler):
		// If ValueMarshaler is implemented on a concrete type, make sure that val isn't a nil pointer.
		if isImplementationNil(val, tValueMarshaler) {
			return vw.WriteNull()
		}
	case reflect.PtrTo(val.Type()).Implements(tValueMarshaler) && val.CanAddr():
		val = val.Addr()
	default:
		return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val}
	}

	m, ok := val.Interface().(ValueMarshaler)
	if !ok {
		return vw.WriteNull()
	}
	t, data, err := m.MarshalBSONValue()
	if err != nil {
		return err
	}
	return copyValueFromBytes(vw, Type(t), data)
}

// marshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations.
func marshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	// Either val or a pointer to val must implement Marshaler.
	switch {
	case !val.IsValid():
		return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val}
	case val.Type().Implements(tMarshaler):
		// If Marshaler is implemented on a concrete type, make sure that val isn't a nil pointer.
		if isImplementationNil(val, tMarshaler) {
			return vw.WriteNull()
		}
	case reflect.PtrTo(val.Type()).Implements(tMarshaler) && val.CanAddr():
		val = val.Addr()
	default:
		return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val}
	}

	m, ok := val.Interface().(Marshaler)
	if !ok {
		return vw.WriteNull()
	}
	data, err := m.MarshalBSON()
	if err != nil {
		return err
	}
	return copyValueFromBytes(vw, TypeEmbeddedDocument, data)
}

// javaScriptEncodeValue is the ValueEncoderFunc for the JavaScript type.
func javaScriptEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tJavaScript {
		return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: val}
	}

	return vw.WriteJavascript(val.String())
}

// symbolEncodeValue is the ValueEncoderFunc for the Symbol type.
func symbolEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tSymbol {
		return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: val}
	}

	return vw.WriteSymbol(val.String())
}

// binaryEncodeValue is the ValueEncoderFunc for Binary.
func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tBinary {
		return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: val}
	}
	b := val.Interface().(Binary)

	return vw.WriteBinaryWithSubtype(b.Data, b.Subtype)
}

// vectorEncodeValue is the ValueEncoderFunc for Vector.
func vectorEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	// Check validity before calling val.Type(), which panics on an invalid Value.
	if !val.IsValid() || val.Type() != tVector {
		return ValueEncoderError{Name: "VectorEncodeValue",
			Types:    []reflect.Type{tVector},
			Received: val,
		}
	}
	v := val.Interface().(Vector)
	b := v.Binary()
	return vw.WriteBinaryWithSubtype(b.Data, b.Subtype)
}

// undefinedEncodeValue is the ValueEncoderFunc for Undefined.
func undefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tUndefined {
		return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: val}
	}

	return vw.WriteUndefined()
}

// dateTimeEncodeValue is the ValueEncoderFunc for DateTime.
func dateTimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tDateTime {
		return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: val}
	}

	return vw.WriteDateTime(val.Int())
}

// nullEncodeValue is the ValueEncoderFunc for Null.
func nullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tNull {
		return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: val}
	}

	return vw.WriteNull()
}

// regexEncodeValue is the ValueEncoderFunc for Regex.
func regexEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tRegex {
		return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: val}
	}

	regex := val.Interface().(Regex)

	return vw.WriteRegex(regex.Pattern, regex.Options)
}

// dbPointerEncodeValue is the ValueEncoderFunc for DBPointer.
func dbPointerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tDBPointer {
		return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: val}
	}

	dbp := val.Interface().(DBPointer)

	return vw.WriteDBPointer(dbp.DB, dbp.Pointer)
}

// timestampEncodeValue is the ValueEncoderFunc for Timestamp.
func timestampEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tTimestamp {
		return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: val}
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := val.Interface().(Timestamp)
|
||||||
|
|
||||||
|
return vw.WriteTimestamp(ts.T, ts.I)
|
||||||
|
}
|
||||||
|
|
||||||
|
// minKeyEncodeValue is the ValueEncoderFunc for MinKey.
|
||||||
|
func minKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||||
|
if !val.IsValid() || val.Type() != tMinKey {
|
||||||
|
return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: val}
|
||||||
|
}
|
||||||
|
|
||||||
|
return vw.WriteMinKey()
|
||||||
|
}
|
||||||
|
|
||||||
|
// maxKeyEncodeValue is the ValueEncoderFunc for MaxKey.
|
||||||
|
func maxKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||||
|
if !val.IsValid() || val.Type() != tMaxKey {
|
||||||
|
return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: val}
|
||||||
|
}
|
||||||
|
|
||||||
|
return vw.WriteMaxKey()
|
||||||
|
}
|
||||||
|
|
||||||
|
// coreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document.
|
||||||
|
func coreDocumentEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||||
|
if !val.IsValid() || val.Type() != tCoreDocument {
|
||||||
|
return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: val}
|
||||||
|
}
|
||||||
|
|
||||||
|
cdoc := val.Interface().(bsoncore.Document)
|
||||||
|
|
||||||
|
return copyDocumentFromBytes(vw, cdoc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// codeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope.
|
||||||
|
func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||||
|
if !val.IsValid() || val.Type() != tCodeWithScope {
|
||||||
|
return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val}
|
||||||
|
}
|
||||||
|
|
||||||
|
cws := val.Interface().(CodeWithScope)
|
||||||
|
|
||||||
|
dw, err := vw.WriteCodeWithScope(string(cws.Code))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sw := sliceWriterPool.Get().(*sliceWriter)
|
||||||
|
defer sliceWriterPool.Put(sw)
|
||||||
|
*sw = (*sw)[:0]
|
||||||
|
|
||||||
|
scopeVW := bvwPool.Get().(*valueWriter)
|
||||||
|
scopeVW.reset(scopeVW.buf[:0])
|
||||||
|
scopeVW.w = sw
|
||||||
|
defer bvwPool.Put(scopeVW)
|
||||||
|
|
||||||
|
encoder, err := ec.LookupEncoder(reflect.TypeOf(cws.Scope))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = encoder.EncodeValue(ec, scopeVW, reflect.ValueOf(cws.Scope))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = copyBytesToDocumentWriter(dw, *sw)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return dw.WriteDocumentEnd()
|
||||||
|
}
|
||||||
|
|
||||||
|
// isImplementationNil returns if val is a nil pointer and inter is implemented on a concrete type
|
||||||
|
func isImplementationNil(val reflect.Value, inter reflect.Type) bool {
|
||||||
|
vt := val.Type()
|
||||||
|
for vt.Kind() == reflect.Ptr {
|
||||||
|
vt = vt.Elem()
|
||||||
|
}
|
||||||
|
return vt.Implements(inter) && val.Kind() == reflect.Ptr && val.IsNil()
|
||||||
|
}
|
1758
default_value_encoders_test.go
Normal file
File diff suppressed because it is too large
155
doc.go
Normal file
@@ -0,0 +1,155 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

// Package bson is a library for reading, writing, and manipulating BSON. BSON is a binary serialization format used to
// store documents and make remote procedure calls in MongoDB. The BSON specification is located at https://bsonspec.org.
// The BSON library handles marshaling and unmarshaling of values through a configurable codec system. For a description
// of the codec system and examples of registering custom codecs, see the bsoncodec package. For additional information
// and usage examples, check out the [Work with BSON] page in the Go Driver docs site.
//
// # Raw BSON
//
// The Raw family of types is used to validate and retrieve elements from a slice of bytes. This
// type is most useful when you want to do lookups on BSON bytes without unmarshaling it into another
// type.
//
// Example:
//
//	var raw bson.Raw = ... // bytes from somewhere
//	err := raw.Validate()
//	if err != nil { return err }
//	val := raw.Lookup("foo")
//	i32, ok := val.Int32OK()
//	// do something with i32...
//
// # Native Go Types
//
// The D and M types defined in this package can be used to build representations of BSON using native Go types. D is a
// slice and M is a map. For more information about the use cases for these types, see the documentation on the type
// definitions.
//
// Note that a D should not be constructed with duplicate key names, as that can cause undefined server behavior.
//
// Example:
//
//	bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
//	bson.M{"foo": "bar", "hello": "world", "pi": 3.14159}
//
// When decoding BSON to a D or M, the following type mappings apply when unmarshaling:
//
//  1. BSON int32 unmarshals to an int32.
//  2. BSON int64 unmarshals to an int64.
//  3. BSON double unmarshals to a float64.
//  4. BSON string unmarshals to a string.
//  5. BSON boolean unmarshals to a bool.
//  6. BSON embedded document unmarshals to the parent type (i.e. D for a D, M for an M).
//  7. BSON array unmarshals to a bson.A.
//  8. BSON ObjectId unmarshals to a bson.ObjectID.
//  9. BSON datetime unmarshals to a bson.DateTime.
//  10. BSON binary unmarshals to a bson.Binary.
//  11. BSON regular expression unmarshals to a bson.Regex.
//  12. BSON JavaScript unmarshals to a bson.JavaScript.
//  13. BSON code with scope unmarshals to a bson.CodeWithScope.
//  14. BSON timestamp unmarshals to a bson.Timestamp.
//  15. BSON 128-bit decimal unmarshals to a bson.Decimal128.
//  16. BSON min key unmarshals to a bson.MinKey.
//  17. BSON max key unmarshals to a bson.MaxKey.
//  18. BSON undefined unmarshals to a bson.Undefined.
//  19. BSON null unmarshals to nil.
//  20. BSON DBPointer unmarshals to a bson.DBPointer.
//  21. BSON symbol unmarshals to a bson.Symbol.
//
// The above mappings also apply when marshaling a D or M to BSON. Some other useful marshaling mappings are:
//
//  1. time.Time marshals to a BSON datetime.
//  2. int8, int16, and int32 marshal to a BSON int32.
//  3. int marshals to a BSON int32 if the value is between math.MinInt32 and math.MaxInt32, inclusive, and a BSON int64
//     otherwise.
//  4. int64 marshals to BSON int64 (unless [Encoder.IntMinSize] is set).
//  5. uint8 and uint16 marshal to a BSON int32.
//  6. uint, uint32, and uint64 marshal to a BSON int64 (unless [Encoder.IntMinSize] is set).
//  7. BSON null and undefined values will unmarshal into the zero value of a field (e.g. unmarshaling a BSON null or
//     undefined value into a string will yield the empty string).
//
// # Structs
//
// Structs can be marshaled/unmarshaled to/from BSON or Extended JSON. When transforming structs to/from BSON or Extended
// JSON, the following rules apply:
//
//  1. Only exported fields in structs will be marshaled or unmarshaled.
//
//  2. When marshaling a struct, each field will be lowercased to generate the key for the corresponding BSON element.
//     For example, a struct field named "Foo" will generate key "foo". This can be overridden via a struct tag (e.g.
//     `bson:"fooField"` to generate key "fooField" instead).
//
//  3. An embedded struct field is marshaled as a subdocument. The key will be the lowercased name of the field's type.
//
//  4. A pointer field is marshaled as the underlying type if the pointer is non-nil. If the pointer is nil, it is
//     marshaled as a BSON null value.
//
//  5. When unmarshaling, a field of type interface{} will follow the D/M type mappings listed above. BSON documents
//     unmarshaled into an interface{} field will be unmarshaled as a D.
//
// The encoding of each struct field can be customized by the "bson" struct tag.
//
// This tag behavior is configurable, and different struct tag behavior can be configured by initializing a new
// bsoncodec.StructCodec with the desired tag parser and registering that StructCodec onto the Registry. By default, JSON
// tags are not honored, but that can be enabled by creating a StructCodec with JSONFallbackStructTagParser, like below:
//
// Example:
//
//	structcodec, _ := bsoncodec.NewStructCodec(bsoncodec.JSONFallbackStructTagParser)
//
// The bson tag gives the name of the field, possibly followed by a comma-separated list of options.
// The name may be empty in order to specify options without overriding the default field name. The following options can
// be used to configure behavior:
//
//  1. omitempty: If the "omitempty" struct tag is specified on a field, the field will not be marshaled if it is set to
//     an "empty" value. Numbers, booleans, and strings are considered empty if their value is equal to the zero value for
//     the type (i.e. 0 for numbers, false for booleans, and "" for strings). Slices, maps, and arrays are considered
//     empty if they are of length zero. Interfaces and pointers are considered empty if their value is nil. By default,
//     structs are only considered empty if the struct type implements [bsoncodec.Zeroer] and the IsZero
//     method returns true. Struct types that do not implement [bsoncodec.Zeroer] are never considered empty and will be
//     marshaled as embedded documents. NOTE: It is recommended that this tag be used for all slice and map fields.
//
//  2. minsize: If the minsize struct tag is specified on a field of type int64, uint, uint32, or uint64 and the value of
//     the field can fit in a signed int32, the field will be serialized as a BSON int32 rather than a BSON int64. For
//     other types, this tag is ignored.
//
//  3. truncate: If the truncate struct tag is specified on a field with a non-float numeric type, BSON doubles
//     unmarshaled into that field will be truncated at the decimal point. For example, if 3.14 is unmarshaled into a
//     field of type int, it will be unmarshaled as 3. If this tag is not specified, the decoder will throw an error if
//     the value cannot be decoded without losing precision. For float64 or non-numeric types, this tag is ignored.
//
//  4. inline: If the inline struct tag is specified for a struct or map field, the field will be "flattened" when
//     marshaling and "un-flattened" when unmarshaling. This means that all of the fields in that struct/map will be
//     pulled up one level and will become top-level fields rather than being fields in a nested document. For example,
//     if a map field named "Map" with value map[string]interface{}{"foo": "bar"} is inlined, the resulting document will
//     be {"foo": "bar"} instead of {"map": {"foo": "bar"}}. There can only be one inlined map field in a struct. If
//     there are duplicated fields in the resulting document when an inlined struct is marshaled, the inlined field will
//     be overwritten. If there are duplicated fields in the resulting document when an inlined map is marshaled, an
//     error will be returned. This tag can be used with fields that are pointers to structs. If an inlined pointer field
//     is nil, it will not be marshaled. For fields that are not maps or structs, this tag is ignored.
//
// # Marshaling and Unmarshaling
//
// Manually marshaling and unmarshaling can be done with the Marshal and Unmarshal family of functions.
//
// The bsoncodec package provides a system for encoding values to BSON representations and decoding
// values from BSON representations. This package considers both binary BSON and Extended JSON as
// BSON representations. The types in this package enable a flexible system for handling this
// encoding and decoding.
//
// The codec system is composed of two parts:
//
// 1) [ValueEncoder] and [ValueDecoder] that handle encoding and decoding Go values to and from BSON
// representations.
//
// 2) A [Registry] that holds these ValueEncoders and ValueDecoders and provides methods for
// retrieving them.
//
// [Work with BSON]: https://www.mongodb.com/docs/drivers/go/current/fundamentals/bson/
package bson
127
empty_interface_codec.go
Normal file
@@ -0,0 +1,127 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"reflect"
)

// emptyInterfaceCodec is the Codec used for interface{} values.
type emptyInterfaceCodec struct {
	// decodeBinaryAsSlice causes DecodeValue to unmarshal BSON binary field values that are the
	// "Generic" or "Old" BSON binary subtype as a Go byte slice instead of a Binary.
	decodeBinaryAsSlice bool
}

// Assert that emptyInterfaceCodec satisfies the typeDecoder interface, which allows it
// to be used by collection type decoders (e.g. map, slice, etc) to set individual values in a
// collection.
var _ typeDecoder = &emptyInterfaceCodec{}

// EncodeValue is the ValueEncoderFunc for interface{}.
func (eic *emptyInterfaceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
	if !val.IsValid() || val.Type() != tEmpty {
		return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val}
	}

	if val.IsNil() {
		return vw.WriteNull()
	}
	encoder, err := ec.LookupEncoder(val.Elem().Type())
	if err != nil {
		return err
	}

	return encoder.EncodeValue(ec, vw, val.Elem())
}

func (eic *emptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, valueType Type) (reflect.Type, error) {
	isDocument := valueType == Type(0) || valueType == TypeEmbeddedDocument
	if isDocument {
		if dc.defaultDocumentType != nil {
			// If the bsontype is an embedded document and the DocumentType is set on the DecodeContext, then return
			// that type.
			return dc.defaultDocumentType, nil
		}
	}

	rtype, err := dc.LookupTypeMapEntry(valueType)
	if err == nil {
		return rtype, nil
	}

	if isDocument {
		// For documents, fall back to looking up a type map entry for Type(0) or TypeEmbeddedDocument,
		// depending on the original valueType.
		var lookupType Type
		switch valueType {
		case Type(0):
			lookupType = TypeEmbeddedDocument
		case TypeEmbeddedDocument:
			lookupType = Type(0)
		}

		rtype, err = dc.LookupTypeMapEntry(lookupType)
		if err == nil {
			return rtype, nil
		}
		// Fall back to bson.D.
		return tD, nil
	}

	return nil, err
}

func (eic *emptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
	if t != tEmpty {
		return emptyValue, ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.Zero(t)}
	}

	rtype, err := eic.getEmptyInterfaceDecodeType(dc, vr.Type())
	if err != nil {
		switch vr.Type() {
		case TypeNull:
			return reflect.Zero(t), vr.ReadNull()
		default:
			return emptyValue, err
		}
	}

	decoder, err := dc.LookupDecoder(rtype)
	if err != nil {
		return emptyValue, err
	}

	elem, err := decodeTypeOrValueWithInfo(decoder, dc, vr, rtype)
	if err != nil {
		return emptyValue, err
	}

	if (eic.decodeBinaryAsSlice || dc.binaryAsSlice) && rtype == tBinary {
		binElem := elem.Interface().(Binary)
		if binElem.Subtype == TypeBinaryGeneric || binElem.Subtype == TypeBinaryBinaryOld {
			elem = reflect.ValueOf(binElem.Data)
		}
	}

	return elem, nil
}

// DecodeValue is the ValueDecoderFunc for interface{}.
func (eic *emptyInterfaceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
	if !val.CanSet() || val.Type() != tEmpty {
		return ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: val}
	}

	elem, err := eic.decodeType(dc, vr, val.Type())
	if err != nil {
		return err
	}

	val.Set(elem)
	return nil
}
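The fallback chain in getEmptyInterfaceDecodeType (exact BSON type, then the "swapped" document type, then bson.D) can be sketched with a plain map standing in for the registry's type map. All names here are hypothetical stand-ins, not the bson package's API:

```go
package main

import "fmt"

// Byte values mirroring Type(0) (top-level document) and TypeEmbeddedDocument.
const (
	typeTopLevel    byte = 0x00
	typeEmbeddedDoc byte = 0x03
)

// lookupDocType tries the exact type, then the other document representation,
// then falls back to "bson.D", mirroring the decoder's resolution order.
func lookupDocType(typeMap map[byte]string, valueType byte) string {
	if t, ok := typeMap[valueType]; ok {
		return t // direct type map hit
	}
	// Swap Type(0) <-> TypeEmbeddedDocument and retry.
	swap := typeEmbeddedDoc
	if valueType == typeEmbeddedDoc {
		swap = typeTopLevel
	}
	if t, ok := typeMap[swap]; ok {
		return t
	}
	return "bson.D" // final fallback for documents
}

func main() {
	tm := map[byte]string{typeEmbeddedDoc: "bson.M"}
	fmt.Println(lookupDocType(tm, typeTopLevel))     // bson.M (found via the swap)
	fmt.Println(lookupDocType(nil, typeEmbeddedDoc)) // bson.D (final fallback)
}
```

The real codec additionally honors DecodeContext's defaultDocumentType before any map lookup, and returns errors for non-document types with no map entry.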
123
encoder.go
Normal file
@@ -0,0 +1,123 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"reflect"
	"sync"
)

// This pool is used to keep the allocations of Encoders down. This is only used for the Marshal*
// methods and is not consumable from outside of this package. The Encoders retrieved from this pool
// must have both Reset and SetRegistry called on them.
var encPool = sync.Pool{
	New: func() interface{} {
		return new(Encoder)
	},
}

// An Encoder writes a serialization format to an output stream. It writes to a ValueWriter
// as the destination of BSON data.
type Encoder struct {
	ec EncodeContext
	vw ValueWriter
}

// NewEncoder returns a new encoder that writes to vw.
func NewEncoder(vw ValueWriter) *Encoder {
	return &Encoder{
		ec: EncodeContext{Registry: defaultRegistry},
		vw: vw,
	}
}

// Encode writes the BSON encoding of val to the stream.
//
// See [Marshal] for details about BSON marshaling behavior.
func (e *Encoder) Encode(val interface{}) error {
	if marshaler, ok := val.(Marshaler); ok {
		// TODO(skriptble): Should we have a MarshalAppender interface so that we can have []byte reuse?
		buf, err := marshaler.MarshalBSON()
		if err != nil {
			return err
		}
		return copyDocumentFromBytes(e.vw, buf)
	}

	encoder, err := e.ec.LookupEncoder(reflect.TypeOf(val))
	if err != nil {
		return err
	}

	return encoder.EncodeValue(e.ec, e.vw, reflect.ValueOf(val))
}

// Reset will reset the state of the Encoder, keeping the same *EncodeContext used in
// the original construction but writing to vw.
func (e *Encoder) Reset(vw ValueWriter) {
	e.vw = vw
}

// SetRegistry replaces the current registry of the Encoder with r.
func (e *Encoder) SetRegistry(r *Registry) {
	e.ec.Registry = r
}

// ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in
// the marshaled BSON when the "inline" struct tag option is set.
func (e *Encoder) ErrorOnInlineDuplicates() {
	e.ec.errorOnInlineDuplicates = true
}

// IntMinSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, uint,
// uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) that can
// represent the integer value.
func (e *Encoder) IntMinSize() {
	e.ec.minSize = true
}

// StringifyMapKeysWithFmt causes the Encoder to convert Go map keys to BSON document field name
// strings using fmt.Sprint instead of the default string conversion logic.
func (e *Encoder) StringifyMapKeysWithFmt() {
	e.ec.stringifyMapKeysWithFmt = true
}

// NilMapAsEmpty causes the Encoder to marshal nil Go maps as empty BSON documents instead of BSON
// null.
func (e *Encoder) NilMapAsEmpty() {
	e.ec.nilMapAsEmpty = true
}

// NilSliceAsEmpty causes the Encoder to marshal nil Go slices as empty BSON arrays instead of BSON
// null.
func (e *Encoder) NilSliceAsEmpty() {
	e.ec.nilSliceAsEmpty = true
}

// NilByteSliceAsEmpty causes the Encoder to marshal nil Go byte slices as empty BSON binary values
// instead of BSON null.
func (e *Encoder) NilByteSliceAsEmpty() {
	e.ec.nilByteSliceAsEmpty = true
}

// TODO(GODRIVER-2820): Update the description to remove the note about only examining exported
// TODO struct fields once the logic is updated to also inspect private struct fields.

// OmitZeroStruct causes the Encoder to consider the zero value for a struct (e.g. MyStruct{})
// as empty and omit it from the marshaled BSON when the "omitempty" struct tag option is set.
//
// Note that the Encoder only examines exported struct fields when determining if a struct is the
// zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty.
func (e *Encoder) OmitZeroStruct() {
	e.ec.omitZeroStruct = true
}

// UseJSONStructTags causes the Encoder to fall back to using the "json" struct tag if a "bson"
// struct tag is not specified.
func (e *Encoder) UseJSONStructTags() {
	e.ec.useJSONStructTags = true
}
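The encPool comment above requires every pooled Encoder to have Reset and SetRegistry called before reuse, because sync.Pool hands back objects in whatever state they were left in. A minimal stdlib sketch of that discipline (the encoder type and encode helper here are illustrative, not this package's API):

```go
package main

import (
	"fmt"
	"strings"
	"sync"
)

// encoder is a toy stand-in for the pooled Encoder: it carries mutable state
// (an output buffer) that must be reset on every checkout.
type encoder struct {
	out *strings.Builder
}

var pool = sync.Pool{New: func() interface{} { return new(encoder) }}

// encode checks an encoder out of the pool, resets its state, uses it, and
// returns it. Skipping the reset would leak output from a previous caller.
func encode(s string) string {
	e := pool.Get().(*encoder)
	defer pool.Put(e)
	e.out = &strings.Builder{} // the "Reset before use" step
	e.out.WriteString(s)
	return e.out.String()
}

func main() {
	fmt.Println(encode("first"))
	fmt.Println(encode("second")) // no residue from the first use
}
```

The same reasoning explains why SetRegistry must also be re-applied: a pooled Encoder may still hold the registry from an unrelated earlier Marshal call.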
240
encoder_example_test.go
Normal file
@@ -0,0 +1,240 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson_test

import (
	"bytes"
	"errors"
	"fmt"
	"io"

	"gitea.psichedelico.com/go/bson"
)

func ExampleEncoder() {
	// Create an Encoder that writes BSON values to a bytes.Buffer.
	buf := new(bytes.Buffer)
	vw := bson.NewDocumentWriter(buf)
	encoder := bson.NewEncoder(vw)

	type Product struct {
		Name  string `bson:"name"`
		SKU   string `bson:"sku"`
		Price int64  `bson:"price_cents"`
	}

	// Use the Encoder to marshal a BSON document that contains the name, SKU,
	// and price (in cents) of a product.
	product := Product{
		Name:  "Cereal Rounds",
		SKU:   "AB12345",
		Price: 399,
	}
	err := encoder.Encode(product)
	if err != nil {
		panic(err)
	}

	// Print the BSON document as Extended JSON by converting it to bson.Raw.
	fmt.Println(bson.Raw(buf.Bytes()).String())
	// Output: {"name": "Cereal Rounds","sku": "AB12345","price_cents": {"$numberLong":"399"}}
}

type CityState struct {
	City  string
	State string
}

func (k CityState) String() string {
	return fmt.Sprintf("%s, %s", k.City, k.State)
}

func ExampleEncoder_StringifyMapKeysWithFmt() {
	// Create an Encoder that writes BSON values to a bytes.Buffer.
	buf := new(bytes.Buffer)
	vw := bson.NewDocumentWriter(buf)
	encoder := bson.NewEncoder(vw)

	// Configure the Encoder to convert Go map keys to BSON document field names
	// using fmt.Sprint instead of the default string conversion logic.
	encoder.StringifyMapKeysWithFmt()

	// Use the Encoder to marshal a BSON document that is a map of
	// city and state to a list of zip codes in that city.
	zipCodes := map[CityState][]int{
		{City: "New York", State: "NY"}: {10001, 10301, 10451},
	}
	err := encoder.Encode(zipCodes)
	if err != nil {
		panic(err)
	}

	// Print the BSON document as Extended JSON by converting it to bson.Raw.
	fmt.Println(bson.Raw(buf.Bytes()).String())
	// Output: {"New York, NY": [{"$numberInt":"10001"},{"$numberInt":"10301"},{"$numberInt":"10451"}]}
}

func ExampleEncoder_UseJSONStructTags() {
	// Create an Encoder that writes BSON values to a bytes.Buffer.
	buf := new(bytes.Buffer)
	vw := bson.NewDocumentWriter(buf)
	encoder := bson.NewEncoder(vw)

	type Product struct {
		Name  string `json:"name"`
		SKU   string `json:"sku"`
		Price int64  `json:"price_cents"`
	}

	// Configure the Encoder to use "json" struct tags when encoding if "bson"
	// struct tags are not present.
	encoder.UseJSONStructTags()

	// Use the Encoder to marshal a BSON document that contains the name, SKU,
	// and price (in cents) of a product.
	product := Product{
		Name:  "Cereal Rounds",
		SKU:   "AB12345",
		Price: 399,
	}
	err := encoder.Encode(product)
	if err != nil {
		panic(err)
	}

	// Print the BSON document as Extended JSON by converting it to bson.Raw.
	fmt.Println(bson.Raw(buf.Bytes()).String())
	// Output: {"name": "Cereal Rounds","sku": "AB12345","price_cents": {"$numberLong":"399"}}
}

func ExampleEncoder_multipleBSONDocuments() {
	// Create an Encoder that writes BSON values to a bytes.Buffer.
	buf := new(bytes.Buffer)
	vw := bson.NewDocumentWriter(buf)
	encoder := bson.NewEncoder(vw)

	type Coordinate struct {
		X int
		Y int
	}

	// Use the encoder to marshal 5 Coordinate values as a sequence of BSON
	// documents.
	for i := 0; i < 5; i++ {
		err := encoder.Encode(Coordinate{
			X: i,
			Y: i + 1,
		})
		if err != nil {
			panic(err)
		}
	}

	// Read each marshaled BSON document from the buffer and print them as
	// Extended JSON by converting them to bson.Raw.
	for {
		doc, err := bson.ReadDocument(buf)
		if errors.Is(err, io.EOF) {
			return
		}
		if err != nil {
			panic(err)
		}
		fmt.Println(doc.String())
	}
	// Output:
	// {"x": {"$numberInt":"0"},"y": {"$numberInt":"1"}}
	// {"x": {"$numberInt":"1"},"y": {"$numberInt":"2"}}
	// {"x": {"$numberInt":"2"},"y": {"$numberInt":"3"}}
	// {"x": {"$numberInt":"3"},"y": {"$numberInt":"4"}}
	// {"x": {"$numberInt":"4"},"y": {"$numberInt":"5"}}
}

func ExampleEncoder_extendedJSON() {
	// Create an Encoder that writes canonical Extended JSON values to a
	// bytes.Buffer.
	buf := new(bytes.Buffer)
	vw := bson.NewExtJSONValueWriter(buf, true, false)
	encoder := bson.NewEncoder(vw)

	type Product struct {
		Name  string `bson:"name"`
		SKU   string `bson:"sku"`
		Price int64  `bson:"price_cents"`
	}

	// Use the Encoder to marshal a BSON document that contains the name, SKU,
	// and price (in cents) of a product.
	product := Product{
		Name:  "Cereal Rounds",
		SKU:   "AB12345",
		Price: 399,
	}
	err := encoder.Encode(product)
	if err != nil {
		panic(err)
	}

	fmt.Println(buf.String())
	// Output: {"name":"Cereal Rounds","sku":"AB12345","price_cents":{"$numberLong":"399"}}
}

func ExampleEncoder_multipleExtendedJSONDocuments() {
	// Create an Encoder that writes canonical Extended JSON values to a
	// bytes.Buffer.
	buf := new(bytes.Buffer)
	vw := bson.NewExtJSONValueWriter(buf, true, false)
	encoder := bson.NewEncoder(vw)

	type Coordinate struct {
		X int
		Y int
	}

	// Use the encoder to marshal 5 Coordinate values as a sequence of Extended
|
||||||
|
// JSON documents.
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
err := encoder.Encode(Coordinate{
|
||||||
|
X: i,
|
||||||
|
Y: i + 1,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(buf.String())
|
||||||
|
// Output:
|
||||||
|
// {"x":{"$numberInt":"0"},"y":{"$numberInt":"1"}}
|
||||||
|
// {"x":{"$numberInt":"1"},"y":{"$numberInt":"2"}}
|
||||||
|
// {"x":{"$numberInt":"2"},"y":{"$numberInt":"3"}}
|
||||||
|
// {"x":{"$numberInt":"3"},"y":{"$numberInt":"4"}}
|
||||||
|
// {"x":{"$numberInt":"4"},"y":{"$numberInt":"5"}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleEncoder_IntMinSize() {
|
||||||
|
// Create an encoder that will marshal integers as the minimum BSON int size
|
||||||
|
// (either 32 or 64 bits) that can represent the integer value.
|
||||||
|
type foo struct {
|
||||||
|
Bar uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
vw := bson.NewDocumentWriter(buf)
|
||||||
|
|
||||||
|
enc := bson.NewEncoder(vw)
|
||||||
|
enc.IntMinSize()
|
||||||
|
|
||||||
|
err := enc.Encode(foo{2})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(bson.Raw(buf.Bytes()).String())
|
||||||
|
// Output:
|
||||||
|
// {"bar": {"$numberInt":"2"}}
|
||||||
|
}
|
303
encoder_test.go
Normal file
@@ -0,0 +1,303 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"bytes"
	"errors"
	"reflect"
	"testing"

	"gitea.psichedelico.com/go/bson/internal/assert"
	"gitea.psichedelico.com/go/bson/internal/require"
	"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
)

func TestBasicEncode(t *testing.T) {
	for _, tc := range marshalingTestCases {
		t.Run(tc.name, func(t *testing.T) {
			got := make(sliceWriter, 0, 1024)
			vw := NewDocumentWriter(&got)
			reg := defaultRegistry
			encoder, err := reg.LookupEncoder(reflect.TypeOf(tc.val))
			noerr(t, err)
			err = encoder.EncodeValue(EncodeContext{Registry: reg}, vw, reflect.ValueOf(tc.val))
			noerr(t, err)

			if !bytes.Equal(got, tc.want) {
				t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
				t.Errorf("Bytes:\n%v\n%v", got, tc.want)
			}
		})
	}
}

func TestEncoderEncode(t *testing.T) {
	for _, tc := range marshalingTestCases {
		t.Run(tc.name, func(t *testing.T) {
			got := make(sliceWriter, 0, 1024)
			vw := NewDocumentWriter(&got)
			enc := NewEncoder(vw)
			err := enc.Encode(tc.val)
			noerr(t, err)

			if !bytes.Equal(got, tc.want) {
				t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
				t.Errorf("Bytes:\n%v\n%v", got, tc.want)
			}
		})
	}

	t.Run("Marshaler", func(t *testing.T) {
		testCases := []struct {
			name    string
			buf     []byte
			err     error
			wanterr error
			vw      ValueWriter
		}{
			{
				"error",
				nil,
				errors.New("Marshaler error"),
				errors.New("Marshaler error"),
				&valueReaderWriter{},
			},
			{
				"copy error",
				[]byte{0x05, 0x00, 0x00, 0x00, 0x00},
				nil,
				errors.New("copy error"),
				&valueReaderWriter{Err: errors.New("copy error"), ErrAfter: writeDocument},
			},
			{
				"success",
				[]byte{0x07, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00},
				nil,
				nil,
				nil,
			},
		}

		for _, tc := range testCases {
			t.Run(tc.name, func(t *testing.T) {
				marshaler := testMarshaler{buf: tc.buf, err: tc.err}

				var vw ValueWriter
				b := make(sliceWriter, 0, 100)
				compareVW := false
				if tc.vw != nil {
					vw = tc.vw
				} else {
					compareVW = true
					vw = NewDocumentWriter(&b)
				}
				enc := NewEncoder(vw)
				got := enc.Encode(marshaler)
				want := tc.wanterr
				if !assert.CompareErrors(got, want) {
					t.Errorf("Did not receive expected error. got %v; want %v", got, want)
				}
				if compareVW {
					buf := b
					if !bytes.Equal(buf, tc.buf) {
						t.Errorf("Copied bytes do not match. got %v; want %v", buf, tc.buf)
					}
				}
			})
		}
	})
}

type testMarshaler struct {
	buf []byte
	err error
}

func (tm testMarshaler) MarshalBSON() ([]byte, error) { return tm.buf, tm.err }

func docToBytes(d interface{}) []byte {
	b, err := Marshal(d)
	if err != nil {
		panic(err)
	}
	return b
}

type stringerTest struct{}

func (stringerTest) String() string {
	return "test key"
}

func TestEncoderConfiguration(t *testing.T) {
	type inlineDuplicateInner struct {
		Duplicate string
	}

	type inlineDuplicateOuter struct {
		Inline    inlineDuplicateInner `bson:",inline"`
		Duplicate string
	}

	type zeroStruct struct {
		MyString string
	}

	testCases := []struct {
		description string
		configure   func(*Encoder)
		input       interface{}
		want        []byte
		wantErr     error
	}{
		// Test that ErrorOnInlineDuplicates causes the Encoder to return an error if there are any
		// duplicate fields in the marshaled document caused by using the "inline" struct tag.
		{
			description: "ErrorOnInlineDuplicates",
			configure: func(enc *Encoder) {
				enc.ErrorOnInlineDuplicates()
			},
			input: inlineDuplicateOuter{
				Inline:    inlineDuplicateInner{Duplicate: "inner"},
				Duplicate: "outer",
			},
			wantErr: errors.New("struct bson.inlineDuplicateOuter has duplicated key duplicate"),
		},
		// Test that IntMinSize encodes Go int and int64 values as BSON int32 if the value is small
		// enough.
		{
			description: "IntMinSize",
			configure: func(enc *Encoder) {
				enc.IntMinSize()
			},
			input: D{
				{Key: "myInt", Value: int(1)},
				{Key: "myInt64", Value: int64(1)},
				{Key: "myUint", Value: uint(1)},
				{Key: "myUint32", Value: uint32(1)},
				{Key: "myUint64", Value: uint64(1)},
			},
			want: bsoncore.NewDocumentBuilder().
				AppendInt32("myInt", 1).
				AppendInt32("myInt64", 1).
				AppendInt32("myUint", 1).
				AppendInt32("myUint32", 1).
				AppendInt32("myUint64", 1).
				Build(),
		},
		// Test that StringifyMapKeysWithFmt uses fmt.Sprint to convert map keys to BSON field names.
		{
			description: "StringifyMapKeysWithFmt",
			configure: func(enc *Encoder) {
				enc.StringifyMapKeysWithFmt()
			},
			input: map[stringerTest]string{
				{}: "test value",
			},
			want: bsoncore.NewDocumentBuilder().
				AppendString("test key", "test value").
				Build(),
		},
		// Test that NilMapAsEmpty encodes nil Go maps as empty BSON documents.
		{
			description: "NilMapAsEmpty",
			configure: func(enc *Encoder) {
				enc.NilMapAsEmpty()
			},
			input: D{{Key: "myMap", Value: map[string]string(nil)}},
			want: bsoncore.NewDocumentBuilder().
				AppendDocument("myMap", bsoncore.NewDocumentBuilder().Build()).
				Build(),
		},
		// Test that NilSliceAsEmpty encodes nil Go slices as empty BSON arrays.
		{
			description: "NilSliceAsEmpty",
			configure: func(enc *Encoder) {
				enc.NilSliceAsEmpty()
			},
			input: D{{Key: "mySlice", Value: []string(nil)}},
			want: bsoncore.NewDocumentBuilder().
				AppendArray("mySlice", bsoncore.NewArrayBuilder().Build()).
				Build(),
		},
		// Test that NilByteSliceAsEmpty encodes nil Go byte slices as empty BSON binary elements.
		{
			description: "NilByteSliceAsEmpty",
			configure: func(enc *Encoder) {
				enc.NilByteSliceAsEmpty()
			},
			input: D{{Key: "myBytes", Value: []byte(nil)}},
			want: bsoncore.NewDocumentBuilder().
				AppendBinary("myBytes", TypeBinaryGeneric, []byte{}).
				Build(),
		},
		// Test that OmitZeroStruct omits empty structs from the marshaled document if the
		// "omitempty" struct tag is used.
		{
			description: "OmitZeroStruct",
			configure: func(enc *Encoder) {
				enc.OmitZeroStruct()
			},
			input: struct {
				Zero zeroStruct `bson:",omitempty"`
			}{},
			want: bsoncore.NewDocumentBuilder().Build(),
		},
		// Test that UseJSONStructTags causes the Encoder to fall back to "json" struct tags if
		// "bson" struct tags are not available.
		{
			description: "UseJSONStructTags",
			configure: func(enc *Encoder) {
				enc.UseJSONStructTags()
			},
			input: struct {
				StructFieldName string `json:"jsonFieldName"`
			}{
				StructFieldName: "test value",
			},
			want: bsoncore.NewDocumentBuilder().
				AppendString("jsonFieldName", "test value").
				Build(),
		},
	}

	for _, tc := range testCases {
		tc := tc // Capture range variable.

		t.Run(tc.description, func(t *testing.T) {
			t.Parallel()

			got := new(bytes.Buffer)
			vw := NewDocumentWriter(got)
			enc := NewEncoder(vw)

			tc.configure(enc)

			err := enc.Encode(tc.input)
			if tc.wantErr != nil {
				assert.Equal(t, tc.wantErr, err, "expected and actual errors do not match")
				return
			}
			require.NoError(t, err, "Encode error")

			assert.Equal(t, tc.want, got.Bytes(), "expected and actual encoded BSON do not match")

			// After we compare the raw bytes, also decode the expected and actual BSON as a bson.D
			// and compare them. The goal is to make assertion failures easier to debug because
			// binary diffs are very difficult to understand.
			var wantDoc D
			err = Unmarshal(tc.want, &wantDoc)
			require.NoError(t, err, "Unmarshal error")
			var gotDoc D
			err = Unmarshal(got.Bytes(), &gotDoc)
			require.NoError(t, err, "Unmarshal error")

			assert.Equal(t, wantDoc, gotDoc, "expected and actual decoded documents do not match")
		})
	}
}
143
example_test.go
Normal file
@@ -0,0 +1,143 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson_test

import (
	"fmt"
	"time"

	"gitea.psichedelico.com/go/bson"
)

// This example uses Raw to skip parsing a nested document in a BSON message.
func ExampleRaw_unmarshal() {
	b, err := bson.Marshal(bson.M{
		"Word":     "beach",
		"Synonyms": bson.A{"coast", "shore", "waterfront"},
	})
	if err != nil {
		panic(err)
	}

	var res struct {
		Word     string
		Synonyms bson.Raw // Don't parse the whole list, we just want to count the elements.
	}

	err = bson.Unmarshal(b, &res)
	if err != nil {
		panic(err)
	}
	elems, err := res.Synonyms.Elements()
	if err != nil {
		panic(err)
	}
	fmt.Printf("%s, synonyms count: %d\n", res.Word, len(elems))

	// Output: beach, synonyms count: 3
}

// This example uses Raw to add a precomputed BSON document during marshal.
func ExampleRaw_marshal() {
	precomputed, err := bson.Marshal(bson.M{"Precomputed": true})
	if err != nil {
		panic(err)
	}

	msg := struct {
		Message  string
		Metadata bson.Raw
	}{
		Message:  "Hello World!",
		Metadata: precomputed,
	}

	b, err := bson.Marshal(msg)
	if err != nil {
		panic(err)
	}
	// Print the Extended JSON by converting BSON to bson.Raw.
	fmt.Println(bson.Raw(b).String())

	// Output: {"message": "Hello World!","metadata": {"Precomputed": true}}
}

// This example uses RawValue to delay parsing a value in a BSON message.
func ExampleRawValue_unmarshal() {
	b1, err := bson.Marshal(bson.M{
		"Format":    "UNIX",
		"Timestamp": 1675282389,
	})
	if err != nil {
		panic(err)
	}

	b2, err := bson.Marshal(bson.M{
		"Format":    "RFC3339",
		"Timestamp": time.Unix(1675282389, 0).Format(time.RFC3339),
	})
	if err != nil {
		panic(err)
	}

	for _, b := range [][]byte{b1, b2} {
		var res struct {
			Format    string
			Timestamp bson.RawValue // Delay parsing until we know the timestamp format.
		}

		err = bson.Unmarshal(b, &res)
		if err != nil {
			panic(err)
		}

		var t time.Time
		switch res.Format {
		case "UNIX":
			t = time.Unix(res.Timestamp.AsInt64(), 0)
		case "RFC3339":
			t, err = time.Parse(time.RFC3339, res.Timestamp.StringValue())
			if err != nil {
				panic(err)
			}
		}
		fmt.Println(res.Format, t.Unix())
	}

	// Output:
	// UNIX 1675282389
	// RFC3339 1675282389
}

// This example uses RawValue to add a precomputed BSON string value during marshal.
func ExampleRawValue_marshal() {
	t, val, err := bson.MarshalValue("Precomputed message!")
	if err != nil {
		panic(err)
	}
	precomputed := bson.RawValue{
		Type:  t,
		Value: val,
	}

	msg := struct {
		Message bson.RawValue
		Time    time.Time
	}{
		Message: precomputed,
		Time:    time.Unix(1675282389, 0),
	}

	b, err := bson.Marshal(msg)
	if err != nil {
		panic(err)
	}
	// Print the Extended JSON by converting BSON to bson.Raw.
	fmt.Println(bson.Raw(b).String())

	// Output: {"message": "Precomputed message!","time": {"$date":{"$numberLong":"1675282389000"}}}
}
804
extjson_parser.go
Normal file
@@ -0,0 +1,804 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"encoding/base64"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"strings"
)

const maxNestingDepth = 200

// ErrInvalidJSON indicates the JSON input is invalid.
var ErrInvalidJSON = errors.New("invalid JSON input")

type jsonParseState byte

const (
	jpsStartState jsonParseState = iota
	jpsSawBeginObject
	jpsSawEndObject
	jpsSawBeginArray
	jpsSawEndArray
	jpsSawColon
	jpsSawComma
	jpsSawKey
	jpsSawValue
	jpsDoneState
	jpsInvalidState
)

type jsonParseMode byte

const (
	jpmInvalidMode jsonParseMode = iota
	jpmObjectMode
	jpmArrayMode
)

type extJSONValue struct {
	t Type
	v interface{}
}

type extJSONObject struct {
	keys   []string
	values []*extJSONValue
}

type extJSONParser struct {
	js *jsonScanner
	s  jsonParseState
	m  []jsonParseMode
	k  string
	v  *extJSONValue

	err           error
	canonicalOnly bool
	depth         int
	maxDepth      int

	emptyObject bool
	relaxedUUID bool
}

// newExtJSONParser returns a new extended JSON parser, ready to begin
// parsing from the first character of the argued json input. It will not
// perform any read-ahead and will therefore not report any errors about
// malformed JSON at this point.
func newExtJSONParser(r io.Reader, canonicalOnly bool) *extJSONParser {
	return &extJSONParser{
		js:            &jsonScanner{r: r},
		s:             jpsStartState,
		m:             []jsonParseMode{},
		canonicalOnly: canonicalOnly,
		maxDepth:      maxNestingDepth,
	}
}

// peekType examines the next value and returns its BSON Type.
func (ejp *extJSONParser) peekType() (Type, error) {
	var t Type
	var err error
	initialState := ejp.s

	ejp.advanceState()
	switch ejp.s {
	case jpsSawValue:
		t = ejp.v.t
	case jpsSawBeginArray:
		t = TypeArray
	case jpsInvalidState:
		err = ejp.err
	case jpsSawComma:
		// in array mode, seeing a comma means we need to progress again to actually observe a type
		if ejp.peekMode() == jpmArrayMode {
			return ejp.peekType()
		}
	case jpsSawEndArray:
		// this would only be a valid state if we were in array mode, so return end-of-array error
		err = ErrEOA
	case jpsSawBeginObject:
		// peek key to determine type
		ejp.advanceState()
		switch ejp.s {
		case jpsSawEndObject: // empty embedded document
			t = TypeEmbeddedDocument
			ejp.emptyObject = true
		case jpsInvalidState:
			err = ejp.err
		case jpsSawKey:
			if initialState == jpsStartState {
				return TypeEmbeddedDocument, nil
			}
			t = wrapperKeyBSONType(ejp.k)

			// if $uuid is encountered, parse as binary subtype 4
			if ejp.k == "$uuid" {
				ejp.relaxedUUID = true
				t = TypeBinary
			}

			switch t {
			case TypeJavaScript:
				// just saw $code, need to check for $scope at same level
				_, err = ejp.readValue(TypeJavaScript)
				if err != nil {
					break
				}

				switch ejp.s {
				case jpsSawEndObject: // type is TypeJavaScript
				case jpsSawComma:
					ejp.advanceState()

					if ejp.s == jpsSawKey && ejp.k == "$scope" {
						t = TypeCodeWithScope
					} else {
						err = fmt.Errorf("invalid extended JSON: unexpected key %s in CodeWithScope object", ejp.k)
					}
				case jpsInvalidState:
					err = ejp.err
				default:
					err = ErrInvalidJSON
				}
			case TypeCodeWithScope:
				err = errors.New("invalid extended JSON: code with $scope must contain $code before $scope")
			}
		}
	}

	return t, err
}

// readKey parses the next key and its type and returns them.
func (ejp *extJSONParser) readKey() (string, Type, error) {
	if ejp.emptyObject {
		ejp.emptyObject = false
		return "", 0, ErrEOD
	}

	// advance to key (or return with error)
	switch ejp.s {
	case jpsStartState:
		ejp.advanceState()
		if ejp.s == jpsSawBeginObject {
			ejp.advanceState()
		}
	case jpsSawBeginObject:
		ejp.advanceState()
	case jpsSawValue, jpsSawEndObject, jpsSawEndArray:
		ejp.advanceState()
		switch ejp.s {
		case jpsSawBeginObject, jpsSawComma:
			ejp.advanceState()
		case jpsSawEndObject:
			return "", 0, ErrEOD
		case jpsDoneState:
			return "", 0, io.EOF
		case jpsInvalidState:
			return "", 0, ejp.err
		default:
			return "", 0, ErrInvalidJSON
		}
	case jpsSawKey: // do nothing (key was peeked before)
	default:
		return "", 0, invalidRequestError("key")
	}

	// read key
	var key string

	switch ejp.s {
	case jpsSawKey:
		key = ejp.k
	case jpsSawEndObject:
		return "", 0, ErrEOD
	case jpsInvalidState:
		return "", 0, ejp.err
	default:
		return "", 0, invalidRequestError("key")
	}

	// check for colon
	ejp.advanceState()
	if err := ensureColon(ejp.s, key); err != nil {
		return "", 0, err
	}

	// peek at the value to determine type
	t, err := ejp.peekType()
	if err != nil {
		return "", 0, err
	}

	return key, t, nil
}

// readValue returns the value corresponding to the Type returned by peekType.
func (ejp *extJSONParser) readValue(t Type) (*extJSONValue, error) {
	if ejp.s == jpsInvalidState {
		return nil, ejp.err
	}

	var v *extJSONValue

	switch t {
	case TypeNull, TypeBoolean, TypeString:
		if ejp.s != jpsSawValue {
			return nil, invalidRequestError(t.String())
		}
		v = ejp.v
	case TypeInt32, TypeInt64, TypeDouble:
		// relaxed version allows these to be literal number values
		if ejp.s == jpsSawValue {
			v = ejp.v
			break
		}
		fallthrough
	case TypeDecimal128, TypeSymbol, TypeObjectID, TypeMinKey, TypeMaxKey, TypeUndefined:
		switch ejp.s {
		case jpsSawKey:
			// read colon
			ejp.advanceState()
			if err := ensureColon(ejp.s, ejp.k); err != nil {
				return nil, err
			}

			// read value
			ejp.advanceState()
			if ejp.s != jpsSawValue || !ejp.ensureExtValueType(t) {
				return nil, invalidJSONErrorForType("value", t)
			}

			v = ejp.v

			// read end object
			ejp.advanceState()
			if ejp.s != jpsSawEndObject {
				return nil, invalidJSONErrorForType("} after value", t)
			}
		default:
			return nil, invalidRequestError(t.String())
		}
	case TypeBinary, TypeRegex, TypeTimestamp, TypeDBPointer:
		if ejp.s != jpsSawKey {
			return nil, invalidRequestError(t.String())
		}
		// read colon
		ejp.advanceState()
		if err := ensureColon(ejp.s, ejp.k); err != nil {
			return nil, err
		}

		ejp.advanceState()
		if t == TypeBinary && ejp.s == jpsSawValue {
			// convert relaxed $uuid format
			if ejp.relaxedUUID {
				defer func() { ejp.relaxedUUID = false }()
				uuid, err := ejp.v.parseSymbol()
				if err != nil {
					return nil, err
				}

				// RFC 4122 defines the length of a UUID as 36 and the hyphens in a UUID as appearing
				// in the 8th, 13th, 18th, and 23rd characters.
				//
				// See https://tools.ietf.org/html/rfc4122#section-3
				valid := len(uuid) == 36 &&
					string(uuid[8]) == "-" &&
					string(uuid[13]) == "-" &&
					string(uuid[18]) == "-" &&
					string(uuid[23]) == "-"
				if !valid {
					return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens")
				}

				// remove hyphens
				uuidNoHyphens := strings.ReplaceAll(uuid, "-", "")
				if len(uuidNoHyphens) != 32 {
					return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens")
				}

				// convert hex to bytes
				bytes, err := hex.DecodeString(uuidNoHyphens)
				if err != nil {
					return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding hex bytes: %w", err)
				}

				ejp.advanceState()
				if ejp.s != jpsSawEndObject {
					return nil, invalidJSONErrorForType("$uuid and value and then }", TypeBinary)
				}

				base64 := &extJSONValue{
					t: TypeString,
					v: base64.StdEncoding.EncodeToString(bytes),
				}
				subType := &extJSONValue{
					t: TypeString,
					v: "04",
				}

				v = &extJSONValue{
					t: TypeEmbeddedDocument,
					v: &extJSONObject{
						keys:   []string{"base64", "subType"},
						values: []*extJSONValue{base64, subType},
					},
				}

				break
			}

			// convert legacy $binary format
			base64 := ejp.v

			ejp.advanceState()
			if ejp.s != jpsSawComma {
				return nil, invalidJSONErrorForType(",", TypeBinary)
			}

			ejp.advanceState()
			key, t, err := ejp.readKey()
			if err != nil {
				return nil, err
			}
			if key != "$type" {
				return nil, invalidJSONErrorForType("$type", TypeBinary)
			}

			subType, err := ejp.readValue(t)
			if err != nil {
				return nil, err
			}

			ejp.advanceState()
			if ejp.s != jpsSawEndObject {
				return nil, invalidJSONErrorForType("2 key-value pairs and then }", TypeBinary)
			}

			v = &extJSONValue{
				t: TypeEmbeddedDocument,
				v: &extJSONObject{
					keys:   []string{"base64", "subType"},
					values: []*extJSONValue{base64, subType},
				},
			}
			break
		}

		// read KV pairs
		if ejp.s != jpsSawBeginObject {
			return nil, invalidJSONErrorForType("{", t)
		}

		keys, vals, err := ejp.readObject(2, true)
		if err != nil {
			return nil, err
		}

		ejp.advanceState()
		if ejp.s != jpsSawEndObject {
			return nil, invalidJSONErrorForType("2 key-value pairs and then }", t)
		}

		v = &extJSONValue{t: TypeEmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}

	case TypeDateTime:
		switch ejp.s {
		case jpsSawValue:
			v = ejp.v
		case jpsSawKey:
			// read colon
			ejp.advanceState()
			if err := ensureColon(ejp.s, ejp.k); err != nil {
				return nil, err
			}

			ejp.advanceState()
			switch ejp.s {
			case jpsSawBeginObject:
				keys, vals, err := ejp.readObject(1, true)
				if err != nil {
					return nil, err
				}
				v = &extJSONValue{t: TypeEmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}
			case jpsSawValue:
				if ejp.canonicalOnly {
					return nil, invalidJSONError("{")
				}
				v = ejp.v
			default:
				if ejp.canonicalOnly {
					return nil, invalidJSONErrorForType("object", t)
				}
				return nil, invalidJSONErrorForType("ISO-8601 Internet Date/Time Format as described in RFC-3339", t)
			}

			ejp.advanceState()
			if ejp.s != jpsSawEndObject {
				return nil, invalidJSONErrorForType("value and then }", t)
			}
		default:
			return nil, invalidRequestError(t.String())
		}
	case TypeJavaScript:
		switch ejp.s {
		case jpsSawKey:
			// read colon
			ejp.advanceState()
			if err := ensureColon(ejp.s, ejp.k); err != nil {
				return nil, err
			}

			// read value
			ejp.advanceState()
			if ejp.s != jpsSawValue {
				return nil, invalidJSONErrorForType("value", t)
			}
			v = ejp.v

			// read end object or comma and just return
			ejp.advanceState()
		case jpsSawEndObject:
			v = ejp.v
		default:
			return nil, invalidRequestError(t.String())
		}
	case TypeCodeWithScope:
		if ejp.s == jpsSawKey && ejp.k == "$scope" {
			v = ejp.v // this is the $code string from earlier

			// read colon
			ejp.advanceState()
			if err := ensureColon(ejp.s, ejp.k); err != nil {
				return nil, err
			}

			// read {
			ejp.advanceState()
			if ejp.s != jpsSawBeginObject {
				return nil, invalidJSONError("$scope to be embedded document")
			}
		} else {
			return nil, invalidRequestError(t.String())
		}
	case TypeEmbeddedDocument, TypeArray:
		return nil, invalidRequestError(t.String())
	}

	return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// readObject is a utility method for reading full objects of known (or expected) size.
// It is useful for extended JSON types such as binary, datetime, regex, and timestamp.
func (ejp *extJSONParser) readObject(numKeys int, started bool) ([]string, []*extJSONValue, error) {
	keys := make([]string, numKeys)
	vals := make([]*extJSONValue, numKeys)

	if !started {
		ejp.advanceState()
		if ejp.s != jpsSawBeginObject {
			return nil, nil, invalidJSONError("{")
		}
	}

	for i := 0; i < numKeys; i++ {
		key, t, err := ejp.readKey()
		if err != nil {
			return nil, nil, err
		}

		switch ejp.s {
		case jpsSawKey:
			v, err := ejp.readValue(t)
			if err != nil {
				return nil, nil, err
			}

			keys[i] = key
			vals[i] = v
		case jpsSawValue:
			keys[i] = key
			vals[i] = ejp.v
		default:
			return nil, nil, invalidJSONError("value")
		}
	}

	ejp.advanceState()
	if ejp.s != jpsSawEndObject {
		return nil, nil, invalidJSONError("}")
	}

	return keys, vals, nil
}

// advanceState reads the next JSON token from the scanner and transitions
// from the current state based on that token's type.
func (ejp *extJSONParser) advanceState() {
	if ejp.s == jpsDoneState || ejp.s == jpsInvalidState {
		return
	}

	jt, err := ejp.js.nextToken()

	if err != nil {
		ejp.err = err
		ejp.s = jpsInvalidState
		return
	}

	valid := ejp.validateToken(jt.t)
	if !valid {
		ejp.err = unexpectedTokenError(jt)
		ejp.s = jpsInvalidState
		return
	}

	switch jt.t {
	case jttBeginObject:
		ejp.s = jpsSawBeginObject
		ejp.pushMode(jpmObjectMode)
		ejp.depth++

		if ejp.depth > ejp.maxDepth {
			ejp.err = nestingDepthError(jt.p, ejp.depth)
			ejp.s = jpsInvalidState
		}
	case jttEndObject:
		ejp.s = jpsSawEndObject
		ejp.depth--

		if ejp.popMode() != jpmObjectMode {
			ejp.err = unexpectedTokenError(jt)
			ejp.s = jpsInvalidState
		}
	case jttBeginArray:
		ejp.s = jpsSawBeginArray
		ejp.pushMode(jpmArrayMode)
	case jttEndArray:
		ejp.s = jpsSawEndArray

		if ejp.popMode() != jpmArrayMode {
			ejp.err = unexpectedTokenError(jt)
			ejp.s = jpsInvalidState
		}
	case jttColon:
		ejp.s = jpsSawColon
	case jttComma:
		ejp.s = jpsSawComma
	case jttEOF:
		ejp.s = jpsDoneState
		if len(ejp.m) != 0 {
			ejp.err = unexpectedTokenError(jt)
			ejp.s = jpsInvalidState
		}
	case jttString:
		switch ejp.s {
		case jpsSawComma:
			if ejp.peekMode() == jpmArrayMode {
				ejp.s = jpsSawValue
				ejp.v = extendJSONToken(jt)
				return
			}
			fallthrough
		case jpsSawBeginObject:
			ejp.s = jpsSawKey
			ejp.k = jt.v.(string)
			return
		}
		fallthrough
	default:
		ejp.s = jpsSawValue
		ejp.v = extendJSONToken(jt)
	}
}

var jpsValidTransitionTokens = map[jsonParseState]map[jsonTokenType]bool{
	jpsStartState: {
		jttBeginObject: true,
		jttBeginArray:  true,
		jttInt32:       true,
		jttInt64:       true,
		jttDouble:      true,
		jttString:      true,
		jttBool:        true,
		jttNull:        true,
		jttEOF:         true,
	},
	jpsSawBeginObject: {
		jttEndObject: true,
		jttString:    true,
	},
	jpsSawEndObject: {
		jttEndObject: true,
		jttEndArray:  true,
		jttComma:     true,
		jttEOF:       true,
	},
	jpsSawBeginArray: {
		jttBeginObject: true,
		jttBeginArray:  true,
		jttEndArray:    true,
		jttInt32:       true,
		jttInt64:       true,
		jttDouble:      true,
		jttString:      true,
		jttBool:        true,
		jttNull:        true,
	},
	jpsSawEndArray: {
		jttEndObject: true,
		jttEndArray:  true,
		jttComma:     true,
		jttEOF:       true,
	},
	jpsSawColon: {
		jttBeginObject: true,
		jttBeginArray:  true,
		jttInt32:       true,
		jttInt64:       true,
		jttDouble:      true,
		jttString:      true,
		jttBool:        true,
		jttNull:        true,
	},
	jpsSawComma: {
		jttBeginObject: true,
		jttBeginArray:  true,
		jttInt32:       true,
		jttInt64:       true,
		jttDouble:      true,
		jttString:      true,
		jttBool:        true,
		jttNull:        true,
	},
	jpsSawKey: {
		jttColon: true,
	},
	jpsSawValue: {
		jttEndObject: true,
		jttEndArray:  true,
		jttComma:     true,
		jttEOF:       true,
	},
	jpsDoneState:    {},
	jpsInvalidState: {},
}

func (ejp *extJSONParser) validateToken(jtt jsonTokenType) bool {
	switch ejp.s {
	case jpsSawEndObject:
		// if we are at depth zero and the next token is a '{',
		// we can consider it valid only if we are not in array mode.
		if jtt == jttBeginObject && ejp.depth == 0 {
			return ejp.peekMode() != jpmArrayMode
		}
	case jpsSawComma:
		switch ejp.peekMode() {
		// the only valid next token after a comma inside a document is a string (a key)
		case jpmObjectMode:
			return jtt == jttString
		case jpmInvalidMode:
			return false
		}
	}

	_, ok := jpsValidTransitionTokens[ejp.s][jtt]
	return ok
}

// ensureExtValueType returns true if the current value has the expected
// value type for single-key extended JSON types. For example, in
// {"$numberInt": v}, v must be TypeString.
func (ejp *extJSONParser) ensureExtValueType(t Type) bool {
	switch t {
	case TypeMinKey, TypeMaxKey:
		return ejp.v.t == TypeInt32
	case TypeUndefined:
		return ejp.v.t == TypeBoolean
	case TypeInt32, TypeInt64, TypeDouble, TypeDecimal128, TypeSymbol, TypeObjectID:
		return ejp.v.t == TypeString
	default:
		return false
	}
}

func (ejp *extJSONParser) pushMode(m jsonParseMode) {
	ejp.m = append(ejp.m, m)
}

func (ejp *extJSONParser) popMode() jsonParseMode {
	l := len(ejp.m)
	if l == 0 {
		return jpmInvalidMode
	}

	m := ejp.m[l-1]
	ejp.m = ejp.m[:l-1]

	return m
}

func (ejp *extJSONParser) peekMode() jsonParseMode {
	l := len(ejp.m)
	if l == 0 {
		return jpmInvalidMode
	}

	return ejp.m[l-1]
}

func extendJSONToken(jt *jsonToken) *extJSONValue {
	var t Type

	switch jt.t {
	case jttInt32:
		t = TypeInt32
	case jttInt64:
		t = TypeInt64
	case jttDouble:
		t = TypeDouble
	case jttString:
		t = TypeString
	case jttBool:
		t = TypeBoolean
	case jttNull:
		t = TypeNull
	default:
		return nil
	}

	return &extJSONValue{t: t, v: jt.v}
}

func ensureColon(s jsonParseState, key string) error {
	if s != jpsSawColon {
		return fmt.Errorf("invalid JSON input: missing colon after key \"%s\"", key)
	}

	return nil
}

func invalidRequestError(s string) error {
	return fmt.Errorf("invalid request to read %s", s)
}

func invalidJSONError(expected string) error {
	return fmt.Errorf("invalid JSON input; expected %s", expected)
}

func invalidJSONErrorForType(expected string, t Type) error {
	return fmt.Errorf("invalid JSON input; expected %s for %s", expected, t)
}

func unexpectedTokenError(jt *jsonToken) error {
	switch jt.t {
	case jttInt32, jttInt64, jttDouble:
		return fmt.Errorf("invalid JSON input; unexpected number (%v) at position %d", jt.v, jt.p)
	case jttString:
		return fmt.Errorf("invalid JSON input; unexpected string (\"%v\") at position %d", jt.v, jt.p)
	case jttBool:
		return fmt.Errorf("invalid JSON input; unexpected boolean literal (%v) at position %d", jt.v, jt.p)
	case jttNull:
		return fmt.Errorf("invalid JSON input; unexpected null literal at position %d", jt.p)
	case jttEOF:
		return fmt.Errorf("invalid JSON input; unexpected end of input at position %d", jt.p)
	default:
		return fmt.Errorf("invalid JSON input; unexpected %c at position %d", jt.v.(byte), jt.p)
	}
}

func nestingDepthError(p, depth int) error {
	return fmt.Errorf("invalid JSON input; nesting too deep (%d levels) at position %d", depth, p)
}

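As a side note on the table above: jpsValidTransitionTokens drives validateToken as a two-level map lookup, where a missing entry means the transition is invalid. The following is a minimal, self-contained sketch of that pattern; it is not part of this package, and every identifier in it (parseState, tokenType, validNext, valid) is hypothetical.

```go
// Sketch of the table-driven transition check used by validateToken:
// each parser state maps to the set of token types allowed to follow it.
package main

import "fmt"

type parseState int
type tokenType int

const (
	sawKey parseState = iota
	sawColon
	sawValue
)

const (
	tokColon tokenType = iota
	tokString
	tokComma
)

// validNext mirrors the shape of jpsValidTransitionTokens.
var validNext = map[parseState]map[tokenType]bool{
	sawKey:   {tokColon: true},
	sawColon: {tokString: true},
	sawValue: {tokComma: true},
}

// valid reports whether token tt may follow state s. Missing map entries
// yield the zero value false, so invalid transitions need no explicit entry.
func valid(s parseState, tt tokenType) bool {
	return validNext[s][tt]
}

func main() {
	fmt.Println(valid(sawKey, tokColon)) // true: a colon must follow a key
	fmt.Println(valid(sawKey, tokComma)) // false: a comma cannot follow a key
}
```

The advantage of this layout is that per-state special cases (like the depth check in validateToken) can be handled in code first, with the table as the fallback.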
804
extjson_parser_test.go
Normal file
@@ -0,0 +1,804 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"errors"
	"io"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
)

var (
	keyDiff = specificDiff("key")
	typDiff = specificDiff("type")
	valDiff = specificDiff("value")

	expectErrEOF = expectSpecificError(io.EOF)
	expectErrEOD = expectSpecificError(ErrEOD)
	expectErrEOA = expectSpecificError(ErrEOA)
)

type expectedErrorFunc func(t *testing.T, err error, desc string)

type peekTypeTestCase struct {
	desc  string
	input string
	typs  []Type
	errFs []expectedErrorFunc
}

type readKeyValueTestCase struct {
	desc  string
	input string
	keys  []string
	typs  []Type
	vals  []*extJSONValue

	keyEFs []expectedErrorFunc
	valEFs []expectedErrorFunc
}

func expectNoError(t *testing.T, err error, desc string) {
	if err != nil {
		t.Helper()
		t.Errorf("%s: Unexpected error: %v", desc, err)
		t.FailNow()
	}
}

func expectError(t *testing.T, err error, desc string) {
	if err == nil {
		t.Helper()
		t.Errorf("%s: Expected error", desc)
		t.FailNow()
	}
}

func expectSpecificError(expected error) expectedErrorFunc {
	return func(t *testing.T, err error, desc string) {
		if !errors.Is(err, expected) {
			t.Helper()
			t.Errorf("%s: Expected %v but got: %v", desc, expected, err)
			t.FailNow()
		}
	}
}

func specificDiff(name string) func(t *testing.T, expected, actual interface{}, desc string) {
	return func(t *testing.T, expected, actual interface{}, desc string) {
		if diff := cmp.Diff(expected, actual); diff != "" {
			t.Helper()
			t.Errorf("%s: Incorrect JSON %s (-want, +got): %s\n", desc, name, diff)
			t.FailNow()
		}
	}
}

func expectErrorNOOP(_ *testing.T, _ error, _ string) {
}

func readKeyDiff(t *testing.T, eKey, aKey string, eTyp, aTyp Type, err error, errF expectedErrorFunc, desc string) {
	keyDiff(t, eKey, aKey, desc)
	typDiff(t, eTyp, aTyp, desc)
	errF(t, err, desc)
}

func readValueDiff(t *testing.T, eVal, aVal *extJSONValue, err error, errF expectedErrorFunc, desc string) {
	if aVal != nil {
		typDiff(t, eVal.t, aVal.t, desc)
		valDiff(t, eVal.v, aVal.v, desc)
	} else {
		valDiff(t, eVal, aVal, desc)
	}

	errF(t, err, desc)
}

func TestExtJSONParserPeekType(t *testing.T) {
	makeValidPeekTypeTestCase := func(input string, typ Type, desc string) peekTypeTestCase {
		return peekTypeTestCase{
			desc: desc, input: input,
			typs:  []Type{typ},
			errFs: []expectedErrorFunc{expectNoError},
		}
	}
	makeInvalidTestCase := func(desc, input string, lastEF expectedErrorFunc) peekTypeTestCase {
		return peekTypeTestCase{
			desc: desc, input: input,
			typs:  []Type{Type(0)},
			errFs: []expectedErrorFunc{lastEF},
		}
	}

	makeInvalidPeekTypeTestCase := func(desc, input string, lastEF expectedErrorFunc) peekTypeTestCase {
		return peekTypeTestCase{
			desc: desc, input: input,
			typs:  []Type{TypeArray, TypeString, Type(0)},
			errFs: []expectedErrorFunc{expectNoError, expectNoError, lastEF},
		}
	}

	cases := []peekTypeTestCase{
		makeValidPeekTypeTestCase(`null`, TypeNull, "Null"),
		makeValidPeekTypeTestCase(`"string"`, TypeString, "String"),
		makeValidPeekTypeTestCase(`true`, TypeBoolean, "Boolean--true"),
		makeValidPeekTypeTestCase(`false`, TypeBoolean, "Boolean--false"),
		makeValidPeekTypeTestCase(`{"$minKey": 1}`, TypeMinKey, "MinKey"),
		makeValidPeekTypeTestCase(`{"$maxKey": 1}`, TypeMaxKey, "MaxKey"),
		makeValidPeekTypeTestCase(`{"$numberInt": "42"}`, TypeInt32, "Int32"),
		makeValidPeekTypeTestCase(`{"$numberLong": "42"}`, TypeInt64, "Int64"),
		makeValidPeekTypeTestCase(`{"$symbol": "symbol"}`, TypeSymbol, "Symbol"),
		makeValidPeekTypeTestCase(`{"$numberDouble": "42.42"}`, TypeDouble, "Double"),
		makeValidPeekTypeTestCase(`{"$undefined": true}`, TypeUndefined, "Undefined"),
		makeValidPeekTypeTestCase(`{"$numberDouble": "NaN"}`, TypeDouble, "Double--NaN"),
		makeValidPeekTypeTestCase(`{"$numberDecimal": "1234"}`, TypeDecimal128, "Decimal"),
		makeValidPeekTypeTestCase(`{"foo": "bar"}`, TypeEmbeddedDocument, "Toplevel document"),
		makeValidPeekTypeTestCase(`{"$date": {"$numberLong": "0"}}`, TypeDateTime, "Datetime"),
		makeValidPeekTypeTestCase(`{"$code": "function() {}"}`, TypeJavaScript, "Code no scope"),
		makeValidPeekTypeTestCase(`[{"$numberInt": "1"},{"$numberInt": "2"}]`, TypeArray, "Array"),
		makeValidPeekTypeTestCase(`{"$timestamp": {"t": 42, "i": 1}}`, TypeTimestamp, "Timestamp"),
		makeValidPeekTypeTestCase(`{"$oid": "57e193d7a9cc81b4027498b5"}`, TypeObjectID, "Object ID"),
		makeValidPeekTypeTestCase(`{"$binary": {"base64": "AQIDBAU=", "subType": "80"}}`, TypeBinary, "Binary"),
		makeValidPeekTypeTestCase(`{"$code": "function() {}", "$scope": {}}`, TypeCodeWithScope, "Code With Scope"),
		makeValidPeekTypeTestCase(`{"$binary": {"base64": "o0w498Or7cijeBSpkquNtg==", "subType": "03"}}`, TypeBinary, "Binary"),
		makeValidPeekTypeTestCase(`{"$binary": "o0w498Or7cijeBSpkquNtg==", "$type": "03"}`, TypeBinary, "Binary"),
		makeValidPeekTypeTestCase(`{"$regularExpression": {"pattern": "foo*", "options": "ix"}}`, TypeRegex, "Regular expression"),
		makeValidPeekTypeTestCase(`{"$dbPointer": {"$ref": "db.collection", "$id": {"$oid": "57e193d7a9cc81b4027498b1"}}}`, TypeDBPointer, "DBPointer"),
		makeValidPeekTypeTestCase(`{"$ref": "collection", "$id": {"$oid": "57fd71e96e32ab4225b723fb"}, "$db": "database"}`, TypeEmbeddedDocument, "DBRef"),
		makeInvalidPeekTypeTestCase("invalid array--missing ]", `["a"`, expectError),
		makeInvalidPeekTypeTestCase("invalid array--colon in array", `["a":`, expectError),
		makeInvalidPeekTypeTestCase("invalid array--extra comma", `["a",,`, expectError),
		makeInvalidPeekTypeTestCase("invalid array--trailing comma", `["a",]`, expectError),
		makeInvalidPeekTypeTestCase("peekType after end of array", `["a"]`, expectErrEOA),
		{
			desc:  "invalid array--leading comma",
			input: `[,`,
			typs:  []Type{TypeArray, Type(0)},
			errFs: []expectedErrorFunc{expectNoError, expectError},
		},
		makeInvalidTestCase("lone $scope", `{"$scope": {}}`, expectError),
		makeInvalidTestCase("empty code with unknown extra key", `{"$code":"", "0":""}`, expectError),
		makeInvalidTestCase("non-empty code with unknown extra key", `{"$code":"foobar", "0":""}`, expectError),
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			ejp := newExtJSONParser(strings.NewReader(tc.input), true)
			// Manually set the parser's starting state to jpsSawColon so peekType will read ahead to find the extjson
			// type of the value. If not set, the parser will be in jpsStartState and advance to jpsSawKey, which will
			// cause it to return without peeking the extjson type.
			ejp.s = jpsSawColon

			for i, eTyp := range tc.typs {
				errF := tc.errFs[i]

				typ, err := ejp.peekType()
				errF(t, err, tc.desc)
				if err != nil {
					// Don't inspect the type if there was an error
					return
				}

				typDiff(t, eTyp, typ, tc.desc)
			}
		})
	}
}

func TestExtJSONParserReadKeyReadValue(t *testing.T) {
	// several test cases will use the same keys, types, and values, and only differ on input structure

	keys := []string{"_id", "Symbol", "String", "Int32", "Int64", "Int", "MinKey"}
	types := []Type{TypeObjectID, TypeSymbol, TypeString, TypeInt32, TypeInt64, TypeInt32, TypeMinKey}
	values := []*extJSONValue{
		{t: TypeString, v: "57e193d7a9cc81b4027498b5"},
		{t: TypeString, v: "symbol"},
		{t: TypeString, v: "string"},
		{t: TypeString, v: "42"},
		{t: TypeString, v: "42"},
		{t: TypeInt32, v: int32(42)},
		{t: TypeInt32, v: int32(1)},
	}

	errFuncs := make([]expectedErrorFunc, 7)
	for i := 0; i < 7; i++ {
		errFuncs[i] = expectNoError
	}

	firstKeyError := func(desc, input string) readKeyValueTestCase {
		return readKeyValueTestCase{
			desc:   desc,
			input:  input,
			keys:   []string{""},
			typs:   []Type{Type(0)},
			vals:   []*extJSONValue{nil},
			keyEFs: []expectedErrorFunc{expectError},
			valEFs: []expectedErrorFunc{expectErrorNOOP},
		}
	}

	secondKeyError := func(desc, input, firstKey string, firstType Type, firstValue *extJSONValue) readKeyValueTestCase {
		return readKeyValueTestCase{
			desc:   desc,
			input:  input,
			keys:   []string{firstKey, ""},
			typs:   []Type{firstType, Type(0)},
			vals:   []*extJSONValue{firstValue, nil},
			keyEFs: []expectedErrorFunc{expectNoError, expectError},
			valEFs: []expectedErrorFunc{expectNoError, expectErrorNOOP},
		}
	}

	cases := []readKeyValueTestCase{
		{
			desc: "normal spacing",
			input: `{
				"_id": { "$oid": "57e193d7a9cc81b4027498b5" },
				"Symbol": { "$symbol": "symbol" },
				"String": "string",
				"Int32": { "$numberInt": "42" },
				"Int64": { "$numberLong": "42" },
				"Int": 42,
				"MinKey": { "$minKey": 1 }
			}`,
			keys: keys, typs: types, vals: values,
			keyEFs: errFuncs, valEFs: errFuncs,
		},
		{
			desc: "new line before comma",
			input: `{ "_id": { "$oid": "57e193d7a9cc81b4027498b5" }
				, "Symbol": { "$symbol": "symbol" }
				, "String": "string"
				, "Int32": { "$numberInt": "42" }
				, "Int64": { "$numberLong": "42" }
				, "Int": 42
				, "MinKey": { "$minKey": 1 }
			}`,
			keys: keys, typs: types, vals: values,
			keyEFs: errFuncs, valEFs: errFuncs,
		},
		{
			desc: "tabs around colons",
			input: `{
				"_id":	{ "$oid"	: "57e193d7a9cc81b4027498b5" },
				"Symbol":	{ "$symbol"	: "symbol" },
				"String":	"string",
				"Int32":	{ "$numberInt"	: "42" },
				"Int64":	{ "$numberLong": "42" },
				"Int":	42,
				"MinKey":	{ "$minKey": 1 }
			}`,
			keys: keys, typs: types, vals: values,
			keyEFs: errFuncs, valEFs: errFuncs,
		},
		{
			desc:  "no whitespace",
			input: `{"_id":{"$oid":"57e193d7a9cc81b4027498b5"},"Symbol":{"$symbol":"symbol"},"String":"string","Int32":{"$numberInt":"42"},"Int64":{"$numberLong":"42"},"Int":42,"MinKey":{"$minKey":1}}`,
			keys:  keys, typs: types, vals: values,
			keyEFs: errFuncs, valEFs: errFuncs,
		},
		{
			desc: "mixed whitespace",
			input: `	{
				"_id"		: { "$oid": "57e193d7a9cc81b4027498b5" },
				"Symbol"	: { "$symbol": "symbol" }	,
				"String"	: "string",
				"Int32"		: { "$numberInt": "42" }    ,
				"Int64"		: {"$numberLong" : "42"},
				"Int"		: 42,
				"MinKey"	: { "$minKey": 1 } 	}	`,
			keys: keys, typs: types, vals: values,
			keyEFs: errFuncs, valEFs: errFuncs,
		},
		{
			desc:  "nested object",
			input: `{"k1": 1, "k2": { "k3": { "k4": 4 } }, "k5": 5}`,
			keys:  []string{"k1", "k2", "k3", "k4", "", "", "k5", ""},
			typs:  []Type{TypeInt32, TypeEmbeddedDocument, TypeEmbeddedDocument, TypeInt32, Type(0), Type(0), TypeInt32, Type(0)},
			vals: []*extJSONValue{
				{t: TypeInt32, v: int32(1)}, nil, nil, {t: TypeInt32, v: int32(4)}, nil, nil, {t: TypeInt32, v: int32(5)}, nil,
			},
			keyEFs: []expectedErrorFunc{
				expectNoError, expectNoError, expectNoError, expectNoError, expectErrEOD,
				expectErrEOD, expectNoError, expectErrEOD,
			},
			valEFs: []expectedErrorFunc{
				expectNoError, expectError, expectError, expectNoError, expectErrorNOOP,
				expectErrorNOOP, expectNoError, expectErrorNOOP,
			},
		},
		{
			desc:   "invalid input: invalid values for extended type",
			input:  `{"a": {"$numberInt": "1", "x"`,
			keys:   []string{"a"},
			typs:   []Type{TypeInt32},
			vals:   []*extJSONValue{nil},
			keyEFs: []expectedErrorFunc{expectNoError},
			valEFs: []expectedErrorFunc{expectError},
		},
		firstKeyError("invalid input: missing key--EOF", "{"),
		firstKeyError("invalid input: missing key--colon first", "{:"),
		firstKeyError("invalid input: missing value", `{"a":`),
		firstKeyError("invalid input: missing colon", `{"a" 1`),
		firstKeyError("invalid input: extra colon", `{"a"::`),
		secondKeyError("invalid input: missing }", `{"a": 1`, "a", TypeInt32, &extJSONValue{t: TypeInt32, v: int32(1)}),
		secondKeyError("invalid input: missing comma", `{"a": 1 "b"`, "a", TypeInt32, &extJSONValue{t: TypeInt32, v: int32(1)}),
		secondKeyError("invalid input: extra comma", `{"a": 1,, "b"`, "a", TypeInt32, &extJSONValue{t: TypeInt32, v: int32(1)}),
		secondKeyError("invalid input: trailing comma in object", `{"a": 1,}`, "a", TypeInt32, &extJSONValue{t: TypeInt32, v: int32(1)}),
		{
			desc:   "invalid input: lone scope after a complete value",
			input:  `{"a": "", "b": {"$scope: ""}}`,
			keys:   []string{"a"},
			typs:   []Type{TypeString},
			vals:   []*extJSONValue{{TypeString, ""}},
			keyEFs: []expectedErrorFunc{expectNoError, expectNoError},
			valEFs: []expectedErrorFunc{expectNoError, expectError},
		},
		{
			desc:   "invalid input: lone scope nested",
			input:  `{"a":{"b":{"$scope":{`,
			keys:   []string{},
			typs:   []Type{},
			vals:   []*extJSONValue{nil},
			keyEFs: []expectedErrorFunc{expectNoError},
			valEFs: []expectedErrorFunc{expectError},
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			ejp := newExtJSONParser(strings.NewReader(tc.input), true)

			for i, eKey := range tc.keys {
				eTyp := tc.typs[i]
				eVal := tc.vals[i]

				keyErrF := tc.keyEFs[i]
				valErrF := tc.valEFs[i]

				k, typ, err := ejp.readKey()
				readKeyDiff(t, eKey, k, eTyp, typ, err, keyErrF, tc.desc)

				v, err := ejp.readValue(typ)
				readValueDiff(t, eVal, v, err, valErrF, tc.desc)
			}
		})
	}
}

type ejpExpectationTest func(t *testing.T, p *extJSONParser, expectedKey string, expectedType Type, expectedValue interface{})

type ejpTestCase struct {
	f ejpExpectationTest
	p *extJSONParser
	k string
	t Type
	v interface{}
}

// expectSingleValue is used for simple JSON types (strings, numbers, literals) and for extended JSON types that
// have single key-value pairs (i.e. { "$minKey": 1 }, { "$numberLong": "42.42" })
func expectSingleValue(t *testing.T, p *extJSONParser, expectedKey string, expectedType Type, expectedValue interface{}) {
	eVal := expectedValue.(*extJSONValue)

	k, typ, err := p.readKey()
	readKeyDiff(t, expectedKey, k, expectedType, typ, err, expectNoError, expectedKey)

	v, err := p.readValue(typ)
	readValueDiff(t, eVal, v, err, expectNoError, expectedKey)
}

// expectMultipleValues is used for values that are subdocuments of known size and with known keys (such as extended
// JSON types { "$timestamp": {"t": 1, "i": 1} } and { "$regularExpression": {"pattern": "", "options": ""} })
func expectMultipleValues(t *testing.T, p *extJSONParser, expectedKey string, expectedType Type, expectedValue interface{}) {
	k, typ, err := p.readKey()
	readKeyDiff(t, expectedKey, k, expectedType, typ, err, expectNoError, expectedKey)

	v, err := p.readValue(typ)
	expectNoError(t, err, "")
	typDiff(t, TypeEmbeddedDocument, v.t, expectedKey)

	actObj := v.v.(*extJSONObject)
	expObj := expectedValue.(*extJSONObject)

	for i, actKey := range actObj.keys {
		expKey := expObj.keys[i]
		actVal := actObj.values[i]
		expVal := expObj.values[i]

		keyDiff(t, expKey, actKey, expectedKey)
		typDiff(t, expVal.t, actVal.t, expectedKey)
		valDiff(t, expVal.v, actVal.v, expectedKey)
	}
}

type ejpKeyTypValTriple struct {
	key string
	typ Type
	val *extJSONValue
}

type ejpSubDocumentTestValue struct {
	code string              // code is only used for TypeCodeWithScope (and is ignored for TypeEmbeddedDocument)
	ktvs []ejpKeyTypValTriple // list of (key, type, value) triples; this is "scope" for TypeCodeWithScope
}

// expectSubDocument is used for embedded documents and code with scope types; it reads all the keys and values
// in the embedded document (or scope for codeWithScope) and compares them to the expectedValue's list of (key, type,
// value) triples
func expectSubDocument(t *testing.T, p *extJSONParser, expectedKey string, expectedType Type, expectedValue interface{}) {
	subdoc := expectedValue.(ejpSubDocumentTestValue)

	k, typ, err := p.readKey()
	readKeyDiff(t, expectedKey, k, expectedType, typ, err, expectNoError, expectedKey)

	if expectedType == TypeCodeWithScope {
		v, err := p.readValue(typ)
		readValueDiff(t, &extJSONValue{t: TypeString, v: subdoc.code}, v, err, expectNoError, expectedKey)
	}

	for _, ktv := range subdoc.ktvs {
		eKey := ktv.key
		eTyp := ktv.typ
		eVal := ktv.val

		k, typ, err = p.readKey()
		readKeyDiff(t, eKey, k, eTyp, typ, err, expectNoError, expectedKey)

		v, err := p.readValue(typ)
		readValueDiff(t, eVal, v, err, expectNoError, expectedKey)
	}

	if expectedType == TypeCodeWithScope {
		// expect scope doc to close
		k, typ, err = p.readKey()
		readKeyDiff(t, "", k, Type(0), typ, err, expectErrEOD, expectedKey)
	}

	// expect subdoc to close
	k, typ, err = p.readKey()
	readKeyDiff(t, "", k, Type(0), typ, err, expectErrEOD, expectedKey)
}

// expectArray takes the expectedKey, ignores the expectedType, and uses the expectedValue
// as a slice of (type Type, value *extJSONValue) pairs
func expectArray(t *testing.T, p *extJSONParser, expectedKey string, _ Type, expectedValue interface{}) {
	ktvs := expectedValue.([]ejpKeyTypValTriple)

	k, typ, err := p.readKey()
	readKeyDiff(t, expectedKey, k, TypeArray, typ, err, expectNoError, expectedKey)

	for _, ktv := range ktvs {
		eTyp := ktv.typ
		eVal := ktv.val

		typ, err = p.peekType()
		typDiff(t, eTyp, typ, expectedKey)
		expectNoError(t, err, expectedKey)

		v, err := p.readValue(typ)
		readValueDiff(t, eVal, v, err, expectNoError, expectedKey)
	}

	// expect array to end
	typ, err = p.peekType()
	typDiff(t, Type(0), typ, expectedKey)
	expectErrEOA(t, err, expectedKey)
}

func TestExtJSONParserAllTypes(t *testing.T) {
	in := ` { "_id" : { "$oid": "57e193d7a9cc81b4027498b5"}
	, "Symbol" : { "$symbol": "symbol"}
	, "String" : "string"
	, "Int32" : { "$numberInt": "42"}
	, "Int64" : { "$numberLong": "42"}
	, "Double" : { "$numberDouble": "42.42"}
	, "SpecialFloat" : { "$numberDouble": "NaN" }
	, "Decimal" : { "$numberDecimal": "1234" }
	, "Binary" : { "$binary": { "base64": "o0w498Or7cijeBSpkquNtg==", "subType": "03" } }
	, "BinaryLegacy" : { "$binary": "o0w498Or7cijeBSpkquNtg==", "$type": "03" }
	, "BinaryUserDefined" : { "$binary": { "base64": "AQIDBAU=", "subType": "80" } }
	, "Code" : { "$code": "function() {}" }
	, "CodeWithEmptyScope" : { "$code": "function() {}", "$scope": {} }
	, "CodeWithScope" : { "$code": "function() {}", "$scope": { "x": 1 } }
	, "EmptySubdocument" : {}
	, "Subdocument" : { "foo": "bar", "baz": { "$numberInt": "42" } }
	, "Array" : [{"$numberInt": "1"}, {"$numberLong": "2"}, {"$numberDouble": "3"}, 4, "string", 5.0]
	, "Timestamp" : { "$timestamp": { "t": 42, "i": 1 } }
	, "RegularExpression" : { "$regularExpression": { "pattern": "foo*", "options": "ix" } }
	, "DatetimeEpoch" : { "$date": { "$numberLong": "0" } }
	, "DatetimePositive" : { "$date": { "$numberLong": "9223372036854775807" } }
	, "DatetimeNegative" : { "$date": { "$numberLong": "-9223372036854775808" } }
	, "True" : true
	, "False" : false
	, "DBPointer" : { "$dbPointer": { "$ref": "db.collection", "$id": { "$oid": "57e193d7a9cc81b4027498b1" } } }
	, "DBRef" : { "$ref": "collection", "$id": { "$oid": "57fd71e96e32ab4225b723fb" }, "$db": "database" }
	, "DBRefNoDB" : { "$ref": "collection", "$id": { "$oid": "57fd71e96e32ab4225b723fb" } }
	, "MinKey" : { "$minKey": 1 }
	, "MaxKey" : { "$maxKey": 1 }
	, "Null" : null
	, "Undefined" : { "$undefined": true }
	}`

	ejp := newExtJSONParser(strings.NewReader(in), true)

	cases := []ejpTestCase{
		{
			f: expectSingleValue, p: ejp,
			k: "_id", t: TypeObjectID, v: &extJSONValue{t: TypeString, v: "57e193d7a9cc81b4027498b5"},
		},
		{
			f: expectSingleValue, p: ejp,
			k: "Symbol", t: TypeSymbol, v: &extJSONValue{t: TypeString, v: "symbol"},
		},
		{
			f: expectSingleValue, p: ejp,
			k: "String", t: TypeString, v: &extJSONValue{t: TypeString, v: "string"},
		},
		{
			f: expectSingleValue, p: ejp,
			k: "Int32", t: TypeInt32, v: &extJSONValue{t: TypeString, v: "42"},
		},
		{
			f: expectSingleValue, p: ejp,
			k: "Int64", t: TypeInt64, v: &extJSONValue{t: TypeString, v: "42"},
		},
		{
			f: expectSingleValue, p: ejp,
			k: "Double", t: TypeDouble, v: &extJSONValue{t: TypeString, v: "42.42"},
		},
		{
			f: expectSingleValue, p: ejp,
			k: "SpecialFloat", t: TypeDouble, v: &extJSONValue{t: TypeString, v: "NaN"},
		},
		{
			f: expectSingleValue, p: ejp,
			k: "Decimal", t: TypeDecimal128, v: &extJSONValue{t: TypeString, v: "1234"},
		},
		{
			f: expectMultipleValues, p: ejp,
			k: "Binary", t: TypeBinary,
			v: &extJSONObject{
				keys: []string{"base64", "subType"},
				values: []*extJSONValue{
					{t: TypeString, v: "o0w498Or7cijeBSpkquNtg=="},
					{t: TypeString, v: "03"},
				},
			},
		},
		{
			f: expectMultipleValues, p: ejp,
			k: "BinaryLegacy", t: TypeBinary,
			v: &extJSONObject{
				keys: []string{"base64", "subType"},
				values: []*extJSONValue{
					{t: TypeString, v: "o0w498Or7cijeBSpkquNtg=="},
					{t: TypeString, v: "03"},
				},
			},
		},
		{
			f: expectMultipleValues, p: ejp,
			k: "BinaryUserDefined", t: TypeBinary,
			v: &extJSONObject{
				keys: []string{"base64", "subType"},
				values: []*extJSONValue{
					{t: TypeString, v: "AQIDBAU="},
					{t: TypeString, v: "80"},
				},
			},
		},
		{
			f: expectSingleValue, p: ejp,
			k: "Code", t: TypeJavaScript, v: &extJSONValue{t: TypeString, v: "function() {}"},
		},
		{
			f: expectSubDocument, p: ejp,
			k: "CodeWithEmptyScope", t: TypeCodeWithScope,
			v: ejpSubDocumentTestValue{
				code: "function() {}",
				ktvs: []ejpKeyTypValTriple{},
			},
		},
		{
			f: expectSubDocument, p: ejp,
			k: "CodeWithScope", t: TypeCodeWithScope,
			v: ejpSubDocumentTestValue{
				code: "function() {}",
				ktvs: []ejpKeyTypValTriple{
					{"x", TypeInt32, &extJSONValue{t: TypeInt32, v: int32(1)}},
				},
			},
		},
		{
			f: expectSubDocument, p: ejp,
			k: "EmptySubdocument", t: TypeEmbeddedDocument,
			v: ejpSubDocumentTestValue{
				ktvs: []ejpKeyTypValTriple{},
			},
		},
		{
			f: expectSubDocument, p: ejp,
			k: "Subdocument", t: TypeEmbeddedDocument,
			v: ejpSubDocumentTestValue{
				ktvs: []ejpKeyTypValTriple{
					{"foo", TypeString, &extJSONValue{t: TypeString, v: "bar"}},
					{"baz", TypeInt32, &extJSONValue{t: TypeString, v: "42"}},
				},
			},
		},
		{
			f: expectArray, p: ejp,
			k: "Array", t: TypeArray,
			v: []ejpKeyTypValTriple{
				{typ: TypeInt32, val: &extJSONValue{t: TypeString, v: "1"}},
				{typ: TypeInt64, val: &extJSONValue{t: TypeString, v: "2"}},
				{typ: TypeDouble, val: &extJSONValue{t: TypeString, v: "3"}},
				{typ: TypeInt32, val: &extJSONValue{t: TypeInt32, v: int32(4)}},
				{typ: TypeString, val: &extJSONValue{t: TypeString, v: "string"}},
				{typ: TypeDouble, val: &extJSONValue{t: TypeDouble, v: 5.0}},
			},
		},
		{
			f: expectMultipleValues, p: ejp,
			k: "Timestamp", t: TypeTimestamp,
			v: &extJSONObject{
				keys: []string{"t", "i"},
				values: []*extJSONValue{
					{t: TypeInt32, v: int32(42)},
					{t: TypeInt32, v: int32(1)},
				},
			},
		},
		{
			f: expectMultipleValues, p: ejp,
			k: "RegularExpression", t: TypeRegex,
			v: &extJSONObject{
				keys: []string{"pattern", "options"},
				values: []*extJSONValue{
					{t: TypeString, v: "foo*"},
					{t: TypeString, v: "ix"},
				},
			},
		},
		{
			f: expectMultipleValues, p: ejp,
			k: "DatetimeEpoch", t: TypeDateTime,
			v: &extJSONObject{
				keys: []string{"$numberLong"},
				values: []*extJSONValue{
					{t: TypeString, v: "0"},
				},
			},
		},
		{
			f: expectMultipleValues, p: ejp,
			k: "DatetimePositive", t: TypeDateTime,
			v: &extJSONObject{
				keys: []string{"$numberLong"},
				values: []*extJSONValue{
					{t: TypeString, v: "9223372036854775807"},
				},
			},
		},
		{
			f: expectMultipleValues, p: ejp,
			k: "DatetimeNegative", t: TypeDateTime,
			v: &extJSONObject{
				keys: []string{"$numberLong"},
				values: []*extJSONValue{
					{t: TypeString, v: "-9223372036854775808"},
				},
			},
		},
		{
			f: expectSingleValue, p: ejp,
			k: "True", t: TypeBoolean, v: &extJSONValue{t: TypeBoolean, v: true},
		},
		{
			f: expectSingleValue, p: ejp,
			k: "False", t: TypeBoolean, v: &extJSONValue{t: TypeBoolean, v: false},
		},
		{
			f: expectMultipleValues, p: ejp,
			k: "DBPointer", t: TypeDBPointer,
			v: &extJSONObject{
				keys: []string{"$ref", "$id"},
				values: []*extJSONValue{
					{t: TypeString, v: "db.collection"},
					{t: TypeString, v: "57e193d7a9cc81b4027498b1"},
				},
			},
		},
		{
			f: expectSubDocument, p: ejp,
			k: "DBRef", t: TypeEmbeddedDocument,
			v: ejpSubDocumentTestValue{
				ktvs: []ejpKeyTypValTriple{
					{"$ref", TypeString, &extJSONValue{t: TypeString, v: "collection"}},
					{"$id", TypeObjectID, &extJSONValue{t: TypeString, v: "57fd71e96e32ab4225b723fb"}},
					{"$db", TypeString, &extJSONValue{t: TypeString, v: "database"}},
				},
			},
		},
		{
			f: expectSubDocument, p: ejp,
			k: "DBRefNoDB", t: TypeEmbeddedDocument,
			v: ejpSubDocumentTestValue{
				ktvs: []ejpKeyTypValTriple{
					{"$ref", TypeString, &extJSONValue{t: TypeString, v: "collection"}},
					{"$id", TypeObjectID, &extJSONValue{t: TypeString, v: "57fd71e96e32ab4225b723fb"}},
				},
			},
		},
		{
			f: expectSingleValue, p: ejp,
			k: "MinKey", t: TypeMinKey, v: &extJSONValue{t: TypeInt32, v: int32(1)},
		},
		{
			f: expectSingleValue, p: ejp,
			k: "MaxKey", t: TypeMaxKey, v: &extJSONValue{t: TypeInt32, v: int32(1)},
		},
		{
			f: expectSingleValue, p: ejp,
			k: "Null", t: TypeNull, v: &extJSONValue{t: TypeNull, v: nil},
		},
		{
			f: expectSingleValue, p: ejp,
			k: "Undefined", t: TypeUndefined, v: &extJSONValue{t: TypeBoolean, v: true},
		},
	}

	// run the test cases
	for _, tc := range cases {
		tc.f(t, tc.p, tc.k, tc.t, tc.v)
	}

	// expect end of whole document: read final }
	k, typ, err := ejp.readKey()
	readKeyDiff(t, "", k, Type(0), typ, err, expectErrEOD, "")

	// expect end of whole document: read EOF
	k, typ, err = ejp.readKey()
	readKeyDiff(t, "", k, Type(0), typ, err, expectErrEOF, "")
	if diff := cmp.Diff(jpsDoneState, ejp.s); diff != "" {
		t.Errorf("expected parser to be in done state but instead is in %v\n", ejp.s)
		t.FailNow()
	}
}

func TestExtJSONValue(t *testing.T) {
	t.Run("Large Date", func(t *testing.T) {
		val := &extJSONValue{
			t: TypeString,
			v: "3001-01-01T00:00:00Z",
		}

		intVal, err := val.parseDateTime()
		if err != nil {
			t.Fatalf("error parsing date time: %v", err)
		}

		if intVal <= 0 {
			t.Fatalf("expected value above 0, got %v", intVal)
		}
	})
	t.Run("fallback time format", func(t *testing.T) {
		val := &extJSONValue{
			t: TypeString,
			v: "2019-06-04T14:54:31.416+0000",
		}

		_, err := val.parseDateTime()
		if err != nil {
			t.Fatalf("error parsing date time: %v", err)
		}
	})
}
46
extjson_prose_test.go
Normal file
@ -0,0 +1,46 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"fmt"
	"testing"

	"gitea.psichedelico.com/go/bson/internal/assert"
)

func TestExtJSON(t *testing.T) {
	timestampNegativeInt32Err := fmt.Errorf("$timestamp i number should be uint32: -1")
	timestampNegativeInt64Err := fmt.Errorf("$timestamp i number should be uint32: -2147483649")
	timestampLargeValueErr := fmt.Errorf("$timestamp i number should be uint32: 4294967296")

	testCases := []struct {
		name      string
		input     string
		canonical bool
		err       error
	}{
		{"timestamp - negative int32 value", `{"":{"$timestamp":{"t":0,"i":-1}}}`, false, timestampNegativeInt32Err},
		{"timestamp - negative int64 value", `{"":{"$timestamp":{"t":0,"i":-2147483649}}}`, false, timestampNegativeInt64Err},
		{"timestamp - value overflows uint32", `{"":{"$timestamp":{"t":0,"i":4294967296}}}`, false, timestampLargeValueErr},
		{"top level key is not treated as special", `{"$code": "foo"}`, false, nil},
		{"escaped single quote errors", `{"f\'oo": "bar"}`, false, ErrInvalidJSON},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			var res Raw
			err := UnmarshalExtJSON([]byte(tc.input), tc.canonical, &res)
			if tc.err == nil {
				assert.Nil(t, err, "UnmarshalExtJSON error: %v", err)
				return
			}

			assert.NotNil(t, err, "expected error %v, got nil", tc.err)
			assert.Equal(t, tc.err.Error(), err.Error(), "expected error %v, got %v", tc.err, err)
		})
	}
}
606
extjson_reader.go
Normal file
@ -0,0 +1,606 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"errors"
	"fmt"
	"io"
)

type ejvrState struct {
	mode  mode
	vType Type
	depth int
}

// extJSONValueReader is for reading extended JSON.
type extJSONValueReader struct {
	p *extJSONParser

	stack []ejvrState
	frame int
}

// NewExtJSONValueReader returns a ValueReader that reads Extended JSON values
// from r. If canonicalOnly is true, reading values from the ValueReader returns
// an error if the Extended JSON was not marshaled in canonical mode.
func NewExtJSONValueReader(r io.Reader, canonicalOnly bool) (ValueReader, error) {
	return newExtJSONValueReader(r, canonicalOnly)
}

func newExtJSONValueReader(r io.Reader, canonicalOnly bool) (*extJSONValueReader, error) {
	ejvr := new(extJSONValueReader)
	return ejvr.reset(r, canonicalOnly)
}

func (ejvr *extJSONValueReader) reset(r io.Reader, canonicalOnly bool) (*extJSONValueReader, error) {
	p := newExtJSONParser(r, canonicalOnly)
	typ, err := p.peekType()

	if err != nil {
		return nil, ErrInvalidJSON
	}

	var m mode
	switch typ {
	case TypeEmbeddedDocument:
		m = mTopLevel
	case TypeArray:
		m = mArray
	default:
		m = mValue
	}

	stack := make([]ejvrState, 1, 5)
	stack[0] = ejvrState{
		mode:  m,
		vType: typ,
	}
	return &extJSONValueReader{
		p:     p,
		stack: stack,
	}, nil
}

func (ejvr *extJSONValueReader) advanceFrame() {
	if ejvr.frame+1 >= len(ejvr.stack) { // We need to grow the stack
		length := len(ejvr.stack)
		if length+1 >= cap(ejvr.stack) {
			// double it
			buf := make([]ejvrState, 2*cap(ejvr.stack)+1)
			copy(buf, ejvr.stack)
			ejvr.stack = buf
		}
		ejvr.stack = ejvr.stack[:length+1]
	}
	ejvr.frame++

	// Clean the stack
	ejvr.stack[ejvr.frame].mode = 0
	ejvr.stack[ejvr.frame].vType = 0
	ejvr.stack[ejvr.frame].depth = 0
}

func (ejvr *extJSONValueReader) pushDocument() {
	ejvr.advanceFrame()

	ejvr.stack[ejvr.frame].mode = mDocument
	ejvr.stack[ejvr.frame].depth = ejvr.p.depth
}

func (ejvr *extJSONValueReader) pushCodeWithScope() {
	ejvr.advanceFrame()

	ejvr.stack[ejvr.frame].mode = mCodeWithScope
}

func (ejvr *extJSONValueReader) pushArray() {
	ejvr.advanceFrame()

	ejvr.stack[ejvr.frame].mode = mArray
}

func (ejvr *extJSONValueReader) push(m mode, t Type) {
	ejvr.advanceFrame()

	ejvr.stack[ejvr.frame].mode = m
	ejvr.stack[ejvr.frame].vType = t
}

func (ejvr *extJSONValueReader) pop() {
	switch ejvr.stack[ejvr.frame].mode {
	case mElement, mValue:
		ejvr.frame--
	case mDocument, mArray, mCodeWithScope:
		ejvr.frame -= 2 // we pop twice to jump over the vrElement: vrDocument -> vrElement -> vrDocument/TopLevel/etc...
	}
}

func (ejvr *extJSONValueReader) skipObject() {
	// read entire object until depth returns to 0 (last ending } or ] seen)
	depth := 1
	for depth > 0 {
		ejvr.p.advanceState()

		// If object is empty, lower depth and continue. When emptyObject is true, the
		// parser has already read both the opening and closing brackets of an empty
		// object ("{}"), so the next valid token will be part of the parent document,
		// not part of the nested document.
		//
		// If there is a comma, there are remaining fields, emptyObject must be set back
		// to false, and the comma must be skipped with advanceState().
		if ejvr.p.emptyObject {
			if ejvr.p.s == jpsSawComma {
				ejvr.p.emptyObject = false
				ejvr.p.advanceState()
			}
			depth--
			continue
		}

		switch ejvr.p.s {
		case jpsSawBeginObject, jpsSawBeginArray:
			depth++
		case jpsSawEndObject, jpsSawEndArray:
			depth--
		}
	}
}

func (ejvr *extJSONValueReader) invalidTransitionErr(destination mode, name string, modes []mode) error {
	te := TransitionError{
		name:        name,
		current:     ejvr.stack[ejvr.frame].mode,
		destination: destination,
		modes:       modes,
		action:      "read",
	}
	if ejvr.frame != 0 {
		te.parent = ejvr.stack[ejvr.frame-1].mode
	}
	return te
}

func (ejvr *extJSONValueReader) typeError(t Type) error {
	return fmt.Errorf("positioned on %s, but attempted to read %s", ejvr.stack[ejvr.frame].vType, t)
}

func (ejvr *extJSONValueReader) ensureElementValue(t Type, destination mode, callerName string, addModes ...mode) error {
	switch ejvr.stack[ejvr.frame].mode {
	case mElement, mValue:
		if ejvr.stack[ejvr.frame].vType != t {
			return ejvr.typeError(t)
		}
	default:
		modes := []mode{mElement, mValue}
		if addModes != nil {
			modes = append(modes, addModes...)
		}
		return ejvr.invalidTransitionErr(destination, callerName, modes)
	}

	return nil
}

func (ejvr *extJSONValueReader) Type() Type {
	return ejvr.stack[ejvr.frame].vType
}

func (ejvr *extJSONValueReader) Skip() error {
	switch ejvr.stack[ejvr.frame].mode {
	case mElement, mValue:
	default:
		return ejvr.invalidTransitionErr(0, "Skip", []mode{mElement, mValue})
	}

	defer ejvr.pop()

	t := ejvr.stack[ejvr.frame].vType
	switch t {
	case TypeArray, TypeEmbeddedDocument, TypeCodeWithScope:
		// read entire array, doc or CodeWithScope
		ejvr.skipObject()
	default:
		_, err := ejvr.p.readValue(t)
		if err != nil {
			return err
		}
	}

	return nil
}

func (ejvr *extJSONValueReader) ReadArray() (ArrayReader, error) {
	switch ejvr.stack[ejvr.frame].mode {
	case mTopLevel: // allow reading array from top level
	case mArray:
		return ejvr, nil
	default:
		if err := ejvr.ensureElementValue(TypeArray, mArray, "ReadArray", mTopLevel, mArray); err != nil {
			return nil, err
		}
	}

	ejvr.pushArray()

	return ejvr, nil
}

func (ejvr *extJSONValueReader) ReadBinary() (b []byte, btype byte, err error) {
	if err := ejvr.ensureElementValue(TypeBinary, 0, "ReadBinary"); err != nil {
		return nil, 0, err
	}

	v, err := ejvr.p.readValue(TypeBinary)
	if err != nil {
		return nil, 0, err
	}

	b, btype, err = v.parseBinary()

	ejvr.pop()
	return b, btype, err
}

func (ejvr *extJSONValueReader) ReadBoolean() (bool, error) {
	if err := ejvr.ensureElementValue(TypeBoolean, 0, "ReadBoolean"); err != nil {
		return false, err
	}

	v, err := ejvr.p.readValue(TypeBoolean)
	if err != nil {
		return false, err
	}

	if v.t != TypeBoolean {
		return false, fmt.Errorf("expected type bool, but got type %s", v.t)
	}

	ejvr.pop()
	return v.v.(bool), nil
}

func (ejvr *extJSONValueReader) ReadDocument() (DocumentReader, error) {
	switch ejvr.stack[ejvr.frame].mode {
	case mTopLevel:
		return ejvr, nil
	case mElement, mValue:
		if ejvr.stack[ejvr.frame].vType != TypeEmbeddedDocument {
			return nil, ejvr.typeError(TypeEmbeddedDocument)
		}

		ejvr.pushDocument()
		return ejvr, nil
	default:
		return nil, ejvr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue})
	}
}

func (ejvr *extJSONValueReader) ReadCodeWithScope() (code string, dr DocumentReader, err error) {
	if err = ejvr.ensureElementValue(TypeCodeWithScope, 0, "ReadCodeWithScope"); err != nil {
		return "", nil, err
	}

	v, err := ejvr.p.readValue(TypeCodeWithScope)
	if err != nil {
		return "", nil, err
	}

	code, err = v.parseJavascript()

	ejvr.pushCodeWithScope()
	return code, ejvr, err
}

func (ejvr *extJSONValueReader) ReadDBPointer() (ns string, oid ObjectID, err error) {
	if err = ejvr.ensureElementValue(TypeDBPointer, 0, "ReadDBPointer"); err != nil {
		return "", NilObjectID, err
	}

	v, err := ejvr.p.readValue(TypeDBPointer)
	if err != nil {
		return "", NilObjectID, err
	}

	ns, oid, err = v.parseDBPointer()

	ejvr.pop()
	return ns, oid, err
}

func (ejvr *extJSONValueReader) ReadDateTime() (int64, error) {
	if err := ejvr.ensureElementValue(TypeDateTime, 0, "ReadDateTime"); err != nil {
		return 0, err
	}

	v, err := ejvr.p.readValue(TypeDateTime)
	if err != nil {
		return 0, err
	}

	d, err := v.parseDateTime()

	ejvr.pop()
	return d, err
}

func (ejvr *extJSONValueReader) ReadDecimal128() (Decimal128, error) {
	if err := ejvr.ensureElementValue(TypeDecimal128, 0, "ReadDecimal128"); err != nil {
		return Decimal128{}, err
	}

	v, err := ejvr.p.readValue(TypeDecimal128)
	if err != nil {
		return Decimal128{}, err
	}

	d, err := v.parseDecimal128()

	ejvr.pop()
	return d, err
}

func (ejvr *extJSONValueReader) ReadDouble() (float64, error) {
	if err := ejvr.ensureElementValue(TypeDouble, 0, "ReadDouble"); err != nil {
		return 0, err
	}

	v, err := ejvr.p.readValue(TypeDouble)
	if err != nil {
		return 0, err
	}

	d, err := v.parseDouble()

	ejvr.pop()
	return d, err
}

func (ejvr *extJSONValueReader) ReadInt32() (int32, error) {
	if err := ejvr.ensureElementValue(TypeInt32, 0, "ReadInt32"); err != nil {
		return 0, err
	}

	v, err := ejvr.p.readValue(TypeInt32)
	if err != nil {
		return 0, err
	}

	i, err := v.parseInt32()

	ejvr.pop()
	return i, err
}

func (ejvr *extJSONValueReader) ReadInt64() (int64, error) {
	if err := ejvr.ensureElementValue(TypeInt64, 0, "ReadInt64"); err != nil {
		return 0, err
	}

	v, err := ejvr.p.readValue(TypeInt64)
	if err != nil {
		return 0, err
	}

	i, err := v.parseInt64()

	ejvr.pop()
	return i, err
}

func (ejvr *extJSONValueReader) ReadJavascript() (code string, err error) {
	if err = ejvr.ensureElementValue(TypeJavaScript, 0, "ReadJavascript"); err != nil {
		return "", err
	}

	v, err := ejvr.p.readValue(TypeJavaScript)
	if err != nil {
		return "", err
	}

	code, err = v.parseJavascript()

	ejvr.pop()
	return code, err
}

func (ejvr *extJSONValueReader) ReadMaxKey() error {
	if err := ejvr.ensureElementValue(TypeMaxKey, 0, "ReadMaxKey"); err != nil {
		return err
	}

	v, err := ejvr.p.readValue(TypeMaxKey)
	if err != nil {
		return err
	}

	err = v.parseMinMaxKey("max")

	ejvr.pop()
	return err
}

func (ejvr *extJSONValueReader) ReadMinKey() error {
	if err := ejvr.ensureElementValue(TypeMinKey, 0, "ReadMinKey"); err != nil {
		return err
	}

	v, err := ejvr.p.readValue(TypeMinKey)
	if err != nil {
		return err
	}

	err = v.parseMinMaxKey("min")

	ejvr.pop()
	return err
}

func (ejvr *extJSONValueReader) ReadNull() error {
	if err := ejvr.ensureElementValue(TypeNull, 0, "ReadNull"); err != nil {
		return err
	}

	v, err := ejvr.p.readValue(TypeNull)
	if err != nil {
		return err
	}

	if v.t != TypeNull {
		return fmt.Errorf("expected type null but got type %s", v.t)
	}

	ejvr.pop()
	return nil
}

func (ejvr *extJSONValueReader) ReadObjectID() (ObjectID, error) {
	if err := ejvr.ensureElementValue(TypeObjectID, 0, "ReadObjectID"); err != nil {
		return ObjectID{}, err
	}

	v, err := ejvr.p.readValue(TypeObjectID)
	if err != nil {
		return ObjectID{}, err
	}

	oid, err := v.parseObjectID()

	ejvr.pop()
	return oid, err
}

func (ejvr *extJSONValueReader) ReadRegex() (pattern string, options string, err error) {
	if err = ejvr.ensureElementValue(TypeRegex, 0, "ReadRegex"); err != nil {
		return "", "", err
	}

	v, err := ejvr.p.readValue(TypeRegex)
	if err != nil {
		return "", "", err
	}

	pattern, options, err = v.parseRegex()

	ejvr.pop()
	return pattern, options, err
}

func (ejvr *extJSONValueReader) ReadString() (string, error) {
	if err := ejvr.ensureElementValue(TypeString, 0, "ReadString"); err != nil {
		return "", err
	}

	v, err := ejvr.p.readValue(TypeString)
	if err != nil {
		return "", err
	}

	if v.t != TypeString {
		return "", fmt.Errorf("expected type string but got type %s", v.t)
	}

	ejvr.pop()
	return v.v.(string), nil
}

func (ejvr *extJSONValueReader) ReadSymbol() (symbol string, err error) {
	if err = ejvr.ensureElementValue(TypeSymbol, 0, "ReadSymbol"); err != nil {
		return "", err
	}

	v, err := ejvr.p.readValue(TypeSymbol)
	if err != nil {
		return "", err
	}

	symbol, err = v.parseSymbol()

	ejvr.pop()
	return symbol, err
}

func (ejvr *extJSONValueReader) ReadTimestamp() (t uint32, i uint32, err error) {
	if err = ejvr.ensureElementValue(TypeTimestamp, 0, "ReadTimestamp"); err != nil {
		return 0, 0, err
	}

	v, err := ejvr.p.readValue(TypeTimestamp)
	if err != nil {
		return 0, 0, err
	}

	t, i, err = v.parseTimestamp()

	ejvr.pop()
	return t, i, err
}

func (ejvr *extJSONValueReader) ReadUndefined() error {
	if err := ejvr.ensureElementValue(TypeUndefined, 0, "ReadUndefined"); err != nil {
		return err
	}

	v, err := ejvr.p.readValue(TypeUndefined)
	if err != nil {
		return err
	}
|
||||||
|
|
||||||
|
err = v.parseUndefined()
|
||||||
|
|
||||||
|
ejvr.pop()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ejvr *extJSONValueReader) ReadElement() (string, ValueReader, error) {
|
||||||
|
switch ejvr.stack[ejvr.frame].mode {
|
||||||
|
case mTopLevel, mDocument, mCodeWithScope:
|
||||||
|
default:
|
||||||
|
return "", nil, ejvr.invalidTransitionErr(mElement, "ReadElement", []mode{mTopLevel, mDocument, mCodeWithScope})
|
||||||
|
}
|
||||||
|
|
||||||
|
name, t, err := ejvr.p.readKey()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrEOD) {
|
||||||
|
if ejvr.stack[ejvr.frame].mode == mCodeWithScope {
|
||||||
|
_, err := ejvr.p.peekType()
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ejvr.pop()
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ejvr.push(mElement, t)
|
||||||
|
return name, ejvr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ejvr *extJSONValueReader) ReadValue() (ValueReader, error) {
|
||||||
|
switch ejvr.stack[ejvr.frame].mode {
|
||||||
|
case mArray:
|
||||||
|
default:
|
||||||
|
return nil, ejvr.invalidTransitionErr(mValue, "ReadValue", []mode{mArray})
|
||||||
|
}
|
||||||
|
|
||||||
|
t, err := ejvr.p.peekType()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrEOA) {
|
||||||
|
ejvr.pop()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ejvr.push(mValue, t)
|
||||||
|
return ejvr, nil
|
||||||
|
}
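ReadElement and ReadValue above drive a small mode stack: each successful read pushes a frame before handing out a value reader, and end-of-document/end-of-array conditions pop it. A minimal self-contained sketch of that push/pop discipline (the type and method names here are illustrative, not the driver's actual API):

```go
package main

import "fmt"

type mode int

const (
	mTopLevel mode = iota
	mElement
	mValue
)

// stackReader mimics the frame discipline used by extJSONValueReader:
// push a frame before descending into a value, pop when the value ends.
type stackReader struct {
	stack []mode
	frame int
}

func (sr *stackReader) push(m mode) {
	sr.frame++
	if sr.frame >= len(sr.stack) {
		sr.stack = append(sr.stack, m)
	} else {
		sr.stack[sr.frame] = m
	}
}

func (sr *stackReader) pop() {
	if sr.frame > 0 {
		sr.frame--
	}
}

func (sr *stackReader) current() mode { return sr.stack[sr.frame] }

func main() {
	sr := &stackReader{stack: []mode{mTopLevel}}
	sr.push(mElement) // ReadElement pushes before handing out a value reader
	fmt.Println(sr.current() == mElement)
	sr.pop() // finishing the value pops back to the document frame
	fmt.Println(sr.current() == mTopLevel)
}
```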
168	extjson_reader_test.go	Normal file
@@ -0,0 +1,168 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"errors"
	"fmt"
	"io"
	"strings"
	"testing"

	"gitea.psichedelico.com/go/bson/internal/assert"
	"github.com/google/go-cmp/cmp"
)

func TestExtJSONReader(t *testing.T) {
	t.Run("ReadDocument", func(t *testing.T) {
		t.Run("EmbeddedDocument", func(t *testing.T) {
			ejvr := &extJSONValueReader{
				stack: []ejvrState{
					{mode: mTopLevel},
					{mode: mElement, vType: TypeBoolean},
				},
				frame: 1,
			}

			ejvr.stack[1].mode = mArray
			wanterr := ejvr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue})
			_, err := ejvr.ReadDocument()
			if err == nil || err.Error() != wanterr.Error() {
				t.Errorf("Incorrect returned error. got %v; want %v", err, wanterr)
			}
		})
	})

	t.Run("invalid transition", func(t *testing.T) {
		t.Run("Skip", func(t *testing.T) {
			ejvr := &extJSONValueReader{stack: []ejvrState{{mode: mTopLevel}}}
			wanterr := (&extJSONValueReader{stack: []ejvrState{{mode: mTopLevel}}}).invalidTransitionErr(0, "Skip", []mode{mElement, mValue})
			goterr := ejvr.Skip()
			if !cmp.Equal(goterr, wanterr, cmp.Comparer(assert.CompareErrors)) {
				t.Errorf("Expected correct invalid transition error. got %v; want %v", goterr, wanterr)
			}
		})
	})
}

func TestReadMultipleTopLevelDocuments(t *testing.T) {
	testCases := []struct {
		name     string
		input    string
		expected [][]byte
	}{
		{
			"single top-level document",
			"{\"foo\":1}",
			[][]byte{
				{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
			},
		},
		{
			"single top-level document with leading and trailing whitespace",
			"\n\n {\"foo\":1} \n",
			[][]byte{
				{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
			},
		},
		{
			"two top-level documents",
			"{\"foo\":1}{\"foo\":2}",
			[][]byte{
				{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
				{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x02, 0x00, 0x00, 0x00, 0x00},
			},
		},
		{
			"two top-level documents with leading and trailing whitespace and whitespace separation",
			"\n\n {\"foo\":1}\n{\"foo\":2}\n ",
			[][]byte{
				{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
				{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x02, 0x00, 0x00, 0x00, 0x00},
			},
		},
		{
			"top-level array with single document",
			"[{\"foo\":1}]",
			[][]byte{
				{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
			},
		},
		{
			"top-level array with 2 documents",
			"[{\"foo\":1},{\"foo\":2}]",
			[][]byte{
				{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
				{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x02, 0x00, 0x00, 0x00, 0x00},
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			r := strings.NewReader(tc.input)
			vr, err := NewExtJSONValueReader(r, false)
			if err != nil {
				t.Fatalf("expected no error, but got %v", err)
			}

			actual, err := readAllDocuments(vr)
			if err != nil {
				t.Fatalf("expected no error, but got %v", err)
			}

			if diff := cmp.Diff(tc.expected, actual); diff != "" {
				t.Fatalf("expected does not match actual: %v", diff)
			}
		})
	}
}

func readAllDocuments(vr ValueReader) ([][]byte, error) {
	var actual [][]byte

	switch vr.Type() {
	case TypeEmbeddedDocument:
		for {
			result, err := copyDocumentToBytes(vr)
			if err != nil {
				if errors.Is(err, io.EOF) {
					break
				}
				return nil, err
			}

			actual = append(actual, result)
		}
	case TypeArray:
		ar, err := vr.ReadArray()
		if err != nil {
			return nil, err
		}
		for {
			evr, err := ar.ReadValue()
			if err != nil {
				if errors.Is(err, ErrEOA) {
					break
				}
				return nil, err
			}

			result, err := copyDocumentToBytes(evr)
			if err != nil {
				return nil, err
			}

			actual = append(actual, result)
		}
	default:
		return nil, fmt.Errorf("expected an array or a document, but got %s", vr.Type())
	}

	return actual, nil
}
223	extjson_tables.go	Normal file
@@ -0,0 +1,223 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/golang/go by The Go Authors
// See THIRD-PARTY-NOTICES for original license terms.

package bson

import "unicode/utf8"

// safeSet holds the value true if the ASCII character with the given array
// position can be represented inside a JSON string without any further
// escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), and the backslash character ("\").
var safeSet = [utf8.RuneSelf]bool{
	' ':  true,
	'!':  true,
	'"':  false,
	'#':  true,
	'$':  true,
	'%':  true,
	'&':  true,
	'\'': true,
	'(':  true,
	')':  true,
	'*':  true,
	'+':  true,
	',':  true,
	'-':  true,
	'.':  true,
	'/':  true,
	'0':  true,
	'1':  true,
	'2':  true,
	'3':  true,
	'4':  true,
	'5':  true,
	'6':  true,
	'7':  true,
	'8':  true,
	'9':  true,
	':':  true,
	';':  true,
	'<':  true,
	'=':  true,
	'>':  true,
	'?':  true,
	'@':  true,
	'A':  true,
	'B':  true,
	'C':  true,
	'D':  true,
	'E':  true,
	'F':  true,
	'G':  true,
	'H':  true,
	'I':  true,
	'J':  true,
	'K':  true,
	'L':  true,
	'M':  true,
	'N':  true,
	'O':  true,
	'P':  true,
	'Q':  true,
	'R':  true,
	'S':  true,
	'T':  true,
	'U':  true,
	'V':  true,
	'W':  true,
	'X':  true,
	'Y':  true,
	'Z':  true,
	'[':  true,
	'\\': false,
	']':  true,
	'^':  true,
	'_':  true,
	'`':  true,
	'a':  true,
	'b':  true,
	'c':  true,
	'd':  true,
	'e':  true,
	'f':  true,
	'g':  true,
	'h':  true,
	'i':  true,
	'j':  true,
	'k':  true,
	'l':  true,
	'm':  true,
	'n':  true,
	'o':  true,
	'p':  true,
	'q':  true,
	'r':  true,
	's':  true,
	't':  true,
	'u':  true,
	'v':  true,
	'w':  true,
	'x':  true,
	'y':  true,
	'z':  true,
	'{':  true,
	'|':  true,
	'}':  true,
	'~':  true,
	'\u007f': true,
}

// htmlSafeSet holds the value true if the ASCII character with the given
// array position can be safely represented inside a JSON string, embedded
// inside of HTML <script> tags, without any additional escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), the backslash character ("\"), HTML opening and closing
// tags ("<" and ">"), and the ampersand ("&").
var htmlSafeSet = [utf8.RuneSelf]bool{
	' ':  true,
	'!':  true,
	'"':  false,
	'#':  true,
	'$':  true,
	'%':  true,
	'&':  false,
	'\'': true,
	'(':  true,
	')':  true,
	'*':  true,
	'+':  true,
	',':  true,
	'-':  true,
	'.':  true,
	'/':  true,
	'0':  true,
	'1':  true,
	'2':  true,
	'3':  true,
	'4':  true,
	'5':  true,
	'6':  true,
	'7':  true,
	'8':  true,
	'9':  true,
	':':  true,
	';':  true,
	'<':  false,
	'=':  true,
	'>':  false,
	'?':  true,
	'@':  true,
	'A':  true,
	'B':  true,
	'C':  true,
	'D':  true,
	'E':  true,
	'F':  true,
	'G':  true,
	'H':  true,
	'I':  true,
	'J':  true,
	'K':  true,
	'L':  true,
	'M':  true,
	'N':  true,
	'O':  true,
	'P':  true,
	'Q':  true,
	'R':  true,
	'S':  true,
	'T':  true,
	'U':  true,
	'V':  true,
	'W':  true,
	'X':  true,
	'Y':  true,
	'Z':  true,
	'[':  true,
	'\\': false,
	']':  true,
	'^':  true,
	'_':  true,
	'`':  true,
	'a':  true,
	'b':  true,
	'c':  true,
	'd':  true,
	'e':  true,
	'f':  true,
	'g':  true,
	'h':  true,
	'i':  true,
	'j':  true,
	'k':  true,
	'l':  true,
	'm':  true,
	'n':  true,
	'o':  true,
	'p':  true,
	'q':  true,
	'r':  true,
	's':  true,
	't':  true,
	'u':  true,
	'v':  true,
	'w':  true,
	'x':  true,
	'y':  true,
	'z':  true,
	'{':  true,
	'|':  true,
	'}':  true,
	'~':  true,
	'\u007f': true,
}
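These tables are indexed by byte: a JSON writer can copy runs of `true` bytes verbatim and only fall into its escaping path for `false` entries or non-ASCII bytes. A minimal sketch of how such a lookup table drives escaping (the helper and its simplified escape form are illustrative, not the driver's actual writer logic):

```go
package main

import (
	"fmt"
	"unicode/utf8"
)

// demoSafe marks which ASCII bytes may appear unescaped, like safeSet above:
// everything from 0x20 up, except '"' and '\\'.
var demoSafe = func() [utf8.RuneSelf]bool {
	var s [utf8.RuneSelf]bool
	for c := 0x20; c < utf8.RuneSelf; c++ {
		s[c] = true
	}
	s['"'] = false
	s['\\'] = false
	return s
}()

// escapeASCII copies safe bytes through and backslash-escapes the rest.
func escapeASCII(in string) string {
	out := make([]byte, 0, len(in))
	for i := 0; i < len(in); i++ {
		c := in[i]
		if c < utf8.RuneSelf && demoSafe[c] {
			out = append(out, c)
			continue
		}
		out = append(out, '\\', c) // simplified: real writers emit \n, \uXXXX, etc.
	}
	return string(out)
}

func main() {
	fmt.Println(escapeASCII(`say "hi"`)) // say \"hi\"
}
```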
489	extjson_wrappers.go	Normal file
@@ -0,0 +1,489 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"encoding/base64"
	"errors"
	"fmt"
	"math"
	"strconv"
	"time"
)

func wrapperKeyBSONType(key string) Type {
	switch key {
	case "$numberInt":
		return TypeInt32
	case "$numberLong":
		return TypeInt64
	case "$oid":
		return TypeObjectID
	case "$symbol":
		return TypeSymbol
	case "$numberDouble":
		return TypeDouble
	case "$numberDecimal":
		return TypeDecimal128
	case "$binary":
		return TypeBinary
	case "$code":
		return TypeJavaScript
	case "$scope":
		return TypeCodeWithScope
	case "$timestamp":
		return TypeTimestamp
	case "$regularExpression":
		return TypeRegex
	case "$dbPointer":
		return TypeDBPointer
	case "$date":
		return TypeDateTime
	case "$minKey":
		return TypeMinKey
	case "$maxKey":
		return TypeMaxKey
	case "$undefined":
		return TypeUndefined
	}

	return TypeEmbeddedDocument
}
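The dispatch above is keyed purely on a `$`-prefixed field name; any unrecognized key falls through to a plain embedded document. A standalone sketch of the same idea (string results stand in for the package's `Type` constants, and only a few keys are shown):

```go
package main

import "fmt"

// wrapperKind mirrors wrapperKeyBSONType's shape: known Extended JSON
// wrapper keys map to a concrete BSON type, everything else falls
// through to a plain embedded document.
func wrapperKind(key string) string {
	switch key {
	case "$numberInt":
		return "int32"
	case "$numberLong":
		return "int64"
	case "$oid":
		return "objectID"
	case "$date":
		return "dateTime"
	case "$undefined":
		return "undefined"
	}
	return "embeddedDocument"
}

func main() {
	fmt.Println(wrapperKind("$oid")) // objectID
	fmt.Println(wrapperKind("name")) // embeddedDocument
}
```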

func (ejv *extJSONValue) parseBinary() (b []byte, subType byte, err error) {
	if ejv.t != TypeEmbeddedDocument {
		return nil, 0, fmt.Errorf("$binary value should be object, but instead is %s", ejv.t)
	}

	binObj := ejv.v.(*extJSONObject)
	bFound := false
	stFound := false

	for i, key := range binObj.keys {
		val := binObj.values[i]

		switch key {
		case "base64":
			if bFound {
				return nil, 0, errors.New("duplicate base64 key in $binary")
			}

			if val.t != TypeString {
				return nil, 0, fmt.Errorf("$binary base64 value should be string, but instead is %s", val.t)
			}

			base64Bytes, err := base64.StdEncoding.DecodeString(val.v.(string))
			if err != nil {
				return nil, 0, fmt.Errorf("invalid $binary base64 string: %s", val.v.(string))
			}

			b = base64Bytes
			bFound = true
		case "subType":
			if stFound {
				return nil, 0, errors.New("duplicate subType key in $binary")
			}

			if val.t != TypeString {
				return nil, 0, fmt.Errorf("$binary subType value should be string, but instead is %s", val.t)
			}

			i, err := strconv.ParseUint(val.v.(string), 16, 8)
			if err != nil {
				return nil, 0, fmt.Errorf("invalid $binary subType string: %q: %w", val.v.(string), err)
			}

			subType = byte(i)
			stFound = true
		default:
			return nil, 0, fmt.Errorf("invalid key in $binary object: %s", key)
		}
	}

	if !bFound {
		return nil, 0, errors.New("missing base64 field in $binary object")
	}

	if !stFound {
		return nil, 0, errors.New("missing subType field in $binary object")
	}

	return b, subType, nil
}
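parseBinary accepts exactly two fields: `base64` (a standard base64 payload) and `subType` (a hex string that must fit in one byte, hence `ParseUint(..., 16, 8)`). The two conversions can be sketched standalone (the helper name is illustrative):

```go
package main

import (
	"encoding/base64"
	"fmt"
	"strconv"
)

// decodeBinary mimics parseBinary's two conversions: a base64 payload
// plus a hex subtype string parsed with base 16 into 8 bits.
func decodeBinary(b64, subTypeHex string) ([]byte, byte, error) {
	payload, err := base64.StdEncoding.DecodeString(b64)
	if err != nil {
		return nil, 0, fmt.Errorf("invalid base64: %w", err)
	}
	st, err := strconv.ParseUint(subTypeHex, 16, 8)
	if err != nil {
		return nil, 0, fmt.Errorf("invalid subType: %w", err)
	}
	return payload, byte(st), nil
}

func main() {
	payload, st, err := decodeBinary("AQID", "80")
	fmt.Println(payload, st, err) // [1 2 3] 128 <nil>
}
```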

func (ejv *extJSONValue) parseDBPointer() (ns string, oid ObjectID, err error) {
	if ejv.t != TypeEmbeddedDocument {
		return "", NilObjectID, fmt.Errorf("$dbPointer value should be object, but instead is %s", ejv.t)
	}

	dbpObj := ejv.v.(*extJSONObject)
	oidFound := false
	nsFound := false

	for i, key := range dbpObj.keys {
		val := dbpObj.values[i]

		switch key {
		case "$ref":
			if nsFound {
				return "", NilObjectID, errors.New("duplicate $ref key in $dbPointer")
			}

			if val.t != TypeString {
				return "", NilObjectID, fmt.Errorf("$dbPointer $ref value should be string, but instead is %s", val.t)
			}

			ns = val.v.(string)
			nsFound = true
		case "$id":
			if oidFound {
				return "", NilObjectID, errors.New("duplicate $id key in $dbPointer")
			}

			if val.t != TypeString {
				return "", NilObjectID, fmt.Errorf("$dbPointer $id value should be string, but instead is %s", val.t)
			}

			oid, err = ObjectIDFromHex(val.v.(string))
			if err != nil {
				return "", NilObjectID, err
			}

			oidFound = true
		default:
			return "", NilObjectID, fmt.Errorf("invalid key in $dbPointer object: %s", key)
		}
	}

	if !nsFound {
		return "", oid, errors.New("missing $ref field in $dbPointer object")
	}

	if !oidFound {
		return "", oid, errors.New("missing $id field in $dbPointer object")
	}

	return ns, oid, nil
}

const (
	rfc3339Milli = "2006-01-02T15:04:05.999Z07:00"
)

var (
	timeFormats = []string{rfc3339Milli, "2006-01-02T15:04:05.999Z0700"}
)

func (ejv *extJSONValue) parseDateTime() (int64, error) {
	switch ejv.t {
	case TypeInt32:
		return int64(ejv.v.(int32)), nil
	case TypeInt64:
		return ejv.v.(int64), nil
	case TypeString:
		return parseDatetimeString(ejv.v.(string))
	case TypeEmbeddedDocument:
		return parseDatetimeObject(ejv.v.(*extJSONObject))
	default:
		return 0, fmt.Errorf("$date value should be string or object, but instead is %s", ejv.t)
	}
}

func parseDatetimeString(data string) (int64, error) {
	var t time.Time
	var err error
	// try acceptable time formats until one matches
	for _, format := range timeFormats {
		t, err = time.Parse(format, data)
		if err == nil {
			break
		}
	}
	if err != nil {
		return 0, fmt.Errorf("invalid $date value string: %s", data)
	}

	return int64(NewDateTimeFromTime(t)), nil
}

func parseDatetimeObject(data *extJSONObject) (d int64, err error) {
	dFound := false

	for i, key := range data.keys {
		val := data.values[i]

		switch key {
		case "$numberLong":
			if dFound {
				return 0, errors.New("duplicate $numberLong key in $date")
			}

			if val.t != TypeString {
				return 0, fmt.Errorf("$date $numberLong field should be string, but instead is %s", val.t)
			}

			d, err = val.parseInt64()
			if err != nil {
				return 0, err
			}
			dFound = true
		default:
			return 0, fmt.Errorf("invalid key in $date object: %s", key)
		}
	}

	if !dFound {
		return 0, errors.New("missing $numberLong field in $date object")
	}

	return d, nil
}
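parseDatetimeString tries each layout in timeFormats until one parses; the two layouts differ only in whether the UTC offset contains a colon (`+05:30` vs `+0530`). A standalone sketch of that try-each-format loop (the helper name is illustrative):

```go
package main

import (
	"fmt"
	"time"
)

// parseFlexible tries each acceptable layout until one matches, like
// parseDatetimeString above. The layouts are the same two RFC 3339-style
// strings the package uses.
func parseFlexible(data string) (time.Time, error) {
	formats := []string{
		"2006-01-02T15:04:05.999Z07:00", // rfc3339Milli
		"2006-01-02T15:04:05.999Z0700",  // offset without colon
	}
	var t time.Time
	var err error
	for _, f := range formats {
		t, err = time.Parse(f, data)
		if err == nil {
			return t, nil
		}
	}
	return time.Time{}, fmt.Errorf("invalid $date value string: %s", data)
}

func main() {
	t1, err1 := parseFlexible("2017-01-02T03:04:05.006+05:30") // colon offset
	t2, err2 := parseFlexible("2017-01-02T03:04:05.006+0530")  // bare offset
	fmt.Println(err1 == nil, err2 == nil, t1.Equal(t2)) // true true true
}
```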

func (ejv *extJSONValue) parseDecimal128() (Decimal128, error) {
	if ejv.t != TypeString {
		return Decimal128{}, fmt.Errorf("$numberDecimal value should be string, but instead is %s", ejv.t)
	}

	d, err := ParseDecimal128(ejv.v.(string))
	if err != nil {
		return Decimal128{}, fmt.Errorf("invalid $numberDecimal string: %s", ejv.v.(string))
	}

	return d, nil
}

func (ejv *extJSONValue) parseDouble() (float64, error) {
	if ejv.t == TypeDouble {
		return ejv.v.(float64), nil
	}

	if ejv.t != TypeString {
		return 0, fmt.Errorf("$numberDouble value should be string, but instead is %s", ejv.t)
	}

	switch ejv.v.(string) {
	case "Infinity":
		return math.Inf(1), nil
	case "-Infinity":
		return math.Inf(-1), nil
	case "NaN":
		return math.NaN(), nil
	}

	f, err := strconv.ParseFloat(ejv.v.(string), 64)
	if err != nil {
		return 0, err
	}

	return f, nil
}

func (ejv *extJSONValue) parseInt32() (int32, error) {
	if ejv.t == TypeInt32 {
		return ejv.v.(int32), nil
	}

	if ejv.t != TypeString {
		return 0, fmt.Errorf("$numberInt value should be string, but instead is %s", ejv.t)
	}

	i, err := strconv.ParseInt(ejv.v.(string), 10, 64)
	if err != nil {
		return 0, err
	}

	if i < math.MinInt32 || i > math.MaxInt32 {
		return 0, fmt.Errorf("$numberInt value should be int32 but instead is int64: %d", i)
	}

	return int32(i), nil
}

func (ejv *extJSONValue) parseInt64() (int64, error) {
	if ejv.t == TypeInt64 {
		return ejv.v.(int64), nil
	}

	if ejv.t != TypeString {
		return 0, fmt.Errorf("$numberLong value should be string, but instead is %s", ejv.t)
	}

	i, err := strconv.ParseInt(ejv.v.(string), 10, 64)
	if err != nil {
		return 0, err
	}

	return i, nil
}

func (ejv *extJSONValue) parseJavascript() (code string, err error) {
	if ejv.t != TypeString {
		return "", fmt.Errorf("$code value should be string, but instead is %s", ejv.t)
	}

	return ejv.v.(string), nil
}

func (ejv *extJSONValue) parseMinMaxKey(minmax string) error {
	if ejv.t != TypeInt32 {
		return fmt.Errorf("$%sKey value should be int32, but instead is %s", minmax, ejv.t)
	}

	if ejv.v.(int32) != 1 {
		return fmt.Errorf("$%sKey value must be 1, but instead is %d", minmax, ejv.v.(int32))
	}

	return nil
}

func (ejv *extJSONValue) parseObjectID() (ObjectID, error) {
	if ejv.t != TypeString {
		return NilObjectID, fmt.Errorf("$oid value should be string, but instead is %s", ejv.t)
	}

	return ObjectIDFromHex(ejv.v.(string))
}

func (ejv *extJSONValue) parseRegex() (pattern, options string, err error) {
	if ejv.t != TypeEmbeddedDocument {
		return "", "", fmt.Errorf("$regularExpression value should be object, but instead is %s", ejv.t)
	}

	regexObj := ejv.v.(*extJSONObject)
	patFound := false
	optFound := false

	for i, key := range regexObj.keys {
		val := regexObj.values[i]

		switch key {
		case "pattern":
			if patFound {
				return "", "", errors.New("duplicate pattern key in $regularExpression")
			}

			if val.t != TypeString {
				return "", "", fmt.Errorf("$regularExpression pattern value should be string, but instead is %s", val.t)
			}

			pattern = val.v.(string)
			patFound = true
		case "options":
			if optFound {
				return "", "", errors.New("duplicate options key in $regularExpression")
			}

			if val.t != TypeString {
				return "", "", fmt.Errorf("$regularExpression options value should be string, but instead is %s", val.t)
			}

			options = val.v.(string)
			optFound = true
		default:
			return "", "", fmt.Errorf("invalid key in $regularExpression object: %s", key)
		}
	}

	if !patFound {
		return "", "", errors.New("missing pattern field in $regularExpression object")
	}

	if !optFound {
		return "", "", errors.New("missing options field in $regularExpression object")
	}

	return pattern, options, nil
}

func (ejv *extJSONValue) parseSymbol() (string, error) {
	if ejv.t != TypeString {
		return "", fmt.Errorf("$symbol value should be string, but instead is %s", ejv.t)
	}

	return ejv.v.(string), nil
}

func (ejv *extJSONValue) parseTimestamp() (t, i uint32, err error) {
	if ejv.t != TypeEmbeddedDocument {
		return 0, 0, fmt.Errorf("$timestamp value should be object, but instead is %s", ejv.t)
	}

	handleKey := func(key string, val *extJSONValue, flag bool) (uint32, error) {
		if flag {
			return 0, fmt.Errorf("duplicate %s key in $timestamp", key)
		}

		switch val.t {
		case TypeInt32:
			value := val.v.(int32)

			if value < 0 {
				return 0, fmt.Errorf("$timestamp %s number should be uint32: %d", key, value)
			}

			return uint32(value), nil
		case TypeInt64:
			value := val.v.(int64)
			if value < 0 || value > int64(math.MaxUint32) {
				return 0, fmt.Errorf("$timestamp %s number should be uint32: %d", key, value)
			}

			return uint32(value), nil
		default:
			return 0, fmt.Errorf("$timestamp %s value should be uint32, but instead is %s", key, val.t)
		}
	}

	tsObj := ejv.v.(*extJSONObject)
	tFound := false
	iFound := false

	for j, key := range tsObj.keys {
		val := tsObj.values[j]

		switch key {
		case "t":
			if t, err = handleKey(key, val, tFound); err != nil {
				return 0, 0, err
			}

			tFound = true
		case "i":
			if i, err = handleKey(key, val, iFound); err != nil {
				return 0, 0, err
			}

			iFound = true
		default:
			return 0, 0, fmt.Errorf("invalid key in $timestamp object: %s", key)
		}
	}

	if !tFound {
		return 0, 0, errors.New("missing t field in $timestamp object")
	}

	if !iFound {
		return 0, 0, errors.New("missing i field in $timestamp object")
	}

	return t, i, nil
}

func (ejv *extJSONValue) parseUndefined() error {
	if ejv.t != TypeBoolean {
		return fmt.Errorf("undefined value should be boolean, but instead is %s", ejv.t)
	}

	if !ejv.v.(bool) {
		return fmt.Errorf("$undefined value boolean should be true, but instead is %v", ejv.v.(bool))
	}

	return nil
}
690	extjson_writer.go	Normal file
@@ -0,0 +1,690 @@
|
|||||||
|
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"bytes"
	"encoding/base64"
	"fmt"
	"io"
	"math"
	"sort"
	"strconv"
	"strings"
	"time"
	"unicode/utf8"
)

type ejvwState struct {
	mode mode
}

type extJSONValueWriter struct {
	w   io.Writer
	buf []byte

	stack      []ejvwState
	frame      int64
	canonical  bool
	escapeHTML bool
	newlines   bool
}
// NewExtJSONValueWriter creates a ValueWriter that writes Extended JSON to w.
func NewExtJSONValueWriter(w io.Writer, canonical, escapeHTML bool) ValueWriter {
	// Enable newlines for all Extended JSON value writers created by NewExtJSONValueWriter. We
	// expect these value writers to be used with an Encoder, which should add newlines after
	// encoded Extended JSON documents.
	return newExtJSONWriter(w, canonical, escapeHTML, true)
}

func newExtJSONWriter(w io.Writer, canonical, escapeHTML, newlines bool) *extJSONValueWriter {
	stack := make([]ejvwState, 1, 5)
	stack[0] = ejvwState{mode: mTopLevel}

	return &extJSONValueWriter{
		w:          w,
		buf:        []byte{},
		stack:      stack,
		canonical:  canonical,
		escapeHTML: escapeHTML,
		newlines:   newlines,
	}
}

func newExtJSONWriterFromSlice(buf []byte, canonical, escapeHTML bool) *extJSONValueWriter {
	stack := make([]ejvwState, 1, 5)
	stack[0] = ejvwState{mode: mTopLevel}

	return &extJSONValueWriter{
		buf:        buf,
		stack:      stack,
		canonical:  canonical,
		escapeHTML: escapeHTML,
	}
}
func (ejvw *extJSONValueWriter) reset(buf []byte, canonical, escapeHTML bool) {
	if ejvw.stack == nil {
		ejvw.stack = make([]ejvwState, 1, 5)
	}

	ejvw.stack = ejvw.stack[:1]
	ejvw.stack[0] = ejvwState{mode: mTopLevel}
	ejvw.canonical = canonical
	ejvw.escapeHTML = escapeHTML
	ejvw.frame = 0
	ejvw.buf = buf
	ejvw.w = nil
}

func (ejvw *extJSONValueWriter) advanceFrame() {
	if ejvw.frame+1 >= int64(len(ejvw.stack)) { // We need to grow the stack
		length := len(ejvw.stack)
		if length+1 >= cap(ejvw.stack) {
			// double it
			buf := make([]ejvwState, 2*cap(ejvw.stack)+1)
			copy(buf, ejvw.stack)
			ejvw.stack = buf
		}
		ejvw.stack = ejvw.stack[:length+1]
	}
	ejvw.frame++
}

func (ejvw *extJSONValueWriter) push(m mode) {
	ejvw.advanceFrame()

	ejvw.stack[ejvw.frame].mode = m
}
func (ejvw *extJSONValueWriter) pop() {
	switch ejvw.stack[ejvw.frame].mode {
	case mElement, mValue:
		ejvw.frame--
	case mDocument, mArray, mCodeWithScope:
		ejvw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc...
	}
}

func (ejvw *extJSONValueWriter) invalidTransitionErr(destination mode, name string, modes []mode) error {
	te := TransitionError{
		name:        name,
		current:     ejvw.stack[ejvw.frame].mode,
		destination: destination,
		modes:       modes,
		action:      "write",
	}
	if ejvw.frame != 0 {
		te.parent = ejvw.stack[ejvw.frame-1].mode
	}
	return te
}

func (ejvw *extJSONValueWriter) ensureElementValue(destination mode, callerName string, addmodes ...mode) error {
	switch ejvw.stack[ejvw.frame].mode {
	case mElement, mValue:
	default:
		modes := []mode{mElement, mValue}
		if addmodes != nil {
			modes = append(modes, addmodes...)
		}
		return ejvw.invalidTransitionErr(destination, callerName, modes)
	}

	return nil
}

func (ejvw *extJSONValueWriter) writeExtendedSingleValue(key string, value string, quotes bool) {
	var s string
	if quotes {
		s = fmt.Sprintf(`{"$%s":"%s"}`, key, value)
	} else {
		s = fmt.Sprintf(`{"$%s":%s}`, key, value)
	}

	ejvw.buf = append(ejvw.buf, []byte(s)...)
}
func (ejvw *extJSONValueWriter) WriteArray() (ArrayWriter, error) {
	if err := ejvw.ensureElementValue(mArray, "WriteArray"); err != nil {
		return nil, err
	}

	ejvw.buf = append(ejvw.buf, '[')

	ejvw.push(mArray)
	return ejvw, nil
}

func (ejvw *extJSONValueWriter) WriteBinary(b []byte) error {
	return ejvw.WriteBinaryWithSubtype(b, 0x00)
}

func (ejvw *extJSONValueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteBinaryWithSubtype"); err != nil {
		return err
	}

	var buf bytes.Buffer
	buf.WriteString(`{"$binary":{"base64":"`)
	buf.WriteString(base64.StdEncoding.EncodeToString(b))
	buf.WriteString(fmt.Sprintf(`","subType":"%02x"}},`, btype))

	ejvw.buf = append(ejvw.buf, buf.Bytes()...)

	ejvw.pop()
	return nil
}

func (ejvw *extJSONValueWriter) WriteBoolean(b bool) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteBoolean"); err != nil {
		return err
	}

	ejvw.buf = append(ejvw.buf, []byte(strconv.FormatBool(b))...)
	ejvw.buf = append(ejvw.buf, ',')

	ejvw.pop()
	return nil
}
func (ejvw *extJSONValueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
	if err := ejvw.ensureElementValue(mCodeWithScope, "WriteCodeWithScope"); err != nil {
		return nil, err
	}

	var buf bytes.Buffer
	buf.WriteString(`{"$code":`)
	writeStringWithEscapes(code, &buf, ejvw.escapeHTML)
	buf.WriteString(`,"$scope":{`)

	ejvw.buf = append(ejvw.buf, buf.Bytes()...)

	ejvw.push(mCodeWithScope)
	return ejvw, nil
}

func (ejvw *extJSONValueWriter) WriteDBPointer(ns string, oid ObjectID) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteDBPointer"); err != nil {
		return err
	}

	var buf bytes.Buffer
	buf.WriteString(`{"$dbPointer":{"$ref":"`)
	buf.WriteString(ns)
	buf.WriteString(`","$id":{"$oid":"`)
	buf.WriteString(oid.Hex())
	buf.WriteString(`"}}},`)

	ejvw.buf = append(ejvw.buf, buf.Bytes()...)

	ejvw.pop()
	return nil
}

func (ejvw *extJSONValueWriter) WriteDateTime(dt int64) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteDateTime"); err != nil {
		return err
	}

	t := time.Unix(dt/1e3, dt%1e3*1e6).UTC()

	if ejvw.canonical || t.Year() < 1970 || t.Year() > 9999 {
		s := fmt.Sprintf(`{"$numberLong":"%d"}`, dt)
		ejvw.writeExtendedSingleValue("date", s, false)
	} else {
		ejvw.writeExtendedSingleValue("date", t.Format(rfc3339Milli), true)
	}

	ejvw.buf = append(ejvw.buf, ',')

	ejvw.pop()
	return nil
}
func (ejvw *extJSONValueWriter) WriteDecimal128(d Decimal128) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteDecimal128"); err != nil {
		return err
	}

	ejvw.writeExtendedSingleValue("numberDecimal", d.String(), true)
	ejvw.buf = append(ejvw.buf, ',')

	ejvw.pop()
	return nil
}

func (ejvw *extJSONValueWriter) WriteDocument() (DocumentWriter, error) {
	if ejvw.stack[ejvw.frame].mode == mTopLevel {
		ejvw.buf = append(ejvw.buf, '{')
		return ejvw, nil
	}

	if err := ejvw.ensureElementValue(mDocument, "WriteDocument", mTopLevel); err != nil {
		return nil, err
	}

	ejvw.buf = append(ejvw.buf, '{')
	ejvw.push(mDocument)
	return ejvw, nil
}

func (ejvw *extJSONValueWriter) WriteDouble(f float64) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteDouble"); err != nil {
		return err
	}

	s := formatDouble(f)

	if ejvw.canonical {
		ejvw.writeExtendedSingleValue("numberDouble", s, true)
	} else {
		switch s {
		case "Infinity", "-Infinity", "NaN":
			s = fmt.Sprintf(`{"$numberDouble":"%s"}`, s)
		}
		ejvw.buf = append(ejvw.buf, []byte(s)...)
	}

	ejvw.buf = append(ejvw.buf, ',')

	ejvw.pop()
	return nil
}
func (ejvw *extJSONValueWriter) WriteInt32(i int32) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteInt32"); err != nil {
		return err
	}

	s := strconv.FormatInt(int64(i), 10)

	if ejvw.canonical {
		ejvw.writeExtendedSingleValue("numberInt", s, true)
	} else {
		ejvw.buf = append(ejvw.buf, []byte(s)...)
	}

	ejvw.buf = append(ejvw.buf, ',')

	ejvw.pop()
	return nil
}

func (ejvw *extJSONValueWriter) WriteInt64(i int64) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteInt64"); err != nil {
		return err
	}

	s := strconv.FormatInt(i, 10)

	if ejvw.canonical {
		ejvw.writeExtendedSingleValue("numberLong", s, true)
	} else {
		ejvw.buf = append(ejvw.buf, []byte(s)...)
	}

	ejvw.buf = append(ejvw.buf, ',')

	ejvw.pop()
	return nil
}

func (ejvw *extJSONValueWriter) WriteJavascript(code string) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteJavascript"); err != nil {
		return err
	}

	var buf bytes.Buffer
	writeStringWithEscapes(code, &buf, ejvw.escapeHTML)

	ejvw.writeExtendedSingleValue("code", buf.String(), false)
	ejvw.buf = append(ejvw.buf, ',')

	ejvw.pop()
	return nil
}
func (ejvw *extJSONValueWriter) WriteMaxKey() error {
	if err := ejvw.ensureElementValue(mode(0), "WriteMaxKey"); err != nil {
		return err
	}

	ejvw.writeExtendedSingleValue("maxKey", "1", false)
	ejvw.buf = append(ejvw.buf, ',')

	ejvw.pop()
	return nil
}

func (ejvw *extJSONValueWriter) WriteMinKey() error {
	if err := ejvw.ensureElementValue(mode(0), "WriteMinKey"); err != nil {
		return err
	}

	ejvw.writeExtendedSingleValue("minKey", "1", false)
	ejvw.buf = append(ejvw.buf, ',')

	ejvw.pop()
	return nil
}

func (ejvw *extJSONValueWriter) WriteNull() error {
	if err := ejvw.ensureElementValue(mode(0), "WriteNull"); err != nil {
		return err
	}

	ejvw.buf = append(ejvw.buf, []byte("null")...)
	ejvw.buf = append(ejvw.buf, ',')

	ejvw.pop()
	return nil
}

func (ejvw *extJSONValueWriter) WriteObjectID(oid ObjectID) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteObjectID"); err != nil {
		return err
	}

	ejvw.writeExtendedSingleValue("oid", oid.Hex(), true)
	ejvw.buf = append(ejvw.buf, ',')

	ejvw.pop()
	return nil
}
func (ejvw *extJSONValueWriter) WriteRegex(pattern string, options string) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteRegex"); err != nil {
		return err
	}

	options = sortStringAlphebeticAscending(options)
	var buf bytes.Buffer
	buf.WriteString(`{"$regularExpression":{"pattern":`)
	writeStringWithEscapes(pattern, &buf, ejvw.escapeHTML)
	buf.WriteString(`,"options":`)
	writeStringWithEscapes(options, &buf, ejvw.escapeHTML)
	buf.WriteString(`}},`)

	ejvw.buf = append(ejvw.buf, buf.Bytes()...)

	ejvw.pop()
	return nil
}

func (ejvw *extJSONValueWriter) WriteString(s string) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteString"); err != nil {
		return err
	}

	var buf bytes.Buffer
	writeStringWithEscapes(s, &buf, ejvw.escapeHTML)

	ejvw.buf = append(ejvw.buf, buf.Bytes()...)
	ejvw.buf = append(ejvw.buf, ',')

	ejvw.pop()
	return nil
}

func (ejvw *extJSONValueWriter) WriteSymbol(symbol string) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteSymbol"); err != nil {
		return err
	}

	var buf bytes.Buffer
	writeStringWithEscapes(symbol, &buf, ejvw.escapeHTML)

	ejvw.writeExtendedSingleValue("symbol", buf.String(), false)
	ejvw.buf = append(ejvw.buf, ',')

	ejvw.pop()
	return nil
}
func (ejvw *extJSONValueWriter) WriteTimestamp(t uint32, i uint32) error {
	if err := ejvw.ensureElementValue(mode(0), "WriteTimestamp"); err != nil {
		return err
	}

	var buf bytes.Buffer
	buf.WriteString(`{"$timestamp":{"t":`)
	buf.WriteString(strconv.FormatUint(uint64(t), 10))
	buf.WriteString(`,"i":`)
	buf.WriteString(strconv.FormatUint(uint64(i), 10))
	buf.WriteString(`}},`)

	ejvw.buf = append(ejvw.buf, buf.Bytes()...)

	ejvw.pop()
	return nil
}

func (ejvw *extJSONValueWriter) WriteUndefined() error {
	if err := ejvw.ensureElementValue(mode(0), "WriteUndefined"); err != nil {
		return err
	}

	ejvw.writeExtendedSingleValue("undefined", "true", false)
	ejvw.buf = append(ejvw.buf, ',')

	ejvw.pop()
	return nil
}

func (ejvw *extJSONValueWriter) WriteDocumentElement(key string) (ValueWriter, error) {
	switch ejvw.stack[ejvw.frame].mode {
	case mDocument, mTopLevel, mCodeWithScope:
		var buf bytes.Buffer
		writeStringWithEscapes(key, &buf, ejvw.escapeHTML)

		ejvw.buf = append(ejvw.buf, []byte(fmt.Sprintf(`%s:`, buf.String()))...)
		ejvw.push(mElement)
	default:
		return nil, ejvw.invalidTransitionErr(mElement, "WriteDocumentElement", []mode{mDocument, mTopLevel, mCodeWithScope})
	}

	return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteDocumentEnd() error {
	switch ejvw.stack[ejvw.frame].mode {
	case mDocument, mTopLevel, mCodeWithScope:
	default:
		return fmt.Errorf("incorrect mode to end document: %s", ejvw.stack[ejvw.frame].mode)
	}

	// close the document
	if ejvw.buf[len(ejvw.buf)-1] == ',' {
		ejvw.buf[len(ejvw.buf)-1] = '}'
	} else {
		ejvw.buf = append(ejvw.buf, '}')
	}

	switch ejvw.stack[ejvw.frame].mode {
	case mCodeWithScope:
		ejvw.buf = append(ejvw.buf, '}')
		fallthrough
	case mDocument:
		ejvw.buf = append(ejvw.buf, ',')
	case mTopLevel:
		// If the value writer has newlines enabled, end top-level documents with a newline so that
		// multiple documents encoded to the same writer are separated by newlines. That matches the
		// Go json.Encoder behavior and also works with NewExtJSONValueReader.
		if ejvw.newlines {
			ejvw.buf = append(ejvw.buf, '\n')
		}
		if ejvw.w != nil {
			if _, err := ejvw.w.Write(ejvw.buf); err != nil {
				return err
			}
			ejvw.buf = ejvw.buf[:0]
		}
	}

	ejvw.pop()
	return nil
}

func (ejvw *extJSONValueWriter) WriteArrayElement() (ValueWriter, error) {
	switch ejvw.stack[ejvw.frame].mode {
	case mArray:
		ejvw.push(mValue)
	default:
		return nil, ejvw.invalidTransitionErr(mValue, "WriteArrayElement", []mode{mArray})
	}

	return ejvw, nil
}

func (ejvw *extJSONValueWriter) WriteArrayEnd() error {
	switch ejvw.stack[ejvw.frame].mode {
	case mArray:
		// close the array
		if ejvw.buf[len(ejvw.buf)-1] == ',' {
			ejvw.buf[len(ejvw.buf)-1] = ']'
		} else {
			ejvw.buf = append(ejvw.buf, ']')
		}

		ejvw.buf = append(ejvw.buf, ',')

		ejvw.pop()
	default:
		return fmt.Errorf("incorrect mode to end array: %s", ejvw.stack[ejvw.frame].mode)
	}

	return nil
}
func formatDouble(f float64) string {
	var s string
	switch {
	case math.IsInf(f, 1):
		s = "Infinity"
	case math.IsInf(f, -1):
		s = "-Infinity"
	case math.IsNaN(f):
		s = "NaN"
	default:
		// Print exactly one decimal place for integers; otherwise, print as many as necessary to
		// perfectly represent it.
		s = strconv.FormatFloat(f, 'G', -1, 64)
		if !strings.ContainsRune(s, 'E') && !strings.ContainsRune(s, '.') {
			s += ".0"
		}
	}

	return s
}
var hexChars = "0123456789abcdef"

func writeStringWithEscapes(s string, buf *bytes.Buffer, escapeHTML bool) {
	buf.WriteByte('"')
	start := 0
	for i := 0; i < len(s); {
		if b := s[i]; b < utf8.RuneSelf {
			if htmlSafeSet[b] || (!escapeHTML && safeSet[b]) {
				i++
				continue
			}
			if start < i {
				buf.WriteString(s[start:i])
			}
			switch b {
			case '\\', '"':
				buf.WriteByte('\\')
				buf.WriteByte(b)
			case '\n':
				buf.WriteByte('\\')
				buf.WriteByte('n')
			case '\r':
				buf.WriteByte('\\')
				buf.WriteByte('r')
			case '\t':
				buf.WriteByte('\\')
				buf.WriteByte('t')
			case '\b':
				buf.WriteByte('\\')
				buf.WriteByte('b')
			case '\f':
				buf.WriteByte('\\')
				buf.WriteByte('f')
			default:
				// This encodes bytes < 0x20 except for \t, \n and \r.
				// If escapeHTML is set, it also escapes <, >, and &
				// because they can lead to security holes when
				// user-controlled strings are rendered into JSON
				// and served to some browsers.
				buf.WriteString(`\u00`)
				buf.WriteByte(hexChars[b>>4])
				buf.WriteByte(hexChars[b&0xF])
			}
			i++
			start = i
			continue
		}
		c, size := utf8.DecodeRuneInString(s[i:])
		if c == utf8.RuneError && size == 1 {
			if start < i {
				buf.WriteString(s[start:i])
			}
			buf.WriteString(`\ufffd`)
			i += size
			start = i
			continue
		}
		// U+2028 is LINE SEPARATOR.
		// U+2029 is PARAGRAPH SEPARATOR.
		// They are both technically valid characters in JSON strings,
		// but don't work in JSONP, which has to be evaluated as JavaScript,
		// and can lead to security holes there. It is valid JSON to
		// escape them, so we do so unconditionally.
		// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
		if c == '\u2028' || c == '\u2029' {
			if start < i {
				buf.WriteString(s[start:i])
			}
			buf.WriteString(`\u202`)
			buf.WriteByte(hexChars[c&0xF])
			i += size
			start = i
			continue
		}
		i += size
	}
	if start < len(s) {
		buf.WriteString(s[start:])
	}
	buf.WriteByte('"')
}
type sortableString []rune

func (ss sortableString) Len() int {
	return len(ss)
}

func (ss sortableString) Less(i, j int) bool {
	return ss[i] < ss[j]
}

func (ss sortableString) Swap(i, j int) {
	ss[i], ss[j] = ss[j], ss[i]
}

func sortStringAlphebeticAscending(s string) string {
	ss := sortableString([]rune(s))
	sort.Sort(ss)
	return string([]rune(ss))
}
259 extjson_writer_test.go Normal file
@@ -0,0 +1,259 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"fmt"
	"io"
	"reflect"
	"strings"
	"testing"

	"gitea.psichedelico.com/go/bson/internal/assert"
)

func TestExtJSONValueWriter(t *testing.T) {
	oid := ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}
	testCases := []struct {
		name   string
		fn     interface{}
		params []interface{}
	}{
		{
			"WriteBinary",
			(*extJSONValueWriter).WriteBinary,
			[]interface{}{[]byte{0x01, 0x02, 0x03}},
		},
		{
			"WriteBinaryWithSubtype (not 0x02)",
			(*extJSONValueWriter).WriteBinaryWithSubtype,
			[]interface{}{[]byte{0x01, 0x02, 0x03}, byte(0xFF)},
		},
		{
			"WriteBinaryWithSubtype (0x02)",
			(*extJSONValueWriter).WriteBinaryWithSubtype,
			[]interface{}{[]byte{0x01, 0x02, 0x03}, byte(0x02)},
		},
		{
			"WriteBoolean",
			(*extJSONValueWriter).WriteBoolean,
			[]interface{}{true},
		},
		{
			"WriteDBPointer",
			(*extJSONValueWriter).WriteDBPointer,
			[]interface{}{"bar", oid},
		},
		{
			"WriteDateTime",
			(*extJSONValueWriter).WriteDateTime,
			[]interface{}{int64(12345678)},
		},
		{
			"WriteDecimal128",
			(*extJSONValueWriter).WriteDecimal128,
			[]interface{}{NewDecimal128(10, 20)},
		},
		{
			"WriteDouble",
			(*extJSONValueWriter).WriteDouble,
			[]interface{}{float64(3.14159)},
		},
		{
			"WriteInt32",
			(*extJSONValueWriter).WriteInt32,
			[]interface{}{int32(123456)},
		},
		{
			"WriteInt64",
			(*extJSONValueWriter).WriteInt64,
			[]interface{}{int64(1234567890)},
		},
		{
			"WriteJavascript",
			(*extJSONValueWriter).WriteJavascript,
			[]interface{}{"var foo = 'bar';"},
		},
		{
			"WriteMaxKey",
			(*extJSONValueWriter).WriteMaxKey,
			[]interface{}{},
		},
		{
			"WriteMinKey",
			(*extJSONValueWriter).WriteMinKey,
			[]interface{}{},
		},
		{
			"WriteNull",
			(*extJSONValueWriter).WriteNull,
			[]interface{}{},
		},
		{
			"WriteObjectID",
			(*extJSONValueWriter).WriteObjectID,
			[]interface{}{oid},
		},
		{
			"WriteRegex",
			(*extJSONValueWriter).WriteRegex,
			[]interface{}{"bar", "baz"},
		},
		{
			"WriteString",
			(*extJSONValueWriter).WriteString,
			[]interface{}{"hello, world!"},
		},
		{
			"WriteSymbol",
			(*extJSONValueWriter).WriteSymbol,
			[]interface{}{"symbollolz"},
		},
		{
			"WriteTimestamp",
			(*extJSONValueWriter).WriteTimestamp,
			[]interface{}{uint32(10), uint32(20)},
		},
		{
			"WriteUndefined",
			(*extJSONValueWriter).WriteUndefined,
			[]interface{}{},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			fn := reflect.ValueOf(tc.fn)
			if fn.Kind() != reflect.Func {
				t.Fatalf("fn must be of kind Func but it is a %v", fn.Kind())
			}
			if fn.Type().NumIn() != len(tc.params)+1 || fn.Type().In(0) != reflect.TypeOf((*extJSONValueWriter)(nil)) {
				t.Fatalf("fn must have at least one parameter and the first parameter must be a *valueWriter")
			}
			if fn.Type().NumOut() != 1 || fn.Type().Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
				t.Fatalf("fn must have one return value and it must be an error.")
			}
			params := make([]reflect.Value, 1, len(tc.params)+1)
			ejvw := newExtJSONWriter(io.Discard, true, true, false)
			params[0] = reflect.ValueOf(ejvw)
			for _, param := range tc.params {
				params = append(params, reflect.ValueOf(param))
			}

			t.Run("incorrect transition", func(t *testing.T) {
				results := fn.Call(params)
				got := results[0].Interface().(error)
				fnName := tc.name
				if strings.Contains(fnName, "WriteBinary") {
					fnName = "WriteBinaryWithSubtype"
				}
				want := TransitionError{current: mTopLevel, name: fnName, modes: []mode{mElement, mValue},
					action: "write"}
				if !assert.CompareErrors(got, want) {
					t.Errorf("Errors do not match. got %v; want %v", got, want)
				}
			})
		})
	}
	t.Run("WriteArray", func(t *testing.T) {
		ejvw := newExtJSONWriter(io.Discard, true, true, false)
		ejvw.push(mArray)
		want := TransitionError{current: mArray, destination: mArray, parent: mTopLevel,
			name: "WriteArray", modes: []mode{mElement, mValue}, action: "write"}
		_, got := ejvw.WriteArray()
		if !assert.CompareErrors(got, want) {
			t.Errorf("Did not get expected error. got %v; want %v", got, want)
		}
	})
	t.Run("WriteCodeWithScope", func(t *testing.T) {
		ejvw := newExtJSONWriter(io.Discard, true, true, false)
		ejvw.push(mArray)
		want := TransitionError{current: mArray, destination: mCodeWithScope, parent: mTopLevel,
			name: "WriteCodeWithScope", modes: []mode{mElement, mValue}, action: "write"}
		_, got := ejvw.WriteCodeWithScope("")
		if !assert.CompareErrors(got, want) {
			t.Errorf("Did not get expected error. got %v; want %v", got, want)
		}
	})
	t.Run("WriteDocument", func(t *testing.T) {
		ejvw := newExtJSONWriter(io.Discard, true, true, false)
		ejvw.push(mArray)
		want := TransitionError{current: mArray, destination: mDocument, parent: mTopLevel,
			name: "WriteDocument", modes: []mode{mElement, mValue, mTopLevel}, action: "write"}
		_, got := ejvw.WriteDocument()
		if !assert.CompareErrors(got, want) {
			t.Errorf("Did not get expected error. got %v; want %v", got, want)
		}
	})
	t.Run("WriteDocumentElement", func(t *testing.T) {
		ejvw := newExtJSONWriter(io.Discard, true, true, false)
		ejvw.push(mElement)
		want := TransitionError{current: mElement,
			destination: mElement,
			parent:      mTopLevel,
			name:        "WriteDocumentElement",
			modes:       []mode{mDocument, mTopLevel, mCodeWithScope},
			action:      "write"}
		_, got := ejvw.WriteDocumentElement("")
		if !assert.CompareErrors(got, want) {
			t.Errorf("Did not get expected error. got %v; want %v", got, want)
		}
	})
	t.Run("WriteDocumentEnd", func(t *testing.T) {
		ejvw := newExtJSONWriter(io.Discard, true, true, false)
		ejvw.push(mElement)
		want := fmt.Errorf("incorrect mode to end document: %s", mElement)
		got := ejvw.WriteDocumentEnd()
		if !assert.CompareErrors(got, want) {
			t.Errorf("Did not get expected error. got %v; want %v", got, want)
		}
	})
	t.Run("WriteArrayElement", func(t *testing.T) {
		ejvw := newExtJSONWriter(io.Discard, true, true, false)
		ejvw.push(mElement)
		want := TransitionError{current: mElement,
			destination: mValue,
			parent:      mTopLevel,
			name:        "WriteArrayElement",
			modes:       []mode{mArray},
			action:      "write"}
		_, got := ejvw.WriteArrayElement()
		if !assert.CompareErrors(got, want) {
			t.Errorf("Did not get expected error. got %v; want %v", got, want)
		}
	})
	t.Run("WriteArrayEnd", func(t *testing.T) {
		ejvw := newExtJSONWriter(io.Discard, true, true, false)
		ejvw.push(mElement)
		want := fmt.Errorf("incorrect mode to end array: %s", mElement)
		got := ejvw.WriteArrayEnd()
		if !assert.CompareErrors(got, want) {
			t.Errorf("Did not get expected error. got %v; want %v", got, want)
		}
	})
	t.Run("WriteBytes", func(t *testing.T) {
		t.Run("writeElementHeader error", func(t *testing.T) {
			ejvw := newExtJSONWriterFromSlice(nil, true, true)
			want := TransitionError{current: mTopLevel, destination: mode(0),
				name: "WriteBinaryWithSubtype", modes: []mode{mElement, mValue}, action: "write"}
			got := ejvw.WriteBinaryWithSubtype(nil, (byte)(TypeEmbeddedDocument))
			if !assert.CompareErrors(got, want) {
				t.Errorf("Did not receive expected error. got %v; want %v", got, want)
			}
		})
	})

	t.Run("FormatDoubleWithExponent", func(t *testing.T) {
		want := "3E-12"
		got := formatDouble(float64(0.000000000003))
		if got != want {
			t.Errorf("Did not receive expected string. got %s: want %s", got, want)
		}
	})
}
|
40
fuzz_test.go
Normal file
@@ -0,0 +1,40 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bson

import (
	"testing"
)

func FuzzDecode(f *testing.F) {
	seedBSONCorpus(f)

	f.Fuzz(func(t *testing.T, data []byte) {
		for _, typ := range []func() interface{}{
			func() interface{} { return new(D) },
			func() interface{} { return new([]E) },
			func() interface{} { return new(M) },
			func() interface{} { return new(interface{}) },
			func() interface{} { return make(map[string]interface{}) },
			func() interface{} { return new([]interface{}) },
		} {
			i := typ()
			if err := Unmarshal(data, i); err != nil {
				return
			}

			encoded, err := Marshal(i)
			if err != nil {
				t.Fatal("failed to marshal", err)
			}

			if err := Unmarshal(encoded, i); err != nil {
				t.Fatal("failed to unmarshal", err)
			}
		}
	})
}
12
go.mod
Normal file
@@ -0,0 +1,12 @@
module gitea.psichedelico.com/go/bson

go 1.23.0

toolchain go1.23.7

require (
	github.com/google/go-cmp v0.7.0
	golang.org/x/sync v0.12.0
)

require github.com/davecgh/go-spew v1.1.1
6
go.sum
Normal file
@@ -0,0 +1,6 @@
@ -0,0 +1,6 @@
|
|||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
|
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||||
|
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
481
internal/assert/assertion_compare.go
Normal file
@@ -0,0 +1,481 @@
@ -0,0 +1,481 @@
|
|||||||
|
// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare.go
|
||||||
|
|
||||||
|
// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved.
|
||||||
|
// Use of this source code is governed by an MIT-style license that can be found in
|
||||||
|
// the THIRD-PARTY-NOTICES file.
|
||||||
|
|
||||||
|
package assert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CompareType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
compareLess CompareType = iota - 1
|
||||||
|
compareEqual
|
||||||
|
compareGreater
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
intType = reflect.TypeOf(int(1))
|
||||||
|
int8Type = reflect.TypeOf(int8(1))
|
||||||
|
int16Type = reflect.TypeOf(int16(1))
|
||||||
|
int32Type = reflect.TypeOf(int32(1))
|
||||||
|
int64Type = reflect.TypeOf(int64(1))
|
||||||
|
|
||||||
|
uintType = reflect.TypeOf(uint(1))
|
||||||
|
uint8Type = reflect.TypeOf(uint8(1))
|
||||||
|
uint16Type = reflect.TypeOf(uint16(1))
|
||||||
|
uint32Type = reflect.TypeOf(uint32(1))
|
||||||
|
uint64Type = reflect.TypeOf(uint64(1))
|
||||||
|
|
||||||
|
float32Type = reflect.TypeOf(float32(1))
|
||||||
|
float64Type = reflect.TypeOf(float64(1))
|
||||||
|
|
||||||
|
stringType = reflect.TypeOf("")
|
||||||
|
|
||||||
|
timeType = reflect.TypeOf(time.Time{})
|
||||||
|
bytesType = reflect.TypeOf([]byte{})
|
||||||
|
)
|
||||||
|
|
||||||
|
func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
|
||||||
|
obj1Value := reflect.ValueOf(obj1)
|
||||||
|
obj2Value := reflect.ValueOf(obj2)
|
||||||
|
|
||||||
|
// throughout this switch we try and avoid calling .Convert() if possible,
|
||||||
|
// as this has a pretty big performance impact
|
||||||
|
switch kind {
|
||||||
|
case reflect.Int:
|
||||||
|
{
|
||||||
|
intobj1, ok := obj1.(int)
|
||||||
|
if !ok {
|
||||||
|
intobj1 = obj1Value.Convert(intType).Interface().(int)
|
||||||
|
}
|
||||||
|
intobj2, ok := obj2.(int)
|
||||||
|
if !ok {
|
||||||
|
intobj2 = obj2Value.Convert(intType).Interface().(int)
|
||||||
|
}
|
||||||
|
if intobj1 > intobj2 {
|
||||||
|
return compareGreater, true
|
||||||
|
}
|
||||||
|
if intobj1 == intobj2 {
|
||||||
|
return compareEqual, true
|
||||||
|
}
|
||||||
|
if intobj1 < intobj2 {
|
||||||
|
return compareLess, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Int8:
|
||||||
|
{
|
||||||
|
int8obj1, ok := obj1.(int8)
|
||||||
|
if !ok {
|
||||||
|
int8obj1 = obj1Value.Convert(int8Type).Interface().(int8)
|
||||||
|
}
|
||||||
|
int8obj2, ok := obj2.(int8)
|
||||||
|
if !ok {
|
||||||
|
int8obj2 = obj2Value.Convert(int8Type).Interface().(int8)
|
||||||
|
}
|
||||||
|
if int8obj1 > int8obj2 {
|
||||||
|
return compareGreater, true
|
||||||
|
}
|
||||||
|
if int8obj1 == int8obj2 {
|
||||||
|
return compareEqual, true
|
||||||
|
}
|
||||||
|
if int8obj1 < int8obj2 {
|
||||||
|
return compareLess, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Int16:
|
||||||
|
{
|
||||||
|
int16obj1, ok := obj1.(int16)
|
||||||
|
if !ok {
|
||||||
|
int16obj1 = obj1Value.Convert(int16Type).Interface().(int16)
|
||||||
|
}
|
||||||
|
int16obj2, ok := obj2.(int16)
|
||||||
|
if !ok {
|
||||||
|
int16obj2 = obj2Value.Convert(int16Type).Interface().(int16)
|
||||||
|
}
|
||||||
|
if int16obj1 > int16obj2 {
|
||||||
|
return compareGreater, true
|
||||||
|
}
|
||||||
|
if int16obj1 == int16obj2 {
|
||||||
|
return compareEqual, true
|
||||||
|
}
|
||||||
|
if int16obj1 < int16obj2 {
|
||||||
|
return compareLess, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Int32:
|
||||||
|
{
|
||||||
|
int32obj1, ok := obj1.(int32)
|
||||||
|
if !ok {
|
||||||
|
int32obj1 = obj1Value.Convert(int32Type).Interface().(int32)
|
||||||
|
}
|
||||||
|
int32obj2, ok := obj2.(int32)
|
||||||
|
if !ok {
|
||||||
|
int32obj2 = obj2Value.Convert(int32Type).Interface().(int32)
|
||||||
|
}
|
||||||
|
if int32obj1 > int32obj2 {
|
||||||
|
return compareGreater, true
|
||||||
|
}
|
||||||
|
if int32obj1 == int32obj2 {
|
||||||
|
return compareEqual, true
|
||||||
|
}
|
||||||
|
if int32obj1 < int32obj2 {
|
||||||
|
return compareLess, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Int64:
|
||||||
|
{
|
||||||
|
int64obj1, ok := obj1.(int64)
|
||||||
|
if !ok {
|
||||||
|
int64obj1 = obj1Value.Convert(int64Type).Interface().(int64)
|
||||||
|
}
|
||||||
|
int64obj2, ok := obj2.(int64)
|
||||||
|
if !ok {
|
||||||
|
int64obj2 = obj2Value.Convert(int64Type).Interface().(int64)
|
||||||
|
}
|
||||||
|
if int64obj1 > int64obj2 {
|
||||||
|
return compareGreater, true
|
||||||
|
}
|
||||||
|
if int64obj1 == int64obj2 {
|
||||||
|
return compareEqual, true
|
||||||
|
}
|
||||||
|
if int64obj1 < int64obj2 {
|
||||||
|
return compareLess, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Uint:
|
||||||
|
{
|
||||||
|
uintobj1, ok := obj1.(uint)
|
||||||
|
if !ok {
|
||||||
|
uintobj1 = obj1Value.Convert(uintType).Interface().(uint)
|
||||||
|
}
|
||||||
|
uintobj2, ok := obj2.(uint)
|
||||||
|
if !ok {
|
||||||
|
uintobj2 = obj2Value.Convert(uintType).Interface().(uint)
|
||||||
|
}
|
||||||
|
if uintobj1 > uintobj2 {
|
||||||
|
return compareGreater, true
|
||||||
|
}
|
||||||
|
if uintobj1 == uintobj2 {
|
||||||
|
return compareEqual, true
|
||||||
|
}
|
||||||
|
if uintobj1 < uintobj2 {
|
||||||
|
return compareLess, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Uint8:
|
||||||
|
{
|
||||||
|
uint8obj1, ok := obj1.(uint8)
|
||||||
|
if !ok {
|
||||||
|
uint8obj1 = obj1Value.Convert(uint8Type).Interface().(uint8)
|
||||||
|
}
|
||||||
|
uint8obj2, ok := obj2.(uint8)
|
||||||
|
if !ok {
|
||||||
|
uint8obj2 = obj2Value.Convert(uint8Type).Interface().(uint8)
|
||||||
|
}
|
||||||
|
if uint8obj1 > uint8obj2 {
|
||||||
|
return compareGreater, true
|
||||||
|
}
|
||||||
|
if uint8obj1 == uint8obj2 {
|
||||||
|
return compareEqual, true
|
||||||
|
}
|
||||||
|
if uint8obj1 < uint8obj2 {
|
||||||
|
return compareLess, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Uint16:
|
||||||
|
{
|
||||||
|
uint16obj1, ok := obj1.(uint16)
|
||||||
|
if !ok {
|
||||||
|
uint16obj1 = obj1Value.Convert(uint16Type).Interface().(uint16)
|
||||||
|
}
|
||||||
|
uint16obj2, ok := obj2.(uint16)
|
||||||
|
if !ok {
|
||||||
|
uint16obj2 = obj2Value.Convert(uint16Type).Interface().(uint16)
|
||||||
|
}
|
||||||
|
if uint16obj1 > uint16obj2 {
|
||||||
|
return compareGreater, true
|
||||||
|
}
|
||||||
|
if uint16obj1 == uint16obj2 {
|
||||||
|
return compareEqual, true
|
||||||
|
}
|
||||||
|
if uint16obj1 < uint16obj2 {
|
||||||
|
return compareLess, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Uint32:
|
||||||
|
{
|
||||||
|
uint32obj1, ok := obj1.(uint32)
|
||||||
|
if !ok {
|
||||||
|
uint32obj1 = obj1Value.Convert(uint32Type).Interface().(uint32)
|
||||||
|
}
|
||||||
|
uint32obj2, ok := obj2.(uint32)
|
||||||
|
if !ok {
|
||||||
|
uint32obj2 = obj2Value.Convert(uint32Type).Interface().(uint32)
|
||||||
|
}
|
||||||
|
if uint32obj1 > uint32obj2 {
|
||||||
|
return compareGreater, true
|
||||||
|
}
|
||||||
|
if uint32obj1 == uint32obj2 {
|
||||||
|
return compareEqual, true
|
||||||
|
}
|
||||||
|
if uint32obj1 < uint32obj2 {
|
||||||
|
return compareLess, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Uint64:
|
||||||
|
{
|
||||||
|
uint64obj1, ok := obj1.(uint64)
|
||||||
|
if !ok {
|
||||||
|
uint64obj1 = obj1Value.Convert(uint64Type).Interface().(uint64)
|
||||||
|
}
|
||||||
|
uint64obj2, ok := obj2.(uint64)
|
||||||
|
if !ok {
|
||||||
|
uint64obj2 = obj2Value.Convert(uint64Type).Interface().(uint64)
|
||||||
|
}
|
||||||
|
if uint64obj1 > uint64obj2 {
|
||||||
|
return compareGreater, true
|
||||||
|
}
|
||||||
|
if uint64obj1 == uint64obj2 {
|
||||||
|
return compareEqual, true
|
||||||
|
}
|
||||||
|
if uint64obj1 < uint64obj2 {
|
||||||
|
return compareLess, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Float32:
|
||||||
|
{
|
||||||
|
float32obj1, ok := obj1.(float32)
|
||||||
|
if !ok {
|
||||||
|
float32obj1 = obj1Value.Convert(float32Type).Interface().(float32)
|
||||||
|
}
|
||||||
|
float32obj2, ok := obj2.(float32)
|
||||||
|
if !ok {
|
||||||
|
float32obj2 = obj2Value.Convert(float32Type).Interface().(float32)
|
||||||
|
}
|
||||||
|
if float32obj1 > float32obj2 {
|
||||||
|
return compareGreater, true
|
||||||
|
}
|
||||||
|
if float32obj1 == float32obj2 {
|
||||||
|
return compareEqual, true
|
||||||
|
}
|
||||||
|
if float32obj1 < float32obj2 {
|
||||||
|
return compareLess, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Float64:
|
||||||
|
{
|
||||||
|
float64obj1, ok := obj1.(float64)
|
||||||
|
if !ok {
|
||||||
|
float64obj1 = obj1Value.Convert(float64Type).Interface().(float64)
|
||||||
|
}
|
||||||
|
float64obj2, ok := obj2.(float64)
|
||||||
|
if !ok {
|
||||||
|
float64obj2 = obj2Value.Convert(float64Type).Interface().(float64)
|
||||||
|
}
|
||||||
|
if float64obj1 > float64obj2 {
|
||||||
|
return compareGreater, true
|
||||||
|
}
|
||||||
|
if float64obj1 == float64obj2 {
|
||||||
|
return compareEqual, true
|
||||||
|
}
|
||||||
|
if float64obj1 < float64obj2 {
|
||||||
|
return compareLess, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.String:
|
||||||
|
{
|
||||||
|
stringobj1, ok := obj1.(string)
|
||||||
|
if !ok {
|
||||||
|
stringobj1 = obj1Value.Convert(stringType).Interface().(string)
|
||||||
|
}
|
||||||
|
stringobj2, ok := obj2.(string)
|
||||||
|
if !ok {
|
||||||
|
stringobj2 = obj2Value.Convert(stringType).Interface().(string)
|
||||||
|
}
|
||||||
|
if stringobj1 > stringobj2 {
|
||||||
|
return compareGreater, true
|
||||||
|
}
|
||||||
|
if stringobj1 == stringobj2 {
|
||||||
|
return compareEqual, true
|
||||||
|
}
|
||||||
|
if stringobj1 < stringobj2 {
|
||||||
|
return compareLess, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check for known struct types we can check for compare results.
|
||||||
|
case reflect.Struct:
|
||||||
|
{
|
||||||
|
// All structs enter here. We're not interested in most types.
|
||||||
|
if !canConvert(obj1Value, timeType) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// time.Time can compared!
|
||||||
|
timeObj1, ok := obj1.(time.Time)
|
||||||
|
if !ok {
|
||||||
|
timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
|
||||||
|
}
|
||||||
|
|
||||||
|
timeObj2, ok := obj2.(time.Time)
|
||||||
|
if !ok {
|
||||||
|
timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
|
||||||
|
}
|
||||||
|
|
||||||
|
return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
|
||||||
|
}
|
||||||
|
case reflect.Slice:
|
||||||
|
{
|
||||||
|
// We only care about the []byte type.
|
||||||
|
if !canConvert(obj1Value, bytesType) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// []byte can be compared!
|
||||||
|
bytesObj1, ok := obj1.([]byte)
|
||||||
|
if !ok {
|
||||||
|
bytesObj1 = obj1Value.Convert(bytesType).Interface().([]byte)
|
||||||
|
|
||||||
|
}
|
||||||
|
bytesObj2, ok := obj2.([]byte)
|
||||||
|
if !ok {
|
||||||
|
bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte)
|
||||||
|
}
|
||||||
|
|
||||||
|
return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return compareEqual, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Greater asserts that the first element is greater than the second
|
||||||
|
//
|
||||||
|
// assert.Greater(t, 2, 1)
|
||||||
|
// assert.Greater(t, float64(2), float64(1))
|
||||||
|
// assert.Greater(t, "b", "a")
|
||||||
|
func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
|
||||||
|
if h, ok := t.(tHelper); ok {
|
||||||
|
h.Helper()
|
||||||
|
}
|
||||||
|
return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GreaterOrEqual asserts that the first element is greater than or equal to the second
|
||||||
|
//
|
||||||
|
// assert.GreaterOrEqual(t, 2, 1)
|
||||||
|
// assert.GreaterOrEqual(t, 2, 2)
|
||||||
|
// assert.GreaterOrEqual(t, "b", "a")
|
||||||
|
// assert.GreaterOrEqual(t, "b", "b")
|
||||||
|
func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
|
||||||
|
if h, ok := t.(tHelper); ok {
|
||||||
|
h.Helper()
|
||||||
|
}
|
||||||
|
return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Less asserts that the first element is less than the second
|
||||||
|
//
|
||||||
|
// assert.Less(t, 1, 2)
|
||||||
|
// assert.Less(t, float64(1), float64(2))
|
||||||
|
// assert.Less(t, "a", "b")
|
||||||
|
func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
|
||||||
|
if h, ok := t.(tHelper); ok {
|
||||||
|
h.Helper()
|
||||||
|
}
|
||||||
|
return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LessOrEqual asserts that the first element is less than or equal to the second
|
||||||
|
//
|
||||||
|
// assert.LessOrEqual(t, 1, 2)
|
||||||
|
// assert.LessOrEqual(t, 2, 2)
|
||||||
|
// assert.LessOrEqual(t, "a", "b")
|
||||||
|
// assert.LessOrEqual(t, "b", "b")
|
||||||
|
func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
|
||||||
|
if h, ok := t.(tHelper); ok {
|
||||||
|
h.Helper()
|
||||||
|
}
|
||||||
|
return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Positive asserts that the specified element is positive
|
||||||
|
//
|
||||||
|
// assert.Positive(t, 1)
|
||||||
|
// assert.Positive(t, 1.23)
|
||||||
|
func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
|
||||||
|
if h, ok := t.(tHelper); ok {
|
||||||
|
h.Helper()
|
||||||
|
}
|
||||||
|
zero := reflect.Zero(reflect.TypeOf(e))
|
||||||
|
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Negative asserts that the specified element is negative
|
||||||
|
//
|
||||||
|
// assert.Negative(t, -1)
|
||||||
|
// assert.Negative(t, -1.23)
|
||||||
|
func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
|
||||||
|
if h, ok := t.(tHelper); ok {
|
||||||
|
h.Helper()
|
||||||
|
}
|
||||||
|
zero := reflect.Zero(reflect.TypeOf(e))
|
||||||
|
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {
|
||||||
|
if h, ok := t.(tHelper); ok {
|
||||||
|
h.Helper()
|
||||||
|
}
|
||||||
|
|
||||||
|
e1Kind := reflect.ValueOf(e1).Kind()
|
||||||
|
e2Kind := reflect.ValueOf(e2).Kind()
|
||||||
|
if e1Kind != e2Kind {
|
||||||
|
return Fail(t, "Elements should be the same type", msgAndArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
compareResult, isComparable := compare(e1, e2, e1Kind)
|
||||||
|
if !isComparable {
|
||||||
|
return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !containsValue(allowedComparesResults, compareResult) {
|
||||||
|
return Fail(t, fmt.Sprintf(failMessage, e1, e2), msgAndArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsValue(values []CompareType, value CompareType) bool {
|
||||||
|
for _, v := range values {
|
||||||
|
if v == value {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompareErrors asserts two errors
|
||||||
|
func CompareErrors(err1, err2 error) bool {
|
||||||
|
if err1 == nil && err2 == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if err1 == nil || err2 == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err1.Error() != err2.Error() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
18
internal/assert/assertion_compare_can_convert.go
Normal file
@@ -0,0 +1,18 @@
@ -0,0 +1,18 @@
|
|||||||
|
// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare_can_convert.go
|
||||||
|
|
||||||
|
// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved.
|
||||||
|
// Use of this source code is governed by an MIT-style license that can be found in
|
||||||
|
// the THIRD-PARTY-NOTICES file.
|
||||||
|
|
||||||
|
//go:build go1.17
|
||||||
|
// +build go1.17
|
||||||
|
|
||||||
|
package assert
|
||||||
|
|
||||||
|
import "reflect"
|
||||||
|
|
||||||
|
// Wrapper around reflect.Value.CanConvert, for compatibility
|
||||||
|
// reasons.
|
||||||
|
func canConvert(value reflect.Value, to reflect.Type) bool {
|
||||||
|
return value.CanConvert(to)
|
||||||
|
}
|
184
internal/assert/assertion_compare_go1.17_test.go
Normal file
@@ -0,0 +1,184 @@
@ -0,0 +1,184 @@
|
|||||||
|
// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare_go1.17_test.go
|
||||||
|
|
||||||
|
// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved.
|
||||||
|
// Use of this source code is governed by an MIT-style license that can be found in
|
||||||
|
// the THIRD-PARTY-NOTICES file.
|
||||||
|
|
||||||
|
//go:build go1.17
|
||||||
|
// +build go1.17
|
||||||
|
|
||||||
|
package assert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCompare17(t *testing.T) {
|
||||||
|
type customTime time.Time
|
||||||
|
type customBytes []byte
|
||||||
|
for _, currCase := range []struct {
|
||||||
|
less interface{}
|
||||||
|
greater interface{}
|
||||||
|
cType string
|
||||||
|
}{
|
||||||
|
{less: time.Now(), greater: time.Now().Add(time.Hour), cType: "time.Time"},
|
||||||
|
{less: customTime(time.Now()), greater: customTime(time.Now().Add(time.Hour)), cType: "time.Time"},
|
||||||
|
{less: []byte{1, 1}, greater: []byte{1, 2}, cType: "[]byte"},
|
||||||
|
{less: customBytes([]byte{1, 1}), greater: customBytes([]byte{1, 2}), cType: "[]byte"},
|
||||||
|
} {
|
||||||
|
resLess, isComparable := compare(currCase.less, currCase.greater, reflect.ValueOf(currCase.less).Kind())
|
||||||
|
if !isComparable {
|
||||||
|
t.Error("object should be comparable for type " + currCase.cType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resLess != compareLess {
|
||||||
|
t.Errorf("object less (%v) should be less than greater (%v) for type "+currCase.cType,
|
||||||
|
currCase.less, currCase.greater)
|
||||||
|
}
|
||||||
|
|
||||||
|
resGreater, isComparable := compare(currCase.greater, currCase.less, reflect.ValueOf(currCase.less).Kind())
|
||||||
|
if !isComparable {
|
||||||
|
t.Error("object are comparable for type " + currCase.cType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resGreater != compareGreater {
|
||||||
|
t.Errorf("object greater should be greater than less for type " + currCase.cType)
|
||||||
|
}
|
||||||
|
|
||||||
|
resEqual, isComparable := compare(currCase.less, currCase.less, reflect.ValueOf(currCase.less).Kind())
|
||||||
|
if !isComparable {
|
||||||
|
t.Error("object are comparable for type " + currCase.cType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resEqual != 0 {
|
||||||
|
t.Errorf("objects should be equal for type " + currCase.cType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGreater17(t *testing.T) {
|
||||||
|
mockT := new(testing.T)
|
||||||
|
|
||||||
|
if !Greater(mockT, 2, 1) {
|
||||||
|
t.Error("Greater should return true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if Greater(mockT, 1, 1) {
|
||||||
|
t.Error("Greater should return false")
|
||||||
|
}
|
||||||
|
|
||||||
|
if Greater(mockT, 1, 2) {
|
||||||
|
t.Error("Greater should return false")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check error report
|
||||||
|
for _, currCase := range []struct {
|
||||||
|
less interface{}
|
||||||
|
greater interface{}
|
||||||
|
msg string
|
||||||
|
}{
|
||||||
|
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 1]" is not greater than "[1 2]"`},
|
||||||
|
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 00:00:00 +0000 UTC" is not greater than "0001-01-01 01:00:00 +0000 UTC"`},
|
||||||
|
} {
|
||||||
|
out := &outputT{buf: bytes.NewBuffer(nil)}
|
||||||
|
False(t, Greater(out, currCase.less, currCase.greater))
|
||||||
|
Contains(t, out.buf.String(), currCase.msg)
|
||||||
|
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.Greater")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGreaterOrEqual17(t *testing.T) {
|
||||||
|
mockT := new(testing.T)
|
||||||
|
|
||||||
|
if !GreaterOrEqual(mockT, 2, 1) {
|
||||||
|
t.Error("GreaterOrEqual should return true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !GreaterOrEqual(mockT, 1, 1) {
|
||||||
|
t.Error("GreaterOrEqual should return true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if GreaterOrEqual(mockT, 1, 2) {
|
||||||
|
t.Error("GreaterOrEqual should return false")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check error report
|
||||||
|
for _, currCase := range []struct {
|
||||||
|
less interface{}
|
||||||
|
greater interface{}
|
||||||
|
msg string
|
||||||
|
}{
|
||||||
|
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 1]" is not greater than or equal to "[1 2]"`},
|
||||||
|
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 00:00:00 +0000 UTC" is not greater than or equal to "0001-01-01 01:00:00 +0000 UTC"`},
|
||||||
|
} {
|
||||||
|
out := &outputT{buf: bytes.NewBuffer(nil)}
|
||||||
|
False(t, GreaterOrEqual(out, currCase.less, currCase.greater))
|
||||||
|
Contains(t, out.buf.String(), currCase.msg)
|
||||||
|
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.GreaterOrEqual")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLess17(t *testing.T) {
|
||||||
|
mockT := new(testing.T)
|
||||||
|
|
||||||
|
if !Less(mockT, 1, 2) {
|
||||||
|
t.Error("Less should return true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if Less(mockT, 1, 1) {
|
||||||
|
t.Error("Less should return false")
|
||||||
|
}
|
||||||
|
|
||||||
|
if Less(mockT, 2, 1) {
|
||||||
|
t.Error("Less should return false")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check error report
|
||||||
|
for _, currCase := range []struct {
|
||||||
|
less interface{}
|
||||||
|
greater interface{}
|
||||||
|
msg string
|
||||||
|
}{
|
||||||
|
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 2]" is not less than "[1 1]"`},
|
||||||
|
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 01:00:00 +0000 UTC" is not less than "0001-01-01 00:00:00 +0000 UTC"`},
|
||||||
|
} {
|
||||||
|
out := &outputT{buf: bytes.NewBuffer(nil)}
|
||||||
|
False(t, Less(out, currCase.greater, currCase.less))
|
||||||
|
Contains(t, out.buf.String(), currCase.msg)
|
||||||
|
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.Less")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLessOrEqual17(t *testing.T) {
|
||||||
|
mockT := new(testing.T)
|
||||||
|
|
||||||
|
if !LessOrEqual(mockT, 1, 2) {
|
||||||
|
t.Error("LessOrEqual should return true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !LessOrEqual(mockT, 1, 1) {
|
||||||
|
t.Error("LessOrEqual should return true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if LessOrEqual(mockT, 2, 1) {
|
||||||
|
t.Error("LessOrEqual should return false")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check error report
|
||||||
|
for _, currCase := range []struct {
|
||||||
|
less interface{}
|
||||||
|
greater interface{}
|
||||||
|
msg string
|
||||||
|
}{
|
||||||
|
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 2]" is not less than or equal to "[1 1]"`},
|
||||||
|
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 01:00:00 +0000 UTC" is not less than or equal to "0001-01-01 00:00:00 +0000 UTC"`},
|
||||||
|
} {
|
||||||
|
out := &outputT{buf: bytes.NewBuffer(nil)}
|
||||||
|
False(t, LessOrEqual(out, currCase.greater, currCase.less))
|
||||||
|
Contains(t, out.buf.String(), currCase.msg)
|
||||||
|
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.LessOrEqual")
|
||||||
|
}
|
||||||
|
}
|
18
internal/assert/assertion_compare_legacy.go
Normal file
@@ -0,0 +1,18 @@
@ -0,0 +1,18 @@
|
|||||||
|
// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare_legacy.go
|
||||||
|
|
||||||
|
// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved.
|
||||||
|
// Use of this source code is governed by an MIT-style license that can be found in
|
||||||
|
// the THIRD-PARTY-NOTICES file.
|
||||||
|
|
||||||
|
//go:build !go1.17
|
||||||
|
// +build !go1.17
|
||||||
|
|
||||||
|
package assert
|
||||||
|
|
||||||
|
import "reflect"
|
||||||
|
|
||||||
|
// Older versions of Go does not have the reflect.Value.CanConvert
|
||||||
|
// method.
|
||||||
|
func canConvert(value reflect.Value, to reflect.Type) bool {
|
||||||
|
return false
|
||||||
|
}
|
455
internal/assert/assertion_compare_test.go
Normal file
@@ -0,0 +1,455 @@
@ -0,0 +1,455 @@
|
|||||||
|
// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare_test.go
|
||||||
|
|
||||||
|
// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved.
|
||||||
|
// Use of this source code is governed by an MIT-style license that can be found in
|
||||||
|
// the THIRD-PARTY-NOTICES file.
|
||||||
|
|
||||||
|
package assert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCompare(t *testing.T) {
|
||||||
|
type customInt int
|
||||||
|
type customInt8 int8
|
||||||
|
type customInt16 int16
|
||||||
|
type customInt32 int32
|
||||||
|
type customInt64 int64
|
||||||
|
type customUInt uint
|
||||||
|
type customUInt8 uint8
|
||||||
|
type customUInt16 uint16
|
||||||
|
type customUInt32 uint32
|
||||||
|
type customUInt64 uint64
|
||||||
|
type customFloat32 float32
|
||||||
|
type customFloat64 float64
|
||||||
|
type customString string
|
||||||
|
for _, currCase := range []struct {
|
||||||
|
less interface{}
|
||||||
|
greater interface{}
|
||||||
|
cType string
|
||||||
|
}{
|
||||||
|
{less: customString("a"), greater: customString("b"), cType: "string"},
|
||||||
|
{less: "a", greater: "b", cType: "string"},
|
||||||
|
{less: customInt(1), greater: customInt(2), cType: "int"},
|
||||||
|
{less: int(1), greater: int(2), cType: "int"},
|
||||||
|
{less: customInt8(1), greater: customInt8(2), cType: "int8"},
|
||||||
|
{less: int8(1), greater: int8(2), cType: "int8"},
|
||||||
|
{less: customInt16(1), greater: customInt16(2), cType: "int16"},
|
||||||
|
{less: int16(1), greater: int16(2), cType: "int16"},
|
||||||
|
{less: customInt32(1), greater: customInt32(2), cType: "int32"},
|
||||||
|
{less: int32(1), greater: int32(2), cType: "int32"},
|
||||||
|
{less: customInt64(1), greater: customInt64(2), cType: "int64"},
|
||||||
|
{less: int64(1), greater: int64(2), cType: "int64"},
|
||||||
|
{less: customUInt(1), greater: customUInt(2), cType: "uint"},
|
||||||
|
{less: uint8(1), greater: uint8(2), cType: "uint8"},
|
||||||
|
{less: customUInt8(1), greater: customUInt8(2), cType: "uint8"},
|
||||||
|
{less: uint16(1), greater: uint16(2), cType: "uint16"},
|
||||||
|
{less: customUInt16(1), greater: customUInt16(2), cType: "uint16"},
|
||||||
|
{less: uint32(1), greater: uint32(2), cType: "uint32"},
|
||||||
|
{less: customUInt32(1), greater: customUInt32(2), cType: "uint32"},
|
||||||
|
{less: uint64(1), greater: uint64(2), cType: "uint64"},
|
||||||
|
{less: customUInt64(1), greater: customUInt64(2), cType: "uint64"},
|
||||||
|
{less: float32(1.23), greater: float32(2.34), cType: "float32"},
|
||||||
|
{less: customFloat32(1.23), greater: customFloat32(2.23), cType: "float32"},
|
||||||
|
{less: float64(1.23), greater: float64(2.34), cType: "float64"},
|
||||||
|
{less: customFloat64(1.23), greater: customFloat64(2.34), cType: "float64"},
|
||||||
|
} {
|
||||||
|
resLess, isComparable := compare(currCase.less, currCase.greater, reflect.ValueOf(currCase.less).Kind())
|
||||||
|
if !isComparable {
|
||||||
|
t.Error("object should be comparable for type " + currCase.cType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resLess != compareLess {
|
||||||
|
t.Errorf("object less (%v) should be less than greater (%v) for type "+currCase.cType,
|
||||||
|
currCase.less, currCase.greater)
|
||||||
|
}
|
||||||
|
|
||||||
|
resGreater, isComparable := compare(currCase.greater, currCase.less, reflect.ValueOf(currCase.less).Kind())
|
||||||
|
if !isComparable {
|
||||||
|
t.Error("object are comparable for type " + currCase.cType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resGreater != compareGreater {
|
||||||
|
t.Errorf("object greater should be greater than less for type " + currCase.cType)
|
||||||
|
}
|
||||||
|
|
||||||
|
resEqual, isComparable := compare(currCase.less, currCase.less, reflect.ValueOf(currCase.less).Kind())
|
||||||
|
if !isComparable {
|
||||||
|
t.Error("object are comparable for type " + currCase.cType)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resEqual != 0 {
|
||||||
|
t.Errorf("objects should be equal for type " + currCase.cType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type outputT struct {
|
||||||
|
buf *bytes.Buffer
|
||||||
|
helpers map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implements TestingT
|
||||||
|
func (t *outputT) Errorf(format string, args ...interface{}) {
|
||||||
|
s := fmt.Sprintf(format, args...)
|
||||||
|
t.buf.WriteString(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *outputT) Helper() {
|
||||||
|
if t.helpers == nil {
|
||||||
|
t.helpers = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
t.helpers[callerName(1)] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// callerName gives the function name (qualified with a package path)
|
||||||
|
// for the caller after skip frames (where 0 means the current function).
|
||||||
|
func callerName(skip int) string {
|
||||||
|
// Make room for the skip PC.
|
||||||
|
var pc [1]uintptr
|
||||||
|
n := runtime.Callers(skip+2, pc[:]) // skip + runtime.Callers + callerName
|
||||||
|
if n == 0 {
|
||||||
|
panic("testing: zero callers found")
|
||||||
|
}
|
||||||
|
frames := runtime.CallersFrames(pc[:n])
|
||||||
|
frame, _ := frames.Next()
|
||||||
|
return frame.Function
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGreater(t *testing.T) {
	mockT := new(testing.T)

	if !Greater(mockT, 2, 1) {
		t.Error("Greater should return true")
	}

	if Greater(mockT, 1, 1) {
		t.Error("Greater should return false")
	}

	if Greater(mockT, 1, 2) {
		t.Error("Greater should return false")
	}

	// Check error report
	for _, currCase := range []struct {
		less    interface{}
		greater interface{}
		msg     string
	}{
		{less: "a", greater: "b", msg: `"a" is not greater than "b"`},
		{less: int(1), greater: int(2), msg: `"1" is not greater than "2"`},
		{less: int8(1), greater: int8(2), msg: `"1" is not greater than "2"`},
		{less: int16(1), greater: int16(2), msg: `"1" is not greater than "2"`},
		{less: int32(1), greater: int32(2), msg: `"1" is not greater than "2"`},
		{less: int64(1), greater: int64(2), msg: `"1" is not greater than "2"`},
		{less: uint8(1), greater: uint8(2), msg: `"1" is not greater than "2"`},
		{less: uint16(1), greater: uint16(2), msg: `"1" is not greater than "2"`},
		{less: uint32(1), greater: uint32(2), msg: `"1" is not greater than "2"`},
		{less: uint64(1), greater: uint64(2), msg: `"1" is not greater than "2"`},
		{less: float32(1.23), greater: float32(2.34), msg: `"1.23" is not greater than "2.34"`},
		{less: float64(1.23), greater: float64(2.34), msg: `"1.23" is not greater than "2.34"`},
	} {
		out := &outputT{buf: bytes.NewBuffer(nil)}
		False(t, Greater(out, currCase.less, currCase.greater))
		Contains(t, out.buf.String(), currCase.msg)
		Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.Greater")
	}
}

func TestGreaterOrEqual(t *testing.T) {
	mockT := new(testing.T)

	if !GreaterOrEqual(mockT, 2, 1) {
		t.Error("GreaterOrEqual should return true")
	}

	if !GreaterOrEqual(mockT, 1, 1) {
		t.Error("GreaterOrEqual should return true")
	}

	if GreaterOrEqual(mockT, 1, 2) {
		t.Error("GreaterOrEqual should return false")
	}

	// Check error report
	for _, currCase := range []struct {
		less    interface{}
		greater interface{}
		msg     string
	}{
		{less: "a", greater: "b", msg: `"a" is not greater than or equal to "b"`},
		{less: int(1), greater: int(2), msg: `"1" is not greater than or equal to "2"`},
		{less: int8(1), greater: int8(2), msg: `"1" is not greater than or equal to "2"`},
		{less: int16(1), greater: int16(2), msg: `"1" is not greater than or equal to "2"`},
		{less: int32(1), greater: int32(2), msg: `"1" is not greater than or equal to "2"`},
		{less: int64(1), greater: int64(2), msg: `"1" is not greater than or equal to "2"`},
		{less: uint8(1), greater: uint8(2), msg: `"1" is not greater than or equal to "2"`},
		{less: uint16(1), greater: uint16(2), msg: `"1" is not greater than or equal to "2"`},
		{less: uint32(1), greater: uint32(2), msg: `"1" is not greater than or equal to "2"`},
		{less: uint64(1), greater: uint64(2), msg: `"1" is not greater than or equal to "2"`},
		{less: float32(1.23), greater: float32(2.34), msg: `"1.23" is not greater than or equal to "2.34"`},
		{less: float64(1.23), greater: float64(2.34), msg: `"1.23" is not greater than or equal to "2.34"`},
	} {
		out := &outputT{buf: bytes.NewBuffer(nil)}
		False(t, GreaterOrEqual(out, currCase.less, currCase.greater))
		Contains(t, out.buf.String(), currCase.msg)
		Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.GreaterOrEqual")
	}
}

func TestLess(t *testing.T) {
	mockT := new(testing.T)

	if !Less(mockT, 1, 2) {
		t.Error("Less should return true")
	}

	if Less(mockT, 1, 1) {
		t.Error("Less should return false")
	}

	if Less(mockT, 2, 1) {
		t.Error("Less should return false")
	}

	// Check error report
	for _, currCase := range []struct {
		less    interface{}
		greater interface{}
		msg     string
	}{
		{less: "a", greater: "b", msg: `"b" is not less than "a"`},
		{less: int(1), greater: int(2), msg: `"2" is not less than "1"`},
		{less: int8(1), greater: int8(2), msg: `"2" is not less than "1"`},
		{less: int16(1), greater: int16(2), msg: `"2" is not less than "1"`},
		{less: int32(1), greater: int32(2), msg: `"2" is not less than "1"`},
		{less: int64(1), greater: int64(2), msg: `"2" is not less than "1"`},
		{less: uint8(1), greater: uint8(2), msg: `"2" is not less than "1"`},
		{less: uint16(1), greater: uint16(2), msg: `"2" is not less than "1"`},
		{less: uint32(1), greater: uint32(2), msg: `"2" is not less than "1"`},
		{less: uint64(1), greater: uint64(2), msg: `"2" is not less than "1"`},
		{less: float32(1.23), greater: float32(2.34), msg: `"2.34" is not less than "1.23"`},
		{less: float64(1.23), greater: float64(2.34), msg: `"2.34" is not less than "1.23"`},
	} {
		out := &outputT{buf: bytes.NewBuffer(nil)}
		False(t, Less(out, currCase.greater, currCase.less))
		Contains(t, out.buf.String(), currCase.msg)
		Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.Less")
	}
}

func TestLessOrEqual(t *testing.T) {
	mockT := new(testing.T)

	if !LessOrEqual(mockT, 1, 2) {
		t.Error("LessOrEqual should return true")
	}

	if !LessOrEqual(mockT, 1, 1) {
		t.Error("LessOrEqual should return true")
	}

	if LessOrEqual(mockT, 2, 1) {
		t.Error("LessOrEqual should return false")
	}

	// Check error report
	for _, currCase := range []struct {
		less    interface{}
		greater interface{}
		msg     string
	}{
		{less: "a", greater: "b", msg: `"b" is not less than or equal to "a"`},
		{less: int(1), greater: int(2), msg: `"2" is not less than or equal to "1"`},
		{less: int8(1), greater: int8(2), msg: `"2" is not less than or equal to "1"`},
		{less: int16(1), greater: int16(2), msg: `"2" is not less than or equal to "1"`},
		{less: int32(1), greater: int32(2), msg: `"2" is not less than or equal to "1"`},
		{less: int64(1), greater: int64(2), msg: `"2" is not less than or equal to "1"`},
		{less: uint8(1), greater: uint8(2), msg: `"2" is not less than or equal to "1"`},
		{less: uint16(1), greater: uint16(2), msg: `"2" is not less than or equal to "1"`},
		{less: uint32(1), greater: uint32(2), msg: `"2" is not less than or equal to "1"`},
		{less: uint64(1), greater: uint64(2), msg: `"2" is not less than or equal to "1"`},
		{less: float32(1.23), greater: float32(2.34), msg: `"2.34" is not less than or equal to "1.23"`},
		{less: float64(1.23), greater: float64(2.34), msg: `"2.34" is not less than or equal to "1.23"`},
	} {
		out := &outputT{buf: bytes.NewBuffer(nil)}
		False(t, LessOrEqual(out, currCase.greater, currCase.less))
		Contains(t, out.buf.String(), currCase.msg)
		Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.LessOrEqual")
	}
}

func TestPositive(t *testing.T) {
	mockT := new(testing.T)

	if !Positive(mockT, 1) {
		t.Error("Positive should return true")
	}

	if !Positive(mockT, 1.23) {
		t.Error("Positive should return true")
	}

	if Positive(mockT, -1) {
		t.Error("Positive should return false")
	}

	if Positive(mockT, -1.23) {
		t.Error("Positive should return false")
	}

	// Check error report
	for _, currCase := range []struct {
		e   interface{}
		msg string
	}{
		{e: int(-1), msg: `"-1" is not positive`},
		{e: int8(-1), msg: `"-1" is not positive`},
		{e: int16(-1), msg: `"-1" is not positive`},
		{e: int32(-1), msg: `"-1" is not positive`},
		{e: int64(-1), msg: `"-1" is not positive`},
		{e: float32(-1.23), msg: `"-1.23" is not positive`},
		{e: float64(-1.23), msg: `"-1.23" is not positive`},
	} {
		out := &outputT{buf: bytes.NewBuffer(nil)}
		False(t, Positive(out, currCase.e))
		Contains(t, out.buf.String(), currCase.msg)
		Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.Positive")
	}
}

func TestNegative(t *testing.T) {
	mockT := new(testing.T)

	if !Negative(mockT, -1) {
		t.Error("Negative should return true")
	}

	if !Negative(mockT, -1.23) {
		t.Error("Negative should return true")
	}

	if Negative(mockT, 1) {
		t.Error("Negative should return false")
	}

	if Negative(mockT, 1.23) {
		t.Error("Negative should return false")
	}

	// Check error report
	for _, currCase := range []struct {
		e   interface{}
		msg string
	}{
		{e: int(1), msg: `"1" is not negative`},
		{e: int8(1), msg: `"1" is not negative`},
		{e: int16(1), msg: `"1" is not negative`},
		{e: int32(1), msg: `"1" is not negative`},
		{e: int64(1), msg: `"1" is not negative`},
		{e: float32(1.23), msg: `"1.23" is not negative`},
		{e: float64(1.23), msg: `"1.23" is not negative`},
	} {
		out := &outputT{buf: bytes.NewBuffer(nil)}
		False(t, Negative(out, currCase.e))
		Contains(t, out.buf.String(), currCase.msg)
		Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.Negative")
	}
}

func Test_compareTwoValuesDifferentValuesTypes(t *testing.T) {
	mockT := new(testing.T)

	for _, currCase := range []struct {
		v1            interface{}
		v2            interface{}
		compareResult bool
	}{
		{v1: 123, v2: "abc"},
		{v1: "abc", v2: 123456},
		{v1: float64(12), v2: "123"},
		{v1: "float(12)", v2: float64(1)},
	} {
		compareResult := compareTwoValues(mockT, currCase.v1, currCase.v2, []CompareType{compareLess, compareEqual, compareGreater}, "testFailMessage")
		False(t, compareResult)
	}
}

func Test_compareTwoValuesNotComparableValues(t *testing.T) {
	mockT := new(testing.T)

	type CompareStruct struct {
	}

	for _, currCase := range []struct {
		v1 interface{}
		v2 interface{}
	}{
		{v1: CompareStruct{}, v2: CompareStruct{}},
		{v1: map[string]int{}, v2: map[string]int{}},
		{v1: make([]int, 5), v2: make([]int, 5)},
	} {
		compareResult := compareTwoValues(mockT, currCase.v1, currCase.v2, []CompareType{compareLess, compareEqual, compareGreater}, "testFailMessage")
		False(t, compareResult)
	}
}

func Test_compareTwoValuesCorrectCompareResult(t *testing.T) {
	mockT := new(testing.T)

	for _, currCase := range []struct {
		v1           interface{}
		v2           interface{}
		compareTypes []CompareType
	}{
		{v1: 1, v2: 2, compareTypes: []CompareType{compareLess}},
		{v1: 1, v2: 2, compareTypes: []CompareType{compareLess, compareEqual}},
		{v1: 2, v2: 2, compareTypes: []CompareType{compareGreater, compareEqual}},
		{v1: 2, v2: 2, compareTypes: []CompareType{compareEqual}},
		{v1: 2, v2: 1, compareTypes: []CompareType{compareEqual, compareGreater}},
		{v1: 2, v2: 1, compareTypes: []CompareType{compareGreater}},
	} {
		compareResult := compareTwoValues(mockT, currCase.v1, currCase.v2, currCase.compareTypes, "testFailMessage")
		True(t, compareResult)
	}
}

func Test_containsValue(t *testing.T) {
	for _, currCase := range []struct {
		values []CompareType
		value  CompareType
		result bool
	}{
		{values: []CompareType{compareGreater}, value: compareGreater, result: true},
		{values: []CompareType{compareGreater, compareLess}, value: compareGreater, result: true},
		{values: []CompareType{compareGreater, compareLess}, value: compareLess, result: true},
		{values: []CompareType{compareGreater, compareLess}, value: compareEqual, result: false},
	} {
		compareResult := containsValue(currCase.values, currCase.value)
		Equal(t, currCase.result, compareResult)
	}
}

func TestComparingMsgAndArgsForwarding(t *testing.T) {
	msgAndArgs := []interface{}{"format %s %x", "this", 0xc001}
	expectedOutput := "format this c001\n"
	funcs := []func(t TestingT){
		func(t TestingT) { Greater(t, 1, 2, msgAndArgs...) },
		func(t TestingT) { GreaterOrEqual(t, 1, 2, msgAndArgs...) },
		func(t TestingT) { Less(t, 2, 1, msgAndArgs...) },
		func(t TestingT) { LessOrEqual(t, 2, 1, msgAndArgs...) },
		func(t TestingT) { Positive(t, 0, msgAndArgs...) },
		func(t TestingT) { Negative(t, 0, msgAndArgs...) },
	}
	for _, f := range funcs {
		out := &outputT{buf: bytes.NewBuffer(nil)}
		f(out)
		Contains(t, out.buf.String(), expectedOutput)
	}
}
325
internal/assert/assertion_format.go
Normal file
@ -0,0 +1,325 @@
// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_format.go

// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in
// the THIRD-PARTY-NOTICES file.

package assert

import (
	time "time"
)

// Containsf asserts that the specified string, list(array, slice...) or map contains the
// specified substring or element.
//
//	assert.Containsf(t, "Hello World", "World", "error message %s", "formatted")
//	assert.Containsf(t, ["Hello", "World"], "World", "error message %s", "formatted")
//	assert.Containsf(t, {"Hello": "World"}, "Hello", "error message %s", "formatted")
func Containsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return Contains(t, s, contains, append([]interface{}{msg}, args...)...)
}

// ElementsMatchf asserts that the specified listA(array, slice...) is equal to specified
// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
// the number of appearances of each of them in both lists should match.
//
//	assert.ElementsMatchf(t, [1, 3, 2, 3], [1, 3, 3, 2], "error message %s", "formatted")
func ElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return ElementsMatch(t, listA, listB, append([]interface{}{msg}, args...)...)
}

// Equalf asserts that two objects are equal.
//
//	assert.Equalf(t, 123, 123, "error message %s", "formatted")
//
// Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses). Function equality
// cannot be determined and will always fail.
func Equalf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return Equal(t, expected, actual, append([]interface{}{msg}, args...)...)
}

// EqualErrorf asserts that a function returned an error (i.e. not `nil`)
// and that it is equal to the provided error.
//
//	actualObj, err := SomeFunction()
//	assert.EqualErrorf(t, err, expectedErrorString, "error message %s", "formatted")
func EqualErrorf(t TestingT, theError error, errString string, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return EqualError(t, theError, errString, append([]interface{}{msg}, args...)...)
}

// EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal.
//
//	assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted")
func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return EqualValues(t, expected, actual, append([]interface{}{msg}, args...)...)
}

// Errorf asserts that a function returned an error (i.e. not `nil`).
//
//	actualObj, err := SomeFunction()
//	if assert.Errorf(t, err, "error message %s", "formatted") {
//		assert.Equal(t, expectedErrorf, err)
//	}
func Errorf(t TestingT, err error, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return Error(t, err, append([]interface{}{msg}, args...)...)
}

// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
//	actualObj, err := SomeFunction()
//	assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted")
func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return ErrorContains(t, theError, contains, append([]interface{}{msg}, args...)...)
}

// Eventuallyf asserts that the given condition will be met in waitFor time,
// periodically checking the target function each tick.
//
//	assert.Eventuallyf(t, func() bool { return true; }, time.Second, 10*time.Millisecond, "error message %s", "formatted")
func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return Eventually(t, condition, waitFor, tick, append([]interface{}{msg}, args...)...)
}

// Failf reports a failure through Fail.
func Failf(t TestingT, failureMessage string, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return Fail(t, failureMessage, append([]interface{}{msg}, args...)...)
}

// FailNowf fails the test immediately through FailNow.
func FailNowf(t TestingT, failureMessage string, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return FailNow(t, failureMessage, append([]interface{}{msg}, args...)...)
}

// Falsef asserts that the specified value is false.
//
//	assert.Falsef(t, myBool, "error message %s", "formatted")
func Falsef(t TestingT, value bool, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return False(t, value, append([]interface{}{msg}, args...)...)
}

// Greaterf asserts that the first element is greater than the second
//
//	assert.Greaterf(t, 2, 1, "error message %s", "formatted")
//	assert.Greaterf(t, float64(2), float64(1), "error message %s", "formatted")
//	assert.Greaterf(t, "b", "a", "error message %s", "formatted")
func Greaterf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return Greater(t, e1, e2, append([]interface{}{msg}, args...)...)
}

// GreaterOrEqualf asserts that the first element is greater than or equal to the second
//
//	assert.GreaterOrEqualf(t, 2, 1, "error message %s", "formatted")
//	assert.GreaterOrEqualf(t, 2, 2, "error message %s", "formatted")
//	assert.GreaterOrEqualf(t, "b", "a", "error message %s", "formatted")
//	assert.GreaterOrEqualf(t, "b", "b", "error message %s", "formatted")
func GreaterOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return GreaterOrEqual(t, e1, e2, append([]interface{}{msg}, args...)...)
}

// InDeltaf asserts that the two numerals are within delta of each other.
//
//	assert.InDeltaf(t, math.Pi, 22/7.0, 0.01, "error message %s", "formatted")
func InDeltaf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return InDelta(t, expected, actual, delta, append([]interface{}{msg}, args...)...)
}

// IsTypef asserts that the specified objects are of the same type.
func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return IsType(t, expectedType, object, append([]interface{}{msg}, args...)...)
}

// Lenf asserts that the specified object has specific length.
// Lenf also fails if the object has a type that len() does not accept.
//
//	assert.Lenf(t, mySlice, 3, "error message %s", "formatted")
func Lenf(t TestingT, object interface{}, length int, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return Len(t, object, length, append([]interface{}{msg}, args...)...)
}

// Lessf asserts that the first element is less than the second
//
//	assert.Lessf(t, 1, 2, "error message %s", "formatted")
//	assert.Lessf(t, float64(1), float64(2), "error message %s", "formatted")
//	assert.Lessf(t, "a", "b", "error message %s", "formatted")
func Lessf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return Less(t, e1, e2, append([]interface{}{msg}, args...)...)
}

// LessOrEqualf asserts that the first element is less than or equal to the second
//
//	assert.LessOrEqualf(t, 1, 2, "error message %s", "formatted")
//	assert.LessOrEqualf(t, 2, 2, "error message %s", "formatted")
//	assert.LessOrEqualf(t, "a", "b", "error message %s", "formatted")
//	assert.LessOrEqualf(t, "b", "b", "error message %s", "formatted")
func LessOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return LessOrEqual(t, e1, e2, append([]interface{}{msg}, args...)...)
}

// Negativef asserts that the specified element is negative
//
//	assert.Negativef(t, -1, "error message %s", "formatted")
//	assert.Negativef(t, -1.23, "error message %s", "formatted")
func Negativef(t TestingT, e interface{}, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return Negative(t, e, append([]interface{}{msg}, args...)...)
}

// Nilf asserts that the specified object is nil.
//
//	assert.Nilf(t, err, "error message %s", "formatted")
func Nilf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return Nil(t, object, append([]interface{}{msg}, args...)...)
}

// NoErrorf asserts that a function returned no error (i.e. `nil`).
//
//	actualObj, err := SomeFunction()
//	if assert.NoErrorf(t, err, "error message %s", "formatted") {
//		assert.Equal(t, expectedObj, actualObj)
//	}
func NoErrorf(t TestingT, err error, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return NoError(t, err, append([]interface{}{msg}, args...)...)
}

// NotContainsf asserts that the specified string, list(array, slice...) or map does NOT contain the
// specified substring or element.
//
//	assert.NotContainsf(t, "Hello World", "Earth", "error message %s", "formatted")
//	assert.NotContainsf(t, ["Hello", "World"], "Earth", "error message %s", "formatted")
//	assert.NotContainsf(t, {"Hello": "World"}, "Earth", "error message %s", "formatted")
func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return NotContains(t, s, contains, append([]interface{}{msg}, args...)...)
}

// NotEqualf asserts that the specified values are NOT equal.
//
//	assert.NotEqualf(t, obj1, obj2, "error message %s", "formatted")
//
// Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses).
func NotEqualf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return NotEqual(t, expected, actual, append([]interface{}{msg}, args...)...)
}

// NotEqualValuesf asserts that two objects are not equal even when converted to the same type
//
//	assert.NotEqualValuesf(t, obj1, obj2, "error message %s", "formatted")
func NotEqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return NotEqualValues(t, expected, actual, append([]interface{}{msg}, args...)...)
}

// NotNilf asserts that the specified object is not nil.
//
//	assert.NotNilf(t, err, "error message %s", "formatted")
func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}
	return NotNil(t, object, append([]interface{}{msg}, args...)...)
}

// Positivef asserts that the specified element is positive
|
||||||
|
//
|
||||||
|
// assert.Positivef(t, 1, "error message %s", "formatted")
|
||||||
|
// assert.Positivef(t, 1.23, "error message %s", "formatted")
|
||||||
|
func Positivef(t TestingT, e interface{}, msg string, args ...interface{}) bool {
|
||||||
|
if h, ok := t.(tHelper); ok {
|
||||||
|
h.Helper()
|
||||||
|
}
|
||||||
|
return Positive(t, e, append([]interface{}{msg}, args...)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Truef asserts that the specified value is true.
|
||||||
|
//
|
||||||
|
// assert.Truef(t, myBool, "error message %s", "formatted")
|
||||||
|
func Truef(t TestingT, value bool, msg string, args ...interface{}) bool {
|
||||||
|
if h, ok := t.(tHelper); ok {
|
||||||
|
h.Helper()
|
||||||
|
}
|
||||||
|
return True(t, value, append([]interface{}{msg}, args...)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithinDurationf asserts that the two times are within duration delta of each other.
|
||||||
|
//
|
||||||
|
// assert.WithinDurationf(t, time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted")
|
||||||
|
func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) bool {
|
||||||
|
if h, ok := t.(tHelper); ok {
|
||||||
|
h.Helper()
|
||||||
|
}
|
||||||
|
return WithinDuration(t, expected, actual, delta, append([]interface{}{msg}, args...)...)
|
||||||
|
}
126	internal/assert/assertion_mongo.go	Normal file
@@ -0,0 +1,126 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

// assertion_mongo.go contains MongoDB-specific extensions to the "assert"
// package.

package assert

import (
	"context"
	"fmt"
	"reflect"
	"time"
	"unsafe"
)

// DifferentAddressRanges asserts that two byte slices reference distinct memory
// address ranges, meaning they reference different underlying byte arrays.
func DifferentAddressRanges(t TestingT, a, b []byte) (ok bool) {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}

	if len(a) == 0 || len(b) == 0 {
		return true
	}

	// Find the start and end memory addresses for the underlying byte array for
	// each input byte slice.
	sliceAddrRange := func(b []byte) (uintptr, uintptr) {
		sh := (*reflect.SliceHeader)(unsafe.Pointer(&b))
		return sh.Data, sh.Data + uintptr(sh.Cap-1)
	}
	aStart, aEnd := sliceAddrRange(a)
	bStart, bEnd := sliceAddrRange(b)

	// If "b" starts after "a" ends or "a" starts after "b" ends, there is no
	// overlap.
	if bStart > aEnd || aStart > bEnd {
		return true
	}

	// Otherwise, calculate the overlap start and end and print the memory
	// overlap error message.
	min := func(a, b uintptr) uintptr {
		if a < b {
			return a
		}
		return b
	}
	max := func(a, b uintptr) uintptr {
		if a > b {
			return a
		}
		return b
	}
	overlapLow := max(aStart, bStart)
	overlapHigh := min(aEnd, bEnd)

	t.Errorf("Byte slices point to the same underlying byte array:\n"+
		"\ta addresses:\t%d ... %d\n"+
		"\tb addresses:\t%d ... %d\n"+
		"\toverlap:\t%d ... %d",
		aStart, aEnd,
		bStart, bEnd,
		overlapLow, overlapHigh)

	return false
}
// EqualBSON asserts that the expected and actual BSON binary values are equal.
// If the values are not equal, it prints both the binary and Extended JSON diff
// of the BSON values. The provided BSON value types must implement the
// fmt.Stringer interface.
func EqualBSON(t TestingT, expected, actual interface{}) bool {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}

	return Equal(t,
		expected,
		actual,
		`expected and actual BSON values do not match
As Extended JSON:
Expected: %s
Actual : %s`,
		expected.(fmt.Stringer).String(),
		actual.(fmt.Stringer).String())
}

// Soon runs the provided callback and fails the passed-in test if the callback
// does not complete within timeout. The provided callback should respect the
// passed-in context and cease execution when it has expired.
//
// Deprecated: This function will be removed with GODRIVER-2667, use
// assert.Eventually instead.
func Soon(t TestingT, callback func(ctx context.Context), timeout time.Duration) {
	if h, ok := t.(tHelper); ok {
		h.Helper()
	}

	// Create context to manually cancel callback after Soon assertion.
	callbackCtx, cancel := context.WithCancel(context.Background())
	defer cancel()

	done := make(chan struct{})
	fullCallback := func() {
		callback(callbackCtx)
		done <- struct{}{}
	}

	timer := time.NewTimer(timeout)
	defer timer.Stop()

	go fullCallback()

	select {
	case <-done:
		return
	case <-timer.C:
		t.Errorf("timed out in %s waiting for callback", timeout)
	}
}
125	internal/assert/assertion_mongo_test.go	Normal file
@@ -0,0 +1,125 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package assert

import (
	"testing"

	"gitea.psichedelico.com/go/bson"
)

func TestDifferentAddressRanges(t *testing.T) {
	t.Parallel()

	slice := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}

	testCases := []struct {
		name string
		a    []byte
		b    []byte
		want bool
	}{
		{
			name: "distinct byte slices",
			a:    []byte{0, 1, 2, 3},
			b:    []byte{0, 1, 2, 3},
			want: true,
		},
		{
			name: "same byte slice",
			a:    slice,
			b:    slice,
			want: false,
		},
		{
			name: "whole and subslice",
			a:    slice,
			b:    slice[:4],
			want: false,
		},
		{
			name: "two subslices",
			a:    slice[1:2],
			b:    slice[3:4],
			want: false,
		},
		{
			name: "empty",
			a:    []byte{0, 1, 2, 3},
			b:    []byte{},
			want: true,
		},
		{
			name: "nil",
			a:    []byte{0, 1, 2, 3},
			b:    nil,
			want: true,
		},
	}

	for _, tc := range testCases {
		tc := tc // Capture range variable.

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			got := DifferentAddressRanges(new(testing.T), tc.a, tc.b)
			if got != tc.want {
				t.Errorf("DifferentAddressRanges(%p, %p) = %v, want %v", tc.a, tc.b, got, tc.want)
			}
		})
	}
}

func TestEqualBSON(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		name     string
		expected interface{}
		actual   interface{}
		want     bool
	}{
		{
			name:     "equal bson.Raw",
			expected: bson.Raw{5, 0, 0, 0, 0},
			actual:   bson.Raw{5, 0, 0, 0, 0},
			want:     true,
		},
		{
			name:     "different bson.Raw",
			expected: bson.Raw{8, 0, 0, 0, 10, 120, 0, 0},
			actual:   bson.Raw{5, 0, 0, 0, 0},
			want:     false,
		},
		{
			name:     "invalid bson.Raw",
			expected: bson.Raw{99, 99, 99, 99},
			actual:   bson.Raw{5, 0, 0, 0, 0},
			want:     false,
		},
		{
			name:     "nil bson.Raw",
			expected: bson.Raw(nil),
			actual:   bson.Raw(nil),
			want:     true,
		},
	}

	for _, tc := range testCases {
		tc := tc // Capture range variable.

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			got := EqualBSON(new(testing.T), tc.expected, tc.actual)
			if got != tc.want {
				t.Errorf("EqualBSON(%#v, %#v) = %v, want %v", tc.expected, tc.actual, got, tc.want)
			}
		})
	}
}
1075	internal/assert/assertions.go	Normal file
File diff suppressed because it is too large	Load Diff
1231	internal/assert/assertions_test.go	Normal file
File diff suppressed because it is too large	Load Diff
766	internal/assert/difflib.go	Normal file
@@ -0,0 +1,766 @@
// Copied from https://github.com/pmezard/go-difflib/blob/5d4384ee4fb2527b0a1256a821ebfc92f91efefc/difflib/difflib.go

// Copyright 2013 Patrick Mezard. All rights reserved. Use of this source code is
// governed by a license that can be found in the THIRD-PARTY-NOTICES file.

package assert

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"strings"
)

func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}

func max(a, b int) int {
	if a > b {
		return a
	}
	return b
}

func calculateRatio(matches, length int) float64 {
	if length > 0 {
		return 2.0 * float64(matches) / float64(length)
	}
	return 1.0
}
type Match struct {
	A    int
	B    int
	Size int
}

type OpCode struct {
	Tag byte
	I1  int
	I2  int
	J1  int
	J2  int
}

// SequenceMatcher compares sequence of strings. The basic
// algorithm predates, and is a little fancier than, an algorithm
// published in the late 1980's by Ratcliff and Obershelp under the
// hyperbolic name "gestalt pattern matching". The basic idea is to find
// the longest contiguous matching subsequence that contains no "junk"
// elements (R-O doesn't address junk). The same idea is then applied
// recursively to the pieces of the sequences to the left and to the right
// of the matching subsequence. This does not yield minimal edit
// sequences, but does tend to yield matches that "look right" to people.
//
// SequenceMatcher tries to compute a "human-friendly diff" between two
// sequences. Unlike e.g. UNIX(tm) diff, the fundamental notion is the
// longest *contiguous* & junk-free matching subsequence. That's what
// catches peoples' eyes. The Windows(tm) windiff has another interesting
// notion, pairing up elements that appear uniquely in each sequence.
// That, and the method here, appear to yield more intuitive difference
// reports than does diff. This method appears to be the least vulnerable
// to syncing up on blocks of "junk lines", though (like blank lines in
// ordinary text files, or maybe "<P>" lines in HTML files). That may be
// because this is the only method of the 3 that has a *concept* of
// "junk" <wink>.
//
// Timing: Basic R-O is cubic time worst case and quadratic time expected
// case. SequenceMatcher is quadratic time for the worst case and has
// expected-case behavior dependent in a complicated way on how many
// elements the sequences have in common; best case time is linear.
type SequenceMatcher struct {
	a              []string
	b              []string
	b2j            map[string][]int
	IsJunk         func(string) bool
	autoJunk       bool
	bJunk          map[string]struct{}
	matchingBlocks []Match
	fullBCount     map[string]int
	bPopular       map[string]struct{}
	opCodes        []OpCode
}

func NewMatcher(a, b []string) *SequenceMatcher {
	m := SequenceMatcher{autoJunk: true}
	m.SetSeqs(a, b)
	return &m
}

func NewMatcherWithJunk(a, b []string, autoJunk bool,
	isJunk func(string) bool) *SequenceMatcher {

	m := SequenceMatcher{IsJunk: isJunk, autoJunk: autoJunk}
	m.SetSeqs(a, b)
	return &m
}

// SetSeqs sets the two sequences to be compared.
func (m *SequenceMatcher) SetSeqs(a, b []string) {
	m.SetSeq1(a)
	m.SetSeq2(b)
}

// SetSeq1 sets the first sequence to be compared. The second sequence to be compared is
// not changed.
//
// SequenceMatcher computes and caches detailed information about the second
// sequence, so if you want to compare one sequence S against many sequences,
// use .SetSeq2(s) once and call .SetSeq1(x) repeatedly for each of the other
// sequences.
//
// See also SetSeqs() and SetSeq2().
func (m *SequenceMatcher) SetSeq1(a []string) {
	if &a == &m.a {
		return
	}
	m.a = a
	m.matchingBlocks = nil
	m.opCodes = nil
}

// SetSeq2 sets the second sequence to be compared. The first sequence to be compared is
// not changed.
func (m *SequenceMatcher) SetSeq2(b []string) {
	if &b == &m.b {
		return
	}
	m.b = b
	m.matchingBlocks = nil
	m.opCodes = nil
	m.fullBCount = nil
	m.chainB()
}

func (m *SequenceMatcher) chainB() {
	// Populate line -> index mapping
	b2j := map[string][]int{}
	for i, s := range m.b {
		indices := b2j[s]
		indices = append(indices, i)
		b2j[s] = indices
	}

	// Purge junk elements
	m.bJunk = map[string]struct{}{}
	if m.IsJunk != nil {
		junk := m.bJunk
		for s := range b2j {
			if m.IsJunk(s) {
				junk[s] = struct{}{}
			}
		}
		for s := range junk {
			delete(b2j, s)
		}
	}

	// Purge remaining popular elements
	popular := map[string]struct{}{}
	n := len(m.b)
	if m.autoJunk && n >= 200 {
		ntest := n/100 + 1
		for s, indices := range b2j {
			if len(indices) > ntest {
				popular[s] = struct{}{}
			}
		}
		for s := range popular {
			delete(b2j, s)
		}
	}
	m.bPopular = popular
	m.b2j = b2j
}

func (m *SequenceMatcher) isBJunk(s string) bool {
	_, ok := m.bJunk[s]
	return ok
}

// Find longest matching block in a[alo:ahi] and b[blo:bhi].
//
// If IsJunk is not defined:
//
// Return (i,j,k) such that a[i:i+k] is equal to b[j:j+k], where
//
//	alo <= i <= i+k <= ahi
//	blo <= j <= j+k <= bhi
//
// and for all (i',j',k') meeting those conditions,
//
//	k >= k'
//	i <= i'
//	and if i == i', j <= j'
//
// In other words, of all maximal matching blocks, return one that
// starts earliest in a, and of all those maximal matching blocks that
// start earliest in a, return the one that starts earliest in b.
//
// If IsJunk is defined, first the longest matching block is
// determined as above, but with the additional restriction that no
// junk element appears in the block. Then that block is extended as
// far as possible by matching (only) junk elements on both sides. So
// the resulting block never matches on junk except as identical junk
// happens to be adjacent to an "interesting" match.
//
// If no blocks match, return (alo, blo, 0).
func (m *SequenceMatcher) findLongestMatch(alo, ahi, blo, bhi int) Match {
	// CAUTION: stripping common prefix or suffix would be incorrect.
	// E.g.,
	//    ab
	//    acab
	// Longest matching block is "ab", but if common prefix is
	// stripped, it's "a" (tied with "b"). UNIX(tm) diff does so
	// strip, so ends up claiming that ab is changed to acab by
	// inserting "ca" in the middle. That's minimal but unintuitive:
	// "it's obvious" that someone inserted "ac" at the front.
	// Windiff ends up at the same place as diff, but by pairing up
	// the unique 'b's and then matching the first two 'a's.
	besti, bestj, bestsize := alo, blo, 0

	// find longest junk-free match
	// during an iteration of the loop, j2len[j] = length of longest
	// junk-free match ending with a[i-1] and b[j]
	j2len := map[int]int{}
	for i := alo; i != ahi; i++ {
		// look at all instances of a[i] in b; note that because
		// b2j has no junk keys, the loop is skipped if a[i] is junk
		newj2len := map[int]int{}
		for _, j := range m.b2j[m.a[i]] {
			// a[i] matches b[j]
			if j < blo {
				continue
			}
			if j >= bhi {
				break
			}
			k := j2len[j-1] + 1
			newj2len[j] = k
			if k > bestsize {
				besti, bestj, bestsize = i-k+1, j-k+1, k
			}
		}
		j2len = newj2len
	}

	// Extend the best by non-junk elements on each end. In particular,
	// "popular" non-junk elements aren't in b2j, which greatly speeds
	// the inner loop above, but also means "the best" match so far
	// doesn't contain any junk *or* popular non-junk elements.
	for besti > alo && bestj > blo && !m.isBJunk(m.b[bestj-1]) &&
		m.a[besti-1] == m.b[bestj-1] {
		besti, bestj, bestsize = besti-1, bestj-1, bestsize+1
	}
	for besti+bestsize < ahi && bestj+bestsize < bhi &&
		!m.isBJunk(m.b[bestj+bestsize]) &&
		m.a[besti+bestsize] == m.b[bestj+bestsize] {
		bestsize++
	}

	// Now that we have a wholly interesting match (albeit possibly
	// empty!), we may as well suck up the matching junk on each
	// side of it too. Can't think of a good reason not to, and it
	// saves post-processing the (possibly considerable) expense of
	// figuring out what to do with it. In the case of an empty
	// interesting match, this is clearly the right thing to do,
	// because no other kind of match is possible in the regions.
	for besti > alo && bestj > blo && m.isBJunk(m.b[bestj-1]) &&
		m.a[besti-1] == m.b[bestj-1] {
		besti, bestj, bestsize = besti-1, bestj-1, bestsize+1
	}
	for besti+bestsize < ahi && bestj+bestsize < bhi &&
		m.isBJunk(m.b[bestj+bestsize]) &&
		m.a[besti+bestsize] == m.b[bestj+bestsize] {
		bestsize++
	}

	return Match{A: besti, B: bestj, Size: bestsize}
}
// GetMatchingBlocks returns list of triples describing matching subsequences.
//
// Each triple is of the form (i, j, n), and means that
// a[i:i+n] == b[j:j+n]. The triples are monotonically increasing in
// i and in j. It's also guaranteed that if (i, j, n) and (i', j', n') are
// adjacent triples in the list, and the second is not the last triple in the
// list, then i+n != i' or j+n != j'. IOW, adjacent triples never describe
// adjacent equal blocks.
//
// The last triple is a dummy, (len(a), len(b), 0), and is the only
// triple with n==0.
func (m *SequenceMatcher) GetMatchingBlocks() []Match {
	if m.matchingBlocks != nil {
		return m.matchingBlocks
	}

	var matchBlocks func(alo, ahi, blo, bhi int, matched []Match) []Match
	matchBlocks = func(alo, ahi, blo, bhi int, matched []Match) []Match {
		match := m.findLongestMatch(alo, ahi, blo, bhi)
		i, j, k := match.A, match.B, match.Size
		if match.Size > 0 {
			if alo < i && blo < j {
				matched = matchBlocks(alo, i, blo, j, matched)
			}
			matched = append(matched, match)
			if i+k < ahi && j+k < bhi {
				matched = matchBlocks(i+k, ahi, j+k, bhi, matched)
			}
		}
		return matched
	}
	matched := matchBlocks(0, len(m.a), 0, len(m.b), nil)

	// It's possible that we have adjacent equal blocks in the
	// matching_blocks list now.
	nonAdjacent := []Match{}
	i1, j1, k1 := 0, 0, 0
	for _, b := range matched {
		// Is this block adjacent to i1, j1, k1?
		i2, j2, k2 := b.A, b.B, b.Size
		if i1+k1 == i2 && j1+k1 == j2 {
			// Yes, so collapse them -- this just increases the length of
			// the first block by the length of the second, and the first
			// block so lengthened remains the block to compare against.
			k1 += k2
		} else {
			// Not adjacent. Remember the first block (k1==0 means it's
			// the dummy we started with), and make the second block the
			// new block to compare against.
			if k1 > 0 {
				nonAdjacent = append(nonAdjacent, Match{i1, j1, k1})
			}
			i1, j1, k1 = i2, j2, k2
		}
	}
	if k1 > 0 {
		nonAdjacent = append(nonAdjacent, Match{i1, j1, k1})
	}

	nonAdjacent = append(nonAdjacent, Match{len(m.a), len(m.b), 0})
	m.matchingBlocks = nonAdjacent
	return m.matchingBlocks
}

// GetOpCodes returns a list of 5-tuples describing how to turn a into b.
//
// Each tuple is of the form (tag, i1, i2, j1, j2). The first tuple
// has i1 == j1 == 0, and remaining tuples have i1 == the i2 from the
// tuple preceding it, and likewise for j1 == the previous j2.
//
// The tags are characters, with these meanings:
//
// 'r' (replace): a[i1:i2] should be replaced by b[j1:j2]
//
// 'd' (delete): a[i1:i2] should be deleted, j1==j2 in this case.
//
// 'i' (insert): b[j1:j2] should be inserted at a[i1:i1], i1==i2 in this case.
//
// 'e' (equal): a[i1:i2] == b[j1:j2]
func (m *SequenceMatcher) GetOpCodes() []OpCode {
	if m.opCodes != nil {
		return m.opCodes
	}
	i, j := 0, 0
	matching := m.GetMatchingBlocks()
	opCodes := make([]OpCode, 0, len(matching))
	for _, m := range matching {
		// invariant: we've pumped out correct diffs to change
		// a[:i] into b[:j], and the next matching block is
		// a[ai:ai+size] == b[bj:bj+size]. So we need to pump
		// out a diff to change a[i:ai] into b[j:bj], pump out
		// the matching block, and move (i,j) beyond the match
		ai, bj, size := m.A, m.B, m.Size
		tag := byte(0)
		if i < ai && j < bj {
			tag = 'r'
		} else if i < ai {
			tag = 'd'
		} else if j < bj {
			tag = 'i'
		}
		if tag > 0 {
			opCodes = append(opCodes, OpCode{tag, i, ai, j, bj})
		}
		i, j = ai+size, bj+size
		// the list of matching blocks is terminated by a
		// sentinel with size 0
		if size > 0 {
			opCodes = append(opCodes, OpCode{'e', ai, i, bj, j})
		}
	}
	m.opCodes = opCodes
	return m.opCodes
}
// GetGroupedOpCodes isolates change clusters by eliminating ranges with no changes.
//
// Returns a generator of groups with up to n lines of context.
// Each group is in the same format as returned by GetOpCodes().
func (m *SequenceMatcher) GetGroupedOpCodes(n int) [][]OpCode {
	if n < 0 {
		n = 3
	}
	codes := m.GetOpCodes()
	if len(codes) == 0 {
		codes = []OpCode{{'e', 0, 1, 0, 1}}
	}
	// Fixup leading and trailing groups if they show no changes.
	if codes[0].Tag == 'e' {
		c := codes[0]
		i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
		codes[0] = OpCode{c.Tag, max(i1, i2-n), i2, max(j1, j2-n), j2}
	}
	if codes[len(codes)-1].Tag == 'e' {
		c := codes[len(codes)-1]
		i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
		codes[len(codes)-1] = OpCode{c.Tag, i1, min(i2, i1+n), j1, min(j2, j1+n)}
	}
	nn := n + n
	groups := [][]OpCode{}
	group := []OpCode{}
	for _, c := range codes {
		i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
		// End the current group and start a new one whenever
		// there is a large range with no changes.
		if c.Tag == 'e' && i2-i1 > nn {
			group = append(group, OpCode{c.Tag, i1, min(i2, i1+n),
				j1, min(j2, j1+n)})
			groups = append(groups, group)
			group = []OpCode{}
			i1, j1 = max(i1, i2-n), max(j1, j2-n)
		}
		group = append(group, OpCode{c.Tag, i1, i2, j1, j2})
	}
	if len(group) > 0 && !(len(group) == 1 && group[0].Tag == 'e') {
		groups = append(groups, group)
	}
	return groups
}
// Ratio returns a measure of the sequences' similarity (float in [0,1]).
//
// Where T is the total number of elements in both sequences, and
// M is the number of matches, this is 2.0*M / T.
// Note that this is 1 if the sequences are identical, and 0 if
// they have nothing in common.
//
// .Ratio() is expensive to compute if you haven't already computed
// .GetMatchingBlocks() or .GetOpCodes(), in which case you may
// want to try .QuickRatio() or .RealQuickRatio() first to get an
// upper bound.
func (m *SequenceMatcher) Ratio() float64 {
	matches := 0
	for _, m := range m.GetMatchingBlocks() {
		matches += m.Size
	}
	return calculateRatio(matches, len(m.a)+len(m.b))
}

// QuickRatio returns an upper bound on ratio() relatively quickly.
//
// This isn't defined beyond that it is an upper bound on .Ratio(), and
// is faster to compute.
func (m *SequenceMatcher) QuickRatio() float64 {
	// viewing a and b as multisets, set matches to the cardinality
	// of their intersection; this counts the number of matches
	// without regard to order, so is clearly an upper bound
	if m.fullBCount == nil {
		m.fullBCount = map[string]int{}
		for _, s := range m.b {
			m.fullBCount[s] = m.fullBCount[s] + 1
		}
	}

	// avail[x] is the number of times x appears in 'b' less the
	// number of times we've seen it in 'a' so far ... kinda
	avail := map[string]int{}
	matches := 0
	for _, s := range m.a {
		n, ok := avail[s]
		if !ok {
			n = m.fullBCount[s]
		}
		avail[s] = n - 1
		if n > 0 {
			matches++
		}
	}
	return calculateRatio(matches, len(m.a)+len(m.b))
}

// RealQuickRatio returns an upper bound on ratio() very quickly.
//
// This isn't defined beyond that it is an upper bound on .Ratio(), and
// is faster to compute than either .Ratio() or .QuickRatio().
func (m *SequenceMatcher) RealQuickRatio() float64 {
	la, lb := len(m.a), len(m.b)
	return calculateRatio(min(la, lb), la+lb)
}
// Convert range to the "ed" format
func formatRangeUnified(start, stop int) string {
	// Per the diff spec at http://www.unix.org/single_unix_specification/
	beginning := start + 1 // lines start numbering with one
	length := stop - start
	if length == 1 {
		return fmt.Sprintf("%d", beginning)
	}
	if length == 0 {
		beginning-- // empty ranges begin at line just before the range
	}
	return fmt.Sprintf("%d,%d", beginning, length)
}

// UnifiedDiff represents the unified diff parameters.
type UnifiedDiff struct {
	A        []string // First sequence lines
	FromFile string   // First file name
	FromDate string   // First file time
	B        []string // Second sequence lines
	ToFile   string   // Second file name
	ToDate   string   // Second file time
	Eol      string   // Headers end of line, defaults to LF
	Context  int      // Number of context lines
}

// WriteUnifiedDiff compares two sequences of lines; generates the delta as
// a unified diff.
//
// Unified diffs are a compact way of showing line changes and a few
// lines of context. The number of context lines is set by 'n' which
// defaults to three.
//
// By default, the diff control lines (those with ---, +++, or @@) are
// created with a trailing newline. This is helpful so that inputs
// created from file.readlines() result in diffs that are suitable for
// file.writelines() since both the inputs and outputs have trailing
// newlines.
//
// For inputs that do not have trailing newlines, set the lineterm
// argument to "" so that the output will be uniformly newline free.
//
// The unidiff format normally has a header for filenames and modification
// times. Any or all of these may be specified using strings for
// 'fromfile', 'tofile', 'fromfiledate', and 'tofiledate'.
// The modification times are normally expressed in the ISO 8601 format.
func WriteUnifiedDiff(writer io.Writer, diff UnifiedDiff) error {
	buf := bufio.NewWriter(writer)
	defer buf.Flush()
	wf := func(format string, args ...interface{}) error {
		_, err := buf.WriteString(fmt.Sprintf(format, args...))
		return err
	}
	ws := func(s string) error {
		_, err := buf.WriteString(s)
		return err
	}

	if len(diff.Eol) == 0 {
		diff.Eol = "\n"
	}

	started := false
	m := NewMatcher(diff.A, diff.B)
	for _, g := range m.GetGroupedOpCodes(diff.Context) {
		if !started {
			started = true
			fromDate := ""
			if len(diff.FromDate) > 0 {
				fromDate = "\t" + diff.FromDate
			}
			toDate := ""
			if len(diff.ToDate) > 0 {
				toDate = "\t" + diff.ToDate
			}
			if diff.FromFile != "" || diff.ToFile != "" {
				err := wf("--- %s%s%s", diff.FromFile, fromDate, diff.Eol)
				if err != nil {
					return err
				}
				err = wf("+++ %s%s%s", diff.ToFile, toDate, diff.Eol)
				if err != nil {
					return err
				}
			}
		}
		first, last := g[0], g[len(g)-1]
		range1 := formatRangeUnified(first.I1, last.I2)
		range2 := formatRangeUnified(first.J1, last.J2)
		if err := wf("@@ -%s +%s @@%s", range1, range2, diff.Eol); err != nil {
			return err
		}
		for _, c := range g {
			i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
			if c.Tag == 'e' {
				for _, line := range diff.A[i1:i2] {
					if err := ws(" " + line); err != nil {
						return err
					}
				}
				continue
			}
			if c.Tag == 'r' || c.Tag == 'd' {
				for _, line := range diff.A[i1:i2] {
					if err := ws("-" + line); err != nil {
						return err
					}
				}
			}
			if c.Tag == 'r' || c.Tag == 'i' {
				for _, line := range diff.B[j1:j2] {
					if err := ws("+" + line); err != nil {
						return err
					}
				}
			}
		}
	}
	return nil
}

// GetUnifiedDiffString is like WriteUnifiedDiff but returns the diff as a string.
func GetUnifiedDiffString(diff UnifiedDiff) (string, error) {
	w := &bytes.Buffer{}
	err := WriteUnifiedDiff(w, diff)
	return w.String(), err
}

// Convert range to the "ed" format.
func formatRangeContext(start, stop int) string {
	// Per the diff spec at http://www.unix.org/single_unix_specification/
	beginning := start + 1 // lines start numbering with one
	length := stop - start
	if length == 0 {
		beginning-- // empty ranges begin at line just before the range
	}
	if length <= 1 {
		return fmt.Sprintf("%d", beginning)
	}
	return fmt.Sprintf("%d,%d", beginning, beginning+length-1)
}
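The context-diff variant differs from the unified one in that it prints an inclusive end line rather than a length; a standalone sketch (reimplemented locally for illustration):

```go
package main

import "fmt"

// formatRangeContext mirrors the context-diff range formatting above:
// ranges of zero or one lines print a single 1-based number, while longer
// ranges print "begin,end" with an inclusive end line.
func formatRangeContext(start, stop int) string {
	beginning := start + 1
	length := stop - start
	if length == 0 {
		beginning--
	}
	if length <= 1 {
		return fmt.Sprintf("%d", beginning)
	}
	return fmt.Sprintf("%d,%d", beginning, beginning+length-1)
}

func main() {
	fmt.Println(formatRangeContext(3, 4)) // single line -> "4"
	fmt.Println(formatRangeContext(3, 6)) // three lines -> "4,6"
	fmt.Println(formatRangeContext(3, 3)) // empty range -> "3"
}
```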

type ContextDiff UnifiedDiff

// WriteContextDiff compares two sequences of lines; generates the delta as a context diff.
//
// Context diffs are a compact way of showing line changes and a few
// lines of context. The number of context lines is set by diff.Context
// which defaults to three.
//
// By default, the diff control lines (those with *** or ---) are
// created with a trailing newline.
//
// For inputs that do not have trailing newlines, set the diff.Eol
// argument to "" so that the output will be uniformly newline free.
//
// The context diff format normally has a header for filenames and
// modification times. Any or all of these may be specified using
// strings for diff.FromFile, diff.ToFile, diff.FromDate, diff.ToDate.
// The modification times are normally expressed in the ISO 8601 format.
// If not specified, the strings default to blanks.
func WriteContextDiff(writer io.Writer, diff ContextDiff) error {
	buf := bufio.NewWriter(writer)
	defer buf.Flush()
	var diffErr error
	wf := func(format string, args ...interface{}) {
		_, err := buf.WriteString(fmt.Sprintf(format, args...))
		if diffErr == nil && err != nil {
			diffErr = err
		}
	}
	ws := func(s string) {
		_, err := buf.WriteString(s)
		if diffErr == nil && err != nil {
			diffErr = err
		}
	}

	if len(diff.Eol) == 0 {
		diff.Eol = "\n"
	}

	prefix := map[byte]string{
		'i': "+ ",
		'd': "- ",
		'r': "! ",
		'e': "  ",
	}

	started := false
	m := NewMatcher(diff.A, diff.B)
	for _, g := range m.GetGroupedOpCodes(diff.Context) {
		if !started {
			started = true
			fromDate := ""
			if len(diff.FromDate) > 0 {
				fromDate = "\t" + diff.FromDate
			}
			toDate := ""
			if len(diff.ToDate) > 0 {
				toDate = "\t" + diff.ToDate
			}
			if diff.FromFile != "" || diff.ToFile != "" {
				wf("*** %s%s%s", diff.FromFile, fromDate, diff.Eol)
				wf("--- %s%s%s", diff.ToFile, toDate, diff.Eol)
			}
		}

		first, last := g[0], g[len(g)-1]
		ws("***************" + diff.Eol)

		range1 := formatRangeContext(first.I1, last.I2)
		wf("*** %s ****%s", range1, diff.Eol)
		for _, c := range g {
			if c.Tag == 'r' || c.Tag == 'd' {
				for _, cc := range g {
					if cc.Tag == 'i' {
						continue
					}
					for _, line := range diff.A[cc.I1:cc.I2] {
						ws(prefix[cc.Tag] + line)
					}
				}
				break
			}
		}

		range2 := formatRangeContext(first.J1, last.J2)
		wf("--- %s ----%s", range2, diff.Eol)
		for _, c := range g {
			if c.Tag == 'r' || c.Tag == 'i' {
				for _, cc := range g {
					if cc.Tag == 'd' {
						continue
					}
					for _, line := range diff.B[cc.J1:cc.J2] {
						ws(prefix[cc.Tag] + line)
					}
				}
				break
			}
		}
	}
	return diffErr
}

// GetContextDiffString is like WriteContextDiff but returns the diff as a string.
func GetContextDiffString(diff ContextDiff) (string, error) {
	w := &bytes.Buffer{}
	err := WriteContextDiff(w, diff)
	return w.String(), err
}

// SplitLines splits a string on "\n" while preserving them. The output can be used
// as input for UnifiedDiff and ContextDiff structures.
func SplitLines(s string) []string {
	lines := strings.SplitAfter(s, "\n")
	lines[len(lines)-1] += "\n"
	return lines
}
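The SplitLines behavior (every element, including the last, ends with a newline) is easy to see in a standalone sketch built on strings.SplitAfter, reimplemented locally for illustration:

```go
package main

import (
	"fmt"
	"strings"
)

// splitLines mirrors SplitLines above: strings.SplitAfter keeps each "\n"
// attached to its line, and a "\n" is appended to the final element so that
// every returned line ends with a newline, as the diff writers expect.
func splitLines(s string) []string {
	lines := strings.SplitAfter(s, "\n")
	lines[len(lines)-1] += "\n"
	return lines
}

func main() {
	fmt.Printf("%q\n", splitLines("foo\nbar"))   // no trailing newline in input
	fmt.Printf("%q\n", splitLines("foo\nbar\n")) // trailing newline yields an extra "\n" element
}
```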
326
internal/assert/difflib_test.go
Normal file
@ -0,0 +1,326 @@
// Copied from https://github.com/pmezard/go-difflib/blob/5d4384ee4fb2527b0a1256a821ebfc92f91efefc/difflib/difflib_test.go

// Copyright 2013 Patrick Mezard. All rights reserved. Use of this source code is
// governed by a license that can be found in the THIRD-PARTY-NOTICES file.

package assert

import (
	"bytes"
	"fmt"
	"math"
	"reflect"
	"strings"
	"testing"
)

func assertAlmostEqual(t *testing.T, a, b float64, places int) {
	if math.Abs(a-b) > math.Pow10(-places) {
		t.Errorf("%.7f != %.7f", a, b)
	}
}

func assertEqual(t *testing.T, a, b interface{}) {
	if !reflect.DeepEqual(a, b) {
		t.Errorf("%v != %v", a, b)
	}
}

func splitChars(s string) []string {
	chars := make([]string, 0, len(s))
	// Assume ASCII inputs
	for i := 0; i != len(s); i++ {
		chars = append(chars, string(s[i]))
	}
	return chars
}

func TestSequenceMatcherRatio(t *testing.T) {
	s := NewMatcher(splitChars("abcd"), splitChars("bcde"))
	assertEqual(t, s.Ratio(), 0.75)
	assertEqual(t, s.QuickRatio(), 0.75)
	assertEqual(t, s.RealQuickRatio(), 1.0)
}

func TestGetOptCodes(t *testing.T) {
	a := "qabxcd"
	b := "abycdf"
	s := NewMatcher(splitChars(a), splitChars(b))
	w := &bytes.Buffer{}
	for _, op := range s.GetOpCodes() {
		fmt.Fprintf(w, "%s a[%d:%d], (%s) b[%d:%d] (%s)\n", string(op.Tag),
			op.I1, op.I2, a[op.I1:op.I2], op.J1, op.J2, b[op.J1:op.J2])
	}
	result := w.String()
	expected := `d a[0:1], (q) b[0:0] ()
e a[1:3], (ab) b[0:2] (ab)
r a[3:4], (x) b[2:3] (y)
e a[4:6], (cd) b[3:5] (cd)
i a[6:6], () b[5:6] (f)
`
	if expected != result {
		t.Errorf("unexpected op codes: \n%s", result)
	}
}

func TestGroupedOpCodes(t *testing.T) {
	a := []string{}
	for i := 0; i != 39; i++ {
		a = append(a, fmt.Sprintf("%02d", i))
	}
	b := []string{}
	b = append(b, a[:8]...)
	b = append(b, " i")
	b = append(b, a[8:19]...)
	b = append(b, " x")
	b = append(b, a[20:22]...)
	b = append(b, a[27:34]...)
	b = append(b, " y")
	b = append(b, a[35:]...)
	s := NewMatcher(a, b)
	w := &bytes.Buffer{}
	for _, g := range s.GetGroupedOpCodes(-1) {
		fmt.Fprintf(w, "group\n")
		for _, op := range g {
			fmt.Fprintf(w, "  %s, %d, %d, %d, %d\n", string(op.Tag),
				op.I1, op.I2, op.J1, op.J2)
		}
	}
	result := w.String()
	expected := `group
  e, 5, 8, 5, 8
  i, 8, 8, 8, 9
  e, 8, 11, 9, 12
group
  e, 16, 19, 17, 20
  r, 19, 20, 20, 21
  e, 20, 22, 21, 23
  d, 22, 27, 23, 23
  e, 27, 30, 23, 26
group
  e, 31, 34, 27, 30
  r, 34, 35, 30, 31
  e, 35, 38, 31, 34
`
	if expected != result {
		t.Errorf("unexpected op codes: \n%s", result)
	}
}

func rep(s string, count int) string {
	return strings.Repeat(s, count)
}

func TestWithAsciiOneInsert(t *testing.T) {
	sm := NewMatcher(splitChars(rep("b", 100)),
		splitChars("a"+rep("b", 100)))
	assertAlmostEqual(t, sm.Ratio(), 0.995, 3)
	assertEqual(t, sm.GetOpCodes(),
		[]OpCode{{'i', 0, 0, 0, 1}, {'e', 0, 100, 1, 101}})
	assertEqual(t, len(sm.bPopular), 0)

	sm = NewMatcher(splitChars(rep("b", 100)),
		splitChars(rep("b", 50)+"a"+rep("b", 50)))
	assertAlmostEqual(t, sm.Ratio(), 0.995, 3)
	assertEqual(t, sm.GetOpCodes(),
		[]OpCode{{'e', 0, 50, 0, 50}, {'i', 50, 50, 50, 51}, {'e', 50, 100, 51, 101}})
	assertEqual(t, len(sm.bPopular), 0)
}

func TestWithAsciiOnDelete(t *testing.T) {
	sm := NewMatcher(splitChars(rep("a", 40)+"c"+rep("b", 40)),
		splitChars(rep("a", 40)+rep("b", 40)))
	assertAlmostEqual(t, sm.Ratio(), 0.994, 3)
	assertEqual(t, sm.GetOpCodes(),
		[]OpCode{{'e', 0, 40, 0, 40}, {'d', 40, 41, 40, 40}, {'e', 41, 81, 40, 80}})
}

func TestWithAsciiBJunk(t *testing.T) {
	isJunk := func(s string) bool {
		return s == " "
	}
	sm := NewMatcherWithJunk(splitChars(rep("a", 40)+rep("b", 40)),
		splitChars(rep("a", 44)+rep("b", 40)), true, isJunk)
	assertEqual(t, sm.bJunk, map[string]struct{}{})

	sm = NewMatcherWithJunk(splitChars(rep("a", 40)+rep("b", 40)),
		splitChars(rep("a", 44)+rep("b", 40)+rep(" ", 20)), false, isJunk)
	assertEqual(t, sm.bJunk, map[string]struct{}{" ": struct{}{}})

	isJunk = func(s string) bool {
		return s == " " || s == "b"
	}
	sm = NewMatcherWithJunk(splitChars(rep("a", 40)+rep("b", 40)),
		splitChars(rep("a", 44)+rep("b", 40)+rep(" ", 20)), false, isJunk)
	assertEqual(t, sm.bJunk, map[string]struct{}{" ": struct{}{}, "b": struct{}{}})
}

func TestSFBugsRatioForNullSeqn(t *testing.T) {
	sm := NewMatcher(nil, nil)
	assertEqual(t, sm.Ratio(), 1.0)
	assertEqual(t, sm.QuickRatio(), 1.0)
	assertEqual(t, sm.RealQuickRatio(), 1.0)
}

func TestSFBugsComparingEmptyLists(t *testing.T) {
	groups := NewMatcher(nil, nil).GetGroupedOpCodes(-1)
	assertEqual(t, len(groups), 0)
	diff := UnifiedDiff{
		FromFile: "Original",
		ToFile:   "Current",
		Context:  3,
	}
	result, err := GetUnifiedDiffString(diff)
	assertEqual(t, err, nil)
	assertEqual(t, result, "")
}

func TestOutputFormatRangeFormatUnified(t *testing.T) {
	// Per the diff spec at http://www.unix.org/single_unix_specification/
	//
	// Each <range> field shall be of the form:
	//   "%1d", <beginning line number> if the range contains exactly one line,
	// and:
	//   "%1d,%1d", <beginning line number>, <number of lines> otherwise.
	// If a range is empty, its beginning line number shall be the number of
	// the line just before the range, or 0 if the empty range starts the file.
	fm := formatRangeUnified
	assertEqual(t, fm(3, 3), "3,0")
	assertEqual(t, fm(3, 4), "4")
	assertEqual(t, fm(3, 5), "4,2")
	assertEqual(t, fm(3, 6), "4,3")
	assertEqual(t, fm(0, 0), "0,0")
}

func TestOutputFormatRangeFormatContext(t *testing.T) {
	// Per the diff spec at http://www.unix.org/single_unix_specification/
	//
	// The range of lines in file1 shall be written in the following format
	// if the range contains two or more lines:
	//   "*** %d,%d ****\n", <beginning line number>, <ending line number>
	// and the following format otherwise:
	//   "*** %d ****\n", <ending line number>
	// The ending line number of an empty range shall be the number of the preceding line,
	// or 0 if the range is at the start of the file.
	//
	// Next, the range of lines in file2 shall be written in the following format
	// if the range contains two or more lines:
	//   "--- %d,%d ----\n", <beginning line number>, <ending line number>
	// and the following format otherwise:
	//   "--- %d ----\n", <ending line number>
	fm := formatRangeContext
	assertEqual(t, fm(3, 3), "3")
	assertEqual(t, fm(3, 4), "4")
	assertEqual(t, fm(3, 5), "4,5")
	assertEqual(t, fm(3, 6), "4,6")
	assertEqual(t, fm(0, 0), "0")
}

func TestOutputFormatTabDelimiter(t *testing.T) {
	diff := UnifiedDiff{
		A:        splitChars("one"),
		B:        splitChars("two"),
		FromFile: "Original",
		FromDate: "2005-01-26 23:30:50",
		ToFile:   "Current",
		ToDate:   "2010-04-12 10:20:52",
		Eol:      "\n",
	}
	ud, err := GetUnifiedDiffString(diff)
	assertEqual(t, err, nil)
	assertEqual(t, SplitLines(ud)[:2], []string{
		"--- Original\t2005-01-26 23:30:50\n",
		"+++ Current\t2010-04-12 10:20:52\n",
	})
	cd, err := GetContextDiffString(ContextDiff(diff))
	assertEqual(t, err, nil)
	assertEqual(t, SplitLines(cd)[:2], []string{
		"*** Original\t2005-01-26 23:30:50\n",
		"--- Current\t2010-04-12 10:20:52\n",
	})
}

func TestOutputFormatNoTrailingTabOnEmptyFiledate(t *testing.T) {
	diff := UnifiedDiff{
		A:        splitChars("one"),
		B:        splitChars("two"),
		FromFile: "Original",
		ToFile:   "Current",
		Eol:      "\n",
	}
	ud, err := GetUnifiedDiffString(diff)
	assertEqual(t, err, nil)
	assertEqual(t, SplitLines(ud)[:2], []string{"--- Original\n", "+++ Current\n"})

	cd, err := GetContextDiffString(ContextDiff(diff))
	assertEqual(t, err, nil)
	assertEqual(t, SplitLines(cd)[:2], []string{"*** Original\n", "--- Current\n"})
}

func TestOmitFilenames(t *testing.T) {
	diff := UnifiedDiff{
		A:   SplitLines("o\nn\ne\n"),
		B:   SplitLines("t\nw\no\n"),
		Eol: "\n",
	}
	ud, err := GetUnifiedDiffString(diff)
	assertEqual(t, err, nil)
	assertEqual(t, SplitLines(ud), []string{
		"@@ -0,0 +1,2 @@\n",
		"+t\n",
		"+w\n",
		"@@ -2,2 +3,0 @@\n",
		"-n\n",
		"-e\n",
		"\n",
	})

	cd, err := GetContextDiffString(ContextDiff(diff))
	assertEqual(t, err, nil)
	assertEqual(t, SplitLines(cd), []string{
		"***************\n",
		"*** 0 ****\n",
		"--- 1,2 ----\n",
		"+ t\n",
		"+ w\n",
		"***************\n",
		"*** 2,3 ****\n",
		"- n\n",
		"- e\n",
		"--- 3 ----\n",
		"\n",
	})
}

func TestSplitLines(t *testing.T) {
	allTests := []struct {
		input string
		want  []string
	}{
		{"foo", []string{"foo\n"}},
		{"foo\nbar", []string{"foo\n", "bar\n"}},
		{"foo\nbar\n", []string{"foo\n", "bar\n", "\n"}},
	}
	for _, test := range allTests {
		assertEqual(t, SplitLines(test.input), test.want)
	}
}

func benchmarkSplitLines(b *testing.B, count int) {
	str := strings.Repeat("foo\n", count)

	b.ResetTimer()

	n := 0
	for i := 0; i < b.N; i++ {
		n += len(SplitLines(str))
	}
}

func BenchmarkSplitLines100(b *testing.B) {
	benchmarkSplitLines(b, 100)
}

func BenchmarkSplitLines10000(b *testing.B) {
	benchmarkSplitLines(b, 10000)
}
60
internal/aws/awserr/error.go
Normal file
@ -0,0 +1,60 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/awserr/error.go
// See THIRD-PARTY-NOTICES for original license terms

// Package awserr represents API error interface accessors for the SDK.
package awserr

// An Error wraps lower level errors with code, message and an original error.
// The underlying concrete error type may also satisfy other interfaces which
// can be used to obtain more specific information about the error.
type Error interface {
	// Satisfy the generic error interface.
	error

	// Returns the short phrase depicting the classification of the error.
	Code() string

	// Returns the error details message.
	Message() string

	// Returns the original error if one was set. Nil is returned if not set.
	OrigErr() error
}

// BatchedErrors is a batch of errors which also wraps lower level errors with
// code, message, and original errors. Calling Error() will include all errors
// that occurred in the batch.
//
// Replaces BatchError
type BatchedErrors interface {
	// Satisfy the base Error interface.
	Error

	// Returns the original errors if they were set. Nil is returned if not set.
	OrigErrs() []error
}

// New returns an Error object described by the code, message, and origErr.
//
// If origErr satisfies the Error interface it will not be wrapped within a new
// Error object and will instead be returned.
func New(code, message string, origErr error) Error {
	var errs []error
	if origErr != nil {
		errs = append(errs, origErr)
	}
	return newBaseError(code, message, errs)
}

// NewBatchError returns a BatchedErrors with a collection of errors as an
// array of errors.
func NewBatchError(code, message string, errs []error) BatchedErrors {
	return newBaseError(code, message, errs)
}
144
internal/aws/awserr/types.go
Normal file
@ -0,0 +1,144 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/awserr/types.go
// See THIRD-PARTY-NOTICES for original license terms

package awserr

import (
	"fmt"
)

// SprintError returns a string of the formatted error code.
//
// Both extra and origErr are optional. If they are included their lines
// will be added, but if they are not included their lines will be ignored.
func SprintError(code, message, extra string, origErr error) string {
	msg := fmt.Sprintf("%s: %s", code, message)
	if extra != "" {
		msg = fmt.Sprintf("%s\n\t%s", msg, extra)
	}
	if origErr != nil {
		msg = fmt.Sprintf("%s\ncaused by: %s", msg, origErr.Error())
	}
	return msg
}

// A baseError wraps the code and message which defines an error. It also
// can be used to wrap an original error object.
//
// Should be used as the root for errors satisfying the awserr.Error. Also
// for any error which does not fit into a specific error wrapper type.
type baseError struct {
	// Classification of error
	code string

	// Detailed information about error
	message string

	// Optional original error this error is based off of. Allows building
	// chained errors.
	errs []error
}

// newBaseError returns an error object for the code, message, and errors.
//
// code is a short no whitespace phrase depicting the classification of
// the error that is being created.
//
// message is the free flow string containing detailed information about the
// error.
//
// origErrs is the error objects which will be nested under the new errors to
// be returned.
func newBaseError(code, message string, origErrs []error) *baseError {
	b := &baseError{
		code:    code,
		message: message,
		errs:    origErrs,
	}

	return b
}

// Error returns the string representation of the error.
//
// See ErrorWithExtra for formatting.
//
// Satisfies the error interface.
func (b baseError) Error() string {
	size := len(b.errs)
	if size > 0 {
		return SprintError(b.code, b.message, "", errorList(b.errs))
	}

	return SprintError(b.code, b.message, "", nil)
}

// String returns the string representation of the error.
// Alias for Error to satisfy the stringer interface.
func (b baseError) String() string {
	return b.Error()
}

// Code returns the short phrase depicting the classification of the error.
func (b baseError) Code() string {
	return b.code
}

// Message returns the error details message.
func (b baseError) Message() string {
	return b.message
}

// OrigErr returns the original error if one was set. Nil is returned if no
// error was set. This only returns the first element in the list. If the full
// list is needed, use BatchedErrors.
func (b baseError) OrigErr() error {
	switch len(b.errs) {
	case 0:
		return nil
	case 1:
		return b.errs[0]
	default:
		if err, ok := b.errs[0].(Error); ok {
			return NewBatchError(err.Code(), err.Message(), b.errs[1:])
		}
		return NewBatchError("BatchedErrors",
			"multiple errors occurred", b.errs)
	}
}

// OrigErrs returns the original errors if they were set. An empty slice is
// returned if no error was set.
func (b baseError) OrigErrs() []error {
	return b.errs
}

// An error list that satisfies the golang interface
type errorList []error

// Error returns the string representation of the error.
//
// Satisfies the error interface.
func (e errorList) Error() string {
	msg := ""
	// How do we want to handle the array size being zero
	if size := len(e); size > 0 {
		for i := 0; i < size; i++ {
			msg += e[i].Error()
			// We check the next index to see if it is within the slice.
			// If it is, then we append a newline. We do this, because unit tests
			// could be broken with the additional '\n'
			if i+1 < size {
				msg += "\n"
			}
		}
	}
	return msg
}
72
internal/aws/credentials/chain_provider.go
Normal file
@ -0,0 +1,72 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/chain_provider.go
// See THIRD-PARTY-NOTICES for original license terms

package credentials

import (
	"gitea.psichedelico.com/go/bson/internal/aws/awserr"
)

// A ChainProvider will search for a provider which returns credentials
// and cache that provider until Retrieve is called again.
//
// The ChainProvider provides a way of chaining multiple providers together
// which will pick the first available using priority order of the Providers
// in the list.
//
// If none of the Providers retrieve valid credentials Value, ChainProvider's
// Retrieve() will return the error ErrNoValidProvidersFoundInChain.
//
// If a Provider is found which returns valid credentials Value ChainProvider
// will cache that Provider for all calls to IsExpired(), until Retrieve is
// called again.
type ChainProvider struct {
	Providers []Provider
	curr      Provider
}

// NewChainCredentials returns a pointer to a new Credentials object
// wrapping a chain of providers.
func NewChainCredentials(providers []Provider) *Credentials {
	return NewCredentials(&ChainProvider{
		Providers: append([]Provider{}, providers...),
	})
}

// Retrieve returns the credentials value, or an error if no provider returned
// without error.
//
// If a provider is found it will be cached and any calls to IsExpired()
// will return the expired state of the cached provider.
func (c *ChainProvider) Retrieve() (Value, error) {
	var errs = make([]error, 0, len(c.Providers))
	for _, p := range c.Providers {
		creds, err := p.Retrieve()
		if err == nil {
			c.curr = p
			return creds, nil
		}
		errs = append(errs, err)
	}
	c.curr = nil

	var err = awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs)
	return Value{}, err
}

// IsExpired will return the expired state of the currently cached provider
// if there is one. If there is no current provider, true will be returned.
func (c *ChainProvider) IsExpired() bool {
	if c.curr != nil {
		return c.curr.IsExpired()
	}

	return true
}
|
176
internal/aws/credentials/chain_provider_test.go
Normal file
@@ -0,0 +1,176 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/chain_provider_test.go
// See THIRD-PARTY-NOTICES for original license terms

package credentials

import (
	"reflect"
	"testing"

	"gitea.psichedelico.com/go/bson/internal/aws/awserr"
)

type secondStubProvider struct {
	creds   Value
	expired bool
	err     error
}

func (s *secondStubProvider) Retrieve() (Value, error) {
	s.expired = false
	s.creds.ProviderName = "secondStubProvider"
	return s.creds, s.err
}

func (s *secondStubProvider) IsExpired() bool {
	return s.expired
}

func TestChainProviderWithNames(t *testing.T) {
	p := &ChainProvider{
		Providers: []Provider{
			&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
			&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
			&secondStubProvider{
				creds: Value{
					AccessKeyID:     "AKIF",
					SecretAccessKey: "NOSECRET",
					SessionToken:    "",
				},
			},
			&stubProvider{
				creds: Value{
					AccessKeyID:     "AKID",
					SecretAccessKey: "SECRET",
					SessionToken:    "",
				},
			},
		},
	}

	creds, err := p.Retrieve()
	if err != nil {
		t.Errorf("Expect no error, got %v", err)
	}
	if e, a := "secondStubProvider", creds.ProviderName; e != a {
		t.Errorf("Expect provider name to match, %v got, %v", e, a)
	}

	// Also check credentials
	if e, a := "AKIF", creds.AccessKeyID; e != a {
		t.Errorf("Expect access key ID to match, %v got %v", e, a)
	}
	if e, a := "NOSECRET", creds.SecretAccessKey; e != a {
		t.Errorf("Expect secret access key to match, %v got %v", e, a)
	}
	if v := creds.SessionToken; len(v) != 0 {
		t.Errorf("Expect session token to be empty, %v", v)
	}
}

func TestChainProviderGet(t *testing.T) {
	p := &ChainProvider{
		Providers: []Provider{
			&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
			&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
			&stubProvider{
				creds: Value{
					AccessKeyID:     "AKID",
					SecretAccessKey: "SECRET",
					SessionToken:    "",
				},
			},
		},
	}

	creds, err := p.Retrieve()
	if err != nil {
		t.Errorf("Expect no error, got %v", err)
	}
	if e, a := "AKID", creds.AccessKeyID; e != a {
		t.Errorf("Expect access key ID to match, %v got %v", e, a)
	}
	if e, a := "SECRET", creds.SecretAccessKey; e != a {
		t.Errorf("Expect secret access key to match, %v got %v", e, a)
	}
	if v := creds.SessionToken; len(v) != 0 {
		t.Errorf("Expect session token to be empty, %v", v)
	}
}

func TestChainProviderIsExpired(t *testing.T) {
	stubProvider := &stubProvider{expired: true}
	p := &ChainProvider{
		Providers: []Provider{
			stubProvider,
		},
	}

	if !p.IsExpired() {
		t.Errorf("Expect expired to be true before any Retrieve")
	}
	_, err := p.Retrieve()
	if err != nil {
		t.Errorf("Expect no error, got %v", err)
	}
	if p.IsExpired() {
		t.Errorf("Expect not expired after retrieve")
	}

	stubProvider.expired = true
	if !p.IsExpired() {
		t.Errorf("Expect return of expired provider")
	}

	_, err = p.Retrieve()
	if err != nil {
		t.Errorf("Expect no error, got %v", err)
	}
	if p.IsExpired() {
		t.Errorf("Expect not expired after retrieve")
	}
}

func TestChainProviderWithNoProvider(t *testing.T) {
	p := &ChainProvider{
		Providers: []Provider{},
	}

	if !p.IsExpired() {
		t.Errorf("Expect expired with no providers")
	}
	_, err := p.Retrieve()
	if err.Error() != "NoCredentialProviders: no valid providers in chain" {
		t.Errorf("Expect no providers error returned, got %v", err)
	}
}

func TestChainProviderWithNoValidProvider(t *testing.T) {
	errs := []error{
		awserr.New("FirstError", "first provider error", nil),
		awserr.New("SecondError", "second provider error", nil),
	}
	p := &ChainProvider{
		Providers: []Provider{
			&stubProvider{err: errs[0]},
			&stubProvider{err: errs[1]},
		},
	}

	if !p.IsExpired() {
		t.Errorf("Expect expired with no providers")
	}
	_, err := p.Retrieve()

	expectErr := awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs)
	if e, a := expectErr, err; !reflect.DeepEqual(e, a) {
		t.Errorf("Expect no providers error returned, %v, got %v", e, a)
	}
}
197
internal/aws/credentials/credentials.go
Normal file
@@ -0,0 +1,197 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/credentials.go
// See THIRD-PARTY-NOTICES for original license terms

package credentials

import (
	"context"
	"sync"
	"time"

	"gitea.psichedelico.com/go/bson/internal/aws/awserr"
	"golang.org/x/sync/singleflight"
)

// A Value is the AWS credentials value for individual credential fields.
//
// A Value is also used to represent Azure credentials.
// Azure credentials only consist of an access token, which is stored in the `SessionToken` field.
type Value struct {
	// AWS Access key ID
	AccessKeyID string

	// AWS Secret Access Key
	SecretAccessKey string

	// AWS Session Token
	SessionToken string

	// Provider used to get credentials
	ProviderName string
}

// HasKeys returns whether the credentials Value has both AccessKeyID and
// SecretAccessKey values set.
func (v Value) HasKeys() bool {
	return len(v.AccessKeyID) != 0 && len(v.SecretAccessKey) != 0
}

// A Provider is the interface for any component which will provide credentials
// Value. A provider is required to manage its own Expired state, and what
// being expired means.
//
// The Provider should not need to implement its own mutexes, because
// that will be managed by Credentials.
type Provider interface {
	// Retrieve returns nil if it successfully retrieved the value.
	// An error is returned if the value was not obtainable, or empty.
	Retrieve() (Value, error)

	// IsExpired returns if the credentials are no longer valid, and need
	// to be retrieved.
	IsExpired() bool
}

// ProviderWithContext is a Provider that can retrieve credentials with a Context
type ProviderWithContext interface {
	Provider

	RetrieveWithContext(context.Context) (Value, error)
}

// A Credentials provides concurrency safe retrieval of AWS credentials Value.
//
// A Credentials is also used to fetch Azure credentials Value.
//
// Credentials will cache the credentials value until they expire. Once the value
// expires the next Get will attempt to retrieve valid credentials.
//
// Credentials is safe to use across multiple goroutines and will manage the
// synchronous state so the Providers do not need to implement their own
// synchronization.
//
// The first Credentials.Get() will always call Provider.Retrieve() to get the
// first instance of the credentials Value. All calls to Get() after that
// will return the cached credentials Value until IsExpired() returns true.
type Credentials struct {
	sf singleflight.Group

	m        sync.RWMutex
	creds    Value
	provider Provider
}

// NewCredentials returns a pointer to a new Credentials with the provider set.
func NewCredentials(provider Provider) *Credentials {
	c := &Credentials{
		provider: provider,
	}
	return c
}

// GetWithContext returns the credentials value, or an error if the credentials
// Value failed to be retrieved. Will return early if the passed in context is
// canceled.
//
// Will return the cached credentials Value if it has not expired. If the
// credentials Value has expired the Provider's Retrieve() will be called
// to refresh the credentials.
//
// If Credentials.Expire() was called the credentials Value will be force
// expired, and the next call to Get() will cause them to be refreshed.
func (c *Credentials) GetWithContext(ctx context.Context) (Value, error) {
	// Check if credentials are cached, and not expired.
	select {
	case curCreds, ok := <-c.asyncIsExpired():
		// ok will only be true if the credentials were not expired. ok will
		// be false and have no value if the credentials are expired.
		if ok {
			return curCreds, nil
		}
	case <-ctx.Done():
		return Value{}, awserr.New("RequestCanceled",
			"request context canceled", ctx.Err())
	}

	// Cannot pass context down to the actual retrieve, because the first
	// context would cancel the whole group when there is no direct
	// association of items in the group.
	resCh := c.sf.DoChan("", func() (interface{}, error) {
		return c.singleRetrieve(&suppressedContext{ctx})
	})
	select {
	case res := <-resCh:
		return res.Val.(Value), res.Err
	case <-ctx.Done():
		return Value{}, awserr.New("RequestCanceled",
			"request context canceled", ctx.Err())
	}
}

func (c *Credentials) singleRetrieve(ctx context.Context) (interface{}, error) {
	c.m.Lock()
	defer c.m.Unlock()

	if curCreds := c.creds; !c.isExpiredLocked(curCreds) {
		return curCreds, nil
	}

	var creds Value
	var err error
	if p, ok := c.provider.(ProviderWithContext); ok {
		creds, err = p.RetrieveWithContext(ctx)
	} else {
		creds, err = c.provider.Retrieve()
	}
	if err == nil {
		c.creds = creds
	}

	return creds, err
}

// asyncIsExpired returns a channel of credentials Value. If the credentials
// are not expired the channel carries the current value; if they are expired
// the channel is closed without sending a value.
func (c *Credentials) asyncIsExpired() <-chan Value {
	ch := make(chan Value, 1)
	go func() {
		c.m.RLock()
		defer c.m.RUnlock()

		if curCreds := c.creds; !c.isExpiredLocked(curCreds) {
			ch <- curCreds
		}

		close(ch)
	}()

	return ch
}

// isExpiredLocked helper method wrapping the definition of expired credentials.
func (c *Credentials) isExpiredLocked(creds interface{}) bool {
	return creds == nil || creds.(Value) == Value{} || c.provider.IsExpired()
}

type suppressedContext struct {
	context.Context
}

func (s *suppressedContext) Deadline() (deadline time.Time, ok bool) {
	return time.Time{}, false
}

func (s *suppressedContext) Done() <-chan struct{} {
	return nil
}

func (s *suppressedContext) Err() error {
	return nil
}
192
internal/aws/credentials/credentials_test.go
Normal file
@@ -0,0 +1,192 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/credentials_test.go
// See THIRD-PARTY-NOTICES for original license terms

package credentials

import (
	"context"
	"sync"
	"testing"
	"time"

	"gitea.psichedelico.com/go/bson/internal/aws/awserr"
)

func isExpired(c *Credentials) bool {
	c.m.RLock()
	defer c.m.RUnlock()

	return c.isExpiredLocked(c.creds)
}

type stubProvider struct {
	creds          Value
	retrievedCount int
	expired        bool
	err            error
}

func (s *stubProvider) Retrieve() (Value, error) {
	s.retrievedCount++
	s.expired = false
	s.creds.ProviderName = "stubProvider"
	return s.creds, s.err
}

func (s *stubProvider) IsExpired() bool {
	return s.expired
}

func TestCredentialsGet(t *testing.T) {
	c := NewCredentials(&stubProvider{
		creds: Value{
			AccessKeyID:     "AKID",
			SecretAccessKey: "SECRET",
			SessionToken:    "",
		},
		expired: true,
	})

	creds, err := c.GetWithContext(context.Background())
	if err != nil {
		t.Errorf("Expected no error, got %v", err)
	}
	if e, a := "AKID", creds.AccessKeyID; e != a {
		t.Errorf("Expect access key ID to match, %v got %v", e, a)
	}
	if e, a := "SECRET", creds.SecretAccessKey; e != a {
		t.Errorf("Expect secret access key to match, %v got %v", e, a)
	}
	if v := creds.SessionToken; len(v) != 0 {
		t.Errorf("Expect session token to be empty, %v", v)
	}
}

func TestCredentialsGetWithError(t *testing.T) {
	c := NewCredentials(&stubProvider{err: awserr.New("provider error", "", nil), expired: true})

	_, err := c.GetWithContext(context.Background())
	if e, a := "provider error", err.(awserr.Error).Code(); e != a {
		t.Errorf("Expected provider error, %v got %v", e, a)
	}
}

func TestCredentialsExpire(t *testing.T) {
	stub := &stubProvider{}
	c := NewCredentials(stub)

	stub.expired = false
	if !isExpired(c) {
		t.Errorf("Expected to start out expired")
	}

	_, err := c.GetWithContext(context.Background())
	if err != nil {
		t.Errorf("Expected no err, got %v", err)
	}
	if isExpired(c) {
		t.Errorf("Expected not to be expired")
	}

	stub.expired = true
	if !isExpired(c) {
		t.Errorf("Expected to be expired")
	}
}

func TestCredentialsGetWithProviderName(t *testing.T) {
	stub := &stubProvider{}

	c := NewCredentials(stub)

	creds, err := c.GetWithContext(context.Background())
	if err != nil {
		t.Errorf("Expected no error, got %v", err)
	}
	if e, a := creds.ProviderName, "stubProvider"; e != a {
		t.Errorf("Expected provider name to match, %v got %v", e, a)
	}
}

type MockProvider struct {
	// The date/time when to expire on
	expiration time.Time

	// If set will be used by IsExpired to determine the current time.
	// Defaults to time.Now if CurrentTime is not set. Available for testing
	// to be able to mock out the current time.
	CurrentTime func() time.Time
}

// IsExpired returns if the credentials are expired.
func (e *MockProvider) IsExpired() bool {
	curTime := e.CurrentTime
	if curTime == nil {
		curTime = time.Now
	}
	return e.expiration.Before(curTime())
}

func (*MockProvider) Retrieve() (Value, error) {
	return Value{}, nil
}

func TestCredentialsIsExpired_Race(_ *testing.T) {
	creds := NewChainCredentials([]Provider{&MockProvider{}})

	starter := make(chan struct{})
	var wg sync.WaitGroup
	wg.Add(10)
	for i := 0; i < 10; i++ {
		go func() {
			defer wg.Done()
			<-starter
			for i := 0; i < 100; i++ {
				isExpired(creds)
			}
		}()
	}
	close(starter)

	wg.Wait()
}

type stubProviderConcurrent struct {
	stubProvider
	done chan struct{}
}

func (s *stubProviderConcurrent) Retrieve() (Value, error) {
	<-s.done
	return s.stubProvider.Retrieve()
}

func TestCredentialsGetConcurrent(t *testing.T) {
	stub := &stubProviderConcurrent{
		done: make(chan struct{}),
	}

	c := NewCredentials(stub)
	done := make(chan struct{})

	for i := 0; i < 2; i++ {
		go func() {
			_, err := c.GetWithContext(context.Background())
			if err != nil {
				t.Errorf("Expected no err, got %v", err)
			}
			done <- struct{}{}
		}()
	}

	// Validates that a single call to Retrieve is shared between two calls to Get
	stub.done <- struct{}{}
	<-done
	<-done
}
51
internal/aws/signer/v4/header_rules.go
Normal file
@@ -0,0 +1,51 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/signer/v4/header_rules.go
// See THIRD-PARTY-NOTICES for original license terms

package v4

// rules houses a set of rules needed for validation of a
// string value
type rules []rule

// rule interface allows for more flexible rules and simply
// checks whether or not a value adheres to that rule
type rule interface {
	IsValid(value string) bool
}

// IsValid will iterate through all rules and see if any rules
// apply to the value; it supports nested rules
func (r rules) IsValid(value string) bool {
	for _, rule := range r {
		if rule.IsValid(value) {
			return true
		}
	}
	return false
}

// mapRule generic rule for maps
type mapRule map[string]struct{}

// IsValid for the map rule is satisfied when the value exists in the map
func (m mapRule) IsValid(value string) bool {
	_, ok := m[value]
	return ok
}

// excludeList is a generic rule for exclude listing
type excludeList struct {
	rule
}

// IsValid for exclude list checks if the value is within the exclude list
func (b excludeList) IsValid(value string) bool {
	return !b.rule.IsValid(value)
}
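A self-contained copy of these small types shows how they compose. The signer's `ignoredHeaders` set in v4.go wraps a `mapRule` in an `excludeList` exactly like this, so `IsValid` answers "should this header be included in the signature?".

```go
package main

import "fmt"

// rule, rules, mapRule, and excludeList are copied from the file above
// so this sketch runs standalone.
type rule interface{ IsValid(value string) bool }

type rules []rule

func (r rules) IsValid(value string) bool {
	for _, ru := range r {
		if ru.IsValid(value) {
			return true
		}
	}
	return false
}

type mapRule map[string]struct{}

func (m mapRule) IsValid(value string) bool { _, ok := m[value]; return ok }

type excludeList struct{ rule }

func (b excludeList) IsValid(value string) bool { return !b.rule.IsValid(value) }

func main() {
	// Headers in the map are excluded from signing; everything else passes.
	ignored := rules{excludeList{mapRule{
		"Authorization": {},
		"User-Agent":    {},
	}}}
	fmt.Println(ignored.IsValid("Content-Type")) // true: header is signed
	fmt.Println(ignored.IsValid("User-Agent"))   // false: header is skipped
}
```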
80
internal/aws/signer/v4/request.go
Normal file
@@ -0,0 +1,80 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/request/request.go
// See THIRD-PARTY-NOTICES for original license terms

package v4

import (
	"net/http"
	"strings"
)

// getHost returns the host from the request.
func getHost(r *http.Request) string {
	if r.Host != "" {
		return r.Host
	}

	if r.URL == nil {
		return ""
	}

	return r.URL.Host
}

// stripPort returns hostport without any port number.
//
// If hostport is an IPv6 literal with a port number, stripPort returns the
// IPv6 literal without the square brackets. IPv6 literals may include
// a zone identifier.
//
// Copied from the Go 1.8 standard library (net/url)
func stripPort(hostport string) string {
	colon := strings.IndexByte(hostport, ':')
	if colon == -1 {
		return hostport
	}
	if i := strings.IndexByte(hostport, ']'); i != -1 {
		return strings.TrimPrefix(hostport[:i], "[")
	}
	return hostport[:colon]
}

// portOnly returns the port part of hostport, without the leading colon.
// If hostport doesn't contain a port, portOnly returns an empty string.
//
// Copied from the Go 1.8 standard library (net/url)
func portOnly(hostport string) string {
	colon := strings.IndexByte(hostport, ':')
	if colon == -1 {
		return ""
	}
	if i := strings.Index(hostport, "]:"); i != -1 {
		return hostport[i+len("]:"):]
	}
	if strings.Contains(hostport, "]") {
		return ""
	}
	return hostport[colon+len(":"):]
}

// isDefaultPort returns true if the specified URI is using the standard port
// (i.e. port 80 for HTTP URIs or 443 for HTTPS URIs)
func isDefaultPort(scheme, port string) bool {
	if port == "" {
		return true
	}

	lowerCaseScheme := strings.ToLower(scheme)
	if (lowerCaseScheme == "http" && port == "80") || (lowerCaseScheme == "https" && port == "443") {
		return true
	}

	return false
}
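A quick sanity check of the two host-parsing helpers, using verbatim copies so the sketch runs standalone. Note how bracketed IPv6 literals are handled differently from plain host:port strings.

```go
package main

import (
	"fmt"
	"strings"
)

// stripPort and portOnly are copied verbatim from the file above.
func stripPort(hostport string) string {
	colon := strings.IndexByte(hostport, ':')
	if colon == -1 {
		return hostport
	}
	if i := strings.IndexByte(hostport, ']'); i != -1 {
		return strings.TrimPrefix(hostport[:i], "[")
	}
	return hostport[:colon]
}

func portOnly(hostport string) string {
	colon := strings.IndexByte(hostport, ':')
	if colon == -1 {
		return ""
	}
	if i := strings.Index(hostport, "]:"); i != -1 {
		return hostport[i+len("]:"):]
	}
	if strings.Contains(hostport, "]") {
		return ""
	}
	return hostport[colon+len(":"):]
}

func main() {
	fmt.Println(stripPort("example.com:443")) // example.com
	fmt.Println(portOnly("example.com:443"))  // 443
	fmt.Println(stripPort("[::1]:80"))        // ::1
	fmt.Println(portOnly("[::1]"))            // empty: bracketed host, no port
}
```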
65
internal/aws/signer/v4/uri_path.go
Normal file
@@ -0,0 +1,65 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/signer/v4/uri_path.go
// - github.com/aws/aws-sdk-go/blob/v1.44.225/private/protocol/rest/build.go
// See THIRD-PARTY-NOTICES for original license terms

package v4

import (
	"bytes"
	"fmt"
	"net/url"
	"strings"
)

// noEscape records whether the byte value can be sent without escaping in AWS URLs.
var noEscape [256]bool

func init() {
	for i := 0; i < len(noEscape); i++ {
		// AWS expects every character except these to be escaped
		noEscape[i] = (i >= 'A' && i <= 'Z') ||
			(i >= 'a' && i <= 'z') ||
			(i >= '0' && i <= '9') ||
			i == '-' ||
			i == '.' ||
			i == '_' ||
			i == '~'
	}
}

func getURIPath(u *url.URL) string {
	var uri string

	if len(u.Opaque) > 0 {
		uri = "/" + strings.Join(strings.Split(u.Opaque, "/")[3:], "/")
	} else {
		uri = u.EscapedPath()
	}

	if len(uri) == 0 {
		uri = "/"
	}

	return uri
}

// EscapePath escapes part of a URL path in Amazon style
func EscapePath(path string, encodeSep bool) string {
	var buf bytes.Buffer
	for i := 0; i < len(path); i++ {
		c := path[i]
		if noEscape[c] || (c == '/' && !encodeSep) {
			buf.WriteByte(c)
		} else {
			fmt.Fprintf(&buf, "%%%02X", c)
		}
	}
	return buf.String()
}
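The escaping rule can be demonstrated with a verbatim copy of `EscapePath` and its `noEscape` table: unreserved characters (`A-Z`, `a-z`, `0-9`, `-._~`) pass through, `/` passes only when `encodeSep` is false, and every other byte becomes a percent-escaped hex pair.

```go
package main

import (
	"bytes"
	"fmt"
)

// noEscape and EscapePath are copied verbatim from the file above.
var noEscape [256]bool

func init() {
	for i := 0; i < len(noEscape); i++ {
		noEscape[i] = (i >= 'A' && i <= 'Z') ||
			(i >= 'a' && i <= 'z') ||
			(i >= '0' && i <= '9') ||
			i == '-' || i == '.' || i == '_' || i == '~'
	}
}

func EscapePath(path string, encodeSep bool) string {
	var buf bytes.Buffer
	for i := 0; i < len(path); i++ {
		c := path[i]
		if noEscape[c] || (c == '/' && !encodeSep) {
			buf.WriteByte(c)
		} else {
			fmt.Fprintf(&buf, "%%%02X", c)
		}
	}
	return buf.String()
}

func main() {
	fmt.Println(EscapePath("/documents and settings/", false)) // /documents%20and%20settings/
	fmt.Println(EscapePath("/a/b", true))                      // %2Fa%2Fb
}
```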
421
internal/aws/signer/v4/v4.go
Normal file
@@ -0,0 +1,421 @@
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
// not use this file except in compliance with the License. You may obtain
|
||||||
|
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
||||||
|
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/signer/v4/v4.go
|
||||||
|
// See THIRD-PARTY-NOTICES for original license terms
|
||||||
|
|
||||||
|
package v4
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gitea.psichedelico.com/go/bson/internal/aws"
|
||||||
|
"gitea.psichedelico.com/go/bson/internal/aws/credentials"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
authorizationHeader = "Authorization"
|
||||||
|
authHeaderSignatureElem = "Signature="
|
||||||
|
|
||||||
|
authHeaderPrefix = "AWS4-HMAC-SHA256"
|
||||||
|
timeFormat = "20060102T150405Z"
|
||||||
|
shortTimeFormat = "20060102"
|
||||||
|
awsV4Request = "aws4_request"
|
||||||
|
|
||||||
|
// emptyStringSHA256 is a SHA256 of an empty string
|
||||||
|
emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
|
||||||
|
)
|
||||||
|
|
||||||
|
var ignoredHeaders = rules{
|
||||||
|
excludeList{
|
||||||
|
mapRule{
|
||||||
|
authorizationHeader: struct{}{},
|
||||||
|
"User-Agent": struct{}{},
|
||||||
|
"X-Amzn-Trace-Id": struct{}{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Signer applies AWS v4 signing to given request. Use this to sign requests
|
||||||
|
// that need to be signed with AWS V4 Signatures.
|
||||||
|
type Signer struct {
|
||||||
|
// The authentication credentials the request will be signed against.
|
||||||
|
// This value must be set to sign requests.
|
||||||
|
Credentials *credentials.Credentials
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSigner returns a Signer pointer configured with the credentials provided.
|
||||||
|
func NewSigner(credentials *credentials.Credentials) *Signer {
|
||||||
|
v4 := &Signer{
|
||||||
|
Credentials: credentials,
|
||||||
|
}
|
||||||
|
|
||||||
|
return v4
|
||||||
|
}
|
||||||
|
|
||||||
|
type signingCtx struct {
|
||||||
|
ServiceName string
|
||||||
|
Region string
|
||||||
|
Request *http.Request
|
||||||
|
Body io.ReadSeeker
|
||||||
|
Query url.Values
|
||||||
|
Time time.Time
|
||||||
|
SignedHeaderVals http.Header
|
||||||
|
|
||||||
|
credValues credentials.Value
|
||||||
|
|
||||||
|
bodyDigest string
|
||||||
|
signedHeaders string
|
||||||
|
canonicalHeaders string
|
||||||
|
canonicalString string
|
||||||
|
credentialString string
|
||||||
|
stringToSign string
|
||||||
|
signature string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign signs AWS v4 requests with the provided body, service name, region the
// request is made to, and time the request is signed at. The signTime allows
// you to specify that a request is signed for the future, and cannot be
// used until then.
//
// Returns a list of HTTP headers that were included in the signature or an
// error if signing the request failed. Generally for signed requests this value
// is not needed as the full request context will be captured by the http.Request
// value. It is included for reference though.
//
// Sign will set the request's Body to be the `body` parameter passed in. If
// the body is not already an io.ReadCloser, it will be wrapped within one. If
// a `nil` body parameter is passed to Sign, the request's Body field will
// also be set to nil. It's important to note that this functionality will not
// change the request's ContentLength.
//
// Sign differs from Presign in that it will sign the request using HTTP
// header values. This type of signing is intended for http.Request values that
// will not be shared, or are shared in a way the header values on the request
// will not be lost.
//
// The request's body is an io.ReadSeeker so the SHA256 of the body can be
// generated. To bypass the signer computing the hash you can set the
// "X-Amz-Content-Sha256" header with a precomputed value. The signer will
// only compute the hash if the request header value is empty.
func (v4 Signer) Sign(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
	return v4.signWithBody(r, body, service, region, signTime)
}

func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
	ctx := &signingCtx{
		Request:     r,
		Body:        body,
		Query:       r.URL.Query(),
		Time:        signTime,
		ServiceName: service,
		Region:      region,
	}

	for key := range ctx.Query {
		sort.Strings(ctx.Query[key])
	}

	if ctx.isRequestSigned() {
		ctx.Time = time.Now()
	}

	var err error
	ctx.credValues, err = v4.Credentials.GetWithContext(r.Context())
	if err != nil {
		return http.Header{}, err
	}

	ctx.sanitizeHostForHeader()
	ctx.assignAmzQueryValues()
	if err := ctx.build(); err != nil {
		return nil, err
	}

	var reader io.ReadCloser
	if body != nil {
		var ok bool
		if reader, ok = body.(io.ReadCloser); !ok {
			reader = ioutil.NopCloser(body)
		}
	}
	r.Body = reader

	return ctx.SignedHeaderVals, nil
}
// sanitizeHostForHeader removes default port from host and updates request.Host
func (ctx *signingCtx) sanitizeHostForHeader() {
	r := ctx.Request
	host := getHost(r)
	port := portOnly(host)
	if port != "" && isDefaultPort(r.URL.Scheme, port) {
		r.Host = stripPort(host)
	}
}

func (ctx *signingCtx) assignAmzQueryValues() {
	if ctx.credValues.SessionToken != "" {
		ctx.Request.Header.Set("X-Amz-Security-Token", ctx.credValues.SessionToken)
	}
}

func (ctx *signingCtx) build() error {
	ctx.buildTime()             // no depends
	ctx.buildCredentialString() // no depends

	if err := ctx.buildBodyDigest(); err != nil {
		return err
	}

	unsignedHeaders := ctx.Request.Header

	ctx.buildCanonicalHeaders(ignoredHeaders, unsignedHeaders)
	ctx.buildCanonicalString() // depends on canon headers / signed headers
	ctx.buildStringToSign()    // depends on canon string
	ctx.buildSignature()       // depends on string to sign

	parts := []string{
		authHeaderPrefix + " Credential=" + ctx.credValues.AccessKeyID + "/" + ctx.credentialString,
		"SignedHeaders=" + ctx.signedHeaders,
		authHeaderSignatureElem + ctx.signature,
	}
	ctx.Request.Header.Set(authorizationHeader, strings.Join(parts, ", "))

	return nil
}

func (ctx *signingCtx) buildTime() {
	ctx.Request.Header.Set("X-Amz-Date", formatTime(ctx.Time))
}

func (ctx *signingCtx) buildCredentialString() {
	ctx.credentialString = buildSigningScope(ctx.Region, ctx.ServiceName, ctx.Time)
}
func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) {
	headers := make([]string, 0, len(header)+1)
	headers = append(headers, "host")
	for k, v := range header {
		if !r.IsValid(k) {
			continue // ignored header
		}
		if ctx.SignedHeaderVals == nil {
			ctx.SignedHeaderVals = make(http.Header)
		}

		lowerCaseKey := strings.ToLower(k)
		if _, ok := ctx.SignedHeaderVals[lowerCaseKey]; ok {
			// include additional values
			ctx.SignedHeaderVals[lowerCaseKey] = append(ctx.SignedHeaderVals[lowerCaseKey], v...)
			continue
		}

		headers = append(headers, lowerCaseKey)
		ctx.SignedHeaderVals[lowerCaseKey] = v
	}
	sort.Strings(headers)

	ctx.signedHeaders = strings.Join(headers, ";")

	headerItems := make([]string, len(headers))
	for i, k := range headers {
		if k == "host" {
			if ctx.Request.Host != "" {
				headerItems[i] = "host:" + ctx.Request.Host
			} else {
				headerItems[i] = "host:" + ctx.Request.URL.Host
			}
		} else {
			headerValues := make([]string, len(ctx.SignedHeaderVals[k]))
			for i, v := range ctx.SignedHeaderVals[k] {
				headerValues[i] = strings.TrimSpace(v)
			}
			headerItems[i] = k + ":" +
				strings.Join(headerValues, ",")
		}
	}
	stripExcessSpaces(headerItems)
	ctx.canonicalHeaders = strings.Join(headerItems, "\n")
}
func (ctx *signingCtx) buildCanonicalString() {
	ctx.Request.URL.RawQuery = strings.Replace(ctx.Query.Encode(), "+", "%20", -1)

	uri := getURIPath(ctx.Request.URL)

	uri = EscapePath(uri, false)

	ctx.canonicalString = strings.Join([]string{
		ctx.Request.Method,
		uri,
		ctx.Request.URL.RawQuery,
		ctx.canonicalHeaders + "\n",
		ctx.signedHeaders,
		ctx.bodyDigest,
	}, "\n")
}

func (ctx *signingCtx) buildStringToSign() {
	ctx.stringToSign = strings.Join([]string{
		authHeaderPrefix,
		formatTime(ctx.Time),
		ctx.credentialString,
		hex.EncodeToString(hashSHA256([]byte(ctx.canonicalString))),
	}, "\n")
}

func (ctx *signingCtx) buildSignature() {
	creds := deriveSigningKey(ctx.Region, ctx.ServiceName, ctx.credValues.SecretAccessKey, ctx.Time)
	signature := hmacSHA256(creds, []byte(ctx.stringToSign))
	ctx.signature = hex.EncodeToString(signature)
}

func (ctx *signingCtx) buildBodyDigest() error {
	hash := ctx.Request.Header.Get("X-Amz-Content-Sha256")
	if hash == "" {
		if ctx.Body == nil {
			hash = emptyStringSHA256
		} else {
			if !aws.IsReaderSeekable(ctx.Body) {
				return fmt.Errorf("cannot use unseekable request body %T, for signed request with body", ctx.Body)
			}
			hashBytes, err := makeSha256Reader(ctx.Body)
			if err != nil {
				return err
			}
			hash = hex.EncodeToString(hashBytes)
		}
	}
	ctx.bodyDigest = hash

	return nil
}

// isRequestSigned returns if the request is currently signed or presigned
func (ctx *signingCtx) isRequestSigned() bool {
	return ctx.Request.Header.Get("Authorization") != ""
}
func hmacSHA256(key []byte, data []byte) []byte {
	hash := hmac.New(sha256.New, key)
	hash.Write(data)
	return hash.Sum(nil)
}

func hashSHA256(data []byte) []byte {
	hash := sha256.New()
	hash.Write(data)
	return hash.Sum(nil)
}

func makeSha256Reader(reader io.ReadSeeker) (hashBytes []byte, err error) {
	hash := sha256.New()
	start, err := reader.Seek(0, io.SeekCurrent)
	if err != nil {
		return nil, err
	}
	defer func() {
		// ensure error is returned if unable to seek back to start of payload.
		_, err = reader.Seek(start, io.SeekStart)
	}()

	// Use CopyN to avoid allocating the 32KB buffer in io.Copy for bodies
	// smaller than 32KB. Fall back to io.Copy if we fail to determine the size.
	size, err := aws.SeekerLen(reader)
	if err != nil {
		_, _ = io.Copy(hash, reader)
	} else {
		_, _ = io.CopyN(hash, reader, size)
	}

	return hash.Sum(nil), nil
}

const doubleSpace = "  "

// stripExcessSpaces will rewrite the passed in slice's string values to not
// contain multiple side-by-side spaces.
func stripExcessSpaces(vals []string) {
	var j, k, l, m, spaces int
	for i, str := range vals {
		// revive:disable:empty-block

		// Trim trailing spaces
		for j = len(str) - 1; j >= 0 && str[j] == ' '; j-- {
		}

		// Trim leading spaces
		for k = 0; k < j && str[k] == ' '; k++ {
		}

		// revive:enable:empty-block

		str = str[k : j+1]

		// Strip multiple spaces.
		j = strings.Index(str, doubleSpace)
		if j < 0 {
			vals[i] = str
			continue
		}

		buf := []byte(str)
		for k, m, l = j, j, len(buf); k < l; k++ {
			if buf[k] == ' ' {
				if spaces == 0 {
					// First space.
					buf[m] = buf[k]
					m++
				}
				spaces++
			} else {
				// End of multiple spaces.
				spaces = 0
				buf[m] = buf[k]
				m++
			}
		}

		vals[i] = string(buf[:m])
	}
}

func buildSigningScope(region, service string, dt time.Time) string {
	return strings.Join([]string{
		formatShortTime(dt),
		region,
		service,
		awsV4Request,
	}, "/")
}

func deriveSigningKey(region, service, secretKey string, dt time.Time) []byte {
	keyDate := hmacSHA256([]byte("AWS4"+secretKey), []byte(formatShortTime(dt)))
	keyRegion := hmacSHA256(keyDate, []byte(region))
	keyService := hmacSHA256(keyRegion, []byte(service))
	signingKey := hmacSHA256(keyService, []byte(awsV4Request))
	return signingKey
}

func formatShortTime(dt time.Time) string {
	return dt.UTC().Format(shortTimeFormat)
}

func formatTime(dt time.Time) string {
	return dt.UTC().Format(timeFormat)
}
434	internal/aws/signer/v4/v4_test.go	Normal file
@ -0,0 +1,434 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/signer/v4/v4_test.go
// See THIRD-PARTY-NOTICES for original license terms

package v4

import (
	"bytes"
	"io"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"reflect"
	"strconv"
	"strings"
	"testing"
	"time"

	"gitea.psichedelico.com/go/bson/internal/aws"
	"gitea.psichedelico.com/go/bson/internal/aws/credentials"
	"gitea.psichedelico.com/go/bson/internal/credproviders"
)

func epochTime() time.Time { return time.Unix(0, 0) }

func TestStripExcessHeaders(t *testing.T) {
	vals := []string{
		"",
		"123",
		"1 2 3",
		"1 2 3 ",
		" 1 2 3",
		"1  2 3",
		"1  23",
		"1  2  3",
		"1  2  ",
		" 1  2  ",
		"12     3",
		"12     3   1",
		"12           3     1",
		"12     3       1abc123",
	}

	expected := []string{
		"",
		"123",
		"1 2 3",
		"1 2 3",
		"1 2 3",
		"1 2 3",
		"1 23",
		"1 2 3",
		"1 2",
		"1 2",
		"12 3",
		"12 3 1",
		"12 3 1",
		"12 3 1abc123",
	}

	stripExcessSpaces(vals)
	for i := 0; i < len(vals); i++ {
		if e, a := expected[i], vals[i]; e != a {
			t.Errorf("%d, expect %v, got %v", i, e, a)
		}
	}
}
func buildRequest(body string) (*http.Request, io.ReadSeeker) {
	reader := strings.NewReader(body)
	return buildRequestWithBodyReader("dynamodb", "us-east-1", reader)
}

func buildRequestReaderSeeker(serviceName, region, body string) (*http.Request, io.ReadSeeker) {
	reader := &readerSeekerWrapper{strings.NewReader(body)}
	return buildRequestWithBodyReader(serviceName, region, reader)
}

func buildRequestWithBodyReader(serviceName, region string, body io.Reader) (*http.Request, io.ReadSeeker) {
	var bodyLen int

	type lenner interface {
		Len() int
	}
	if lr, ok := body.(lenner); ok {
		bodyLen = lr.Len()
	}

	endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
	req, _ := http.NewRequest("POST", endpoint, body)
	req.URL.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()"
	req.Header.Set("X-Amz-Target", "prefix.Operation")
	req.Header.Set("Content-Type", "application/x-amz-json-1.0")

	if bodyLen > 0 {
		req.Header.Set("Content-Length", strconv.Itoa(bodyLen))
	}

	req.Header.Set("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
	req.Header.Add("X-Amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
	req.Header.Add("X-amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")

	var seeker io.ReadSeeker
	if sr, ok := body.(io.ReadSeeker); ok {
		seeker = sr
	} else {
		seeker = aws.ReadSeekCloser(body)
	}

	return req, seeker
}

func buildSigner() Signer {
	return Signer{
		Credentials: newTestStaticCredentials(),
	}
}

func newTestStaticCredentials() *credentials.Credentials {
	return credentials.NewCredentials(&credproviders.StaticProvider{Value: credentials.Value{
		AccessKeyID:     "AKID",
		SecretAccessKey: "SECRET",
		SessionToken:    "SESSION",
	}})
}
func TestSignRequest(t *testing.T) {
	req, body := buildRequest("{}")
	signer := buildSigner()
	_, err := signer.Sign(req, body, "dynamodb", "us-east-1", epochTime())
	if err != nil {
		t.Errorf("Expected no err, got %v", err)
	}

	expectedDate := "19700101T000000Z"
	expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-security-token;x-amz-target, Signature=a518299330494908a70222cec6899f6f32f297f8595f6df1776d998936652ad9"

	q := req.Header
	if e, a := expectedSig, q.Get("Authorization"); e != a {
		t.Errorf("expect\n%v\nactual\n%v\n", e, a)
	}
	if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
		t.Errorf("expect\n%v\nactual\n%v\n", e, a)
	}
}

func TestSignUnseekableBody(t *testing.T) {
	req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello")))
	signer := buildSigner()
	_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
	if err == nil {
		t.Fatalf("expect error signing request")
	}

	if e, a := "unseekable request body", err.Error(); !strings.Contains(a, e) {
		t.Errorf("expect %q to be in %q", e, a)
	}
}

func TestSignPreComputedHashUnseekableBody(t *testing.T) {
	req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello")))

	signer := buildSigner()

	req.Header.Set("X-Amz-Content-Sha256", "some-content-sha256")
	_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
	if err != nil {
		t.Fatalf("expect no error, got %v", err)
	}

	hash := req.Header.Get("X-Amz-Content-Sha256")
	if e, a := "some-content-sha256", hash; e != a {
		t.Errorf("expect %v, got %v", e, a)
	}
}

func TestSignPrecomputedBodyChecksum(t *testing.T) {
	req, body := buildRequest("hello")
	req.Header.Set("X-Amz-Content-Sha256", "PRECOMPUTED")
	signer := buildSigner()
	_, err := signer.Sign(req, body, "dynamodb", "us-east-1", time.Now())
	if err != nil {
		t.Errorf("Expected no err, got %v", err)
	}
	hash := req.Header.Get("X-Amz-Content-Sha256")
	if e, a := "PRECOMPUTED", hash; e != a {
		t.Errorf("expect %v, got %v", e, a)
	}
}
func TestSignWithRequestBody(t *testing.T) {
	creds := newTestStaticCredentials()
	signer := NewSigner(creds)

	expectBody := []byte("abc123")

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		b, err := ioutil.ReadAll(r.Body)
		r.Body.Close()
		if err != nil {
			t.Errorf("expect no error, got %v", err)
		}
		if e, a := expectBody, b; !reflect.DeepEqual(e, a) {
			t.Errorf("expect %v, got %v", e, a)
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	req, err := http.NewRequest("POST", server.URL, nil)
	if err != nil {
		t.Errorf("expect no error, got %v", err)
	}

	_, err = signer.Sign(req, bytes.NewReader(expectBody), "service", "region", time.Now())
	if err != nil {
		t.Errorf("expect no error, got %v", err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Errorf("expect no error, got %v", err)
	}
	if e, a := http.StatusOK, resp.StatusCode; e != a {
		t.Errorf("expect %v, got %v", e, a)
	}
}

func TestSignWithRequestBody_Overwrite(t *testing.T) {
	creds := newTestStaticCredentials()
	signer := NewSigner(creds)

	var expectBody []byte

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		b, err := ioutil.ReadAll(r.Body)
		r.Body.Close()
		if err != nil {
			t.Errorf("expect no error, got %v", err)
		}
		if e, a := len(expectBody), len(b); e != a {
			t.Errorf("expect %v, got %v", e, a)
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	req, err := http.NewRequest("GET", server.URL, strings.NewReader("invalid body"))
	if err != nil {
		t.Errorf("expect no error, got %v", err)
	}

	_, err = signer.Sign(req, nil, "service", "region", time.Now())
	req.ContentLength = 0

	if err != nil {
		t.Errorf("expect no error, got %v", err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Errorf("expect no error, got %v", err)
	}
	if e, a := http.StatusOK, resp.StatusCode; e != a {
		t.Errorf("expect %v, got %v", e, a)
	}
}
func TestBuildCanonicalRequest(t *testing.T) {
	req, body := buildRequest("{}")
	req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
	ctx := &signingCtx{
		ServiceName: "dynamodb",
		Region:      "us-east-1",
		Request:     req,
		Body:        body,
		Query:       req.URL.Query(),
		Time:        time.Now(),
	}

	ctx.buildCanonicalString()
	expected := "https://example.org/bucket/key-._~,!@#$%^&*()?Foo=z&Foo=o&Foo=m&Foo=a"
	if e, a := expected, ctx.Request.URL.String(); e != a {
		t.Errorf("expect %v, got %v", e, a)
	}
}

func TestSignWithBody_ReplaceRequestBody(t *testing.T) {
	creds := newTestStaticCredentials()
	req, seekerBody := buildRequest("{}")
	req.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))

	s := NewSigner(creds)
	origBody := req.Body

	_, err := s.Sign(req, seekerBody, "dynamodb", "us-east-1", time.Now())
	if err != nil {
		t.Fatalf("expect no error, got %v", err)
	}

	if req.Body == origBody {
		t.Errorf("expect request body to not be origBody")
	}

	if req.Body == nil {
		t.Errorf("expect request body to be changed but was nil")
	}
}

func TestRequestHost(t *testing.T) {
	req, body := buildRequest("{}")
	req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
	req.Host = "myhost"
	ctx := &signingCtx{
		ServiceName: "dynamodb",
		Region:      "us-east-1",
		Request:     req,
		Body:        body,
		Query:       req.URL.Query(),
		Time:        time.Now(),
	}

	ctx.buildCanonicalHeaders(ignoredHeaders, ctx.Request.Header)
	if !strings.Contains(ctx.canonicalHeaders, "host:"+req.Host) {
		t.Errorf("canonical host header invalid")
	}
}
func TestSign_buildCanonicalHeaders(t *testing.T) {
	serviceName := "mockAPI"
	region := "mock-region"
	endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"

	req, err := http.NewRequest("POST", endpoint, nil)
	if err != nil {
		t.Fatalf("failed to create request, %v", err)
	}

	req.Header.Set("FooInnerSpace", "   inner      space   ")
	req.Header.Set("FooLeadingSpace", "    leading-space")
	req.Header.Add("FooMultipleSpace", "no-space")
	req.Header.Add("FooMultipleSpace", "\ttab-space")
	req.Header.Add("FooMultipleSpace", "trailing-space    ")
	req.Header.Set("FooNoSpace", "no-space")
	req.Header.Set("FooTabSpace", "\ttab-space\t")
	req.Header.Set("FooTrailingSpace", "trailing-space    ")
	req.Header.Set("FooWrappedSpace", "   wrapped-space    ")

	ctx := &signingCtx{
		ServiceName: serviceName,
		Region:      region,
		Request:     req,
		Body:        nil,
		Query:       req.URL.Query(),
		Time:        time.Now(),
	}

	ctx.buildCanonicalHeaders(ignoredHeaders, ctx.Request.Header)

	expectCanonicalHeaders := strings.Join([]string{
		`fooinnerspace:inner space`,
		`fooleadingspace:leading-space`,
		`foomultiplespace:no-space,tab-space,trailing-space`,
		`foonospace:no-space`,
		`footabspace:tab-space`,
		`footrailingspace:trailing-space`,
		`foowrappedspace:wrapped-space`,
		`host:mockAPI.mock-region.amazonaws.com`,
	}, "\n")
	if e, a := expectCanonicalHeaders, ctx.canonicalHeaders; e != a {
		t.Errorf("expect:\n%s\n\nactual:\n%s", e, a)
	}
}
func BenchmarkSignRequest(b *testing.B) {
	signer := buildSigner()
	req, body := buildRequestReaderSeeker("dynamodb", "us-east-1", "{}")
	for i := 0; i < b.N; i++ {
		_, err := signer.Sign(req, body, "dynamodb", "us-east-1", time.Now())
		if err != nil {
			b.Errorf("Expected no err, got %v", err)
		}
	}
}

var stripExcessSpaceCases = []string{
	`AWS4-HMAC-SHA256 Credential=AKIDFAKEIDFAKEID/20160628/us-west-2/s3/aws4_request, SignedHeaders=host;x-amz-date, Signature=1234567890abcdef1234567890abcdef1234567890abcdef`,
	`123   321   123   321`,
	`   123   321   123   321   `,
	`   123       321       123       321   `,
	"123",
	"1 2 3",
	" 1 2 3",
	"1  2 3",
	"1  23",
	"1  2  3",
	"1  2  ",
	" 1  2  ",
	"12     3",
	"12     3   1",
	"12           3     1",
	"12     3       1abc123",
}

func BenchmarkStripExcessSpaces(b *testing.B) {
	for i := 0; i < b.N; i++ {
		// Make sure to start with a copy of the cases
		cases := append([]string{}, stripExcessSpaceCases...)
		stripExcessSpaces(cases)
	}
}

// readerSeekerWrapper mimics the interface provided by request.offsetReader
type readerSeekerWrapper struct {
	r *strings.Reader
}

func (r *readerSeekerWrapper) Read(p []byte) (n int, err error) {
	return r.r.Read(p)
}

func (r *readerSeekerWrapper) Seek(offset int64, whence int) (int64, error) {
	return r.r.Seek(offset, whence)
}

func (r *readerSeekerWrapper) Len() int {
	return r.r.Len()
}
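The normalization that stripExcessSpaces performs in place (and that TestStripExcessHeaders exercises) can be sketched with a simpler, allocation-heavy equivalent. `collapse` is a hypothetical helper, not part of this package, and unlike the original it also folds tabs, since strings.Fields splits on all whitespace:

```go
package main

import (
	"fmt"
	"strings"
)

// collapse trims a header value and folds runs of whitespace down to single
// spaces — the behavior stripExcessSpaces implements allocation-free for
// space characters only.
func collapse(s string) string {
	return strings.Join(strings.Fields(s), " ")
}

func main() {
	fmt.Println(collapse("  1   2  3  ")) // → "1 2 3"
}
```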
153	internal/aws/types.go	Normal file
@ -0,0 +1,153 @@
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
// not use this file except in compliance with the License. You may obtain
|
||||||
|
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
||||||
|
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/types.go
|
||||||
|
// See THIRD-PARTY-NOTICES for original license terms
|
||||||
|
|
||||||
|
package aws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Allows the
|
||||||
|
// SDK to accept an io.Reader that is not also an io.Seeker for unsigned
|
||||||
|
// streaming payload API operations.
|
||||||
|
//
|
||||||
|
// A ReadSeekCloser wrapping an nonseekable io.Reader used in an API
|
||||||
|
// operation's input will prevent that operation being retried in the case of
|
||||||
|
// network errors, and cause operation requests to fail if the operation
|
||||||
|
// requires payload signing.
|
||||||
|
//
|
||||||
|
// Note: If using With S3 PutObject to stream an object upload The SDK's S3
|
||||||
|
// Upload manager (s3manager.Uploader) provides support for streaming with the
|
||||||
|
// ability to retry network errors.
|
||||||
|
func ReadSeekCloser(r io.Reader) ReaderSeekerCloser {
|
||||||
|
return ReaderSeekerCloser{r}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReaderSeekerCloser represents a reader that can also delegate io.Seeker and
|
||||||
|
// io.Closer interfaces to the underlying object if they are available.
|
||||||
|
type ReaderSeekerCloser struct {
|
||||||
|
r io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsReaderSeekable returns if the underlying reader type can be seeked. A
|
||||||
|
// io.Reader might not actually be seekable if it is the ReaderSeekerCloser
|
||||||
|
// type.
|
||||||
|
func IsReaderSeekable(r io.Reader) bool {
|
||||||
|
switch v := r.(type) {
|
||||||
|
case ReaderSeekerCloser:
|
||||||
|
return v.IsSeeker()
|
||||||
|
case *ReaderSeekerCloser:
|
||||||
|
return v.IsSeeker()
|
||||||
|
case io.ReadSeeker:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
// Read reads from the reader up to the size of p. The number of bytes read
// and an error, if one occurred, are returned.
//
// If the underlying reader is not an io.Reader, zero bytes read and a nil
// error are returned.
//
// Performs the same functionality as io.Reader Read.
func (r ReaderSeekerCloser) Read(p []byte) (int, error) {
	switch t := r.r.(type) {
	case io.Reader:
		return t.Read(p)
	}
	return 0, nil
}

// Seek sets the offset for the next Read to offset, interpreted according to
// whence: 0 means relative to the origin of the file, 1 means relative to the
// current offset, and 2 means relative to the end. Seek returns the new offset
// and an error, if any.
//
// If the ReaderSeekerCloser is not an io.Seeker, nothing will be done.
func (r ReaderSeekerCloser) Seek(offset int64, whence int) (int64, error) {
	switch t := r.r.(type) {
	case io.Seeker:
		return t.Seek(offset, whence)
	}
	return int64(0), nil
}

// IsSeeker returns whether the underlying reader is also a seeker.
func (r ReaderSeekerCloser) IsSeeker() bool {
	_, ok := r.r.(io.Seeker)
	return ok
}

// HasLen returns the length of the underlying reader if the value implements
// the Len() int method.
func (r ReaderSeekerCloser) HasLen() (int, bool) {
	type lenner interface {
		Len() int
	}

	if lr, ok := r.r.(lenner); ok {
		return lr.Len(), true
	}

	return 0, false
}

// GetLen returns the length of the bytes remaining in the underlying reader.
// Checks first for Len(), then io.Seeker to determine the size of the
// underlying reader.
//
// Will return -1 if the length cannot be determined.
func (r ReaderSeekerCloser) GetLen() (int64, error) {
	if l, ok := r.HasLen(); ok {
		return int64(l), nil
	}

	if s, ok := r.r.(io.Seeker); ok {
		return seekerLen(s)
	}

	return -1, nil
}

// SeekerLen attempts to get the number of bytes remaining at the seeker's
// current position. Returns the number of bytes remaining or an error.
func SeekerLen(s io.Seeker) (int64, error) {
	// Determine if the seeker is actually seekable. ReaderSeekerCloser
	// hides the fact that an io.Reader might not actually be seekable.
	switch v := s.(type) {
	case ReaderSeekerCloser:
		return v.GetLen()
	case *ReaderSeekerCloser:
		return v.GetLen()
	}

	return seekerLen(s)
}

func seekerLen(s io.Seeker) (int64, error) {
	curOffset, err := s.Seek(0, io.SeekCurrent)
	if err != nil {
		return 0, err
	}

	endOffset, err := s.Seek(0, io.SeekEnd)
	if err != nil {
		return 0, err
	}

	_, err = s.Seek(curOffset, io.SeekStart)
	if err != nil {
		return 0, err
	}

	return endOffset - curOffset, nil
}
40
internal/bsoncoreutil/bsoncoreutil.go
Normal file
@ -0,0 +1,40 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bsoncoreutil

// Truncate truncates a given string to a certain byte width.
func Truncate(str string, width int) string {
	if width == 0 {
		return ""
	}

	if len(str) <= width {
		return str
	}

	// Truncate the byte slice of the string to the given width.
	newStr := str[:width]

	// Check if the last byte is at the beginning of a multi-byte character.
	// If it is, then remove the last byte.
	if newStr[len(newStr)-1]&0xC0 == 0xC0 {
		return newStr[:len(newStr)-1]
	}

	// Check if the last byte is a continuation byte of a multi-byte
	// character. If it is, step back until we are at the start of a character.
	if newStr[len(newStr)-1]&0xC0 == 0x80 {
		for i := len(newStr) - 1; i >= 0; i-- {
			if newStr[i]&0xC0 == 0xC0 {
				// Truncate at the end of the character before the character we stepped back to.
				return newStr[:i]
			}
		}
	}

	return newStr
}
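The bitmask checks rely on UTF-8 structure: leading bytes of multi-byte characters match `11xxxxxx` (`b&0xC0 == 0xC0`) and continuation bytes match `10xxxxxx` (`b&0xC0 == 0x80`). A minimal runnable sketch of the same truncation logic (`truncate` here is an illustrative copy, not the exported function):

```go
package main

import "fmt"

// truncate cuts str to at most width bytes without splitting a
// multi-byte UTF-8 character.
func truncate(str string, width int) string {
	if width == 0 {
		return ""
	}
	if len(str) <= width {
		return str
	}
	newStr := str[:width]
	i := len(newStr)
	// Step back past continuation bytes (10xxxxxx) ...
	for i > 0 && newStr[i-1]&0xC0 == 0x80 {
		i--
	}
	// ... and past the leading byte (11xxxxxx) of a partially
	// included character.
	if i > 0 && newStr[i-1]&0xC0 == 0xC0 {
		i--
	}
	return newStr[:i]
}

func main() {
	// Each character of "你好" is 3 bytes; width 4 would split the second one.
	fmt.Println(truncate("你好", 4))        // 你
	fmt.Println(truncate("foo bar baz", 9)) // foo bar b
}
```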
59
internal/bsoncoreutil/bsoncoreutil_test.go
Normal file
@ -0,0 +1,59 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bsoncoreutil

import (
	"testing"

	"gitea.psichedelico.com/go/bson/internal/assert"
)

func TestTruncate(t *testing.T) {
	t.Parallel()

	for _, tcase := range []struct {
		name     string
		arg      string
		width    int
		expected string
	}{
		{
			name:     "empty",
			arg:      "",
			width:    0,
			expected: "",
		},
		{
			name:     "short",
			arg:      "foo",
			width:    1000,
			expected: "foo",
		},
		{
			name:     "long",
			arg:      "foo bar baz",
			width:    9,
			expected: "foo bar b",
		},
		{
			name:     "multi-byte",
			arg:      "你好",
			width:    4,
			expected: "你",
		},
	} {
		tcase := tcase

		t.Run(tcase.name, func(t *testing.T) {
			t.Parallel()

			actual := Truncate(tcase.arg, tcase.width)
			assert.Equal(t, tcase.expected, actual)
		})
	}
}
62
internal/bsonutil/bsonutil.go
Normal file
@ -0,0 +1,62 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package bsonutil

import (
	"fmt"

	"gitea.psichedelico.com/go/bson"
)

// StringSliceFromRawValue decodes the provided BSON value into a []string. This function returns an error if the value
// is not an array or any of the elements in the array are not strings. The name parameter is used to add context to
// error messages.
func StringSliceFromRawValue(name string, val bson.RawValue) ([]string, error) {
	arr, ok := val.ArrayOK()
	if !ok {
		return nil, fmt.Errorf("expected '%s' to be an array but it's a BSON %s", name, val.Type)
	}

	arrayValues, err := arr.Values()
	if err != nil {
		return nil, err
	}

	strs := make([]string, 0, len(arrayValues))
	for _, arrayVal := range arrayValues {
		str, ok := arrayVal.StringValueOK()
		if !ok {
			return nil, fmt.Errorf("expected '%s' to be an array of strings, but found a BSON %s", name, arrayVal.Type)
		}
		strs = append(strs, str)
	}
	return strs, nil
}

// RawArrayToDocuments converts an array of documents to []bson.Raw.
func RawArrayToDocuments(arr bson.RawArray) []bson.Raw {
	values, err := arr.Values()
	if err != nil {
		panic(fmt.Sprintf("error converting BSON document to values: %v", err))
	}

	out := make([]bson.Raw, len(values))
	for i := range values {
		out[i] = values[i].Document()
	}

	return out
}

// RawToInterfaces takes one or many bson.Raw documents and returns them as a []interface{}.
func RawToInterfaces(docs ...bson.Raw) []interface{} {
	out := make([]interface{}, len(docs))
	for i := range docs {
		out[i] = docs[i]
	}
	return out
}
62
internal/codecutil/encoding.go
Normal file
@ -0,0 +1,62 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package codecutil

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"reflect"

	"gitea.psichedelico.com/go/bson"
	"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
)

// ErrNilValue is returned when the value to marshal is nil.
var ErrNilValue = errors.New("value is nil")

// MarshalError is returned when attempting to transform a value into a document
// results in an error.
type MarshalError struct {
	Value interface{}
	Err   error
}

// Error implements the error interface.
func (e MarshalError) Error() string {
	return fmt.Sprintf("cannot transform type %s to a BSON Document: %v",
		reflect.TypeOf(e.Value), e.Err)
}

// EncoderFn is used to functionally construct an encoder for marshaling values.
type EncoderFn func(io.Writer) *bson.Encoder

// MarshalValue will attempt to encode the value with the encoder returned by
// the encoder function.
func MarshalValue(val interface{}, encFn EncoderFn) (bsoncore.Value, error) {
	// If the val is already a bsoncore.Value, then do nothing.
	if bval, ok := val.(bsoncore.Value); ok {
		return bval, nil
	}

	if val == nil {
		return bsoncore.Value{}, ErrNilValue
	}

	buf := new(bytes.Buffer)
	enc := encFn(buf)

	// Encode the value in a single-element document with an empty key. Use
	// bsoncore to extract the first element and return the BSON value.
	err := enc.Encode(bson.D{{Key: "", Value: val}})
	if err != nil {
		return bsoncore.Value{}, MarshalError{Value: val, Err: err}
	}

	return bsoncore.Document(buf.Bytes()).Index(0).Value(), nil
}
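The wrap-and-extract trick in `MarshalValue` works because document encoders can only emit whole documents, so a lone value is smuggled inside a single-element document under an empty key and then pulled back out. The same idea can be sketched with the standard library's `encoding/json` instead of BSON (`marshalValue` here is an illustrative analogue, not part of this package):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// marshalValue encodes a lone value by wrapping it in a single-element
// document under the empty key, then extracting that element again.
func marshalValue(val interface{}) (json.RawMessage, error) {
	buf, err := json.Marshal(map[string]interface{}{"": val})
	if err != nil {
		return nil, err
	}
	var doc map[string]json.RawMessage
	if err := json.Unmarshal(buf, &doc); err != nil {
		return nil, err
	}
	return doc[""], nil
}

func main() {
	v, _ := marshalValue(42)
	fmt.Println(string(v)) // 42
}
```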
82
internal/codecutil/encoding_test.go
Normal file
@ -0,0 +1,82 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package codecutil

import (
	"io"
	"testing"

	"gitea.psichedelico.com/go/bson"
	"gitea.psichedelico.com/go/bson/internal/assert"
)

func testEncFn(t *testing.T) EncoderFn {
	t.Helper()

	return func(w io.Writer) *bson.Encoder {
		rw := bson.NewDocumentWriter(w)
		return bson.NewEncoder(rw)
	}
}

func TestMarshalValue(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		val      interface{}
		registry *bson.Registry
		encFn    EncoderFn
		want     string
		wantErr  error
	}{
		{
			name:    "empty",
			val:     nil,
			want:    "",
			wantErr: ErrNilValue,
			encFn:   testEncFn(t),
		},
		{
			name:  "bson.D",
			val:   bson.D{{"foo", "bar"}},
			want:  `{"foo": "bar"}`,
			encFn: testEncFn(t),
		},
		{
			name:  "map",
			val:   map[string]interface{}{"foo": "bar"},
			want:  `{"foo": "bar"}`,
			encFn: testEncFn(t),
		},
		{
			name:  "struct",
			val:   struct{ Foo string }{Foo: "bar"},
			want:  `{"foo": "bar"}`,
			encFn: testEncFn(t),
		},
		{
			name:  "non-document type",
			val:   "foo: bar",
			want:  `"foo: bar"`,
			encFn: testEncFn(t),
		},
	}

	for _, test := range tests {
		test := test

		t.Run(test.name, func(t *testing.T) {
			t.Parallel()

			value, err := MarshalValue(test.val, test.encFn)

			assert.Equal(t, test.wantErr, err, "expected and actual error do not match")
			assert.Equal(t, test.want, value.String(), "expected and actual values are different")
		})
	}
}
148
internal/credproviders/assume_role_provider.go
Normal file
@ -0,0 +1,148 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package credproviders

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"time"

	"gitea.psichedelico.com/go/bson/internal/aws/credentials"
	"gitea.psichedelico.com/go/bson/internal/uuid"
)

const (
	// assumeRoleProviderName provides a name of the assume role provider
	assumeRoleProviderName = "AssumeRoleProvider"

	stsURI = `https://sts.amazonaws.com/?Action=AssumeRoleWithWebIdentity&RoleSessionName=%s&RoleArn=%s&WebIdentityToken=%s&Version=2011-06-15`
)

// An AssumeRoleProvider retrieves credentials for assume role with web identity.
type AssumeRoleProvider struct {
	AwsRoleArnEnv              EnvVar
	AwsWebIdentityTokenFileEnv EnvVar
	AwsRoleSessionNameEnv      EnvVar

	httpClient *http.Client
	expiration time.Time

	// expiryWindow allows the credentials to trigger refreshing prior to the
	// credentials actually expiring. This is beneficial so expiring
	// credentials do not cause requests to fail unexpectedly.
	//
	// So an expiryWindow of 10s would cause calls to IsExpired() to return
	// true 10 seconds before the credentials actually expire.
	expiryWindow time.Duration
}

// NewAssumeRoleProvider returns a pointer to an assume role provider.
func NewAssumeRoleProvider(httpClient *http.Client, expiryWindow time.Duration) *AssumeRoleProvider {
	return &AssumeRoleProvider{
		// AwsRoleArnEnv is the environment variable for AWS_ROLE_ARN
		AwsRoleArnEnv: EnvVar("AWS_ROLE_ARN"),
		// AwsWebIdentityTokenFileEnv is the environment variable for AWS_WEB_IDENTITY_TOKEN_FILE
		AwsWebIdentityTokenFileEnv: EnvVar("AWS_WEB_IDENTITY_TOKEN_FILE"),
		// AwsRoleSessionNameEnv is the environment variable for AWS_ROLE_SESSION_NAME
		AwsRoleSessionNameEnv: EnvVar("AWS_ROLE_SESSION_NAME"),
		httpClient:            httpClient,
		expiryWindow:          expiryWindow,
	}
}

// RetrieveWithContext retrieves the keys from the AWS service.
func (a *AssumeRoleProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
	const defaultHTTPTimeout = 10 * time.Second

	v := credentials.Value{ProviderName: assumeRoleProviderName}

	roleArn := a.AwsRoleArnEnv.Get()
	tokenFile := a.AwsWebIdentityTokenFileEnv.Get()
	if tokenFile == "" && roleArn == "" {
		return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_ARN are missing")
	}
	if tokenFile != "" && roleArn == "" {
		return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE is set, but AWS_ROLE_ARN is missing")
	}
	if tokenFile == "" && roleArn != "" {
		return v, errors.New("AWS_ROLE_ARN is set, but AWS_WEB_IDENTITY_TOKEN_FILE is missing")
	}
	token, err := ioutil.ReadFile(tokenFile)
	if err != nil {
		return v, err
	}

	sessionName := a.AwsRoleSessionNameEnv.Get()
	if sessionName == "" {
		// Use a UUID if the RoleSessionName is not given.
		id, err := uuid.New()
		if err != nil {
			return v, err
		}
		sessionName = id.String()
	}

	fullURI := fmt.Sprintf(stsURI, sessionName, roleArn, string(token))

	req, err := http.NewRequest(http.MethodPost, fullURI, nil)
	if err != nil {
		return v, err
	}
	req.Header.Set("Accept", "application/json")

	ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
	defer cancel()
	resp, err := a.httpClient.Do(req.WithContext(ctx))
	if err != nil {
		return v, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return v, fmt.Errorf("response failure: %s", resp.Status)
	}

	var stsResp struct {
		Response struct {
			Result struct {
				Credentials struct {
					AccessKeyID     string  `json:"AccessKeyId"`
					SecretAccessKey string  `json:"SecretAccessKey"`
					Token           string  `json:"SessionToken"`
					Expiration      float64 `json:"Expiration"`
				} `json:"Credentials"`
			} `json:"AssumeRoleWithWebIdentityResult"`
		} `json:"AssumeRoleWithWebIdentityResponse"`
	}

	err = json.NewDecoder(resp.Body).Decode(&stsResp)
	if err != nil {
		return v, err
	}
	v.AccessKeyID = stsResp.Response.Result.Credentials.AccessKeyID
	v.SecretAccessKey = stsResp.Response.Result.Credentials.SecretAccessKey
	v.SessionToken = stsResp.Response.Result.Credentials.Token
	if !v.HasKeys() {
		return v, errors.New("failed to retrieve web identity keys")
	}
	sec := int64(stsResp.Response.Result.Credentials.Expiration)
	a.expiration = time.Unix(sec, 0).Add(-a.expiryWindow)

	return v, nil
}

// Retrieve retrieves the keys from the AWS service.
func (a *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
	return a.RetrieveWithContext(context.Background())
}

// IsExpired returns true if the credentials are expired.
func (a *AssumeRoleProvider) IsExpired() bool {
	return a.expiration.Before(time.Now())
}
183
internal/credproviders/ec2_provider.go
Normal file
@ -0,0 +1,183 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package credproviders

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"time"

	"gitea.psichedelico.com/go/bson/internal/aws/credentials"
)

const (
	// ec2ProviderName provides a name of the EC2 provider
	ec2ProviderName = "EC2Provider"

	awsEC2URI       = "http://169.254.169.254/"
	awsEC2RolePath  = "latest/meta-data/iam/security-credentials/"
	awsEC2TokenPath = "latest/api/token"

	defaultHTTPTimeout = 10 * time.Second
)

// An EC2Provider retrieves credentials from EC2 metadata.
type EC2Provider struct {
	httpClient *http.Client
	expiration time.Time

	// expiryWindow allows the credentials to trigger refreshing prior to the
	// credentials actually expiring. This is beneficial so expiring
	// credentials do not cause requests to fail unexpectedly.
	//
	// So an expiryWindow of 10s would cause calls to IsExpired() to return
	// true 10 seconds before the credentials actually expire.
	expiryWindow time.Duration
}

// NewEC2Provider returns a pointer to an EC2 credential provider.
func NewEC2Provider(httpClient *http.Client, expiryWindow time.Duration) *EC2Provider {
	return &EC2Provider{
		httpClient:   httpClient,
		expiryWindow: expiryWindow,
	}
}

func (e *EC2Provider) getToken(ctx context.Context) (string, error) {
	req, err := http.NewRequest(http.MethodPut, awsEC2URI+awsEC2TokenPath, nil)
	if err != nil {
		return "", err
	}
	const defaultEC2TTLSeconds = "30"
	req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", defaultEC2TTLSeconds)

	ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
	defer cancel()
	resp, err := e.httpClient.Do(req.WithContext(ctx))
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("%s %s failed: %s", req.Method, req.URL.String(), resp.Status)
	}

	token, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}
	if len(token) == 0 {
		return "", errors.New("unable to retrieve token from EC2 metadata")
	}
	return string(token), nil
}

func (e *EC2Provider) getRoleName(ctx context.Context, token string) (string, error) {
	req, err := http.NewRequest(http.MethodGet, awsEC2URI+awsEC2RolePath, nil)
	if err != nil {
		return "", err
	}
	req.Header.Set("X-aws-ec2-metadata-token", token)

	ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
	defer cancel()
	resp, err := e.httpClient.Do(req.WithContext(ctx))
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("%s %s failed: %s", req.Method, req.URL.String(), resp.Status)
	}

	role, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}
	if len(role) == 0 {
		return "", errors.New("unable to retrieve role_name from EC2 metadata")
	}
	return string(role), nil
}

func (e *EC2Provider) getCredentials(ctx context.Context, token string, role string) (credentials.Value, time.Time, error) {
	v := credentials.Value{ProviderName: ec2ProviderName}

	pathWithRole := awsEC2URI + awsEC2RolePath + role
	req, err := http.NewRequest(http.MethodGet, pathWithRole, nil)
	if err != nil {
		return v, time.Time{}, err
	}
	req.Header.Set("X-aws-ec2-metadata-token", token)
	ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
	defer cancel()
	resp, err := e.httpClient.Do(req.WithContext(ctx))
	if err != nil {
		return v, time.Time{}, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return v, time.Time{}, fmt.Errorf("%s %s failed: %s", req.Method, req.URL.String(), resp.Status)
	}

	var ec2Resp struct {
		AccessKeyID     string    `json:"AccessKeyId"`
		SecretAccessKey string    `json:"SecretAccessKey"`
		Token           string    `json:"Token"`
		Expiration      time.Time `json:"Expiration"`
	}

	err = json.NewDecoder(resp.Body).Decode(&ec2Resp)
	if err != nil {
		return v, time.Time{}, err
	}

	v.AccessKeyID = ec2Resp.AccessKeyID
	v.SecretAccessKey = ec2Resp.SecretAccessKey
	v.SessionToken = ec2Resp.Token

	return v, ec2Resp.Expiration, nil
}

// RetrieveWithContext retrieves the keys from the AWS service.
func (e *EC2Provider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
	v := credentials.Value{ProviderName: ec2ProviderName}

	token, err := e.getToken(ctx)
	if err != nil {
		return v, err
	}

	role, err := e.getRoleName(ctx, token)
	if err != nil {
		return v, err
	}

	v, exp, err := e.getCredentials(ctx, token, role)
	if err != nil {
		return v, err
	}
	if !v.HasKeys() {
		return v, errors.New("failed to retrieve EC2 keys")
	}
	e.expiration = exp.Add(-e.expiryWindow)

	return v, nil
}

// Retrieve retrieves the keys from the AWS service.
func (e *EC2Provider) Retrieve() (credentials.Value, error) {
	return e.RetrieveWithContext(context.Background())
}

// IsExpired returns true if the credentials are expired.
func (e *EC2Provider) IsExpired() bool {
	return e.expiration.Before(time.Now())
}
112
internal/credproviders/ecs_provider.go
Normal file
@ -0,0 +1,112 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package credproviders

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"time"

	"gitea.psichedelico.com/go/bson/internal/aws/credentials"
)

const (
	// ecsProviderName provides a name of the ECS provider
	ecsProviderName = "ECSProvider"

	awsRelativeURI = "http://169.254.170.2/"
)

// An ECSProvider retrieves credentials from ECS metadata.
type ECSProvider struct {
	AwsContainerCredentialsRelativeURIEnv EnvVar

	httpClient *http.Client
	expiration time.Time

	// expiryWindow allows the credentials to trigger refreshing prior to the
	// credentials actually expiring. This is beneficial so expiring
	// credentials do not cause requests to fail unexpectedly.
	//
	// So an expiryWindow of 10s would cause calls to IsExpired() to return
	// true 10 seconds before the credentials actually expire.
	expiryWindow time.Duration
}

// NewECSProvider returns a pointer to an ECS credential provider.
func NewECSProvider(httpClient *http.Client, expiryWindow time.Duration) *ECSProvider {
	return &ECSProvider{
		// AwsContainerCredentialsRelativeURIEnv is the environment variable for AWS_CONTAINER_CREDENTIALS_RELATIVE_URI
		AwsContainerCredentialsRelativeURIEnv: EnvVar("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"),
		httpClient:                            httpClient,
		expiryWindow:                          expiryWindow,
	}
}

// RetrieveWithContext retrieves the keys from the AWS service.
func (e *ECSProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
	const defaultHTTPTimeout = 10 * time.Second

	v := credentials.Value{ProviderName: ecsProviderName}

	relativeEcsURI := e.AwsContainerCredentialsRelativeURIEnv.Get()
	if len(relativeEcsURI) == 0 {
		return v, errors.New("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI is missing")
	}
	fullURI := awsRelativeURI + relativeEcsURI

	req, err := http.NewRequest(http.MethodGet, fullURI, nil)
	if err != nil {
		return v, err
	}
	req.Header.Set("Accept", "application/json")

	ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
	defer cancel()
	resp, err := e.httpClient.Do(req.WithContext(ctx))
	if err != nil {
		return v, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return v, fmt.Errorf("response failure: %s", resp.Status)
	}

	var ecsResp struct {
		AccessKeyID     string    `json:"AccessKeyId"`
		SecretAccessKey string    `json:"SecretAccessKey"`
		Token           string    `json:"Token"`
		Expiration      time.Time `json:"Expiration"`
	}

	err = json.NewDecoder(resp.Body).Decode(&ecsResp)
	if err != nil {
		return v, err
	}

	v.AccessKeyID = ecsResp.AccessKeyID
	v.SecretAccessKey = ecsResp.SecretAccessKey
	v.SessionToken = ecsResp.Token
	if !v.HasKeys() {
		return v, errors.New("failed to retrieve ECS keys")
	}
	e.expiration = ecsResp.Expiration.Add(-e.expiryWindow)

	return v, nil
}

// Retrieve retrieves the keys from the AWS service.
func (e *ECSProvider) Retrieve() (credentials.Value, error) {
	return e.RetrieveWithContext(context.Background())
}

// IsExpired returns true if the credentials are expired.
func (e *ECSProvider) IsExpired() bool {
	return e.expiration.Before(time.Now())
}
69
internal/credproviders/env_provider.go
Normal file
@ -0,0 +1,69 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package credproviders

import (
	"os"

	"gitea.psichedelico.com/go/bson/internal/aws/credentials"
)

// envProviderName provides a name of the Env provider
const envProviderName = "EnvProvider"

// EnvVar is an environment variable
type EnvVar string

// Get retrieves the environment variable
func (ev EnvVar) Get() string {
	return os.Getenv(string(ev))
}

// An EnvProvider retrieves credentials from the environment variables of the
// running process. Environment credentials never expire.
type EnvProvider struct {
	AwsAccessKeyIDEnv     EnvVar
	AwsSecretAccessKeyEnv EnvVar
	AwsSessionTokenEnv    EnvVar

	retrieved bool
}

// NewEnvProvider returns a pointer to an Env credential provider.
func NewEnvProvider() *EnvProvider {
	return &EnvProvider{
		// AwsAccessKeyIDEnv is the environment variable for AWS_ACCESS_KEY_ID
		AwsAccessKeyIDEnv: EnvVar("AWS_ACCESS_KEY_ID"),
		// AwsSecretAccessKeyEnv is the environment variable for AWS_SECRET_ACCESS_KEY
		AwsSecretAccessKeyEnv: EnvVar("AWS_SECRET_ACCESS_KEY"),
		// AwsSessionTokenEnv is the environment variable for AWS_SESSION_TOKEN
		AwsSessionTokenEnv: EnvVar("AWS_SESSION_TOKEN"),
	}
}

// Retrieve retrieves the keys from the environment.
func (e *EnvProvider) Retrieve() (credentials.Value, error) {
	e.retrieved = false

	v := credentials.Value{
		AccessKeyID:     e.AwsAccessKeyIDEnv.Get(),
		SecretAccessKey: e.AwsSecretAccessKeyEnv.Get(),
		SessionToken:    e.AwsSessionTokenEnv.Get(),
		ProviderName:    envProviderName,
	}
	err := verify(v)
	if err == nil {
		e.retrieved = true
	}

	return v, err
}

// IsExpired returns true if the credentials have not been retrieved.
func (e *EnvProvider) IsExpired() bool {
	return !e.retrieved
}
|
103
internal/credproviders/imds_provider.go
Normal file
@ -0,0 +1,103 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package credproviders

import (
	"context"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"time"

	"gitea.psichedelico.com/go/bson/internal/aws/credentials"
)

const (
	// AzureProviderName provides a name of the Azure provider.
	AzureProviderName = "AzureProvider"

	azureURI = "http://169.254.169.254/metadata/identity/oauth2/token"
)

// An AzureProvider retrieves credentials from Azure IMDS.
type AzureProvider struct {
	httpClient   *http.Client
	expiration   time.Time
	expiryWindow time.Duration
}

// NewAzureProvider returns a pointer to an Azure credential provider.
func NewAzureProvider(httpClient *http.Client, expiryWindow time.Duration) *AzureProvider {
	return &AzureProvider{
		httpClient:   httpClient,
		expiration:   time.Time{},
		expiryWindow: expiryWindow,
	}
}

// RetrieveWithContext retrieves the keys from the Azure service.
func (a *AzureProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
	v := credentials.Value{ProviderName: AzureProviderName}
	req, err := http.NewRequest(http.MethodGet, azureURI, nil)
	if err != nil {
		return v, fmt.Errorf("unable to retrieve Azure credentials: %w", err)
	}
	q := make(url.Values)
	q.Set("api-version", "2018-02-01")
	q.Set("resource", "https://vault.azure.net")
	req.URL.RawQuery = q.Encode()
	req.Header.Set("Metadata", "true")
	req.Header.Set("Accept", "application/json")

	resp, err := a.httpClient.Do(req.WithContext(ctx))
	if err != nil {
		return v, fmt.Errorf("unable to retrieve Azure credentials: %w", err)
	}
	defer resp.Body.Close()
	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return v, fmt.Errorf("unable to retrieve Azure credentials: error reading response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return v, fmt.Errorf("unable to retrieve Azure credentials: expected StatusCode 200, got StatusCode: %v. Response body: %s", resp.StatusCode, body)
	}
	var tokenResponse struct {
		AccessToken string `json:"access_token"`
		ExpiresIn   string `json:"expires_in"`
	}
	// Attempt to read body as JSON.
	err = json.Unmarshal(body, &tokenResponse)
	if err != nil {
		return v, fmt.Errorf("unable to retrieve Azure credentials: error reading body JSON: %w (response body: %s)", err, body)
	}
	if tokenResponse.AccessToken == "" {
		return v, fmt.Errorf("unable to retrieve Azure credentials: got unexpected empty accessToken from Azure Metadata Server. Response body: %s", body)
	}
	v.SessionToken = tokenResponse.AccessToken

	expiresIn, err := time.ParseDuration(tokenResponse.ExpiresIn + "s")
	if err != nil {
		return v, err
	}
	if expiration := expiresIn - a.expiryWindow; expiration > 0 {
		a.expiration = time.Now().Add(expiration)
	}

	return v, err
}

// Retrieve retrieves the keys from the Azure service.
func (a *AzureProvider) Retrieve() (credentials.Value, error) {
	return a.RetrieveWithContext(context.Background())
}

// IsExpired returns true if the credentials are expired.
func (a *AzureProvider) IsExpired() bool {
	return a.expiration.Before(time.Now())
}
59
internal/credproviders/static_provider.go
Normal file
@ -0,0 +1,59 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package credproviders

import (
	"errors"

	"gitea.psichedelico.com/go/bson/internal/aws/credentials"
)

// staticProviderName provides a name of the Static provider.
const staticProviderName = "StaticProvider"

// A StaticProvider is a set of credentials which are set programmatically,
// and will never expire.
type StaticProvider struct {
	credentials.Value

	verified bool
	err      error
}

func verify(v credentials.Value) error {
	if !v.HasKeys() {
		return errors.New("failed to retrieve ACCESS_KEY_ID and SECRET_ACCESS_KEY")
	}
	if v.AccessKeyID != "" && v.SecretAccessKey == "" {
		return errors.New("ACCESS_KEY_ID is set, but SECRET_ACCESS_KEY is missing")
	}
	if v.AccessKeyID == "" && v.SecretAccessKey != "" {
		return errors.New("SECRET_ACCESS_KEY is set, but ACCESS_KEY_ID is missing")
	}
	if v.AccessKeyID == "" && v.SecretAccessKey == "" && v.SessionToken != "" {
		return errors.New("AWS_SESSION_TOKEN is set, but ACCESS_KEY_ID and SECRET_ACCESS_KEY are missing")
	}
	return nil
}

// Retrieve returns the credentials or an error if the credentials are invalid.
func (s *StaticProvider) Retrieve() (credentials.Value, error) {
	if !s.verified {
		s.err = verify(s.Value)
		s.Value.ProviderName = staticProviderName
		s.verified = true
	}
	return s.Value, s.err
}

// IsExpired returns whether the credentials are expired.
//
// For StaticProvider, the credentials never expire.
func (s *StaticProvider) IsExpired() bool {
	return false
}
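The pairing rules enforced by `verify` above can be reproduced as a standalone sketch: an access key ID and secret access key must be supplied together, and a session token alone is not enough. The function name `checkPair` is illustrative, not part of the package:

```go
package main

import (
	"errors"
	"fmt"
)

// checkPair mirrors the specific pairing checks in verify above
// (illustrative sketch; the HasKeys pre-check is omitted).
func checkPair(id, secret, session string) error {
	if id != "" && secret == "" {
		return errors.New("ACCESS_KEY_ID is set, but SECRET_ACCESS_KEY is missing")
	}
	if id == "" && secret != "" {
		return errors.New("SECRET_ACCESS_KEY is set, but ACCESS_KEY_ID is missing")
	}
	if id == "" && secret == "" && session != "" {
		return errors.New("AWS_SESSION_TOKEN is set, but ACCESS_KEY_ID and SECRET_ACCESS_KEY are missing")
	}
	return nil
}

func main() {
	fmt.Println(checkPair("id", "secret", "")) // <nil>
	fmt.Println(checkPair("id", "", ""))       // ACCESS_KEY_ID is set, but SECRET_ACCESS_KEY is missing
}
```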
40
internal/csfle/csfle.go
Normal file
@ -0,0 +1,40 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package csfle

import (
	"errors"
	"fmt"

	"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
)

const (
	EncryptedCacheCollection      = "ecc"
	EncryptedStateCollection      = "esc"
	EncryptedCompactionCollection = "ecoc"
)

// GetEncryptedStateCollectionName returns the encrypted state collection name associated with dataCollectionName.
func GetEncryptedStateCollectionName(efBSON bsoncore.Document, dataCollectionName string, stateCollection string) (string, error) {
	fieldName := stateCollection + "Collection"
	val, err := efBSON.LookupErr(fieldName)
	if err != nil {
		if !errors.Is(err, bsoncore.ErrElementNotFound) {
			return "", err
		}
		// Return default name.
		defaultName := "enxcol_." + dataCollectionName + "." + stateCollection
		return defaultName, nil
	}

	stateCollectionName, ok := val.StringValueOK()
	if !ok {
		return "", fmt.Errorf("expected string for '%v', got: %v", fieldName, val.Type)
	}
	return stateCollectionName, nil
}
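When the `encryptedFields` document does not name the state collection explicitly, `GetEncryptedStateCollectionName` above falls back to a fixed naming rule. A minimal sketch of just that fallback (the function name `defaultStateCollectionName` is illustrative):

```go
package main

import "fmt"

// defaultStateCollectionName applies the fallback naming rule used by
// GetEncryptedStateCollectionName above: "enxcol_." + data collection
// name + "." + state collection suffix ("ecc", "esc", or "ecoc").
func defaultStateCollectionName(dataCollectionName, stateCollection string) string {
	return "enxcol_." + dataCollectionName + "." + stateCollection
}

func main() {
	fmt.Println(defaultStateCollectionName("patients", "esc"))  // enxcol_.patients.esc
	fmt.Println(defaultStateCollectionName("patients", "ecoc")) // enxcol_.patients.ecoc
}
```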
106
internal/csot/csot.go
Normal file
@ -0,0 +1,106 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package csot

import (
	"context"
	"time"
)

type clientLevel struct{}

func isClientLevel(ctx context.Context) bool {
	val := ctx.Value(clientLevel{})
	if val == nil {
		return false
	}

	return val.(bool)
}

// IsTimeoutContext checks if the provided context has been assigned a deadline
// or has unlimited retries.
func IsTimeoutContext(ctx context.Context) bool {
	_, ok := ctx.Deadline()

	return ok || isClientLevel(ctx)
}

// WithTimeout will set the given timeout on the context, if no deadline has
// already been set.
//
// This function assumes that the timeout field is static, given that the
// timeout should be sourced from the client. Therefore, once a timeout function
// parameter has been applied to the context, it will remain for the lifetime
// of the context.
func WithTimeout(parent context.Context, timeout *time.Duration) (context.Context, context.CancelFunc) {
	cancel := func() {}

	if timeout == nil || IsTimeoutContext(parent) {
		// In the following conditions, do nothing:
		// 1. The parent already has a deadline.
		// 2. The parent does not have a deadline, but a client-level timeout
		//    has been applied.
		// 3. The parent does not have a deadline, there is no client-level
		//    timeout, and the timeout parameter is nil.
		return parent, cancel
	}

	// If a client-level timeout has not been applied, then apply it.
	parent = context.WithValue(parent, clientLevel{}, true)

	dur := *timeout

	if dur == 0 {
		// If the parent does not have a deadline and the timeout is zero, then
		// do nothing.
		return parent, cancel
	}

	// If the parent does not have a deadline and the timeout is non-zero, then
	// apply the timeout.
	return context.WithTimeout(parent, dur)
}

// WithServerSelectionTimeout creates a context with a timeout that is the
// minimum of serverSelectionTimeoutMS and the context deadline. The usage of
// non-positive values for serverSelectionTimeoutMS is an anti-pattern and is
// not considered in this calculation.
func WithServerSelectionTimeout(
	parent context.Context,
	serverSelectionTimeout time.Duration,
) (context.Context, context.CancelFunc) {
	if serverSelectionTimeout <= 0 {
		return parent, func() {}
	}

	return context.WithTimeout(parent, serverSelectionTimeout)
}

// ZeroRTTMonitor implements the RTTMonitor interface and is used internally for
// testing. It returns 0 for all RTT calculations and an empty string for RTT
// statistics.
type ZeroRTTMonitor struct{}

// EWMA implements the RTT monitor interface.
func (zrm *ZeroRTTMonitor) EWMA() time.Duration {
	return 0
}

// Min implements the RTT monitor interface.
func (zrm *ZeroRTTMonitor) Min() time.Duration {
	return 0
}

// P90 implements the RTT monitor interface.
func (zrm *ZeroRTTMonitor) P90() time.Duration {
	return 0
}

// Stats implements the RTT monitor interface.
func (zrm *ZeroRTTMonitor) Stats() string {
	return ""
}
249
internal/csot/csot_test.go
Normal file
@ -0,0 +1,249 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package csot

import (
	"context"
	"testing"
	"time"

	"gitea.psichedelico.com/go/bson/internal/assert"
	"gitea.psichedelico.com/go/bson/internal/ptrutil"
)

func newTestContext(t *testing.T, timeout time.Duration) context.Context {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	t.Cleanup(cancel)

	return ctx
}

func TestWithServerSelectionTimeout(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name                   string
		parent                 context.Context
		serverSelectionTimeout time.Duration
		wantTimeout            time.Duration
		wantOk                 bool
	}{
		{
			name:                   "no context deadline and ssto is zero",
			parent:                 context.Background(),
			serverSelectionTimeout: 0,
			wantTimeout:            0,
			wantOk:                 false,
		},
		{
			name:                   "no context deadline and ssto is positive",
			parent:                 context.Background(),
			serverSelectionTimeout: 1,
			wantTimeout:            1,
			wantOk:                 true,
		},
		{
			name:                   "no context deadline and ssto is negative",
			parent:                 context.Background(),
			serverSelectionTimeout: -1,
			wantTimeout:            0,
			wantOk:                 false,
		},
		{
			name:                   "context deadline is zero and ssto is positive",
			parent:                 newTestContext(t, 0),
			serverSelectionTimeout: 1,
			wantTimeout:            1,
			wantOk:                 true,
		},
		{
			name:                   "context deadline is zero and ssto is negative",
			parent:                 newTestContext(t, 0),
			serverSelectionTimeout: -1,
			wantTimeout:            0,
			wantOk:                 true,
		},
		{
			name:                   "context deadline is negative and ssto is zero",
			parent:                 newTestContext(t, -1),
			serverSelectionTimeout: 0,
			wantTimeout:            -1,
			wantOk:                 true,
		},
		{
			name:                   "context deadline is negative and ssto is positive",
			parent:                 newTestContext(t, -1),
			serverSelectionTimeout: 1,
			wantTimeout:            1,
			wantOk:                 true,
		},
		{
			name:                   "context deadline is negative and ssto is negative",
			parent:                 newTestContext(t, -1),
			serverSelectionTimeout: -1,
			wantTimeout:            -1,
			wantOk:                 true,
		},
		{
			name:                   "context deadline is positive and ssto is zero",
			parent:                 newTestContext(t, 1),
			serverSelectionTimeout: 0,
			wantTimeout:            1,
			wantOk:                 true,
		},
		{
			name:                   "context deadline is positive and equal to ssto",
			parent:                 newTestContext(t, 1),
			serverSelectionTimeout: 1,
			wantTimeout:            1,
			wantOk:                 true,
		},
		{
			name:                   "context deadline is positive lt ssto",
			parent:                 newTestContext(t, 1),
			serverSelectionTimeout: 2,
			wantTimeout:            2,
			wantOk:                 true,
		},
		{
			name:                   "context deadline is positive gt ssto",
			parent:                 newTestContext(t, 2),
			serverSelectionTimeout: 1,
			wantTimeout:            2,
			wantOk:                 true,
		},
		{
			name:                   "context deadline is positive and ssto is negative",
			parent:                 newTestContext(t, 1),
			serverSelectionTimeout: -1,
			wantTimeout:            1,
			wantOk:                 true,
		},
	}

	for _, test := range tests {
		test := test // Capture the range variable.

		t.Run(test.name, func(t *testing.T) {
			t.Parallel()

			ctx, cancel := WithServerSelectionTimeout(test.parent, test.serverSelectionTimeout)
			t.Cleanup(cancel)

			deadline, gotOk := ctx.Deadline()
			assert.Equal(t, test.wantOk, gotOk)

			if gotOk {
				delta := time.Until(deadline) - test.wantTimeout
				tolerance := 10 * time.Millisecond

				assert.True(t, delta > -1*tolerance, "expected delta=%d > %d", delta, -1*tolerance)
				assert.True(t, delta <= tolerance, "expected delta=%d <= %d", delta, tolerance)
			}
		})
	}
}

func TestWithTimeout(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name         string
		parent       context.Context
		timeout      *time.Duration
		wantTimeout  time.Duration
		wantDeadline bool
		wantValues   []interface{}
	}{
		{
			name:         "deadline set with non-zero timeout",
			parent:       newTestContext(t, 1),
			timeout:      ptrutil.Ptr(time.Duration(2)),
			wantTimeout:  1,
			wantDeadline: true,
			wantValues:   []interface{}{},
		},
		{
			name:         "deadline set with zero timeout",
			parent:       newTestContext(t, 1),
			timeout:      ptrutil.Ptr(time.Duration(0)),
			wantTimeout:  1,
			wantDeadline: true,
			wantValues:   []interface{}{},
		},
		{
			name:         "deadline set with nil timeout",
			parent:       newTestContext(t, 1),
			timeout:      nil,
			wantTimeout:  1,
			wantDeadline: true,
			wantValues:   []interface{}{},
		},
		{
			name:         "deadline unset with non-zero timeout",
			parent:       context.Background(),
			timeout:      ptrutil.Ptr(time.Duration(1)),
			wantTimeout:  1,
			wantDeadline: true,
			wantValues:   []interface{}{},
		},
		{
			name:         "deadline unset with zero timeout",
			parent:       context.Background(),
			timeout:      ptrutil.Ptr(time.Duration(0)),
			wantTimeout:  0,
			wantDeadline: false,
			wantValues:   []interface{}{clientLevel{}},
		},
		{
			name:         "deadline unset with nil timeout",
			parent:       context.Background(),
			timeout:      nil,
			wantTimeout:  0,
			wantDeadline: false,
			wantValues:   []interface{}{},
		},
		{
			// If "clientLevel" has been set, but a new timeout is applied
			// to the context, then the constructed context should retain the old
			// timeout. To simplify the code, we assume the first timeout is static.
			name:         "deadline unset with non-zero timeout at clientLevel",
			parent:       context.WithValue(context.Background(), clientLevel{}, true),
			timeout:      ptrutil.Ptr(time.Duration(1)),
			wantTimeout:  0,
			wantDeadline: false,
			wantValues:   []interface{}{},
		},
	}

	for _, test := range tests {
		test := test // Capture the range variable.

		t.Run(test.name, func(t *testing.T) {
			t.Parallel()

			ctx, cancel := WithTimeout(test.parent, test.timeout)
			t.Cleanup(cancel)

			deadline, gotDeadline := ctx.Deadline()
			assert.Equal(t, test.wantDeadline, gotDeadline)

			if gotDeadline {
				delta := time.Until(deadline) - test.wantTimeout
				tolerance := 10 * time.Millisecond

				assert.True(t, delta > -1*tolerance, "expected delta=%d > %d", delta, -1*tolerance)
				assert.True(t, delta <= tolerance, "expected delta=%d <= %d", delta, tolerance)
			}

			for _, wantValue := range test.wantValues {
				assert.NotNil(t, ctx.Value(wantValue), "expected context to have value %v", wantValue)
			}
		})
	}
}
117
internal/decimal128/decinal128.go
Normal file
@ -0,0 +1,117 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package decimal128

import (
	"strconv"
)

// These constants are the maximum and minimum values for the exponent field in a decimal128 value.
const (
	MaxDecimal128Exp = 6111
	MinDecimal128Exp = -6176
)

func divmod(h, l uint64, div uint32) (qh, ql uint64, rem uint32) {
	div64 := uint64(div)
	a := h >> 32
	aq := a / div64
	ar := a % div64
	b := ar<<32 + h&(1<<32-1)
	bq := b / div64
	br := b % div64
	c := br<<32 + l>>32
	cq := c / div64
	cr := c % div64
	d := cr<<32 + l&(1<<32-1)
	dq := d / div64
	dr := d % div64
	return (aq<<32 | bq), (cq<<32 | dq), uint32(dr)
}

// String returns a string representation of the decimal value.
func String(h, l uint64) string {
	var posSign int      // positive sign
	var exp int          // exponent
	var high, low uint64 // significand high/low

	if h>>63&1 == 0 {
		posSign = 1
	}

	switch h >> 58 & (1<<5 - 1) {
	case 0x1F:
		return "NaN"
	case 0x1E:
		return "-Infinity"[posSign:]
	}

	low = l
	if h>>61&3 == 3 {
		// Bits: 1*sign 2*ignored 14*exponent 111*significand.
		// Implicit 0b100 prefix in significand.
		exp = int(h >> 47 & (1<<14 - 1))
		// Spec says all of these values are out of range.
		high, low = 0, 0
	} else {
		// Bits: 1*sign 14*exponent 113*significand
		exp = int(h >> 49 & (1<<14 - 1))
		high = h & (1<<49 - 1)
	}
	exp += MinDecimal128Exp

	// Would be handled by the logic below, but that's trivial and common.
	if high == 0 && low == 0 && exp == 0 {
		return "-0"[posSign:]
	}

	var repr [48]byte // Loop 5 times over 9 digits plus dot, negative sign, and leading zero.
	var last = len(repr)
	var i = len(repr)
	var dot = len(repr) + exp
	var rem uint32
Loop:
	for d9 := 0; d9 < 5; d9++ {
		high, low, rem = divmod(high, low, 1e9)
		for d1 := 0; d1 < 9; d1++ {
			// Handle "-0.0", "0.00123400", "-1.00E-6", "1.050E+3", etc.
			if i < len(repr) && (dot == i || low == 0 && high == 0 && rem > 0 && rem < 10 && (dot < i-6 || exp > 0)) {
				exp += len(repr) - i
				i--
				repr[i] = '.'
				last = i - 1
				dot = len(repr) // Unmark.
			}
			c := '0' + byte(rem%10)
			rem /= 10
			i--
			repr[i] = c
			// Handle "0E+3", "1E+3", etc.
			if low == 0 && high == 0 && rem == 0 && i == len(repr)-1 && (dot < i-5 || exp > 0) {
				last = i
				break Loop
			}
			if c != '0' {
				last = i
			}
			// Break early. Works without it, but why.
			if dot > i && low == 0 && high == 0 && rem == 0 {
				break Loop
			}
		}
	}
	repr[last-1] = '-'
	last--

	if exp > 0 {
		return string(repr[last+posSign:]) + "E+" + strconv.Itoa(exp)
	}
	if exp < 0 {
		return string(repr[last+posSign:]) + "E" + strconv.Itoa(exp)
	}
	return string(repr[last+posSign:])
}
85
internal/errutil/join.go
Normal file
@ -0,0 +1,85 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package errutil

import "errors"

// join is a Go 1.13-1.19 compatible version of [errors.Join]. It is only called
// by Join in join_go1.19.go. It is included here in a file without build
// constraints only for testing purposes.
//
// It is heavily based on Join from
// https://cs.opensource.google/go/go/+/refs/tags/go1.21.0:src/errors/join.go
func join(errs ...error) error {
	n := 0
	for _, err := range errs {
		if err != nil {
			n++
		}
	}
	if n == 0 {
		return nil
	}
	e := &joinError{
		errs: make([]error, 0, n),
	}
	for _, err := range errs {
		if err != nil {
			e.errs = append(e.errs, err)
		}
	}
	return e
}

// joinError is a Go 1.13-1.19 compatible joinable error type. Its error
// message is identical to [errors.Join], but it implements "Unwrap() error"
// instead of "Unwrap() []error".
//
// It is heavily based on the joinError from
// https://cs.opensource.google/go/go/+/refs/tags/go1.21.0:src/errors/join.go
type joinError struct {
	errs []error
}

func (e *joinError) Error() string {
	var b []byte
	for i, err := range e.errs {
		if i > 0 {
			b = append(b, '\n')
		}
		b = append(b, err.Error()...)
	}
	return string(b)
}

// Unwrap returns another joinError with the same errors as the current
// joinError except the first error in the slice. Continuing to call Unwrap
// on each returned error will increment through every error in the slice. The
// resulting behavior when using [errors.Is] and [errors.As] is similar to an
// error created using [errors.Join] in Go 1.20+.
func (e *joinError) Unwrap() error {
	if len(e.errs) == 1 {
		return e.errs[0]
	}
	return &joinError{errs: e.errs[1:]}
}

// Is calls [errors.Is] with the first error in the slice.
func (e *joinError) Is(target error) bool {
	if len(e.errs) == 0 {
		return false
	}
	return errors.Is(e.errs[0], target)
}

// As calls [errors.As] with the first error in the slice.
func (e *joinError) As(target interface{}) bool {
	if len(e.errs) == 0 {
		return false
	}
	return errors.As(e.errs[0], target)
}
20
internal/errutil/join_go1.19.go
Normal file
@ -0,0 +1,20 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

//go:build !go1.20
// +build !go1.20

package errutil

// Join returns an error that wraps the given errors. Any nil error values are
// discarded. Join returns nil if every value in errs is nil. The error formats
// as the concatenation of the strings obtained by calling the Error method of
// each element of errs, with a newline between each string.
//
// A non-nil error returned by Join implements the "Unwrap() error" method.
func Join(errs ...error) error {
	return join(errs...)
}
17
internal/errutil/join_go1.20.go
Normal file
@ -0,0 +1,17 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

//go:build go1.20
// +build go1.20

package errutil

import "errors"

// Join calls [errors.Join].
func Join(errs ...error) error {
	return errors.Join(errs...)
}
243
internal/errutil/join_test.go
Normal file
@@ -0,0 +1,243 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package errutil

import (
	"context"
	"errors"
	"fmt"
	"testing"

	"gitea.psichedelico.com/go/bson/internal/assert"
)

// TestJoin_Nil asserts that join returns a nil error for the same inputs that
// [errors.Join] returns a nil error.
func TestJoin_Nil(t *testing.T) {
	t.Parallel()

	assert.Equal(t, errors.Join(), join(), "errors.Join() != join()")
	assert.Equal(t, errors.Join(nil), join(nil), "errors.Join(nil) != join(nil)")
	assert.Equal(t, errors.Join(nil, nil), join(nil, nil), "errors.Join(nil, nil) != join(nil, nil)")
}

// TestJoin_Error asserts that join returns an error with the same error message
// as the error returned by [errors.Join].
func TestJoin_Error(t *testing.T) {
	t.Parallel()

	err1 := errors.New("err1")
	err2 := errors.New("err2")

	tests := []struct {
		desc string
		errs []error
	}{{
		desc: "single error",
		errs: []error{err1},
	}, {
		desc: "two errors",
		errs: []error{err1, err2},
	}, {
		desc: "two errors and a nil value",
		errs: []error{err1, nil, err2},
	}}

	for _, test := range tests {
		test := test // Capture range variable.

		t.Run(test.desc, func(t *testing.T) {
			t.Parallel()

			want := errors.Join(test.errs...).Error()
			got := join(test.errs...).Error()
			assert.Equal(t,
				want,
				got,
				"errors.Join().Error() != join().Error() for input %v",
				test.errs)
		})
	}
}

// TestJoin_ErrorsIs asserts that join returns an error that behaves identically
// to the error returned by [errors.Join] when passed to [errors.Is].
func TestJoin_ErrorsIs(t *testing.T) {
	t.Parallel()

	err1 := errors.New("err1")
	err2 := errors.New("err2")

	tests := []struct {
		desc   string
		errs   []error
		target error
	}{{
		desc:   "one error with a matching target",
		errs:   []error{err1},
		target: err1,
	}, {
		desc:   "one error with a non-matching target",
		errs:   []error{err1},
		target: err2,
	}, {
		desc:   "nil error",
		errs:   []error{nil},
		target: err1,
	}, {
		desc:   "no errors",
		errs:   []error{},
		target: err1,
	}, {
		desc:   "two different errors with a matching target",
		errs:   []error{err1, err2},
		target: err2,
	}, {
		desc:   "two identical errors with a matching target",
		errs:   []error{err1, err1},
		target: err1,
	}, {
		desc:   "wrapped error with a matching target",
		errs:   []error{fmt.Errorf("error: %w", err1)},
		target: err1,
	}, {
		desc:   "nested joined error with a matching target",
		errs:   []error{err1, join(err2, errors.New("nope"))},
		target: err2,
	}, {
		desc:   "nested joined error with no matching targets",
		errs:   []error{err1, join(errors.New("nope"), errors.New("nope 2"))},
		target: err2,
	}, {
		desc:   "nested joined error with a wrapped matching target",
		errs:   []error{join(fmt.Errorf("error: %w", err1), errors.New("nope")), err2},
		target: err1,
	}, {
		desc:   "context.DeadlineExceeded",
		errs:   []error{err1, nil, context.DeadlineExceeded, err2},
		target: context.DeadlineExceeded,
	}, {
		desc:   "wrapped context.DeadlineExceeded",
		errs:   []error{err1, nil, fmt.Errorf("error: %w", context.DeadlineExceeded), err2},
		target: context.DeadlineExceeded,
	}}

	for _, test := range tests {
		test := test // Capture range variable.

		t.Run(test.desc, func(t *testing.T) {
			// Assert that top-level errors returned by errors.Join and join
			// behave the same with errors.Is.
			want := errors.Join(test.errs...)
			got := join(test.errs...)
			assert.Equal(t,
				errors.Is(want, test.target),
				errors.Is(got, test.target),
				"errors.Join() and join() behave differently with errors.Is")

			// Assert that wrapped errors returned by errors.Join and join
			// behave the same with errors.Is.
			want = fmt.Errorf("error: %w", errors.Join(test.errs...))
			got = fmt.Errorf("error: %w", join(test.errs...))
			assert.Equal(t,
				errors.Is(want, test.target),
				errors.Is(got, test.target),
				"errors.Join() and join(), when wrapped, behave differently with errors.Is")
		})
	}
}

type errType1 struct{}

func (errType1) Error() string { return "" }

type errType2 struct{}

func (errType2) Error() string { return "" }

// TestJoin_ErrorsAs asserts that join returns an error that behaves identically
// to the error returned by [errors.Join] when passed to [errors.As].
func TestJoin_ErrorsAs(t *testing.T) {
	t.Parallel()

	err1 := errType1{}
	err2 := errType2{}

	tests := []struct {
		desc   string
		errs   []error
		target interface{}
	}{{
		desc:   "one error with a matching target",
		errs:   []error{err1},
		target: &errType1{},
	}, {
		desc:   "one error with a non-matching target",
		errs:   []error{err1},
		target: &errType2{},
	}, {
		desc:   "nil error",
		errs:   []error{nil},
		target: &errType1{},
	}, {
		desc:   "no errors",
		errs:   []error{},
		target: &errType1{},
	}, {
		desc:   "two different errors with a matching target",
		errs:   []error{err1, err2},
		target: &errType2{},
	}, {
		desc:   "two identical errors with a matching target",
		errs:   []error{err1, err1},
		target: &errType1{},
	}, {
		desc:   "wrapped error with a matching target",
		errs:   []error{fmt.Errorf("error: %w", err1)},
		target: &errType1{},
	}, {
		desc:   "nested joined error with a matching target",
		errs:   []error{err1, join(err2, errors.New("nope"))},
		target: &errType2{},
	}, {
		desc:   "nested joined error with no matching targets",
		errs:   []error{err1, join(errors.New("nope"), errors.New("nope 2"))},
		target: &errType2{},
	}, {
		desc:   "nested joined error with a wrapped matching target",
		errs:   []error{join(fmt.Errorf("error: %w", err1), errors.New("nope")), err2},
		target: &errType1{},
	}, {
		desc:   "context.DeadlineExceeded",
		errs:   []error{err1, nil, context.DeadlineExceeded, err2},
		target: &errType2{},
	}}

	for _, test := range tests {
		test := test // Capture range variable.

		t.Run(test.desc, func(t *testing.T) {
			// Assert that top-level errors returned by errors.Join and join
			// behave the same with errors.As.
			want := errors.Join(test.errs...)
			got := join(test.errs...)
			assert.Equal(t,
				errors.As(want, test.target),
				errors.As(got, test.target),
				"errors.Join() and join() behave differently with errors.As")

			// Assert that wrapped errors returned by errors.Join and join
			// behave the same with errors.As.
			want = fmt.Errorf("error: %w", errors.Join(test.errs...))
			got = fmt.Errorf("error: %w", join(test.errs...))
			assert.Equal(t,
				errors.As(want, test.target),
				errors.As(got, test.target),
				"errors.Join() and join(), when wrapped, behave differently with errors.As")
		})
	}
}
63
internal/failpoint/failpoint.go
Normal file
@@ -0,0 +1,63 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package failpoint

import (
	"gitea.psichedelico.com/go/bson"
)

const (
	// ModeAlwaysOn is the fail point mode that enables the fail point for an
	// indefinite number of matching commands.
	ModeAlwaysOn = "alwaysOn"

	// ModeOff is the fail point mode that disables the fail point.
	ModeOff = "off"
)

// FailPoint is used to configure a server fail point. It is intended to be
// passed as the command argument to RunCommand.
//
// For more information about fail points, see
// https://github.com/mongodb/specifications/tree/HEAD/source/transactions/tests#server-fail-point
type FailPoint struct {
	ConfigureFailPoint string `bson:"configureFailPoint"`
	// Mode should be a string, a Mode, or a map[string]interface{}.
	Mode interface{} `bson:"mode"`
	Data Data        `bson:"data"`
}

// Mode configures when a fail point will be enabled. It is used to set the
// FailPoint.Mode field.
type Mode struct {
	Times int32 `bson:"times"`
	Skip  int32 `bson:"skip"`
}

// Data configures how a fail point will behave. It is used to set the
// FailPoint.Data field.
type Data struct {
	FailCommands                  []string           `bson:"failCommands,omitempty"`
	CloseConnection               bool               `bson:"closeConnection,omitempty"`
	ErrorCode                     int32              `bson:"errorCode,omitempty"`
	FailBeforeCommitExceptionCode int32              `bson:"failBeforeCommitExceptionCode,omitempty"`
	ErrorLabels                   *[]string          `bson:"errorLabels,omitempty"`
	WriteConcernError             *WriteConcernError `bson:"writeConcernError,omitempty"`
	BlockConnection               bool               `bson:"blockConnection,omitempty"`
	BlockTimeMS                   int32              `bson:"blockTimeMS,omitempty"`
	AppName                       string             `bson:"appName,omitempty"`
}

// WriteConcernError is the write concern error to return when the fail point is
// triggered. It is used to set the FailPoint.Data.WriteConcernError field.
type WriteConcernError struct {
	Code        int32     `bson:"code"`
	Name        string    `bson:"codeName"`
	Errmsg      string    `bson:"errmsg"`
	ErrorLabels *[]string `bson:"errorLabels,omitempty"`
	ErrInfo     bson.Raw  `bson:"errInfo,omitempty"`
}
13
internal/handshake/handshake.go
Normal file
@@ -0,0 +1,13 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package handshake

// LegacyHello is the legacy version of the hello command.
var LegacyHello = "isMaster"

// LegacyHelloLowercase is the lowercase, legacy version of the hello command.
var LegacyHelloLowercase = "ismaster"
30
internal/httputil/httputil.go
Normal file
@@ -0,0 +1,30 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package httputil

import (
	"net/http"
)

// DefaultHTTPClient is the default HTTP client used across the driver.
var DefaultHTTPClient = &http.Client{
	Transport: http.DefaultTransport.(*http.Transport).Clone(),
}

// CloseIdleHTTPConnections closes any connections that were established by
// previous requests but are now sitting idle in a "keep-alive" state. It does
// not interrupt any connections currently in use.
//
// Borrowed from the Go standard library.
func CloseIdleHTTPConnections(client *http.Client) {
	type closeIdler interface {
		CloseIdleConnections()
	}
	if tr, ok := client.Transport.(closeIdler); ok {
		tr.CloseIdleConnections()
	}
}
14
internal/israce/norace.go
Normal file
@@ -0,0 +1,14 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

//go:build !race
// +build !race

// Package israce reports if the Go race detector is enabled.
package israce

// Enabled reports if the race detector is enabled.
const Enabled = false
14
internal/israce/race.go
Normal file
@@ -0,0 +1,14 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

//go:build race
// +build race

// Package israce reports if the Go race detector is enabled.
package israce

// Enabled reports if the race detector is enabled.
const Enabled = true
313
internal/logger/component.go
Normal file
@@ -0,0 +1,313 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package logger

import (
	"os"
	"strconv"

	"gitea.psichedelico.com/go/bson"
)

const (
	CommandFailed                    = "Command failed"
	CommandStarted                   = "Command started"
	CommandSucceeded                 = "Command succeeded"
	ConnectionPoolCreated            = "Connection pool created"
	ConnectionPoolReady              = "Connection pool ready"
	ConnectionPoolCleared            = "Connection pool cleared"
	ConnectionPoolClosed             = "Connection pool closed"
	ConnectionCreated                = "Connection created"
	ConnectionReady                  = "Connection ready"
	ConnectionClosed                 = "Connection closed"
	ConnectionCheckoutStarted        = "Connection checkout started"
	ConnectionCheckoutFailed         = "Connection checkout failed"
	ConnectionCheckedOut             = "Connection checked out"
	ConnectionCheckedIn              = "Connection checked in"
	ServerSelectionFailed            = "Server selection failed"
	ServerSelectionStarted           = "Server selection started"
	ServerSelectionSucceeded         = "Server selection succeeded"
	ServerSelectionWaiting           = "Waiting for suitable server to become available"
	TopologyClosed                   = "Stopped topology monitoring"
	TopologyDescriptionChanged       = "Topology description changed"
	TopologyOpening                  = "Starting topology monitoring"
	TopologyServerClosed             = "Stopped server monitoring"
	TopologyServerHeartbeatFailed    = "Server heartbeat failed"
	TopologyServerHeartbeatStarted   = "Server heartbeat started"
	TopologyServerHeartbeatSucceeded = "Server heartbeat succeeded"
	TopologyServerOpening            = "Starting server monitoring"
)

const (
	KeyAwaited             = "awaited"
	KeyCommand             = "command"
	KeyCommandName         = "commandName"
	KeyDatabaseName        = "databaseName"
	KeyDriverConnectionID  = "driverConnectionId"
	KeyDurationMS          = "durationMS"
	KeyError               = "error"
	KeyFailure             = "failure"
	KeyMaxConnecting       = "maxConnecting"
	KeyMaxIdleTimeMS       = "maxIdleTimeMS"
	KeyMaxPoolSize         = "maxPoolSize"
	KeyMessage             = "message"
	KeyMinPoolSize         = "minPoolSize"
	KeyNewDescription      = "newDescription"
	KeyOperation           = "operation"
	KeyOperationID         = "operationId"
	KeyPreviousDescription = "previousDescription"
	KeyRemainingTimeMS     = "remainingTimeMS"
	KeyReason              = "reason"
	KeyReply               = "reply"
	KeyRequestID           = "requestId"
	KeySelector            = "selector"
	KeyServerConnectionID  = "serverConnectionId"
	KeyServerHost          = "serverHost"
	KeyServerPort          = "serverPort"
	KeyServiceID           = "serviceId"
	KeyTimestamp           = "timestamp"
	KeyTopologyDescription = "topologyDescription"
	KeyTopologyID          = "topologyId"
)

// KeyValues is a list of key-value pairs.
type KeyValues []interface{}

// Add adds a key-value pair to an instance of a KeyValues list.
func (kvs *KeyValues) Add(key string, value interface{}) {
	*kvs = append(*kvs, key, value)
}

const (
	ReasonConnClosedStale              = "Connection became stale because the pool was cleared"
	ReasonConnClosedIdle               = "Connection has been available but unused for longer than the configured max idle time"
	ReasonConnClosedError              = "An error occurred while using the connection"
	ReasonConnClosedPoolClosed         = "Connection pool was closed"
	ReasonConnCheckoutFailedTimout     = "Wait queue timeout elapsed without a connection becoming available"
	ReasonConnCheckoutFailedError      = "An error occurred while trying to establish a new connection"
	ReasonConnCheckoutFailedPoolClosed = "Connection pool was closed"
)

// Component is an enumeration representing the "components" which can be
// logged against. A LogLevel can be configured on a per-component basis.
type Component int

const (
	// ComponentAll enables logging for all components.
	ComponentAll Component = iota

	// ComponentCommand enables command monitor logging.
	ComponentCommand

	// ComponentTopology enables topology logging.
	ComponentTopology

	// ComponentServerSelection enables server selection logging.
	ComponentServerSelection

	// ComponentConnection enables connection services logging.
	ComponentConnection
)

const (
	mongoDBLogAllEnvVar             = "MONGODB_LOG_ALL"
	mongoDBLogCommandEnvVar         = "MONGODB_LOG_COMMAND"
	mongoDBLogTopologyEnvVar        = "MONGODB_LOG_TOPOLOGY"
	mongoDBLogServerSelectionEnvVar = "MONGODB_LOG_SERVER_SELECTION"
	mongoDBLogConnectionEnvVar      = "MONGODB_LOG_CONNECTION"
)

var componentEnvVarMap = map[string]Component{
	mongoDBLogAllEnvVar:             ComponentAll,
	mongoDBLogCommandEnvVar:         ComponentCommand,
	mongoDBLogTopologyEnvVar:        ComponentTopology,
	mongoDBLogServerSelectionEnvVar: ComponentServerSelection,
	mongoDBLogConnectionEnvVar:      ComponentConnection,
}

// EnvHasComponentVariables returns true if the environment contains any of the
// component environment variables.
func EnvHasComponentVariables() bool {
	for envVar := range componentEnvVarMap {
		if os.Getenv(envVar) != "" {
			return true
		}
	}

	return false
}

// Command is a struct defining common fields that must be included in all
// commands.
type Command struct {
	DriverConnectionID int64          // Driver's ID for the connection
	Name               string         // Command name
	DatabaseName       string         // Database name
	Message            string         // Message associated with the command
	OperationID        int32          // Driver-generated operation ID
	RequestID          int64          // Driver-generated request ID
	ServerConnectionID *int64         // Server's ID for the connection used for the command
	ServerHost         string         // Hostname or IP address for the server
	ServerPort         string         // Port for the server
	ServiceID          *bson.ObjectID // ID for the command in load balancer mode
}

// SerializeCommand takes a command and a variable number of key-value pairs and
// returns a slice of interface{} that can be passed to the logger for
// structured logging.
func SerializeCommand(cmd Command, extraKeysAndValues ...interface{}) KeyValues {
	// Initialize the boilerplate keys and values.
	keysAndValues := KeyValues{
		KeyCommandName, cmd.Name,
		KeyDatabaseName, cmd.DatabaseName,
		KeyDriverConnectionID, cmd.DriverConnectionID,
		KeyMessage, cmd.Message,
		KeyOperationID, cmd.OperationID,
		KeyRequestID, cmd.RequestID,
		KeyServerHost, cmd.ServerHost,
	}

	// Add the extra keys and values.
	for i := 0; i < len(extraKeysAndValues); i += 2 {
		keysAndValues.Add(extraKeysAndValues[i].(string), extraKeysAndValues[i+1])
	}

	port, err := strconv.ParseInt(cmd.ServerPort, 10, 32)
	if err == nil {
		keysAndValues.Add(KeyServerPort, port)
	}

	// Add the "serverConnectionId" if it is not nil.
	if cmd.ServerConnectionID != nil {
		keysAndValues.Add(KeyServerConnectionID, *cmd.ServerConnectionID)
	}

	// Add the "serviceId" if it is not nil.
	if cmd.ServiceID != nil {
		keysAndValues.Add(KeyServiceID, cmd.ServiceID.Hex())
	}

	return keysAndValues
}

// Connection contains data that all connection log messages MUST contain.
type Connection struct {
	Message    string // Message associated with the connection
	ServerHost string // Hostname or IP address for the server
	ServerPort string // Port for the server
}

// SerializeConnection serializes a Connection message into a slice of keys and
// values that can be passed to a logger.
func SerializeConnection(conn Connection, extraKeysAndValues ...interface{}) KeyValues {
	// Initialize the boilerplate keys and values.
	keysAndValues := KeyValues{
		KeyMessage, conn.Message,
		KeyServerHost, conn.ServerHost,
	}

	// Add the optional keys and values.
	for i := 0; i < len(extraKeysAndValues); i += 2 {
		keysAndValues.Add(extraKeysAndValues[i].(string), extraKeysAndValues[i+1])
	}

	port, err := strconv.ParseInt(conn.ServerPort, 10, 32)
	if err == nil {
		keysAndValues.Add(KeyServerPort, port)
	}

	return keysAndValues
}

// Server contains data that all server messages MAY contain.
type Server struct {
	DriverConnectionID int64         // Driver's ID for the connection
	TopologyID         bson.ObjectID // Driver's unique ID for this topology
	Message            string        // Message associated with the topology
	ServerConnectionID *int64        // Server's ID for the connection
	ServerHost         string        // Hostname or IP address for the server
	ServerPort         string        // Port for the server
}

// SerializeServer serializes a Server message into a slice of keys and
// values that can be passed to a logger.
func SerializeServer(srv Server, extraKV ...interface{}) KeyValues {
	// Initialize the boilerplate keys and values.
	keysAndValues := KeyValues{
		KeyDriverConnectionID, srv.DriverConnectionID,
		KeyMessage, srv.Message,
		KeyServerHost, srv.ServerHost,
		KeyTopologyID, srv.TopologyID.Hex(),
	}

	if connID := srv.ServerConnectionID; connID != nil {
		keysAndValues.Add(KeyServerConnectionID, *connID)
	}

	port, err := strconv.ParseInt(srv.ServerPort, 10, 32)
	if err == nil {
		keysAndValues.Add(KeyServerPort, port)
	}

	// Add the optional keys and values.
	for i := 0; i < len(extraKV); i += 2 {
		keysAndValues.Add(extraKV[i].(string), extraKV[i+1])
	}

	return keysAndValues
}

// ServerSelection contains data that all server selection messages MUST
// contain.
type ServerSelection struct {
	Selector            string
	OperationID         *int32
	Operation           string
	TopologyDescription string
}

// SerializeServerSelection serializes a ServerSelection message into a slice of
// keys and values that can be passed to a logger.
func SerializeServerSelection(srvSelection ServerSelection, extraKV ...interface{}) KeyValues {
	keysAndValues := KeyValues{
		KeySelector, srvSelection.Selector,
		KeyOperation, srvSelection.Operation,
		KeyTopologyDescription, srvSelection.TopologyDescription,
	}

	if srvSelection.OperationID != nil {
		keysAndValues.Add(KeyOperationID, *srvSelection.OperationID)
	}

	// Add the optional keys and values.
	for i := 0; i < len(extraKV); i += 2 {
		keysAndValues.Add(extraKV[i].(string), extraKV[i+1])
	}

	return keysAndValues
}

// Topology contains data that all topology messages MAY contain.
type Topology struct {
	ID      bson.ObjectID // Driver's unique ID for this topology
	Message string        // Message associated with the topology
}

// SerializeTopology serializes a Topology message into a slice of keys and
// values that can be passed to a logger.
func SerializeTopology(topo Topology, extraKV ...interface{}) KeyValues {
	keysAndValues := KeyValues{
		KeyTopologyID, topo.ID.Hex(),
	}

	// Add the optional keys and values.
	for i := 0; i < len(extraKV); i += 2 {
		keysAndValues.Add(extraKV[i].(string), extraKV[i+1])
	}

	return keysAndValues
}
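The serializers above all share one idea: flatten a message struct into a KeyValues list (alternating keys and values) and append optional fields only when present. A self-contained sketch of that pattern, with the KeyValues type reproduced locally:

```go
package main

import "fmt"

// KeyValues mirrors the logger type above: a flat list alternating keys
// and values, the shape many structured loggers accept directly.
type KeyValues []interface{}

// Add appends one key-value pair.
func (kvs *KeyValues) Add(key string, value interface{}) {
	*kvs = append(*kvs, key, value)
}

func main() {
	// Boilerplate fields are always present.
	kvs := KeyValues{"commandName", "find", "databaseName", "db"}

	// Optional fields are appended only when set, as SerializeCommand
	// does for serverConnectionId and serviceId.
	id := int64(42)
	serverConnectionID := &id
	if serverConnectionID != nil {
		kvs.Add("serverConnectionId", *serverConnectionID)
	}

	// A consumer walks the list two elements at a time.
	for i := 0; i < len(kvs); i += 2 {
		fmt.Printf("%s=%v\n", kvs[i], kvs[i+1])
	}
}
```

Keeping keys and values in one flat slice avoids allocating a map per log call and preserves insertion order when the list is rendered.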
229
internal/logger/component_test.go
Normal file
@@ -0,0 +1,229 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package logger

import (
	"testing"

	"gitea.psichedelico.com/go/bson"
	"gitea.psichedelico.com/go/bson/internal/assert"
)

func verifySerialization(t *testing.T, got, want KeyValues) {
	t.Helper()

	for i := 0; i < len(got); i += 2 {
		assert.Equal(t, want[i], got[i], "key position mismatch")
		assert.Equal(t, want[i+1], got[i+1], "value position mismatch for %q", want[i])
	}
}

func TestSerializeCommand(t *testing.T) {
	t.Parallel()

	serverConnectionID := int64(100)
	serviceID := bson.NewObjectID()

	tests := []struct {
		name               string
		cmd                Command
		extraKeysAndValues []interface{}
		want               KeyValues
	}{
		{
			name: "empty",
			want: KeyValues{
				KeyCommandName, "",
				KeyDatabaseName, "",
				KeyDriverConnectionID, int64(0),
				KeyMessage, "",
				KeyOperationID, int32(0),
				KeyRequestID, int64(0),
				KeyServerHost, "",
			},
		},
		{
			name: "complete Command object",
			cmd: Command{
				DriverConnectionID: 1,
				Name:               "foo",
				DatabaseName:       "db",
				Message:            "bar",
				OperationID:        2,
				RequestID:          3,
				ServerHost:         "localhost",
				ServerPort:         "27017",
				ServerConnectionID: &serverConnectionID,
				ServiceID:          &serviceID,
			},
			want: KeyValues{
				KeyCommandName, "foo",
				KeyDatabaseName, "db",
				KeyDriverConnectionID, int64(1),
				KeyMessage, "bar",
				KeyOperationID, int32(2),
				KeyRequestID, int64(3),
				KeyServerHost, "localhost",
				KeyServerPort, int64(27017),
				KeyServerConnectionID, serverConnectionID,
				KeyServiceID, serviceID.Hex(),
			},
		},
	}

	for _, test := range tests {
		test := test

		t.Run(test.name, func(t *testing.T) {
			t.Parallel()

			got := SerializeCommand(test.cmd, test.extraKeysAndValues...)
			verifySerialization(t, got, test.want)
		})
	}
}

func TestSerializeConnection(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name               string
		conn               Connection
		extraKeysAndValues []interface{}
		want               KeyValues
	}{
		{
			name: "empty",
			want: KeyValues{
				KeyMessage, "",
				KeyServerHost, "",
			},
		},
		{
			name: "complete Connection object",
			conn: Connection{
				Message:    "foo",
				ServerHost: "localhost",
				ServerPort: "27017",
			},
			want: KeyValues{
				"message", "foo",
				"serverHost", "localhost",
				"serverPort", int64(27017),
			},
		},
	}

	for _, test := range tests {
		test := test

		t.Run(test.name, func(t *testing.T) {
			t.Parallel()

			got := SerializeConnection(test.conn, test.extraKeysAndValues...)
			verifySerialization(t, got, test.want)
		})
	}
}

func TestSerializeServer(t *testing.T) {
	t.Parallel()

	topologyID := bson.NewObjectID()
	serverConnectionID := int64(100)

	tests := []struct {
		name               string
		srv                Server
		extraKeysAndValues []interface{}
		want               KeyValues
	}{
		{
			name: "empty",
|
want: KeyValues{
|
||||||
|
KeyDriverConnectionID, int64(0),
|
||||||
|
KeyMessage, "",
|
||||||
|
KeyServerHost, "",
|
||||||
|
KeyTopologyID, bson.ObjectID{}.Hex(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complete Server object",
|
||||||
|
srv: Server{
|
||||||
|
DriverConnectionID: 1,
|
||||||
|
TopologyID: topologyID,
|
||||||
|
Message: "foo",
|
||||||
|
ServerConnectionID: &serverConnectionID,
|
||||||
|
ServerHost: "localhost",
|
||||||
|
ServerPort: "27017",
|
||||||
|
},
|
||||||
|
want: KeyValues{
|
||||||
|
KeyDriverConnectionID, int64(1),
|
||||||
|
KeyMessage, "foo",
|
||||||
|
KeyServerHost, "localhost",
|
||||||
|
KeyTopologyID, topologyID.Hex(),
|
||||||
|
KeyServerConnectionID, serverConnectionID,
|
||||||
|
KeyServerPort, int64(27017),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
test := test
|
||||||
|
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
got := SerializeServer(test.srv, test.extraKeysAndValues...)
|
||||||
|
verifySerialization(t, got, test.want)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerializeTopology(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
topologyID := bson.NewObjectID()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
topo Topology
|
||||||
|
extraKeysAndValues []interface{}
|
||||||
|
want KeyValues
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
want: KeyValues{
|
||||||
|
KeyTopologyID, bson.ObjectID{}.Hex(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complete Server object",
|
||||||
|
topo: Topology{
|
||||||
|
ID: topologyID,
|
||||||
|
Message: "foo",
|
||||||
|
},
|
||||||
|
want: KeyValues{
|
||||||
|
KeyTopologyID, topologyID.Hex(),
|
||||||
|
KeyMessage, "foo",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
test := test
|
||||||
|
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
got := SerializeTopology(test.topo, test.extraKeysAndValues...)
|
||||||
|
verifySerialization(t, got, test.want)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
48
internal/logger/context.go
Normal file
@@ -0,0 +1,48 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package logger

import "context"

// contextKey is a custom type used to prevent key collisions when using the
// context package.
type contextKey string

const (
	contextKeyOperation   contextKey = "operation"
	contextKeyOperationID contextKey = "operationID"
)

// WithOperationName adds the operation name to the context.
func WithOperationName(ctx context.Context, operation string) context.Context {
	return context.WithValue(ctx, contextKeyOperation, operation)
}

// WithOperationID adds the operation ID to the context.
func WithOperationID(ctx context.Context, operationID int32) context.Context {
	return context.WithValue(ctx, contextKeyOperationID, operationID)
}

// OperationName returns the operation name from the context.
func OperationName(ctx context.Context) (string, bool) {
	operationName := ctx.Value(contextKeyOperation)
	if operationName == nil {
		return "", false
	}

	return operationName.(string), true
}

// OperationID returns the operation ID from the context.
func OperationID(ctx context.Context) (int32, bool) {
	operationID := ctx.Value(contextKeyOperationID)
	if operationID == nil {
		return 0, false
	}

	return operationID.(int32), true
}
187
internal/logger/context_test.go
Normal file
@@ -0,0 +1,187 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package logger_test

import (
	"context"
	"testing"

	"gitea.psichedelico.com/go/bson/internal/assert"
	"gitea.psichedelico.com/go/bson/internal/logger"
)

func TestContext_WithOperationName(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name   string
		ctx    context.Context
		opName string
		ok     bool
	}{
		{
			name:   "simple",
			ctx:    context.Background(),
			opName: "foo",
			ok:     true,
		},
	}

	for _, tt := range tests {
		tt := tt // Capture the range variable.

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			ctx := logger.WithOperationName(tt.ctx, tt.opName)

			opName, ok := logger.OperationName(ctx)
			assert.Equal(t, tt.ok, ok)

			if ok {
				assert.Equal(t, tt.opName, opName)
			}
		})
	}
}

func TestContext_OperationName(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name   string
		ctx    context.Context
		opName interface{}
		ok     bool
	}{
		{
			name:   "nil",
			ctx:    context.Background(),
			opName: nil,
			ok:     false,
		},
		{
			name:   "string type",
			ctx:    context.Background(),
			opName: "foo",
			ok:     true,
		},
		{
			name:   "non-string type",
			ctx:    context.Background(),
			opName: int32(1),
			ok:     false,
		},
	}

	for _, tt := range tests {
		tt := tt // Capture the range variable.

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			ctx := context.Background()

			if opNameStr, ok := tt.opName.(string); ok {
				ctx = logger.WithOperationName(tt.ctx, opNameStr)
			}

			opName, ok := logger.OperationName(ctx)
			assert.Equal(t, tt.ok, ok)

			if ok {
				assert.Equal(t, tt.opName, opName)
			}
		})
	}
}

func TestContext_WithOperationID(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		ctx  context.Context
		opID int32
		ok   bool
	}{
		{
			name: "non-zero",
			ctx:  context.Background(),
			opID: 1,
			ok:   true,
		},
	}

	for _, tt := range tests {
		tt := tt // Capture the range variable.

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			ctx := logger.WithOperationID(tt.ctx, tt.opID)

			opID, ok := logger.OperationID(ctx)
			assert.Equal(t, tt.ok, ok)

			if ok {
				assert.Equal(t, tt.opID, opID)
			}
		})
	}
}

func TestContext_OperationID(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name string
		ctx  context.Context
		opID interface{}
		ok   bool
	}{
		{
			name: "nil",
			ctx:  context.Background(),
			opID: nil,
			ok:   false,
		},
		{
			name: "i32 type",
			ctx:  context.Background(),
			opID: int32(1),
			ok:   true,
		},
		{
			name: "non-i32 type",
			ctx:  context.Background(),
			opID: "foo",
			ok:   false,
		},
	}

	for _, tt := range tests {
		tt := tt // Capture the range variable.

		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			ctx := context.Background()

			if opIDI32, ok := tt.opID.(int32); ok {
				ctx = logger.WithOperationID(tt.ctx, opIDI32)
			}

			opID, ok := logger.OperationID(ctx)
			assert.Equal(t, tt.ok, ok)

			if ok {
				assert.Equal(t, tt.opID, opID)
			}
		})
	}
}
63
internal/logger/io_sink.go
Normal file
@@ -0,0 +1,63 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package logger

import (
	"encoding/json"
	"io"
	"math"
	"sync"
	"time"
)

// IOSink writes a JSON-encoded message to the io.Writer.
type IOSink struct {
	enc *json.Encoder

	// encMu protects the encoder from concurrent writes. While the logger
	// itself does not concurrently write to the sink, the sink may be used
	// concurrently within the driver.
	encMu sync.Mutex
}

// Compile-time check to ensure IOSink implements the LogSink interface.
var _ LogSink = &IOSink{}

// NewIOSink will create an IOSink object that writes JSON messages to the
// provided io.Writer.
func NewIOSink(out io.Writer) *IOSink {
	return &IOSink{
		enc: json.NewEncoder(out),
	}
}

// Info will write a JSON-encoded message to the io.Writer.
func (sink *IOSink) Info(_ int, msg string, keysAndValues ...interface{}) {
	// Reserve two extra slots for the timestamp and message, guarding
	// against integer overflow.
	mapSize := len(keysAndValues) / 2
	if math.MaxInt-mapSize >= 2 {
		mapSize += 2
	}
	kvMap := make(map[string]interface{}, mapSize)

	kvMap[KeyTimestamp] = time.Now().UnixNano()
	kvMap[KeyMessage] = msg

	for i := 0; i < len(keysAndValues); i += 2 {
		kvMap[keysAndValues[i].(string)] = keysAndValues[i+1]
	}

	sink.encMu.Lock()
	defer sink.encMu.Unlock()

	_ = sink.enc.Encode(kvMap)
}

// Error will write a JSON-encoded error message to the io.Writer.
func (sink *IOSink) Error(err error, msg string, kv ...interface{}) {
	kv = append(kv, KeyError, err.Error())
	sink.Info(0, msg, kv...)
}
74
internal/logger/level.go
Normal file
@@ -0,0 +1,74 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package logger

import "strings"

// DiffToInfo is the number of levels in the Go Driver that come before the
// "Info" level. This should ensure that "Info" is the 0th level passed to the
// sink.
const DiffToInfo = 1

// Level is an enumeration representing the log severity levels supported by
// the driver. The order of the logging levels is important. The driver expects
// that a user will likely use the "logr" package to create a LogSink, which
// defaults InfoLevel as 0. Any additions to the Level enumeration before the
// InfoLevel will need to also update the "DiffToInfo" constant.
type Level int

const (
	// LevelOff suppresses logging.
	LevelOff Level = iota

	// LevelInfo enables logging of informational messages. These logs are
	// high-level information about normal driver behavior.
	LevelInfo

	// LevelDebug enables logging of debug messages. These logs can be
	// voluminous and are intended for detailed information that may be
	// helpful when debugging an application.
	LevelDebug
)

const (
	levelLiteralOff       = "off"
	levelLiteralEmergency = "emergency"
	levelLiteralAlert     = "alert"
	levelLiteralCritical  = "critical"
	levelLiteralError     = "error"
	levelLiteralWarning   = "warning"
	levelLiteralNotice    = "notice"
	levelLiteralInfo      = "info"
	levelLiteralDebug     = "debug"
	levelLiteralTrace     = "trace"
)

// LevelLiteralMap maps the supported severity literals onto the driver's
// Level values.
var LevelLiteralMap = map[string]Level{
	levelLiteralOff:       LevelOff,
	levelLiteralEmergency: LevelInfo,
	levelLiteralAlert:     LevelInfo,
	levelLiteralCritical:  LevelInfo,
	levelLiteralError:     LevelInfo,
	levelLiteralWarning:   LevelInfo,
	levelLiteralNotice:    LevelInfo,
	levelLiteralInfo:      LevelInfo,
	levelLiteralDebug:     LevelDebug,
	levelLiteralTrace:     LevelDebug,
}

// ParseLevel will check if the given string is a valid environment variable
// for a logging severity level. If it is, then it will return the associated
// driver's Level. The default Level is "LevelOff".
func ParseLevel(str string) Level {
	for literal, level := range LevelLiteralMap {
		if strings.EqualFold(literal, str) {
			return level
		}
	}

	return LevelOff
}
265
internal/logger/logger.go
Normal file
@@ -0,0 +1,265 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

// Package logger provides the internal logging solution for the MongoDB Go
// Driver.
package logger

import (
	"fmt"
	"os"
	"strconv"
	"strings"

	"gitea.psichedelico.com/go/bson"
	"gitea.psichedelico.com/go/bson/internal/bsoncoreutil"
	"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
)

// DefaultMaxDocumentLength is the default maximum number of bytes that can be
// logged for a stringified BSON document.
const DefaultMaxDocumentLength = 1000

// TruncationSuffix is the trailing ellipsis "..." appended to a message to
// indicate to the user that truncation occurred. This constant does not count
// toward the max document length.
const TruncationSuffix = "..."

const logSinkPathEnvVar = "MONGODB_LOG_PATH"
const maxDocumentLengthEnvVar = "MONGODB_LOG_MAX_DOCUMENT_LENGTH"

// LogSink represents a logging implementation. This interface should be 1-1
// with the exported "LogSink" interface in the mongo/options package.
type LogSink interface {
	// Info logs a non-error message with the given key/value pairs. The
	// level argument is provided for optional logging.
	Info(level int, msg string, keysAndValues ...interface{})

	// Error logs an error, with the given message and key/value pairs.
	Error(err error, msg string, keysAndValues ...interface{})
}

// Logger represents the configuration for the internal logger.
type Logger struct {
	ComponentLevels   map[Component]Level // Log levels for each component.
	Sink              LogSink             // LogSink for log printing.
	MaxDocumentLength uint                // Command truncation width.
	logFile           *os.File            // File to write logs to.
}

// New will construct a new logger. If any of the given options are the
// zero-value of the argument type, then the constructor will attempt to
// source the data from the environment. If the environment has not been set,
// then the constructor will use the respective default values.
func New(sink LogSink, maxDocLen uint, compLevels map[Component]Level) (*Logger, error) {
	logger := &Logger{
		ComponentLevels:   selectComponentLevels(compLevels),
		MaxDocumentLength: selectMaxDocumentLength(maxDocLen),
	}

	sink, logFile, err := selectLogSink(sink)
	if err != nil {
		return nil, err
	}

	logger.Sink = sink
	logger.logFile = logFile

	return logger, nil
}

// Close will close the logger's log file, if it exists.
func (logger *Logger) Close() error {
	if logger.logFile != nil {
		return logger.logFile.Close()
	}

	return nil
}

// LevelComponentEnabled will return true if the given LogLevel is enabled for
// the given LogComponent. If the ComponentLevels on the logger are enabled for
// "ComponentAll", then this function will return true for any level bound by
// the level assigned to "ComponentAll".
//
// If the level is not enabled (i.e. LevelOff), then false is returned. This is
// to avoid false positives, such as returning "true" for a component that is
// not enabled. For example, without this condition, an empty LevelComponent
// would be considered "enabled" for "LevelOff".
func (logger *Logger) LevelComponentEnabled(level Level, component Component) bool {
	if level == LevelOff {
		return false
	}

	if logger.ComponentLevels == nil {
		return false
	}

	return logger.ComponentLevels[component] >= level ||
		logger.ComponentLevels[ComponentAll] >= level
}

// Print will synchronously print the given message to the configured LogSink.
// If the LogSink is nil, then this method will do nothing. Future work could
// be done to make this method asynchronous; see buffer management in libraries
// such as log4j.
//
// It's worth noting that many structured logs defined by DBX-wide
// specifications include a "message" field, which is often shared with the
// message arguments passed to this print function. The "Info" method used by
// this function is implemented based on the go-logr/logr LogSink interface,
// which is why "Print" has a message parameter. Any duplication in code is
// intentional to adhere to the logr pattern.
func (logger *Logger) Print(level Level, component Component, msg string, keysAndValues ...interface{}) {
	// If the level is not enabled for the component, then skip the
	// message.
	if !logger.LevelComponentEnabled(level, component) {
		return
	}

	// If the sink is nil, then skip the message.
	if logger.Sink == nil {
		return
	}

	logger.Sink.Info(int(level)-DiffToInfo, msg, keysAndValues...)
}

// Error logs an error, with the given message and key/value pairs.
// It functions similarly to Print, but may have unique behavior, and should be
// preferred for logging errors.
func (logger *Logger) Error(err error, msg string, keysAndValues ...interface{}) {
	if logger.Sink == nil {
		return
	}

	logger.Sink.Error(err, msg, keysAndValues...)
}

// selectMaxDocumentLength will return the first non-zero max document length,
// with the user-defined value taking priority over the environment. For the
// environment, the function will attempt to get the value of
// "MONGODB_LOG_MAX_DOCUMENT_LENGTH" and parse it as an unsigned integer. If
// the environment variable is not set or is not an unsigned integer, then this
// function will return the default max document length.
func selectMaxDocumentLength(maxDocLen uint) uint {
	if maxDocLen != 0 {
		return maxDocLen
	}

	maxDocLenEnv := os.Getenv(maxDocumentLengthEnvVar)
	if maxDocLenEnv != "" {
		maxDocLenEnvInt, err := strconv.ParseUint(maxDocLenEnv, 10, 32)
		if err == nil {
			return uint(maxDocLenEnvInt)
		}
	}

	return DefaultMaxDocumentLength
}

const (
	logSinkPathStdout = "stdout"
	logSinkPathStderr = "stderr"
)

// selectLogSink will return the first non-nil LogSink, with the user-defined
// LogSink taking precedence over the environment-defined LogSink. If no
// LogSink is defined, then this function will return a LogSink that writes to
// stderr.
func selectLogSink(sink LogSink) (LogSink, *os.File, error) {
	if sink != nil {
		return sink, nil, nil
	}

	path := os.Getenv(logSinkPathEnvVar)
	lowerPath := strings.ToLower(path)

	if lowerPath == logSinkPathStderr {
		return NewIOSink(os.Stderr), nil, nil
	}

	if lowerPath == logSinkPathStdout {
		return NewIOSink(os.Stdout), nil, nil
	}

	if path != "" {
		logFile, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0666)
		if err != nil {
			return nil, nil, fmt.Errorf("unable to open log file: %w", err)
		}

		return NewIOSink(logFile), logFile, nil
	}

	return NewIOSink(os.Stderr), nil, nil
}

// selectComponentLevels returns a new map of LogComponents to LogLevels that
// is the result of merging the user-defined data with the environment, with
// the user-defined data taking priority.
func selectComponentLevels(componentLevels map[Component]Level) map[Component]Level {
	selected := make(map[Component]Level)

	// Determine if the "MONGODB_LOG_ALL" environment variable is set.
	var globalEnvLevel *Level
	if all := os.Getenv(mongoDBLogAllEnvVar); all != "" {
		level := ParseLevel(all)
		globalEnvLevel = &level
	}

	for envVar, component := range componentEnvVarMap {
		// If the user already set a level for the component, keep it.
		if _, ok := componentLevels[component]; ok {
			selected[component] = componentLevels[component]

			continue
		}

		// If the "MONGODB_LOG_ALL" environment variable is set, then
		// set the level for the component to the value of the
		// environment variable.
		if globalEnvLevel != nil {
			selected[component] = *globalEnvLevel

			continue
		}

		// Otherwise, set the level for the component to the value of
		// its own environment variable.
		selected[component] = ParseLevel(os.Getenv(envVar))
	}

	return selected
}

// FormatDocument formats a BSON document or RawValue for logging. The document
// is truncated to the given width.
func FormatDocument(msg bson.Raw, width uint) string {
	if len(msg) == 0 {
		return "{}"
	}

	str := bsoncore.Document(msg).StringN(int(width))

	// If the last byte is not a closing bracket, then the document was
	// truncated.
	if len(str) > 0 && str[len(str)-1] != '}' {
		str += TruncationSuffix
	}

	return str
}

// FormatString formats a string for logging. The string is truncated to the
// given width.
func FormatString(str string, width uint) string {
	strTrunc := bsoncoreutil.Truncate(str, int(width))

	// Check if the string was truncated by comparing the lengths of the
	// two strings.
	if len(strTrunc) < len(str) {
		strTrunc += TruncationSuffix
	}

	return strTrunc
}
576
internal/logger/logger_test.go
Normal file
@@ -0,0 +1,576 @@
|
|||||||
|
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
// not use this file except in compliance with the License. You may obtain
|
||||||
|
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gitea.psichedelico.com/go/bson"
|
||||||
|
"gitea.psichedelico.com/go/bson/internal/assert"
|
||||||
|
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockLogSink struct{}
|
||||||
|
|
||||||
|
func (mockLogSink) Info(int, string, ...interface{}) {}
|
||||||
|
func (mockLogSink) Error(error, string, ...interface{}) {}
|
||||||
|
|
||||||
|
func BenchmarkLoggerWithLargeDocuments(b *testing.B) {
|
||||||
|
// Define the large document test cases
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
create func() bson.D
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "LargeStrings",
|
||||||
|
create: func() bson.D { return createLargeStringsDocument(10) },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MassiveArrays",
|
||||||
|
create: func() bson.D { return createMassiveArraysDocument(100000) },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "VeryVoluminousDocument",
|
||||||
|
create: func() bson.D { return createVoluminousDocument(100000) },
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
tc := tc
|
||||||
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
|
// Run benchmark with logging and truncation enabled
|
||||||
|
b.Run("LoggingWithTruncation", func(b *testing.B) {
|
||||||
|
logger, err := New(mockLogSink{}, 0, map[Component]Level{
|
||||||
|
ComponentCommand: LevelDebug,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
bs, err := bson.Marshal(tc.create())
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Print(LevelInfo, ComponentCommand, FormatDocument(bs, 1024), "foo", "bar", "baz")
|
||||||
|
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Run benchmark with logging enabled without truncation
|
||||||
|
b.Run("LoggingWithoutTruncation", func(b *testing.B) {
|
||||||
|
logger, err := New(mockLogSink{}, 0, map[Component]Level{
|
||||||
|
ComponentCommand: LevelDebug,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
bs, err := bson.Marshal(tc.create())
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
msg := bsoncore.Document(bs).String()
|
||||||
|
logger.Print(LevelInfo, ComponentCommand, msg, "foo", "bar", "baz")
|
||||||
|
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Run benchmark without logging or truncation
|
||||||
|
b.Run("WithoutLoggingOrTruncation", func(b *testing.B) {
|
||||||
|
bs, err := bson.Marshal(tc.create())
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = bsoncore.Document(bs).String()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions to create large documents.

// createVoluminousDocument generates a BSON document with numKeys distinct keys.
func createVoluminousDocument(numKeys int) bson.D {
	// Allocate with zero length and capacity numKeys so that append does not
	// leave numKeys zero-valued elements at the front of the document.
	d := make(bson.D, 0, numKeys)
	for i := 0; i < numKeys; i++ {
		d = append(d, bson.E{Key: fmt.Sprintf("key%d", i), Value: "value"})
	}

	return d
}

// createLargeStringsDocument generates a BSON document containing four large
// strings, each sizeMB megabytes in size.
func createLargeStringsDocument(sizeMB int) bson.D {
	largeString := strings.Repeat("a", sizeMB*1024*1024)

	return bson.D{
		{Key: "largeString1", Value: largeString},
		{Key: "largeString2", Value: largeString},
		{Key: "largeString3", Value: largeString},
		{Key: "largeString4", Value: largeString},
	}
}

// createMassiveArraysDocument generates a BSON document containing four
// arrays of arraySize elements each.
func createMassiveArraysDocument(arraySize int) bson.D {
	massiveArray := make([]string, arraySize)
	for i := 0; i < arraySize; i++ {
		massiveArray[i] = "value"
	}

	return bson.D{
		{Key: "massiveArray1", Value: massiveArray},
		{Key: "massiveArray2", Value: massiveArray},
		{Key: "massiveArray3", Value: massiveArray},
		{Key: "massiveArray4", Value: massiveArray},
	}
}

func BenchmarkLogger(b *testing.B) {
	b.ReportAllocs()

	b.Run("Print", func(b *testing.B) {
		b.ReportAllocs()

		logger, err := New(mockLogSink{}, 0, map[Component]Level{
			ComponentCommand: LevelDebug,
		})
		if err != nil {
			b.Fatal(err)
		}

		// Reset the timer after setup so logger construction is not timed.
		b.ResetTimer()

		b.RunParallel(func(pb *testing.PB) {
			for pb.Next() {
				logger.Print(LevelInfo, ComponentCommand, "foo", "bar", "baz")
			}
		})
	})
}

// mockKeyValues returns the same key/value pairs both as a KeyValues slice
// and as a map, for logging and for comparison respectively.
func mockKeyValues(length int) (KeyValues, map[string]interface{}) {
	keysAndValues := KeyValues{}
	m := map[string]interface{}{}

	for i := 0; i < length; i++ {
		keyName := fmt.Sprintf("key%d", i)
		valueName := fmt.Sprintf("value%d", i)

		keysAndValues.Add(keyName, valueName)
		m[keyName] = valueName
	}

	return keysAndValues, m
}

func BenchmarkIOSinkInfo(b *testing.B) {
	keysAndValues, _ := mockKeyValues(10)

	sink := NewIOSink(bytes.NewBuffer(nil))

	b.ReportAllocs()
	b.ResetTimer()

	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			sink.Info(0, "foo", keysAndValues...)
		}
	})
}

func TestIOSinkInfo(t *testing.T) {
	t.Parallel()

	const threshold = 1000

	mockKeyValues, kvmap := mockKeyValues(10)

	buf := new(bytes.Buffer)
	sink := NewIOSink(buf)

	wg := sync.WaitGroup{}
	wg.Add(threshold)

	for i := 0; i < threshold; i++ {
		go func() {
			defer wg.Done()

			sink.Info(0, "foo", mockKeyValues...)
		}()
	}

	wg.Wait()

	dec := json.NewDecoder(buf)
	for dec.More() {
		var m map[string]interface{}
		if err := dec.Decode(&m); err != nil {
			t.Fatalf("error unmarshaling JSON: %v", err)
		}

		// Remove the non-deterministic fields before comparing.
		delete(m, KeyTimestamp)
		delete(m, KeyMessage)

		if !reflect.DeepEqual(m, kvmap) {
			t.Fatalf("expected %v, got %v", kvmap, m)
		}
	}
}

func TestSelectMaxDocumentLength(t *testing.T) {
	for _, tcase := range []struct {
		name     string
		arg      uint
		expected uint
		env      map[string]string
	}{
		{
			name:     "default",
			arg:      0,
			expected: DefaultMaxDocumentLength,
		},
		{
			name:     "non-zero",
			arg:      100,
			expected: 100,
		},
		{
			name:     "valid env",
			arg:      0,
			expected: 100,
			env: map[string]string{
				maxDocumentLengthEnvVar: "100",
			},
		},
		{
			name:     "invalid env",
			arg:      0,
			expected: DefaultMaxDocumentLength,
			env: map[string]string{
				maxDocumentLengthEnvVar: "foo",
			},
		},
	} {
		tcase := tcase

		t.Run(tcase.name, func(t *testing.T) {
			for k, v := range tcase.env {
				t.Setenv(k, v)
			}

			actual := selectMaxDocumentLength(tcase.arg)
			if actual != tcase.expected {
				t.Errorf("expected %d, got %d", tcase.expected, actual)
			}
		})
	}
}

func TestSelectLogSink(t *testing.T) {
	for _, tcase := range []struct {
		name     string
		arg      LogSink
		expected LogSink
		env      map[string]string
	}{
		{
			name:     "default",
			arg:      nil,
			expected: NewIOSink(os.Stderr),
		},
		{
			name:     "non-nil",
			arg:      mockLogSink{},
			expected: mockLogSink{},
		},
		{
			name:     "stdout",
			arg:      nil,
			expected: NewIOSink(os.Stdout),
			env: map[string]string{
				logSinkPathEnvVar: logSinkPathStdout,
			},
		},
		{
			name:     "stderr",
			arg:      nil,
			expected: NewIOSink(os.Stderr),
			env: map[string]string{
				logSinkPathEnvVar: logSinkPathStderr,
			},
		},
	} {
		tcase := tcase

		t.Run(tcase.name, func(t *testing.T) {
			for k, v := range tcase.env {
				t.Setenv(k, v)
			}

			actual, _, _ := selectLogSink(tcase.arg)
			if !reflect.DeepEqual(actual, tcase.expected) {
				t.Errorf("expected %+v, got %+v", tcase.expected, actual)
			}
		})
	}
}

func TestSelectedComponentLevels(t *testing.T) {
	for _, tcase := range []struct {
		name     string
		arg      map[Component]Level
		expected map[Component]Level
		env      map[string]string
	}{
		{
			name: "default",
			arg:  nil,
			expected: map[Component]Level{
				ComponentCommand:         LevelOff,
				ComponentTopology:        LevelOff,
				ComponentServerSelection: LevelOff,
				ComponentConnection:      LevelOff,
			},
		},
		{
			name: "non-nil",
			arg: map[Component]Level{
				ComponentCommand: LevelDebug,
			},
			expected: map[Component]Level{
				ComponentCommand:         LevelDebug,
				ComponentTopology:        LevelOff,
				ComponentServerSelection: LevelOff,
				ComponentConnection:      LevelOff,
			},
		},
		{
			name: "valid env",
			arg:  nil,
			expected: map[Component]Level{
				ComponentCommand:         LevelDebug,
				ComponentTopology:        LevelInfo,
				ComponentServerSelection: LevelOff,
				ComponentConnection:      LevelOff,
			},
			env: map[string]string{
				mongoDBLogCommandEnvVar:  levelLiteralDebug,
				mongoDBLogTopologyEnvVar: levelLiteralInfo,
			},
		},
		{
			name: "invalid env",
			arg:  nil,
			expected: map[Component]Level{
				ComponentCommand:         LevelOff,
				ComponentTopology:        LevelOff,
				ComponentServerSelection: LevelOff,
				ComponentConnection:      LevelOff,
			},
			env: map[string]string{
				mongoDBLogCommandEnvVar:  "foo",
				mongoDBLogTopologyEnvVar: "bar",
			},
		},
	} {
		tcase := tcase

		t.Run(tcase.name, func(t *testing.T) {
			for k, v := range tcase.env {
				t.Setenv(k, v)
			}

			actual := selectComponentLevels(tcase.arg)
			for k, v := range tcase.expected {
				if actual[k] != v {
					t.Errorf("expected %d, got %d", v, actual[k])
				}
			}
		})
	}
}

func TestLogger_LevelComponentEnabled(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name      string
		logger    Logger
		level     Level
		component Component
		want      bool
	}{
		{
			name:      "zero",
			logger:    Logger{},
			level:     LevelOff,
			component: ComponentCommand,
			want:      false,
		},
		{
			name: "empty",
			logger: Logger{
				ComponentLevels: map[Component]Level{},
			},
			level:     LevelOff,
			component: ComponentCommand,
			want:      false, // LevelOff should never be considered enabled.
		},
		{
			name: "one level below",
			logger: Logger{
				ComponentLevels: map[Component]Level{
					ComponentCommand: LevelDebug,
				},
			},
			level:     LevelInfo,
			component: ComponentCommand,
			want:      true,
		},
		{
			name: "equal levels",
			logger: Logger{
				ComponentLevels: map[Component]Level{
					ComponentCommand: LevelDebug,
				},
			},
			level:     LevelDebug,
			component: ComponentCommand,
			want:      true,
		},
		{
			name: "one level above",
			logger: Logger{
				ComponentLevels: map[Component]Level{
					ComponentCommand: LevelInfo,
				},
			},
			level:     LevelDebug,
			component: ComponentCommand,
			want:      false,
		},
		{
			name: "component mismatch",
			logger: Logger{
				ComponentLevels: map[Component]Level{
					ComponentCommand: LevelDebug,
				},
			},
			level:     LevelDebug,
			component: ComponentTopology,
			want:      false,
		},
		{
			name: "component all enables with topology",
			logger: Logger{
				ComponentLevels: map[Component]Level{
					ComponentAll: LevelDebug,
				},
			},
			level:     LevelDebug,
			component: ComponentTopology,
			want:      true,
		},
		{
			name: "component all enables with server selection",
			logger: Logger{
				ComponentLevels: map[Component]Level{
					ComponentAll: LevelDebug,
				},
			},
			level:     LevelDebug,
			component: ComponentServerSelection,
			want:      true,
		},
		{
			name: "component all enables with connection",
			logger: Logger{
				ComponentLevels: map[Component]Level{
					ComponentAll: LevelDebug,
				},
			},
			level:     LevelDebug,
			component: ComponentConnection,
			want:      true,
		},
		{
			name: "component all enables with command",
			logger: Logger{
				ComponentLevels: map[Component]Level{
					ComponentAll: LevelDebug,
				},
			},
			level:     LevelDebug,
			component: ComponentCommand,
			want:      true,
		},
		{
			name: "component all enables with all",
			logger: Logger{
				ComponentLevels: map[Component]Level{
					ComponentAll: LevelDebug,
				},
			},
			level:     LevelDebug,
			component: ComponentAll,
			want:      true,
		},
		{
			name: "component all does not enable with lower level",
			logger: Logger{
				ComponentLevels: map[Component]Level{
					ComponentAll: LevelInfo,
				},
			},
			level:     LevelDebug,
			component: ComponentCommand,
			want:      false,
		},
		{
			name: "component all has a lower log level than command",
			logger: Logger{
				ComponentLevels: map[Component]Level{
					ComponentAll:     LevelInfo,
					ComponentCommand: LevelDebug,
				},
			},
			level:     LevelDebug,
			component: ComponentCommand,
			want:      true,
		},
		{
			name: "component all has a higher log level than command",
			logger: Logger{
				ComponentLevels: map[Component]Level{
					ComponentAll:     LevelDebug,
					ComponentCommand: LevelInfo,
				},
			},
			level:     LevelDebug,
			component: ComponentCommand,
			want:      true,
		},
	}

	for _, tcase := range tests {
		tcase := tcase // Capture the range variable.

		t.Run(tcase.name, func(t *testing.T) {
			t.Parallel()

			got := tcase.logger.LevelComponentEnabled(tcase.level, tcase.component)
			assert.Equal(t, tcase.want, got, "unexpected result for LevelComponentEnabled")
		})
	}
}