Initial commit
This commit is contained in:
commit
4b4cceb81c
201
LICENSE
Normal file
201
LICENSE
Normal file
@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
1692
THIRD-PARTY-NOTICES
Normal file
1692
THIRD-PARTY-NOTICES
Normal file
File diff suppressed because it is too large
Load Diff
42
array_codec.go
Normal file
42
array_codec.go
Normal file
@ -0,0 +1,42 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
// arrayCodec is the Codec used for bsoncore.Array values.
|
||||
type arrayCodec struct{}
|
||||
|
||||
// EncodeValue is the ValueEncoder for bsoncore.Array values.
|
||||
func (ac *arrayCodec) EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tCoreArray {
|
||||
return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: val}
|
||||
}
|
||||
|
||||
arr := val.Interface().(bsoncore.Array)
|
||||
return copyArrayFromBytes(vw, arr)
|
||||
}
|
||||
|
||||
// DecodeValue is the ValueDecoder for bsoncore.Array values.
|
||||
func (ac *arrayCodec) DecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error {
|
||||
if !val.CanSet() || val.Type() != tCoreArray {
|
||||
return ValueDecoderError{Name: "CoreArrayDecodeValue", Types: []reflect.Type{tCoreArray}, Received: val}
|
||||
}
|
||||
|
||||
if val.IsNil() {
|
||||
val.Set(reflect.MakeSlice(val.Type(), 0, 0))
|
||||
}
|
||||
|
||||
val.SetLen(0)
|
||||
arr, err := appendArrayBytes(val.Interface().(bsoncore.Array), vr)
|
||||
val.Set(reflect.ValueOf(arr))
|
||||
return err
|
||||
}
|
449
benchmark_test.go
Normal file
449
benchmark_test.go
Normal file
@ -0,0 +1,449 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type encodetest struct {
|
||||
Field1String string
|
||||
Field1Int64 int64
|
||||
Field1Float64 float64
|
||||
Field2String string
|
||||
Field2Int64 int64
|
||||
Field2Float64 float64
|
||||
Field3String string
|
||||
Field3Int64 int64
|
||||
Field3Float64 float64
|
||||
Field4String string
|
||||
Field4Int64 int64
|
||||
Field4Float64 float64
|
||||
}
|
||||
|
||||
type nestedtest1 struct {
|
||||
Nested nestedtest2
|
||||
}
|
||||
|
||||
type nestedtest2 struct {
|
||||
Nested nestedtest3
|
||||
}
|
||||
|
||||
type nestedtest3 struct {
|
||||
Nested nestedtest4
|
||||
}
|
||||
|
||||
type nestedtest4 struct {
|
||||
Nested nestedtest5
|
||||
}
|
||||
|
||||
type nestedtest5 struct {
|
||||
Nested nestedtest6
|
||||
}
|
||||
|
||||
type nestedtest6 struct {
|
||||
Nested nestedtest7
|
||||
}
|
||||
|
||||
type nestedtest7 struct {
|
||||
Nested nestedtest8
|
||||
}
|
||||
|
||||
type nestedtest8 struct {
|
||||
Nested nestedtest9
|
||||
}
|
||||
|
||||
type nestedtest9 struct {
|
||||
Nested nestedtest10
|
||||
}
|
||||
|
||||
type nestedtest10 struct {
|
||||
Nested nestedtest11
|
||||
}
|
||||
|
||||
type nestedtest11 struct {
|
||||
Nested encodetest
|
||||
}
|
||||
|
||||
var encodetestInstance = encodetest{
|
||||
Field1String: "foo",
|
||||
Field1Int64: 1,
|
||||
Field1Float64: 3.0,
|
||||
Field2String: "bar",
|
||||
Field2Int64: 2,
|
||||
Field2Float64: 3.1,
|
||||
Field3String: "baz",
|
||||
Field3Int64: 3,
|
||||
Field3Float64: 3.14,
|
||||
Field4String: "qux",
|
||||
Field4Int64: 4,
|
||||
Field4Float64: 3.141,
|
||||
}
|
||||
|
||||
var nestedInstance = nestedtest1{
|
||||
nestedtest2{
|
||||
nestedtest3{
|
||||
nestedtest4{
|
||||
nestedtest5{
|
||||
nestedtest6{
|
||||
nestedtest7{
|
||||
nestedtest8{
|
||||
nestedtest9{
|
||||
nestedtest10{
|
||||
nestedtest11{
|
||||
encodetest{
|
||||
Field1String: "foo",
|
||||
Field1Int64: 1,
|
||||
Field1Float64: 3.0,
|
||||
Field2String: "bar",
|
||||
Field2Int64: 2,
|
||||
Field2Float64: 3.1,
|
||||
Field3String: "baz",
|
||||
Field3Int64: 3,
|
||||
Field3Float64: 3.14,
|
||||
Field4String: "qux",
|
||||
Field4Int64: 4,
|
||||
Field4Float64: 3.141,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
const extendedBSONDir = "./testdata/extended_bson"
|
||||
|
||||
var (
|
||||
extJSONFiles map[string]map[string]interface{}
|
||||
extJSONFilesMu sync.Mutex
|
||||
)
|
||||
|
||||
// readExtJSONFile reads the GZIP-compressed extended JSON document from the given filename in the
|
||||
// "extended BSON" test data directory (./testdata/extended_bson) and returns it as a
|
||||
// map[string]interface{}. It panics on any errors.
|
||||
func readExtJSONFile(filename string) map[string]interface{} {
|
||||
extJSONFilesMu.Lock()
|
||||
defer extJSONFilesMu.Unlock()
|
||||
if v, ok := extJSONFiles[filename]; ok {
|
||||
return v
|
||||
}
|
||||
filePath := path.Join(extendedBSONDir, filename)
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error opening file %q: %s", filePath, err))
|
||||
}
|
||||
defer func() {
|
||||
_ = file.Close()
|
||||
}()
|
||||
|
||||
gz, err := gzip.NewReader(file)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error creating GZIP reader: %s", err))
|
||||
}
|
||||
defer func() {
|
||||
_ = gz.Close()
|
||||
}()
|
||||
|
||||
data, err := ioutil.ReadAll(gz)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error reading GZIP contents of file: %s", err))
|
||||
}
|
||||
|
||||
var v map[string]interface{}
|
||||
err = UnmarshalExtJSON(data, false, &v)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error unmarshalling extended JSON: %s", err))
|
||||
}
|
||||
|
||||
if extJSONFiles == nil {
|
||||
extJSONFiles = make(map[string]map[string]interface{})
|
||||
}
|
||||
extJSONFiles[filename] = v
|
||||
return v
|
||||
}
|
||||
|
||||
func BenchmarkMarshal(b *testing.B) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
value interface{}
|
||||
}{
|
||||
{
|
||||
desc: "simple struct",
|
||||
value: encodetestInstance,
|
||||
},
|
||||
{
|
||||
desc: "nested struct",
|
||||
value: nestedInstance,
|
||||
},
|
||||
{
|
||||
desc: "deep_bson.json.gz",
|
||||
value: readExtJSONFile("deep_bson.json.gz"),
|
||||
},
|
||||
{
|
||||
desc: "flat_bson.json.gz",
|
||||
value: readExtJSONFile("flat_bson.json.gz"),
|
||||
},
|
||||
{
|
||||
desc: "full_bson.json.gz",
|
||||
value: readExtJSONFile("full_bson.json.gz"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.desc, func(b *testing.B) {
|
||||
b.Run("BSON", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := Marshal(tc.value)
|
||||
if err != nil {
|
||||
b.Errorf("error marshalling BSON: %s", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("extJSON", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := MarshalExtJSON(tc.value, true, false)
|
||||
if err != nil {
|
||||
b.Errorf("error marshalling extended JSON: %s", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("JSON", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := json.Marshal(tc.value)
|
||||
if err != nil {
|
||||
b.Errorf("error marshalling JSON: %s", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUnmarshal(b *testing.B) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
value interface{}
|
||||
}{
|
||||
{
|
||||
desc: "simple struct",
|
||||
value: encodetestInstance,
|
||||
},
|
||||
{
|
||||
desc: "nested struct",
|
||||
value: nestedInstance,
|
||||
},
|
||||
{
|
||||
desc: "deep_bson.json.gz",
|
||||
value: readExtJSONFile("deep_bson.json.gz"),
|
||||
},
|
||||
{
|
||||
desc: "flat_bson.json.gz",
|
||||
value: readExtJSONFile("flat_bson.json.gz"),
|
||||
},
|
||||
{
|
||||
desc: "full_bson.json.gz",
|
||||
value: readExtJSONFile("full_bson.json.gz"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.desc, func(b *testing.B) {
|
||||
b.Run("BSON", func(b *testing.B) {
|
||||
data, err := Marshal(tc.value)
|
||||
if err != nil {
|
||||
b.Errorf("error marshalling BSON: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
var v2 map[string]interface{}
|
||||
for i := 0; i < b.N; i++ {
|
||||
err := Unmarshal(data, &v2)
|
||||
if err != nil {
|
||||
b.Errorf("error unmarshalling BSON: %s", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("extJSON", func(b *testing.B) {
|
||||
data, err := MarshalExtJSON(tc.value, true, false)
|
||||
if err != nil {
|
||||
b.Errorf("error marshalling extended JSON: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
var v2 map[string]interface{}
|
||||
for i := 0; i < b.N; i++ {
|
||||
err := UnmarshalExtJSON(data, true, &v2)
|
||||
if err != nil {
|
||||
b.Errorf("error unmarshalling extended JSON: %s", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("JSON", func(b *testing.B) {
|
||||
data, err := json.Marshal(tc.value)
|
||||
if err != nil {
|
||||
b.Errorf("error marshalling JSON: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
var v2 map[string]interface{}
|
||||
for i := 0; i < b.N; i++ {
|
||||
err := json.Unmarshal(data, &v2)
|
||||
if err != nil {
|
||||
b.Errorf("error unmarshalling JSON: %s", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// The following benchmarks are copied from the Go standard library's
|
||||
// encoding/json package.
|
||||
|
||||
type codeResponse struct {
|
||||
Tree *codeNode `json:"tree"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
type codeNode struct {
|
||||
Name string `json:"name"`
|
||||
Kids []*codeNode `json:"kids"`
|
||||
CLWeight float64 `json:"cl_weight"`
|
||||
Touches int `json:"touches"`
|
||||
MinT int64 `json:"min_t"`
|
||||
MaxT int64 `json:"max_t"`
|
||||
MeanT int64 `json:"mean_t"`
|
||||
}
|
||||
|
||||
var codeJSON []byte
|
||||
var codeBSON []byte
|
||||
var codeStruct codeResponse
|
||||
|
||||
func codeInit() {
|
||||
f, err := os.Open("testdata/code.json.gz")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer f.Close()
|
||||
gz, err := gzip.NewReader(f)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
data, err := io.ReadAll(gz)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
codeJSON = data
|
||||
|
||||
if err := json.Unmarshal(codeJSON, &codeStruct); err != nil {
|
||||
panic("json.Unmarshal code.json: " + err.Error())
|
||||
}
|
||||
|
||||
if data, err = json.Marshal(&codeStruct); err != nil {
|
||||
panic("json.Marshal code.json: " + err.Error())
|
||||
}
|
||||
|
||||
if codeBSON, err = Marshal(&codeStruct); err != nil {
|
||||
panic("Marshal code.json: " + err.Error())
|
||||
}
|
||||
|
||||
if !bytes.Equal(data, codeJSON) {
|
||||
println("different lengths", len(data), len(codeJSON))
|
||||
for i := 0; i < len(data) && i < len(codeJSON); i++ {
|
||||
if data[i] != codeJSON[i] {
|
||||
println("re-marshal: changed at byte", i)
|
||||
println("orig: ", string(codeJSON[i-10:i+10]))
|
||||
println("new: ", string(data[i-10:i+10]))
|
||||
break
|
||||
}
|
||||
}
|
||||
panic("re-marshal code.json: different result")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCodeUnmarshal(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
if codeJSON == nil {
|
||||
b.StopTimer()
|
||||
codeInit()
|
||||
b.StartTimer()
|
||||
}
|
||||
b.Run("BSON", func(b *testing.B) {
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
var r codeResponse
|
||||
if err := Unmarshal(codeBSON, &r); err != nil {
|
||||
b.Fatal("Unmarshal:", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
b.SetBytes(int64(len(codeBSON)))
|
||||
})
|
||||
b.Run("JSON", func(b *testing.B) {
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
var r codeResponse
|
||||
if err := json.Unmarshal(codeJSON, &r); err != nil {
|
||||
b.Fatal("json.Unmarshal:", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
b.SetBytes(int64(len(codeJSON)))
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCodeMarshal(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
if codeJSON == nil {
|
||||
b.StopTimer()
|
||||
codeInit()
|
||||
b.StartTimer()
|
||||
}
|
||||
b.Run("BSON", func(b *testing.B) {
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
if _, err := Marshal(&codeStruct); err != nil {
|
||||
b.Fatal("Marshal:", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
b.SetBytes(int64(len(codeBSON)))
|
||||
})
|
||||
b.Run("JSON", func(b *testing.B) {
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
if _, err := json.Marshal(&codeStruct); err != nil {
|
||||
b.Fatal("json.Marshal:", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
b.SetBytes(int64(len(codeJSON)))
|
||||
})
|
||||
}
|
191
bson_binary_vector_spec_test.go
Normal file
191
bson_binary_vector_spec_test.go
Normal file
@ -0,0 +1,191 @@
|
||||
// Copyright (C) MongoDB, Inc. 2024-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"math"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/require"
|
||||
)
|
||||
|
||||
const bsonBinaryVectorDir = "./testdata/bson-binary-vector/"
|
||||
|
||||
type bsonBinaryVectorTests struct {
|
||||
Description string `json:"description"`
|
||||
TestKey string `json:"test_key"`
|
||||
Tests []bsonBinaryVectorTestCase `json:"tests"`
|
||||
}
|
||||
|
||||
type bsonBinaryVectorTestCase struct {
|
||||
Description string `json:"description"`
|
||||
Valid bool `json:"valid"`
|
||||
Vector []interface{} `json:"vector"`
|
||||
DtypeHex string `json:"dtype_hex"`
|
||||
DtypeAlias string `json:"dtype_alias"`
|
||||
Padding int `json:"padding"`
|
||||
CanonicalBson string `json:"canonical_bson"`
|
||||
}
|
||||
|
||||
func TestBsonBinaryVectorSpec(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
jsonFiles, err := findJSONFilesInDir(bsonBinaryVectorDir)
|
||||
require.NoErrorf(t, err, "error finding JSON files in %s: %v", bsonBinaryVectorDir, err)
|
||||
|
||||
for _, file := range jsonFiles {
|
||||
filepath := path.Join(bsonBinaryVectorDir, file)
|
||||
content, err := os.ReadFile(filepath)
|
||||
require.NoErrorf(t, err, "reading test file %s", filepath)
|
||||
|
||||
var tests bsonBinaryVectorTests
|
||||
require.NoErrorf(t, json.Unmarshal(content, &tests), "parsing test file %s", filepath)
|
||||
|
||||
t.Run(tests.Description, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, test := range tests.Tests {
|
||||
test := test
|
||||
t.Run(test.Description, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runBsonBinaryVectorTest(t, tests.TestKey, test)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("Padding specified with no vector data PACKED_BIT", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Marshaling", func(t *testing.T) {
|
||||
_, err := NewPackedBitVector(nil, 1)
|
||||
require.EqualError(t, err, errNonZeroVectorPadding.Error())
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Exceeding maximum padding PACKED_BIT", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Marshaling", func(t *testing.T) {
|
||||
_, err := NewPackedBitVector(nil, 8)
|
||||
require.EqualError(t, err, errVectorPaddingTooLarge.Error())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func convertSlice[T int8 | float32 | byte](s []interface{}) []T {
|
||||
v := make([]T, len(s))
|
||||
for i, e := range s {
|
||||
f := math.NaN()
|
||||
switch val := e.(type) {
|
||||
case float64:
|
||||
f = val
|
||||
case string:
|
||||
if val == "inf" {
|
||||
f = math.Inf(0)
|
||||
} else if val == "-inf" {
|
||||
f = math.Inf(-1)
|
||||
}
|
||||
}
|
||||
v[i] = T(f)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func runBsonBinaryVectorTest(t *testing.T, testKey string, test bsonBinaryVectorTestCase) {
|
||||
testVector := make(map[string]Vector)
|
||||
switch alias := test.DtypeHex; alias {
|
||||
case "0x03":
|
||||
testVector[testKey] = Vector{
|
||||
dType: Int8Vector,
|
||||
int8Data: convertSlice[int8](test.Vector),
|
||||
}
|
||||
case "0x27":
|
||||
testVector[testKey] = Vector{
|
||||
dType: Float32Vector,
|
||||
float32Data: convertSlice[float32](test.Vector),
|
||||
}
|
||||
case "0x10":
|
||||
testVector[testKey] = Vector{
|
||||
dType: PackedBitVector,
|
||||
bitData: convertSlice[byte](test.Vector),
|
||||
bitPadding: uint8(test.Padding),
|
||||
}
|
||||
default:
|
||||
t.Fatalf("unsupported vector type: %s", alias)
|
||||
}
|
||||
|
||||
testBSON, err := hex.DecodeString(test.CanonicalBson)
|
||||
require.NoError(t, err, "decoding canonical BSON")
|
||||
|
||||
t.Run("Unmarshaling", func(t *testing.T) {
|
||||
skipCases := map[string]string{
|
||||
"Overflow Vector INT8": "compile-time restriction",
|
||||
"Underflow Vector INT8": "compile-time restriction",
|
||||
"INT8 with float inputs": "compile-time restriction",
|
||||
"Overflow Vector PACKED_BIT": "compile-time restriction",
|
||||
"Underflow Vector PACKED_BIT": "compile-time restriction",
|
||||
"Vector with float values PACKED_BIT": "compile-time restriction",
|
||||
"Negative padding PACKED_BIT": "compile-time restriction",
|
||||
}
|
||||
if reason, ok := skipCases[test.Description]; ok {
|
||||
t.Skipf("skip test case %s: %s", test.Description, reason)
|
||||
}
|
||||
|
||||
errMap := map[string]string{
|
||||
"FLOAT32 with padding": "padding must be 0",
|
||||
"INT8 with padding": "padding must be 0",
|
||||
"Padding specified with no vector data PACKED_BIT": "padding must be 0",
|
||||
"Exceeding maximum padding PACKED_BIT": "padding cannot be larger than 7",
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
var got map[string]Vector
|
||||
err := Unmarshal(testBSON, &got)
|
||||
if test.Valid {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testVector, got)
|
||||
} else if errMsg, ok := errMap[test.Description]; ok {
|
||||
require.ErrorContains(t, err, errMsg)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Marshaling", func(t *testing.T) {
|
||||
skipCases := map[string]string{
|
||||
"FLOAT32 with padding": "private padding field",
|
||||
"Insufficient vector data with 3 bytes FLOAT32": "invalid case",
|
||||
"Insufficient vector data with 5 bytes FLOAT32": "invalid case",
|
||||
"Overflow Vector INT8": "compile-time restriction",
|
||||
"Underflow Vector INT8": "compile-time restriction",
|
||||
"INT8 with padding": "private padding field",
|
||||
"INT8 with float inputs": "compile-time restriction",
|
||||
"Overflow Vector PACKED_BIT": "compile-time restriction",
|
||||
"Underflow Vector PACKED_BIT": "compile-time restriction",
|
||||
"Vector with float values PACKED_BIT": "compile-time restriction",
|
||||
"Padding specified with no vector data PACKED_BIT": "run in alternative case",
|
||||
"Exceeding maximum padding PACKED_BIT": "run in alternative case",
|
||||
"Negative padding PACKED_BIT": "compile-time restriction",
|
||||
}
|
||||
if reason, ok := skipCases[test.Description]; ok {
|
||||
t.Skipf("skip test case %s: %s", test.Description, reason)
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
got, err := Marshal(testVector)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, testBSON, got)
|
||||
})
|
||||
}
|
504
bson_corpus_spec_test.go
Normal file
504
bson_corpus_spec_test.go
Normal file
@ -0,0 +1,504 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/assert"
|
||||
"gitea.psichedelico.com/go/bson/internal/require"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
type testCase struct {
|
||||
Description string `json:"description"`
|
||||
BsonType string `json:"bson_type"`
|
||||
TestKey *string `json:"test_key"`
|
||||
Valid []validityTestCase `json:"valid"`
|
||||
DecodeErrors []decodeErrorTestCase `json:"decodeErrors"`
|
||||
ParseErrors []parseErrorTestCase `json:"parseErrors"`
|
||||
Deprecated *bool `json:"deprecated"`
|
||||
}
|
||||
|
||||
type validityTestCase struct {
|
||||
Description string `json:"description"`
|
||||
CanonicalBson string `json:"canonical_bson"`
|
||||
CanonicalExtJSON string `json:"canonical_extjson"`
|
||||
RelaxedExtJSON *string `json:"relaxed_extjson"`
|
||||
DegenerateBSON *string `json:"degenerate_bson"`
|
||||
DegenerateExtJSON *string `json:"degenerate_extjson"`
|
||||
ConvertedBSON *string `json:"converted_bson"`
|
||||
ConvertedExtJSON *string `json:"converted_extjson"`
|
||||
Lossy *bool `json:"lossy"`
|
||||
}
|
||||
|
||||
type decodeErrorTestCase struct {
|
||||
Description string `json:"description"`
|
||||
Bson string `json:"bson"`
|
||||
}
|
||||
|
||||
type parseErrorTestCase struct {
|
||||
Description string `json:"description"`
|
||||
String string `json:"string"`
|
||||
}
|
||||
|
||||
const dataDir = "./testdata/bson-corpus/"
|
||||
|
||||
func findJSONFilesInDir(dir string) ([]string, error) {
|
||||
files := make([]string, 0)
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || path.Ext(entry.Name()) != ".json" {
|
||||
continue
|
||||
}
|
||||
|
||||
files = append(files, entry.Name())
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
// seedExtJSON will add the byte representation of the "extJSON" string to the fuzzer's coprus.
|
||||
func seedExtJSON(f *testing.F, extJSON string, extJSONType string, desc string) {
|
||||
jbytes, err := jsonToBytes(extJSON, extJSONType, desc)
|
||||
if err != nil {
|
||||
f.Fatalf("failed to convert JSON to bytes: %v", err)
|
||||
}
|
||||
|
||||
f.Add(jbytes)
|
||||
}
|
||||
|
||||
// seedTestCase will add the byte representation for each "extJSON" string of each valid test case to the fuzzer's
|
||||
// corpus.
|
||||
func seedTestCase(f *testing.F, tcase *testCase) {
|
||||
for _, vtc := range tcase.Valid {
|
||||
seedExtJSON(f, vtc.CanonicalExtJSON, "canonical", vtc.Description)
|
||||
|
||||
// Seed the relaxed extended JSON.
|
||||
if vtc.RelaxedExtJSON != nil {
|
||||
seedExtJSON(f, *vtc.RelaxedExtJSON, "relaxed", vtc.Description)
|
||||
}
|
||||
|
||||
// Seed the degenerate extended JSON.
|
||||
if vtc.DegenerateExtJSON != nil {
|
||||
seedExtJSON(f, *vtc.DegenerateExtJSON, "degenerate", vtc.Description)
|
||||
}
|
||||
|
||||
// Seed the converted extended JSON.
|
||||
if vtc.ConvertedExtJSON != nil {
|
||||
seedExtJSON(f, *vtc.ConvertedExtJSON, "converted", vtc.Description)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// seedBSONCorpus will unmarshal the data from "testdata/bson-corpus" into a slice of "testCase" structs and then
|
||||
// marshal the "*_extjson" field of each "validityTestCase" into a slice of bytes to seed the fuzz corpus.
|
||||
func seedBSONCorpus(f *testing.F) {
|
||||
fileNames, err := findJSONFilesInDir(dataDir)
|
||||
if err != nil {
|
||||
f.Fatalf("failed to find JSON files in directory %q: %v", dataDir, err)
|
||||
}
|
||||
|
||||
for _, fileName := range fileNames {
|
||||
filePath := path.Join(dataDir, fileName)
|
||||
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
f.Fatalf("failed to open file %q: %v", filePath, err)
|
||||
}
|
||||
|
||||
var tcase testCase
|
||||
if err := json.NewDecoder(file).Decode(&tcase); err != nil {
|
||||
f.Fatal(err)
|
||||
}
|
||||
|
||||
seedTestCase(f, &tcase)
|
||||
}
|
||||
}
|
||||
|
||||
func needsEscapedUnicode(bsonType string) bool {
|
||||
return bsonType == "0x02" || bsonType == "0x0D" || bsonType == "0x0E" || bsonType == "0x0F"
|
||||
}
|
||||
|
||||
func unescapeUnicode(s, bsonType string) string {
|
||||
if !needsEscapedUnicode(bsonType) {
|
||||
return s
|
||||
}
|
||||
|
||||
newS := ""
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
c := s[i]
|
||||
switch c {
|
||||
case '\\':
|
||||
switch s[i+1] {
|
||||
case 'u':
|
||||
us := s[i : i+6]
|
||||
u, err := strconv.Unquote(strings.Replace(strconv.Quote(us), `\\u`, `\u`, 1))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, r := range u {
|
||||
if r < ' ' {
|
||||
newS += fmt.Sprintf(`\u%04x`, r)
|
||||
} else {
|
||||
newS += string(r)
|
||||
}
|
||||
}
|
||||
i += 5
|
||||
default:
|
||||
newS += string(c)
|
||||
}
|
||||
default:
|
||||
if c > unicode.MaxASCII {
|
||||
r, size := utf8.DecodeRune([]byte(s[i:]))
|
||||
newS += string(r)
|
||||
i += size - 1
|
||||
} else {
|
||||
newS += string(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return newS
|
||||
}
|
||||
|
||||
func normalizeCanonicalDouble(t *testing.T, key string, cEJ string) string {
|
||||
// Unmarshal string into map
|
||||
cEJMap := make(map[string]map[string]string)
|
||||
err := json.Unmarshal([]byte(cEJ), &cEJMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse the float contained by the map.
|
||||
expectedString := cEJMap[key]["$numberDouble"]
|
||||
expectedFloat, err := strconv.ParseFloat(expectedString, 64)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Normalize the string
|
||||
return fmt.Sprintf(`{"%s":{"$numberDouble":"%s"}}`, key, formatDouble(expectedFloat))
|
||||
}
|
||||
|
||||
func normalizeRelaxedDouble(t *testing.T, key string, rEJ string) string {
|
||||
// Unmarshal string into map
|
||||
rEJMap := make(map[string]float64)
|
||||
err := json.Unmarshal([]byte(rEJ), &rEJMap)
|
||||
if err != nil {
|
||||
return normalizeCanonicalDouble(t, key, rEJ)
|
||||
}
|
||||
|
||||
// Parse the float contained by the map.
|
||||
expectedFloat := rEJMap[key]
|
||||
|
||||
// Normalize the string
|
||||
return fmt.Sprintf(`{"%s":%s}`, key, formatDouble(expectedFloat))
|
||||
}
|
||||
|
||||
// bsonToNative decodes the BSON bytes (b) into a native Document
|
||||
func bsonToNative(t *testing.T, b []byte, bType, testDesc string) D {
|
||||
var doc D
|
||||
err := Unmarshal(b, &doc)
|
||||
require.NoErrorf(t, err, "%s: decoding %s BSON", testDesc, bType)
|
||||
return doc
|
||||
}
|
||||
|
||||
// nativeToBSON encodes the native Document (doc) into canonical BSON and compares it to the expected
|
||||
// canonical BSON (cB)
|
||||
func nativeToBSON(t *testing.T, cB []byte, doc D, testDesc, bType, docSrcDesc string) {
|
||||
actual, err := Marshal(doc)
|
||||
require.NoErrorf(t, err, "%s: encoding %s BSON", testDesc, bType)
|
||||
|
||||
if diff := cmp.Diff(cB, actual); diff != "" {
|
||||
t.Errorf("%s: 'native_to_bson(%s) = cB' failed (-want, +got):\n-%v\n+%v\n",
|
||||
testDesc, docSrcDesc, cB, actual)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// jsonToNative decodes the extended JSON string (ej) into a native Document
|
||||
func jsonToNative(ej, ejType, testDesc string) (D, error) {
|
||||
var doc D
|
||||
if err := UnmarshalExtJSON([]byte(ej), ejType != "relaxed", &doc); err != nil {
|
||||
return nil, fmt.Errorf("%s: decoding %s extended JSON: %w", testDesc, ejType, err)
|
||||
}
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// jsonToBytes decodes the extended JSON string (ej) into canonical BSON and then encodes it into a byte slice.
|
||||
func jsonToBytes(ej, ejType, testDesc string) ([]byte, error) {
|
||||
native, err := jsonToNative(ej, ejType, testDesc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b, err := Marshal(native)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: encoding %s BSON: %w", testDesc, ejType, err)
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// nativeToJSON encodes the native Document (doc) into an extended JSON string
|
||||
func nativeToJSON(t *testing.T, ej string, doc D, testDesc, ejType, ejShortName, docSrcDesc string) {
|
||||
actualEJ, err := MarshalExtJSON(doc, ejType != "relaxed", true)
|
||||
require.NoErrorf(t, err, "%s: encoding %s extended JSON", testDesc, ejType)
|
||||
|
||||
if diff := cmp.Diff(ej, string(actualEJ)); diff != "" {
|
||||
t.Errorf("%s: 'native_to_%s_extended_json(%s) = %s' failed (-want, +got):\n%s\n",
|
||||
testDesc, ejType, docSrcDesc, ejShortName, diff)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func runTest(t *testing.T, file string) {
|
||||
filepath := path.Join(dataDir, file)
|
||||
content, err := os.ReadFile(filepath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove ".json" from filename.
|
||||
file = file[:len(file)-5]
|
||||
testName := "bson_corpus--" + file
|
||||
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
var test testCase
|
||||
require.NoError(t, json.Unmarshal(content, &test))
|
||||
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
for _, v := range test.Valid {
|
||||
t.Run(v.Description, func(t *testing.T) {
|
||||
// get canonical BSON
|
||||
cB, err := hex.DecodeString(v.CanonicalBson)
|
||||
require.NoErrorf(t, err, "%s: reading canonical BSON", v.Description)
|
||||
|
||||
// get canonical extended JSON
|
||||
var compactEJ bytes.Buffer
|
||||
require.NoError(t, json.Compact(&compactEJ, []byte(v.CanonicalExtJSON)))
|
||||
cEJ := unescapeUnicode(compactEJ.String(), test.BsonType)
|
||||
if test.BsonType == "0x01" {
|
||||
cEJ = normalizeCanonicalDouble(t, *test.TestKey, cEJ)
|
||||
}
|
||||
|
||||
/*** canonical BSON round-trip tests ***/
|
||||
doc := bsonToNative(t, cB, "canonical", v.Description)
|
||||
|
||||
// native_to_bson(bson_to_native(cB)) = cB
|
||||
nativeToBSON(t, cB, doc, v.Description, "canonical", "bson_to_native(cB)")
|
||||
|
||||
// native_to_canonical_extended_json(bson_to_native(cB)) = cEJ
|
||||
nativeToJSON(t, cEJ, doc, v.Description, "canonical", "cEJ", "bson_to_native(cB)")
|
||||
|
||||
// native_to_relaxed_extended_json(bson_to_native(cB)) = rEJ (if rEJ exists)
|
||||
if v.RelaxedExtJSON != nil {
|
||||
var compactEJ bytes.Buffer
|
||||
require.NoError(t, json.Compact(&compactEJ, []byte(*v.RelaxedExtJSON)))
|
||||
rEJ := unescapeUnicode(compactEJ.String(), test.BsonType)
|
||||
if test.BsonType == "0x01" {
|
||||
rEJ = normalizeRelaxedDouble(t, *test.TestKey, rEJ)
|
||||
}
|
||||
|
||||
nativeToJSON(t, rEJ, doc, v.Description, "relaxed", "rEJ", "bson_to_native(cB)")
|
||||
|
||||
/*** relaxed extended JSON round-trip tests (if exists) ***/
|
||||
doc, err = jsonToNative(rEJ, "relaxed", v.Description)
|
||||
require.NoError(t, err)
|
||||
|
||||
// native_to_relaxed_extended_json(json_to_native(rEJ)) = rEJ
|
||||
nativeToJSON(t, rEJ, doc, v.Description, "relaxed", "eJR", "json_to_native(rEJ)")
|
||||
}
|
||||
|
||||
/*** canonical extended JSON round-trip tests ***/
|
||||
doc, err = jsonToNative(cEJ, "canonical", v.Description)
|
||||
require.NoError(t, err)
|
||||
|
||||
// native_to_canonical_extended_json(json_to_native(cEJ)) = cEJ
|
||||
nativeToJSON(t, cEJ, doc, v.Description, "canonical", "cEJ", "json_to_native(cEJ)")
|
||||
|
||||
// native_to_bson(json_to_native(cEJ)) = cb (unless lossy)
|
||||
if v.Lossy == nil || !*v.Lossy {
|
||||
nativeToBSON(t, cB, doc, v.Description, "canonical", "json_to_native(cEJ)")
|
||||
}
|
||||
|
||||
/*** degenerate BSON round-trip tests (if exists) ***/
|
||||
if v.DegenerateBSON != nil {
|
||||
dB, err := hex.DecodeString(*v.DegenerateBSON)
|
||||
require.NoErrorf(t, err, "%s: reading degenerate BSON", v.Description)
|
||||
|
||||
doc = bsonToNative(t, dB, "degenerate", v.Description)
|
||||
|
||||
// native_to_bson(bson_to_native(dB)) = cB
|
||||
nativeToBSON(t, cB, doc, v.Description, "degenerate", "bson_to_native(dB)")
|
||||
}
|
||||
|
||||
/*** degenerate JSON round-trip tests (if exists) ***/
|
||||
if v.DegenerateExtJSON != nil {
|
||||
var compactEJ bytes.Buffer
|
||||
require.NoError(t, json.Compact(&compactEJ, []byte(*v.DegenerateExtJSON)))
|
||||
dEJ := unescapeUnicode(compactEJ.String(), test.BsonType)
|
||||
if test.BsonType == "0x01" {
|
||||
dEJ = normalizeCanonicalDouble(t, *test.TestKey, dEJ)
|
||||
}
|
||||
|
||||
doc, err = jsonToNative(dEJ, "degenerate canonical", v.Description)
|
||||
require.NoError(t, err)
|
||||
|
||||
// native_to_canonical_extended_json(json_to_native(dEJ)) = cEJ
|
||||
nativeToJSON(t, cEJ, doc, v.Description, "degenerate canonical", "cEJ", "json_to_native(dEJ)")
|
||||
|
||||
// native_to_bson(json_to_native(dEJ)) = cB (unless lossy)
|
||||
if v.Lossy == nil || !*v.Lossy {
|
||||
nativeToBSON(t, cB, doc, v.Description, "canonical", "json_to_native(dEJ)")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("decode error", func(t *testing.T) {
|
||||
for _, d := range test.DecodeErrors {
|
||||
t.Run(d.Description, func(t *testing.T) {
|
||||
b, err := hex.DecodeString(d.Bson)
|
||||
require.NoError(t, err, d.Description)
|
||||
|
||||
var doc D
|
||||
err = Unmarshal(b, &doc)
|
||||
|
||||
// The driver unmarshals invalid UTF-8 strings without error. Loop over the unmarshalled elements
|
||||
// and assert that there was no error if any of the string or DBPointer values contain invalid UTF-8
|
||||
// characters.
|
||||
for _, elem := range doc {
|
||||
value := reflect.ValueOf(elem.Value)
|
||||
invalidString := (value.Kind() == reflect.String) && !utf8.ValidString(value.String())
|
||||
dbPtr, ok := elem.Value.(DBPointer)
|
||||
invalidDBPtr := ok && !utf8.ValidString(dbPtr.DB)
|
||||
|
||||
if invalidString || invalidDBPtr {
|
||||
require.NoError(t, err, d.Description)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
require.Errorf(t, err, "%s: expected decode error", d.Description)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("parse error", func(t *testing.T) {
|
||||
for _, p := range test.ParseErrors {
|
||||
t.Run(p.Description, func(t *testing.T) {
|
||||
s := unescapeUnicode(p.String, test.BsonType)
|
||||
if test.BsonType == "0x13" {
|
||||
s = fmt.Sprintf(`{"decimal128": {"$numberDecimal": "%s"}}`, s)
|
||||
}
|
||||
|
||||
switch test.BsonType {
|
||||
case "0x00", "0x05", "0x13":
|
||||
var doc D
|
||||
err := UnmarshalExtJSON([]byte(s), true, &doc)
|
||||
// Null bytes are validated when marshaling to BSON
|
||||
if strings.Contains(p.Description, "Null") {
|
||||
_, err = Marshal(doc)
|
||||
}
|
||||
require.Errorf(t, err, "%s: expected parse error", p.Description)
|
||||
default:
|
||||
t.Errorf("Update test to check for parse errors for type %s", test.BsonType)
|
||||
t.Fail()
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func Test_BsonCorpus(t *testing.T) {
|
||||
jsonFiles, err := findJSONFilesInDir(dataDir)
|
||||
require.NoErrorf(t, err, "error finding JSON files in %s: %v", dataDir, err)
|
||||
|
||||
for _, file := range jsonFiles {
|
||||
runTest(t, file)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelaxedUUIDValidation(t *testing.T) {
|
||||
testCases := []struct {
|
||||
description string
|
||||
canonicalExtJSON string
|
||||
degenerateExtJSON string
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
"valid uuid",
|
||||
"{\"x\" : { \"$binary\" : {\"base64\" : \"c//SZESzTGmQ6OfR38A11A==\", \"subType\" : \"04\"}}}",
|
||||
"{\"x\" : { \"$uuid\" : \"73ffd264-44b3-4c69-90e8-e7d1dfc035d4\"}}",
|
||||
"",
|
||||
},
|
||||
{
|
||||
"invalid uuid--no hyphens",
|
||||
"",
|
||||
"{\"x\" : { \"$uuid\" : \"73ffd26444b34c6990e8e7d1dfc035d4\"}}",
|
||||
"$uuid value does not follow RFC 4122 format regarding length and hyphens",
|
||||
},
|
||||
{
|
||||
"invalid uuid--trailing hyphens",
|
||||
"",
|
||||
"{\"x\" : { \"$uuid\" : \"73ffd264-44b3-4c69-90e8-e7d1dfc035--\"}}",
|
||||
"$uuid value does not follow RFC 4122 format regarding length and hyphens",
|
||||
},
|
||||
{
|
||||
"invalid uuid--malformed hex",
|
||||
"",
|
||||
"{\"x\" : { \"$uuid\" : \"q3@fd26l-44b3-4c69-90e8-e7d1dfc035d4\"}}",
|
||||
"$uuid value does not follow RFC 4122 format regarding hex bytes: encoding/hex: invalid byte: U+0071 'q'",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
// get canonical extended JSON (if provided)
|
||||
cEJ := ""
|
||||
if tc.canonicalExtJSON != "" {
|
||||
var compactCEJ bytes.Buffer
|
||||
require.NoError(t, json.Compact(&compactCEJ, []byte(tc.canonicalExtJSON)))
|
||||
cEJ = unescapeUnicode(compactCEJ.String(), "0x05")
|
||||
}
|
||||
|
||||
// get degenerate extended JSON
|
||||
var compactDEJ bytes.Buffer
|
||||
require.NoError(t, json.Compact(&compactDEJ, []byte(tc.degenerateExtJSON)))
|
||||
dEJ := unescapeUnicode(compactDEJ.String(), "0x05")
|
||||
|
||||
// convert dEJ to native doc
|
||||
var doc D
|
||||
err := UnmarshalExtJSON([]byte(dEJ), true, &doc)
|
||||
|
||||
if tc.expectedErr != "" {
|
||||
assert.Equal(t, tc.expectedErr, err.Error(), "expected error %v, got %v", tc.expectedErr, err)
|
||||
} else {
|
||||
assert.Nil(t, err, "expected no error, got error: %v", err)
|
||||
|
||||
// Marshal doc into extended JSON and compare with cEJ
|
||||
nativeToJSON(t, cEJ, doc, tc.description, "degenerate canonical", "cEJ", "json_to_native(dEJ)")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
679
bson_test.go
Normal file
679
bson_test.go
Normal file
@ -0,0 +1,679 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/assert"
|
||||
"gitea.psichedelico.com/go/bson/internal/require"
|
||||
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func noerr(t *testing.T, err error) {
|
||||
if err != nil {
|
||||
t.Helper()
|
||||
t.Errorf("Unexpected error: (%T)%v", err, err)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimestamp(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
description string
|
||||
tp Timestamp
|
||||
tp2 Timestamp
|
||||
expectedAfter bool
|
||||
expectedBefore bool
|
||||
expectedEqual bool
|
||||
expectedCompare int
|
||||
}{
|
||||
{
|
||||
description: "equal",
|
||||
tp: Timestamp{T: 12345, I: 67890},
|
||||
tp2: Timestamp{T: 12345, I: 67890},
|
||||
expectedBefore: false,
|
||||
expectedAfter: false,
|
||||
expectedEqual: true,
|
||||
expectedCompare: 0,
|
||||
},
|
||||
{
|
||||
description: "T greater than",
|
||||
tp: Timestamp{T: 12345, I: 67890},
|
||||
tp2: Timestamp{T: 2345, I: 67890},
|
||||
expectedBefore: false,
|
||||
expectedAfter: true,
|
||||
expectedEqual: false,
|
||||
expectedCompare: 1,
|
||||
},
|
||||
{
|
||||
description: "I greater than",
|
||||
tp: Timestamp{T: 12345, I: 67890},
|
||||
tp2: Timestamp{T: 12345, I: 7890},
|
||||
expectedBefore: false,
|
||||
expectedAfter: true,
|
||||
expectedEqual: false,
|
||||
expectedCompare: 1,
|
||||
},
|
||||
{
|
||||
description: "T less than",
|
||||
tp: Timestamp{T: 12345, I: 67890},
|
||||
tp2: Timestamp{T: 112345, I: 67890},
|
||||
expectedBefore: true,
|
||||
expectedAfter: false,
|
||||
expectedEqual: false,
|
||||
expectedCompare: -1,
|
||||
},
|
||||
{
|
||||
description: "I less than",
|
||||
tp: Timestamp{T: 12345, I: 67890},
|
||||
tp2: Timestamp{T: 12345, I: 167890},
|
||||
expectedBefore: true,
|
||||
expectedAfter: false,
|
||||
expectedEqual: false,
|
||||
expectedCompare: -1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // Capture range variable.
|
||||
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, tc.expectedAfter, tc.tp.After(tc.tp2), "expected After results to be the same")
|
||||
assert.Equal(t, tc.expectedBefore, tc.tp.Before(tc.tp2), "expected Before results to be the same")
|
||||
assert.Equal(t, tc.expectedEqual, tc.tp.Equal(tc.tp2), "expected Equal results to be the same")
|
||||
assert.Equal(t, tc.expectedCompare, tc.tp.Compare(tc.tp2), "expected Compare result to be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrimitiveIsZero(t *testing.T) {
|
||||
testcases := []struct {
|
||||
name string
|
||||
zero Zeroer
|
||||
nonzero Zeroer
|
||||
}{
|
||||
{"binary", Binary{}, Binary{Data: []byte{0x01, 0x02, 0x03}, Subtype: 0xFF}},
|
||||
{"decimal128", Decimal128{}, NewDecimal128(1, 2)},
|
||||
{"objectID", ObjectID{}, NewObjectID()},
|
||||
{"regex", Regex{}, Regex{Pattern: "foo", Options: "bar"}},
|
||||
{"dbPointer", DBPointer{}, DBPointer{DB: "foobar", Pointer: ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}}},
|
||||
{"timestamp", Timestamp{}, Timestamp{T: 12345, I: 67890}},
|
||||
}
|
||||
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
require.True(t, tc.zero.IsZero())
|
||||
require.False(t, tc.nonzero.IsZero())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegexCompare(t *testing.T) {
|
||||
testcases := []struct {
|
||||
name string
|
||||
r1 Regex
|
||||
r2 Regex
|
||||
eq bool
|
||||
}{
|
||||
{"equal", Regex{Pattern: "foo1", Options: "bar1"}, Regex{Pattern: "foo1", Options: "bar1"}, true},
|
||||
{"not equal", Regex{Pattern: "foo1", Options: "bar1"}, Regex{Pattern: "foo2", Options: "bar2"}, false},
|
||||
{"not equal", Regex{Pattern: "foo1", Options: "bar1"}, Regex{Pattern: "foo1", Options: "bar2"}, false},
|
||||
{"not equal", Regex{Pattern: "foo1", Options: "bar1"}, Regex{Pattern: "foo2", Options: "bar1"}, false},
|
||||
}
|
||||
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
require.True(t, tc.r1.Equal(tc.r2) == tc.eq)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDateTime(t *testing.T) {
|
||||
t.Run("json", func(t *testing.T) {
|
||||
t.Run("round trip", func(t *testing.T) {
|
||||
original := DateTime(1000)
|
||||
jsonBytes, err := json.Marshal(original)
|
||||
assert.Nil(t, err, "Marshal error: %v", err)
|
||||
|
||||
var unmarshalled DateTime
|
||||
err = json.Unmarshal(jsonBytes, &unmarshalled)
|
||||
assert.Nil(t, err, "Unmarshal error: %v", err)
|
||||
|
||||
assert.Equal(t, original, unmarshalled, "expected DateTime %v, got %v", original, unmarshalled)
|
||||
})
|
||||
t.Run("decode null", func(t *testing.T) {
|
||||
jsonBytes := []byte("null")
|
||||
var dt DateTime
|
||||
err := json.Unmarshal(jsonBytes, &dt)
|
||||
assert.Nil(t, err, "Unmarshal error: %v", err)
|
||||
assert.Equal(t, DateTime(0), dt, "expected DateTime value to be 0, got %v", dt)
|
||||
})
|
||||
t.Run("UTC", func(t *testing.T) {
|
||||
dt := DateTime(1681145535123)
|
||||
jsonBytes, err := json.Marshal(dt)
|
||||
assert.Nil(t, err, "Marshal error: %v", err)
|
||||
assert.Equal(t, `"2023-04-10T16:52:15.123Z"`, string(jsonBytes))
|
||||
})
|
||||
})
|
||||
t.Run("NewDateTimeFromTime", func(t *testing.T) {
|
||||
t.Run("range is not limited", func(t *testing.T) {
|
||||
// If the implementation internally calls time.Time.UnixNano(), the constructor cannot handle times after
|
||||
// the year 2262.
|
||||
|
||||
timeFormat := "2006-01-02T15:04:05.999Z07:00"
|
||||
timeString := "3001-01-01T00:00:00Z"
|
||||
tt, err := time.Parse(timeFormat, timeString)
|
||||
assert.Nil(t, err, "Parse error: %v", err)
|
||||
|
||||
dt := NewDateTimeFromTime(tt)
|
||||
assert.True(t, dt > 0, "expected a valid DateTime greater than 0, got %v", dt)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestTimeRoundTrip(t *testing.T) {
|
||||
val := struct {
|
||||
Value time.Time
|
||||
ID string
|
||||
}{
|
||||
ID: "time-rt-test",
|
||||
}
|
||||
|
||||
if !val.Value.IsZero() {
|
||||
t.Errorf("Did not get zero time as expected.")
|
||||
}
|
||||
|
||||
bsonOut, err := Marshal(val)
|
||||
noerr(t, err)
|
||||
rtval := struct {
|
||||
Value time.Time
|
||||
ID string
|
||||
}{}
|
||||
|
||||
err = Unmarshal(bsonOut, &rtval)
|
||||
noerr(t, err)
|
||||
if !cmp.Equal(val, rtval) {
|
||||
t.Errorf("Did not round trip properly. got %v; want %v", val, rtval)
|
||||
}
|
||||
if !rtval.Value.IsZero() {
|
||||
t.Errorf("Did not get zero time as expected.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNonNullTimeRoundTrip(t *testing.T) {
|
||||
now := time.Now()
|
||||
now = time.Unix(now.Unix(), 0)
|
||||
val := struct {
|
||||
Value time.Time
|
||||
ID string
|
||||
}{
|
||||
ID: "time-rt-test",
|
||||
Value: now,
|
||||
}
|
||||
|
||||
bsonOut, err := Marshal(val)
|
||||
noerr(t, err)
|
||||
rtval := struct {
|
||||
Value time.Time
|
||||
ID string
|
||||
}{}
|
||||
|
||||
err = Unmarshal(bsonOut, &rtval)
|
||||
noerr(t, err)
|
||||
if !cmp.Equal(val, rtval) {
|
||||
t.Errorf("Did not round trip properly. got %v; want %v", val, rtval)
|
||||
}
|
||||
}
|
||||
|
||||
func TestD(t *testing.T) {
|
||||
t.Run("can marshal", func(t *testing.T) {
|
||||
d := D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
|
||||
idx, want := bsoncore.AppendDocumentStart(nil)
|
||||
want = bsoncore.AppendStringElement(want, "foo", "bar")
|
||||
want = bsoncore.AppendStringElement(want, "hello", "world")
|
||||
want = bsoncore.AppendDoubleElement(want, "pi", 3.14159)
|
||||
want, err := bsoncore.AppendDocumentEnd(want, idx)
|
||||
noerr(t, err)
|
||||
got, err := Marshal(d)
|
||||
noerr(t, err)
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Errorf("Marshaled documents do not match. got %v; want %v", Raw(got), Raw(want))
|
||||
}
|
||||
})
|
||||
t.Run("can unmarshal", func(t *testing.T) {
|
||||
want := D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
|
||||
idx, doc := bsoncore.AppendDocumentStart(nil)
|
||||
doc = bsoncore.AppendStringElement(doc, "foo", "bar")
|
||||
doc = bsoncore.AppendStringElement(doc, "hello", "world")
|
||||
doc = bsoncore.AppendDoubleElement(doc, "pi", 3.14159)
|
||||
doc, err := bsoncore.AppendDocumentEnd(doc, idx)
|
||||
noerr(t, err)
|
||||
var got D
|
||||
err = Unmarshal(doc, &got)
|
||||
noerr(t, err)
|
||||
if !cmp.Equal(got, want) {
|
||||
t.Errorf("Unmarshaled documents do not match. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDStringer(t *testing.T) {
|
||||
got := D{{"a", 1}, {"b", 2}}.String()
|
||||
want := `{"a":{"$numberInt":"1"},"b":{"$numberInt":"2"}}`
|
||||
assert.Equal(t, want, got, "expected: %s, got: %s", want, got)
|
||||
}
|
||||
|
||||
func TestMStringer(t *testing.T) {
|
||||
type msg struct {
|
||||
A json.RawMessage `json:"a"`
|
||||
B json.RawMessage `json:"b"`
|
||||
}
|
||||
|
||||
var res msg
|
||||
got := M{"a": 1, "b": 2}.String()
|
||||
err := json.Unmarshal([]byte(got), &res)
|
||||
require.NoError(t, err, "Unmarshal error")
|
||||
|
||||
want := msg{
|
||||
A: json.RawMessage(`{"$numberInt":"1"}`),
|
||||
B: json.RawMessage(`{"$numberInt":"2"}`),
|
||||
}
|
||||
|
||||
assert.Equal(t, want, res, "returned string did not unmarshal to the expected document, returned string: %s", got)
|
||||
}
|
||||
func TestD_MarshalJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testcases := []struct {
|
||||
name string
|
||||
test D
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
"nil",
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
D{},
|
||||
struct{}{},
|
||||
},
|
||||
{
|
||||
"non-empty",
|
||||
D{
|
||||
{"a", 42},
|
||||
{"b", true},
|
||||
{"c", "answer"},
|
||||
{"d", nil},
|
||||
{"e", 2.71828},
|
||||
{"f", A{42, true, "answer", nil, 2.71828}},
|
||||
{"g", D{{"foo", "bar"}}},
|
||||
},
|
||||
struct {
|
||||
A int `json:"a"`
|
||||
B bool `json:"b"`
|
||||
C string `json:"c"`
|
||||
D interface{} `json:"d"`
|
||||
E float32 `json:"e"`
|
||||
F []interface{} `json:"f"`
|
||||
G map[string]interface{} `json:"g"`
|
||||
}{
|
||||
A: 42,
|
||||
B: true,
|
||||
C: "answer",
|
||||
D: nil,
|
||||
E: 2.71828,
|
||||
F: []interface{}{42, true, "answer", nil, 2.71828},
|
||||
G: map[string]interface{}{"foo": "bar"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range testcases {
|
||||
tc := tc
|
||||
t.Run("json.Marshal "+tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := json.Marshal(tc.test)
|
||||
assert.NoError(t, err)
|
||||
want, _ := json.Marshal(tc.expected)
|
||||
assert.Equal(t, want, got)
|
||||
})
|
||||
}
|
||||
for _, tc := range testcases {
|
||||
tc := tc
|
||||
t.Run("json.MarshalIndent "+tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := json.MarshalIndent(tc.test, "<prefix>", "<indent>")
|
||||
assert.NoError(t, err)
|
||||
want, _ := json.MarshalIndent(tc.expected, "<prefix>", "<indent>")
|
||||
assert.Equal(t, want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestD_UnmarshalJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
test []byte
|
||||
expected D
|
||||
}{
|
||||
{
|
||||
"nil",
|
||||
[]byte(`null`),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
[]byte(`{}`),
|
||||
D{},
|
||||
},
|
||||
{
|
||||
"non-empty",
|
||||
[]byte(`{"hello":"world","pi":3.142,"boolean":true,"nothing":null,"list":["hello world",3.142,false,null,{"Lorem":"ipsum"}],"document":{"foo":"bar"}}`),
|
||||
D{
|
||||
{"hello", "world"},
|
||||
{"pi", 3.142},
|
||||
{"boolean", true},
|
||||
{"nothing", nil},
|
||||
{"list", []interface{}{"hello world", 3.142, false, nil, D{{"Lorem", "ipsum"}}}},
|
||||
{"document", D{{"foo", "bar"}}},
|
||||
},
|
||||
},
|
||||
} {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var got D
|
||||
err := json.Unmarshal(tc.test, &got)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("failure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
test string
|
||||
}{
|
||||
{
|
||||
"illegal",
|
||||
`nil`,
|
||||
},
|
||||
{
|
||||
"invalid",
|
||||
`{"pi": 3.142ipsum}`,
|
||||
},
|
||||
{
|
||||
"malformatted",
|
||||
`{"pi", 3.142}`,
|
||||
},
|
||||
{
|
||||
"truncated",
|
||||
`{"pi": 3.142`,
|
||||
},
|
||||
{
|
||||
"array type",
|
||||
`["pi", 3.142]`,
|
||||
},
|
||||
{
|
||||
"boolean type",
|
||||
`true`,
|
||||
},
|
||||
} {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var a map[string]interface{}
|
||||
want := json.Unmarshal([]byte(tc.test), &a)
|
||||
var b D
|
||||
got := json.Unmarshal([]byte(tc.test), &b)
|
||||
switch w := want.(type) {
|
||||
case *json.UnmarshalTypeError:
|
||||
w.Type = reflect.TypeOf(b)
|
||||
require.IsType(t, want, got)
|
||||
g := got.(*json.UnmarshalTypeError)
|
||||
assert.Equal(t, w, g)
|
||||
default:
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type stringerString string
|
||||
|
||||
func (ss stringerString) String() string {
|
||||
return "bar"
|
||||
}
|
||||
|
||||
type keyBool bool
|
||||
|
||||
func (kb keyBool) MarshalKey() (string, error) {
|
||||
return fmt.Sprintf("%v", kb), nil
|
||||
}
|
||||
|
||||
func (kb *keyBool) UnmarshalKey(key string) error {
|
||||
switch key {
|
||||
case "true":
|
||||
*kb = true
|
||||
case "false":
|
||||
*kb = false
|
||||
default:
|
||||
return fmt.Errorf("invalid bool value %v", key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type keyStruct struct {
|
||||
val int64
|
||||
}
|
||||
|
||||
func (k keyStruct) MarshalText() (text []byte, err error) {
|
||||
str := strconv.FormatInt(k.val, 10)
|
||||
|
||||
return []byte(str), nil
|
||||
}
|
||||
|
||||
func (k *keyStruct) UnmarshalText(text []byte) error {
|
||||
val, err := strconv.ParseInt(string(text), 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*k = keyStruct{
|
||||
val: val,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestMapCodec(t *testing.T) {
|
||||
t.Run("EncodeKeysWithStringer", func(t *testing.T) {
|
||||
strstr := stringerString("foo")
|
||||
mapObj := map[stringerString]int{strstr: 1}
|
||||
testCases := []struct {
|
||||
name string
|
||||
mapCodec *mapCodec
|
||||
key string
|
||||
}{
|
||||
{"default", &mapCodec{}, "foo"},
|
||||
{"true", &mapCodec{encodeKeysWithStringer: true}, "bar"},
|
||||
{"false", &mapCodec{encodeKeysWithStringer: false}, "foo"},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mapRegistry := NewRegistry()
|
||||
mapRegistry.RegisterKindEncoder(reflect.Map, tc.mapCodec)
|
||||
buf := new(bytes.Buffer)
|
||||
vw := NewDocumentWriter(buf)
|
||||
enc := NewEncoder(vw)
|
||||
enc.SetRegistry(mapRegistry)
|
||||
err := enc.Encode(mapObj)
|
||||
assert.Nil(t, err, "Encode error: %v", err)
|
||||
str := buf.String()
|
||||
assert.True(t, strings.Contains(str, tc.key), "expected result to contain %v, got: %v", tc.key, str)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("keys implements keyMarshaler and keyUnmarshaler", func(t *testing.T) {
|
||||
mapObj := map[keyBool]int{keyBool(true): 1}
|
||||
|
||||
doc, err := Marshal(mapObj)
|
||||
assert.Nil(t, err, "Marshal error: %v", err)
|
||||
idx, want := bsoncore.AppendDocumentStart(nil)
|
||||
want = bsoncore.AppendInt32Element(want, "true", 1)
|
||||
want, _ = bsoncore.AppendDocumentEnd(want, idx)
|
||||
assert.Equal(t, want, doc, "expected result %v, got %v", string(want), string(doc))
|
||||
|
||||
var got map[keyBool]int
|
||||
err = Unmarshal(doc, &got)
|
||||
assert.Nil(t, err, "Unmarshal error: %v", err)
|
||||
assert.Equal(t, mapObj, got, "expected result %v, got %v", mapObj, got)
|
||||
|
||||
})
|
||||
|
||||
t.Run("keys implements encoding.TextMarshaler and encoding.TextUnmarshaler", func(t *testing.T) {
|
||||
mapObj := map[keyStruct]int{
|
||||
{val: 10}: 100,
|
||||
}
|
||||
|
||||
doc, err := Marshal(mapObj)
|
||||
assert.Nil(t, err, "Marshal error: %v", err)
|
||||
idx, want := bsoncore.AppendDocumentStart(nil)
|
||||
want = bsoncore.AppendInt32Element(want, "10", 100)
|
||||
want, _ = bsoncore.AppendDocumentEnd(want, idx)
|
||||
assert.Equal(t, want, doc, "expected result %v, got %v", string(want), string(doc))
|
||||
|
||||
var got map[keyStruct]int
|
||||
err = Unmarshal(doc, &got)
|
||||
assert.Nil(t, err, "Unmarshal error: %v", err)
|
||||
assert.Equal(t, mapObj, got, "expected result %v, got %v", mapObj, got)
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtJSONEscapeKey(t *testing.T) {
|
||||
doc := D{
|
||||
{
|
||||
Key: "\\usb#",
|
||||
Value: int32(1),
|
||||
},
|
||||
{
|
||||
Key: "regex",
|
||||
Value: Regex{Pattern: "ab\\\\\\\"ab", Options: "\""},
|
||||
},
|
||||
}
|
||||
b, err := MarshalExtJSON(&doc, false, false)
|
||||
noerr(t, err)
|
||||
|
||||
want := `{"\\usb#":1,"regex":{"$regularExpression":{"pattern":"ab\\\\\\\"ab","options":"\""}}}`
|
||||
if diff := cmp.Diff(want, string(b)); diff != "" {
|
||||
t.Errorf("Marshaled documents do not match. got %v, want %v", string(b), want)
|
||||
}
|
||||
|
||||
var got D
|
||||
err = UnmarshalExtJSON(b, false, &got)
|
||||
noerr(t, err)
|
||||
if !cmp.Equal(got, doc) {
|
||||
t.Errorf("Unmarshaled documents do not match. got %v; want %v", got, doc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBsoncoreArray(t *testing.T) {
|
||||
type BSONDocumentArray struct {
|
||||
Array []D `bson:"array"`
|
||||
}
|
||||
|
||||
type BSONArray struct {
|
||||
Array bsoncore.Array `bson:"array"`
|
||||
}
|
||||
|
||||
bda := BSONDocumentArray{
|
||||
Array: []D{
|
||||
{{"x", 1}},
|
||||
{{"x", 2}},
|
||||
{{"x", 3}},
|
||||
},
|
||||
}
|
||||
|
||||
expectedBSON, err := Marshal(bda)
|
||||
assert.Nil(t, err, "Marshal bsoncore.Document array error: %v", err)
|
||||
|
||||
var ba BSONArray
|
||||
err = Unmarshal(expectedBSON, &ba)
|
||||
assert.Nil(t, err, "Unmarshal error: %v", err)
|
||||
|
||||
actualBSON, err := Marshal(ba)
|
||||
assert.Nil(t, err, "Marshal bsoncore.Array error: %v", err)
|
||||
|
||||
assert.Equal(t, expectedBSON, actualBSON,
|
||||
"expected BSON to be %v after Marshalling again; got %v", expectedBSON, actualBSON)
|
||||
|
||||
doc := bsoncore.Document(actualBSON)
|
||||
v := doc.Lookup("array")
|
||||
assert.Equal(t, bsoncore.TypeArray, v.Type, "expected type array, got %v", v.Type)
|
||||
}
|
||||
|
||||
var baseTime = time.Date(2024, 10, 11, 12, 13, 14, 12345678, time.UTC)
|
||||
|
||||
func BenchmarkDateTimeMarshalJSON(b *testing.B) {
|
||||
t := NewDateTimeFromTime(baseTime)
|
||||
data, err := t.MarshalJSON()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(data)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := t.MarshalJSON(); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDateTimeUnmarshalJSON(b *testing.B) {
|
||||
t := NewDateTimeFromTime(baseTime)
|
||||
data, err := t.MarshalJSON()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(data)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
var dt DateTime
|
||||
if err := dt.UnmarshalJSON(data); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
199
bsoncodec.go
Normal file
199
bsoncodec.go
Normal file
@ -0,0 +1,199 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
emptyValue = reflect.Value{}
|
||||
)
|
||||
|
||||
// ValueEncoderError is an error returned from a ValueEncoder when the provided value can't be
|
||||
// encoded by the ValueEncoder.
|
||||
type ValueEncoderError struct {
|
||||
Name string
|
||||
Types []reflect.Type
|
||||
Kinds []reflect.Kind
|
||||
Received reflect.Value
|
||||
}
|
||||
|
||||
func (vee ValueEncoderError) Error() string {
|
||||
typeKinds := make([]string, 0, len(vee.Types)+len(vee.Kinds))
|
||||
for _, t := range vee.Types {
|
||||
typeKinds = append(typeKinds, t.String())
|
||||
}
|
||||
for _, k := range vee.Kinds {
|
||||
if k == reflect.Map {
|
||||
typeKinds = append(typeKinds, "map[string]*")
|
||||
continue
|
||||
}
|
||||
typeKinds = append(typeKinds, k.String())
|
||||
}
|
||||
received := vee.Received.Kind().String()
|
||||
if vee.Received.IsValid() {
|
||||
received = vee.Received.Type().String()
|
||||
}
|
||||
return fmt.Sprintf("%s can only encode valid %s, but got %s", vee.Name, strings.Join(typeKinds, ", "), received)
|
||||
}
|
||||
|
||||
// ValueDecoderError is an error returned from a ValueDecoder when the provided value can't be
|
||||
// decoded by the ValueDecoder.
|
||||
type ValueDecoderError struct {
|
||||
Name string
|
||||
Types []reflect.Type
|
||||
Kinds []reflect.Kind
|
||||
Received reflect.Value
|
||||
}
|
||||
|
||||
func (vde ValueDecoderError) Error() string {
|
||||
typeKinds := make([]string, 0, len(vde.Types)+len(vde.Kinds))
|
||||
for _, t := range vde.Types {
|
||||
typeKinds = append(typeKinds, t.String())
|
||||
}
|
||||
for _, k := range vde.Kinds {
|
||||
if k == reflect.Map {
|
||||
typeKinds = append(typeKinds, "map[string]*")
|
||||
continue
|
||||
}
|
||||
typeKinds = append(typeKinds, k.String())
|
||||
}
|
||||
received := vde.Received.Kind().String()
|
||||
if vde.Received.IsValid() {
|
||||
received = vde.Received.Type().String()
|
||||
}
|
||||
return fmt.Sprintf("%s can only decode valid and settable %s, but got %s", vde.Name, strings.Join(typeKinds, ", "), received)
|
||||
}
|
||||
|
||||
// EncodeContext is the contextual information required for a Codec to encode a
|
||||
// value.
|
||||
type EncodeContext struct {
|
||||
*Registry
|
||||
|
||||
// minSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64,
|
||||
// uint, uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits)
|
||||
// that can represent the integer value.
|
||||
minSize bool
|
||||
|
||||
errorOnInlineDuplicates bool
|
||||
stringifyMapKeysWithFmt bool
|
||||
nilMapAsEmpty bool
|
||||
nilSliceAsEmpty bool
|
||||
nilByteSliceAsEmpty bool
|
||||
omitZeroStruct bool
|
||||
useJSONStructTags bool
|
||||
}
|
||||
|
||||
// DecodeContext is the contextual information required for a Codec to decode a
|
||||
// value.
|
||||
type DecodeContext struct {
|
||||
*Registry
|
||||
|
||||
// truncate, if true, instructs decoders to to truncate the fractional part of BSON "double"
|
||||
// values when attempting to unmarshal them into a Go integer (int, int8, int16, int32, int64,
|
||||
// uint, uint8, uint16, uint32, or uint64) struct field. The truncation logic does not apply to
|
||||
// BSON "decimal128" values.
|
||||
truncate bool
|
||||
|
||||
// defaultDocumentType specifies the Go type to decode top-level and nested BSON documents into. In particular, the
|
||||
// usage for this field is restricted to data typed as "interface{}" or "map[string]interface{}". If DocumentType is
|
||||
// set to a type that a BSON document cannot be unmarshaled into (e.g. "string"), unmarshalling will result in an
|
||||
// error.
|
||||
defaultDocumentType reflect.Type
|
||||
|
||||
binaryAsSlice bool
|
||||
|
||||
// a false value results in a decoding error.
|
||||
objectIDAsHexString bool
|
||||
|
||||
useJSONStructTags bool
|
||||
useLocalTimeZone bool
|
||||
zeroMaps bool
|
||||
zeroStructs bool
|
||||
}
|
||||
|
||||
// ValueEncoder is the interface implemented by types that can encode a provided Go type to BSON.
|
||||
// The value to encode is provided as a reflect.Value and a bson.ValueWriter is used within the
|
||||
// EncodeValue method to actually create the BSON representation. For convenience, ValueEncoderFunc
|
||||
// is provided to allow use of a function with the correct signature as a ValueEncoder. An
|
||||
// EncodeContext instance is provided to allow implementations to lookup further ValueEncoders and
|
||||
// to provide configuration information.
|
||||
type ValueEncoder interface {
|
||||
EncodeValue(EncodeContext, ValueWriter, reflect.Value) error
|
||||
}
|
||||
|
||||
// ValueEncoderFunc is an adapter function that allows a function with the correct signature to be
|
||||
// used as a ValueEncoder.
|
||||
type ValueEncoderFunc func(EncodeContext, ValueWriter, reflect.Value) error
|
||||
|
||||
// EncodeValue implements the ValueEncoder interface.
|
||||
func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
return fn(ec, vw, val)
|
||||
}
|
||||
|
||||
// ValueDecoder is the interface implemented by types that can decode BSON to a provided Go type.
|
||||
// Implementations should ensure that the value they receive is settable. Similar to ValueEncoderFunc,
|
||||
// ValueDecoderFunc is provided to allow the use of a function with the correct signature as a
|
||||
// ValueDecoder. A DecodeContext instance is provided and serves similar functionality to the
|
||||
// EncodeContext.
|
||||
type ValueDecoder interface {
|
||||
DecodeValue(DecodeContext, ValueReader, reflect.Value) error
|
||||
}
|
||||
|
||||
// ValueDecoderFunc is an adapter function that allows a function with the correct signature to be
|
||||
// used as a ValueDecoder.
|
||||
type ValueDecoderFunc func(DecodeContext, ValueReader, reflect.Value) error
|
||||
|
||||
// DecodeValue implements the ValueDecoder interface.
|
||||
func (fn ValueDecoderFunc) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
|
||||
return fn(dc, vr, val)
|
||||
}
|
||||
|
||||
// typeDecoder is the interface implemented by types that can handle the decoding of a value given its type.
|
||||
type typeDecoder interface {
|
||||
decodeType(DecodeContext, ValueReader, reflect.Type) (reflect.Value, error)
|
||||
}
|
||||
|
||||
// typeDecoderFunc is an adapter function that allows a function with the correct signature to be used as a typeDecoder.
|
||||
type typeDecoderFunc func(DecodeContext, ValueReader, reflect.Type) (reflect.Value, error)
|
||||
|
||||
func (fn typeDecoderFunc) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
|
||||
return fn(dc, vr, t)
|
||||
}
|
||||
|
||||
// decodeAdapter allows two functions with the correct signatures to be used as both a ValueDecoder and typeDecoder.
|
||||
type decodeAdapter struct {
|
||||
ValueDecoderFunc
|
||||
typeDecoderFunc
|
||||
}
|
||||
|
||||
var _ ValueDecoder = decodeAdapter{}
|
||||
var _ typeDecoder = decodeAdapter{}
|
||||
|
||||
func decodeTypeOrValueWithInfo(vd ValueDecoder, dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
|
||||
if td, _ := vd.(typeDecoder); td != nil {
|
||||
val, err := td.decodeType(dc, vr, t)
|
||||
if err == nil && val.Type() != t {
|
||||
// This conversion step is necessary for slices and maps. If a user declares variables like:
|
||||
//
|
||||
// type myBool bool
|
||||
// var m map[string]myBool
|
||||
//
|
||||
// and tries to decode BSON bytes into the map, the decoding will fail if this conversion is not present
|
||||
// because we'll try to assign a value of type bool to one of type myBool.
|
||||
val = val.Convert(t)
|
||||
}
|
||||
return val, err
|
||||
}
|
||||
|
||||
val := reflect.New(t).Elem()
|
||||
err := vd.DecodeValue(dc, vr, val)
|
||||
return val, err
|
||||
}
|
72
bsoncodec_test.go
Normal file
72
bsoncodec_test.go
Normal file
@ -0,0 +1,72 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func ExampleValueEncoder() {
|
||||
var _ ValueEncoderFunc = func(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if val.Kind() != reflect.String {
|
||||
return ValueEncoderError{Name: "StringEncodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val}
|
||||
}
|
||||
|
||||
return vw.WriteString(val.String())
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleValueDecoder() {
|
||||
var _ ValueDecoderFunc = func(_ DecodeContext, vr ValueReader, val reflect.Value) error {
|
||||
if !val.CanSet() || val.Kind() != reflect.String {
|
||||
return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val}
|
||||
}
|
||||
|
||||
if vr.Type() != TypeString {
|
||||
return fmt.Errorf("cannot decode %v into a string type", vr.Type())
|
||||
}
|
||||
|
||||
str, err := vr.ReadString()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val.SetString(str)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type llCodec struct {
|
||||
t *testing.T
|
||||
decodeval interface{}
|
||||
encodeval interface{}
|
||||
err error
|
||||
}
|
||||
|
||||
func (llc *llCodec) EncodeValue(_ EncodeContext, _ ValueWriter, i interface{}) error {
|
||||
if llc.err != nil {
|
||||
return llc.err
|
||||
}
|
||||
|
||||
llc.encodeval = i
|
||||
return nil
|
||||
}
|
||||
|
||||
func (llc *llCodec) DecodeValue(_ DecodeContext, _ ValueReader, val reflect.Value) error {
|
||||
if llc.err != nil {
|
||||
return llc.err
|
||||
}
|
||||
|
||||
if !reflect.TypeOf(llc.decodeval).AssignableTo(val.Type()) {
|
||||
llc.t.Errorf("decodeval must be assignable to val provided to DecodeValue, but is not. decodeval %T; val %T", llc.decodeval, val)
|
||||
return nil
|
||||
}
|
||||
|
||||
val.Set(reflect.ValueOf(llc.decodeval))
|
||||
return nil
|
||||
}
|
846
bsonrw_test.go
Normal file
846
bsonrw_test.go
Normal file
@ -0,0 +1,846 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
var (
|
||||
_ ValueReader = &valueReaderWriter{}
|
||||
_ ValueWriter = &valueReaderWriter{}
|
||||
)
|
||||
|
||||
// invoked is a type used to indicate what method was called last.
|
||||
type invoked byte
|
||||
|
||||
// These are the different methods that can be invoked.
|
||||
const (
|
||||
nothing invoked = iota
|
||||
readArray
|
||||
readBinary
|
||||
readBoolean
|
||||
readDocument
|
||||
readCodeWithScope
|
||||
readDBPointer
|
||||
readDateTime
|
||||
readDecimal128
|
||||
readDouble
|
||||
readInt32
|
||||
readInt64
|
||||
readJavascript
|
||||
readMaxKey
|
||||
readMinKey
|
||||
readNull
|
||||
readObjectID
|
||||
readRegex
|
||||
readString
|
||||
readSymbol
|
||||
readTimestamp
|
||||
readUndefined
|
||||
readElement
|
||||
readValue
|
||||
writeArray
|
||||
writeBinary
|
||||
writeBinaryWithSubtype
|
||||
writeBoolean
|
||||
writeCodeWithScope
|
||||
writeDBPointer
|
||||
writeDateTime
|
||||
writeDecimal128
|
||||
writeDouble
|
||||
writeInt32
|
||||
writeInt64
|
||||
writeJavascript
|
||||
writeMaxKey
|
||||
writeMinKey
|
||||
writeNull
|
||||
writeObjectID
|
||||
writeRegex
|
||||
writeString
|
||||
writeDocument
|
||||
writeSymbol
|
||||
writeTimestamp
|
||||
writeUndefined
|
||||
writeDocumentElement
|
||||
writeDocumentEnd
|
||||
writeArrayElement
|
||||
writeArrayEnd
|
||||
skip
|
||||
)
|
||||
|
||||
func (i invoked) String() string {
|
||||
switch i {
|
||||
case nothing:
|
||||
return "Nothing"
|
||||
case readArray:
|
||||
return "ReadArray"
|
||||
case readBinary:
|
||||
return "ReadBinary"
|
||||
case readBoolean:
|
||||
return "ReadBoolean"
|
||||
case readDocument:
|
||||
return "ReadDocument"
|
||||
case readCodeWithScope:
|
||||
return "ReadCodeWithScope"
|
||||
case readDBPointer:
|
||||
return "ReadDBPointer"
|
||||
case readDateTime:
|
||||
return "ReadDateTime"
|
||||
case readDecimal128:
|
||||
return "ReadDecimal128"
|
||||
case readDouble:
|
||||
return "ReadDouble"
|
||||
case readInt32:
|
||||
return "ReadInt32"
|
||||
case readInt64:
|
||||
return "ReadInt64"
|
||||
case readJavascript:
|
||||
return "ReadJavascript"
|
||||
case readMaxKey:
|
||||
return "ReadMaxKey"
|
||||
case readMinKey:
|
||||
return "ReadMinKey"
|
||||
case readNull:
|
||||
return "ReadNull"
|
||||
case readObjectID:
|
||||
return "ReadObjectID"
|
||||
case readRegex:
|
||||
return "ReadRegex"
|
||||
case readString:
|
||||
return "ReadString"
|
||||
case readSymbol:
|
||||
return "ReadSymbol"
|
||||
case readTimestamp:
|
||||
return "ReadTimestamp"
|
||||
case readUndefined:
|
||||
return "ReadUndefined"
|
||||
case readElement:
|
||||
return "ReadElement"
|
||||
case readValue:
|
||||
return "ReadValue"
|
||||
case writeArray:
|
||||
return "WriteArray"
|
||||
case writeBinary:
|
||||
return "WriteBinary"
|
||||
case writeBinaryWithSubtype:
|
||||
return "WriteBinaryWithSubtype"
|
||||
case writeBoolean:
|
||||
return "WriteBoolean"
|
||||
case writeCodeWithScope:
|
||||
return "WriteCodeWithScope"
|
||||
case writeDBPointer:
|
||||
return "WriteDBPointer"
|
||||
case writeDateTime:
|
||||
return "WriteDateTime"
|
||||
case writeDecimal128:
|
||||
return "WriteDecimal128"
|
||||
case writeDouble:
|
||||
return "WriteDouble"
|
||||
case writeInt32:
|
||||
return "WriteInt32"
|
||||
case writeInt64:
|
||||
return "WriteInt64"
|
||||
case writeJavascript:
|
||||
return "WriteJavascript"
|
||||
case writeMaxKey:
|
||||
return "WriteMaxKey"
|
||||
case writeMinKey:
|
||||
return "WriteMinKey"
|
||||
case writeNull:
|
||||
return "WriteNull"
|
||||
case writeObjectID:
|
||||
return "WriteObjectID"
|
||||
case writeRegex:
|
||||
return "WriteRegex"
|
||||
case writeString:
|
||||
return "WriteString"
|
||||
case writeDocument:
|
||||
return "WriteDocument"
|
||||
case writeSymbol:
|
||||
return "WriteSymbol"
|
||||
case writeTimestamp:
|
||||
return "WriteTimestamp"
|
||||
case writeUndefined:
|
||||
return "WriteUndefined"
|
||||
case writeDocumentElement:
|
||||
return "WriteDocumentElement"
|
||||
case writeDocumentEnd:
|
||||
return "WriteDocumentEnd"
|
||||
case writeArrayElement:
|
||||
return "WriteArrayElement"
|
||||
case writeArrayEnd:
|
||||
return "WriteArrayEnd"
|
||||
default:
|
||||
return "<unknown>"
|
||||
}
|
||||
}
|
||||
|
||||
// valueReaderWriter is a test implementation of a bsonrw.ValueReader and bsonrw.ValueWriter
|
||||
type valueReaderWriter struct {
|
||||
T *testing.T
|
||||
invoked invoked
|
||||
Return interface{} // Can be a primitive or a bsoncore.Value
|
||||
BSONType Type
|
||||
Err error
|
||||
ErrAfter invoked // error after this method is called
|
||||
depth uint64
|
||||
}
|
||||
|
||||
// prevent infinite recursion.
|
||||
func (llvrw *valueReaderWriter) checkdepth() {
|
||||
llvrw.depth++
|
||||
if llvrw.depth > 1000 {
|
||||
panic("max depth exceeded")
|
||||
}
|
||||
}
|
||||
|
||||
// Type implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) Type() Type {
|
||||
llvrw.checkdepth()
|
||||
return llvrw.BSONType
|
||||
}
|
||||
|
||||
// Skip implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) Skip() error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = skip
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadArray implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadArray() (ArrayReader, error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readArray
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return nil, llvrw.Err
|
||||
}
|
||||
|
||||
return llvrw, nil
|
||||
}
|
||||
|
||||
// ReadBinary implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadBinary() (b []byte, btype byte, err error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readBinary
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return nil, 0x00, llvrw.Err
|
||||
}
|
||||
|
||||
switch tt := llvrw.Return.(type) {
|
||||
case bsoncore.Value:
|
||||
subtype, data, _, ok := bsoncore.ReadBinary(tt.Data)
|
||||
if !ok {
|
||||
llvrw.T.Error("Invalid Value provided for return value of ReadBinary.")
|
||||
return nil, 0x00, nil
|
||||
}
|
||||
return data, subtype, nil
|
||||
default:
|
||||
llvrw.T.Errorf("Incorrect type provided for return value of ReadBinary: %T", llvrw.Return)
|
||||
return nil, 0x00, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ReadBoolean implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadBoolean() (bool, error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readBoolean
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return false, llvrw.Err
|
||||
}
|
||||
|
||||
switch tt := llvrw.Return.(type) {
|
||||
case bool:
|
||||
return tt, nil
|
||||
case bsoncore.Value:
|
||||
b, _, ok := bsoncore.ReadBoolean(tt.Data)
|
||||
if !ok {
|
||||
llvrw.T.Error("Invalid Value provided for return value of ReadBoolean.")
|
||||
return false, nil
|
||||
}
|
||||
return b, nil
|
||||
default:
|
||||
llvrw.T.Errorf("Incorrect type provided for return value of ReadBoolean: %T", llvrw.Return)
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ReadDocument implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadDocument() (DocumentReader, error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readDocument
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return nil, llvrw.Err
|
||||
}
|
||||
|
||||
return llvrw, nil
|
||||
}
|
||||
|
||||
// ReadCodeWithScope implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadCodeWithScope() (code string, dr DocumentReader, err error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readCodeWithScope
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return "", nil, llvrw.Err
|
||||
}
|
||||
|
||||
return "", llvrw, nil
|
||||
}
|
||||
|
||||
// ReadDBPointer implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadDBPointer() (ns string, oid ObjectID, err error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readDBPointer
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return "", ObjectID{}, llvrw.Err
|
||||
}
|
||||
|
||||
switch tt := llvrw.Return.(type) {
|
||||
case bsoncore.Value:
|
||||
ns, oid, _, ok := bsoncore.ReadDBPointer(tt.Data)
|
||||
if !ok {
|
||||
llvrw.T.Error("Invalid Value instance provided for return value of ReadDBPointer")
|
||||
return "", ObjectID{}, nil
|
||||
}
|
||||
return ns, oid, nil
|
||||
default:
|
||||
llvrw.T.Errorf("Incorrect type provided for return value of ReadDBPointer: %T", llvrw.Return)
|
||||
return "", ObjectID{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ReadDateTime implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadDateTime() (int64, error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readDateTime
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return 0, llvrw.Err
|
||||
}
|
||||
|
||||
dt, ok := llvrw.Return.(int64)
|
||||
if !ok {
|
||||
llvrw.T.Errorf("Incorrect type provided for return value of ReadDateTime: %T", llvrw.Return)
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return dt, nil
|
||||
}
|
||||
|
||||
// ReadDecimal128 implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadDecimal128() (Decimal128, error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readDecimal128
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return Decimal128{}, llvrw.Err
|
||||
}
|
||||
|
||||
d128, ok := llvrw.Return.(Decimal128)
|
||||
if !ok {
|
||||
llvrw.T.Errorf("Incorrect type provided for return value of ReadDecimal128: %T", llvrw.Return)
|
||||
return Decimal128{}, nil
|
||||
}
|
||||
|
||||
return d128, nil
|
||||
}
|
||||
|
||||
// ReadDouble implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadDouble() (float64, error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readDouble
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return 0, llvrw.Err
|
||||
}
|
||||
|
||||
f64, ok := llvrw.Return.(float64)
|
||||
if !ok {
|
||||
llvrw.T.Errorf("Incorrect type provided for return value of ReadDouble: %T", llvrw.Return)
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return f64, nil
|
||||
}
|
||||
|
||||
// ReadInt32 implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadInt32() (int32, error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readInt32
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return 0, llvrw.Err
|
||||
}
|
||||
|
||||
i32, ok := llvrw.Return.(int32)
|
||||
if !ok {
|
||||
llvrw.T.Errorf("Incorrect type provided for return value of ReadInt32: %T", llvrw.Return)
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return i32, nil
|
||||
}
|
||||
|
||||
// ReadInt64 implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadInt64() (int64, error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readInt64
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return 0, llvrw.Err
|
||||
}
|
||||
i64, ok := llvrw.Return.(int64)
|
||||
if !ok {
|
||||
llvrw.T.Errorf("Incorrect type provided for return value of ReadInt64: %T", llvrw.Return)
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return i64, nil
|
||||
}
|
||||
|
||||
// ReadJavascript implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadJavascript() (code string, err error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readJavascript
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return "", llvrw.Err
|
||||
}
|
||||
js, ok := llvrw.Return.(string)
|
||||
if !ok {
|
||||
llvrw.T.Errorf("Incorrect type provided for return value of ReadJavascript: %T", llvrw.Return)
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return js, nil
|
||||
}
|
||||
|
||||
// ReadMaxKey implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadMaxKey() error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readMaxKey
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadMinKey implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadMinKey() error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readMinKey
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadNull implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadNull() error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readNull
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadObjectID implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadObjectID() (ObjectID, error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readObjectID
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return ObjectID{}, llvrw.Err
|
||||
}
|
||||
oid, ok := llvrw.Return.(ObjectID)
|
||||
if !ok {
|
||||
llvrw.T.Errorf("Incorrect type provided for return value of ReadObjectID: %T", llvrw.Return)
|
||||
return ObjectID{}, nil
|
||||
}
|
||||
|
||||
return oid, nil
|
||||
}
|
||||
|
||||
// ReadRegex implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadRegex() (pattern string, options string, err error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readRegex
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return "", "", llvrw.Err
|
||||
}
|
||||
switch tt := llvrw.Return.(type) {
|
||||
case bsoncore.Value:
|
||||
pattern, options, _, ok := bsoncore.ReadRegex(tt.Data)
|
||||
if !ok {
|
||||
llvrw.T.Error("Invalid Value instance provided for ReadRegex")
|
||||
return "", "", nil
|
||||
}
|
||||
return pattern, options, nil
|
||||
default:
|
||||
llvrw.T.Errorf("Incorrect type provided for return value of ReadRegex: %T", llvrw.Return)
|
||||
return "", "", nil
|
||||
}
|
||||
}
|
||||
|
||||
// ReadString implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadString() (string, error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readString
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return "", llvrw.Err
|
||||
}
|
||||
str, ok := llvrw.Return.(string)
|
||||
if !ok {
|
||||
llvrw.T.Errorf("Incorrect type provided for return value of ReadString: %T", llvrw.Return)
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return str, nil
|
||||
}
|
||||
|
||||
// ReadSymbol implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadSymbol() (symbol string, err error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readSymbol
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return "", llvrw.Err
|
||||
}
|
||||
switch tt := llvrw.Return.(type) {
|
||||
case string:
|
||||
return tt, nil
|
||||
case bsoncore.Value:
|
||||
symbol, _, ok := bsoncore.ReadSymbol(tt.Data)
|
||||
if !ok {
|
||||
llvrw.T.Error("Invalid Value instance provided for ReadSymbol")
|
||||
return "", nil
|
||||
}
|
||||
return symbol, nil
|
||||
default:
|
||||
llvrw.T.Errorf("Incorrect type provided for return value of ReadSymbol: %T", llvrw.Return)
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
|
||||
// ReadTimestamp implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadTimestamp() (t uint32, i uint32, err error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readTimestamp
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return 0, 0, llvrw.Err
|
||||
}
|
||||
switch tt := llvrw.Return.(type) {
|
||||
case bsoncore.Value:
|
||||
t, i, _, ok := bsoncore.ReadTimestamp(tt.Data)
|
||||
if !ok {
|
||||
llvrw.T.Errorf("Invalid Value instance provided for return value of ReadTimestamp")
|
||||
return 0, 0, nil
|
||||
}
|
||||
return t, i, nil
|
||||
default:
|
||||
llvrw.T.Errorf("Incorrect type provided for return value of ReadTimestamp: %T", llvrw.Return)
|
||||
return 0, 0, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ReadUndefined implements the ValueReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadUndefined() error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readUndefined
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteArray implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteArray() (ArrayWriter, error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeArray
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return nil, llvrw.Err
|
||||
}
|
||||
return llvrw, nil
|
||||
}
|
||||
|
||||
// WriteBinary implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteBinary([]byte) error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeBinary
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteBinaryWithSubtype implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteBinaryWithSubtype([]byte, byte) error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeBinaryWithSubtype
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteBoolean implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteBoolean(bool) error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeBoolean
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteCodeWithScope implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteCodeWithScope(string) (DocumentWriter, error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeCodeWithScope
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return nil, llvrw.Err
|
||||
}
|
||||
return llvrw, nil
|
||||
}
|
||||
|
||||
// WriteDBPointer implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteDBPointer(string, ObjectID) error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeDBPointer
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteDateTime implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteDateTime(int64) error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeDateTime
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteDecimal128 implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteDecimal128(Decimal128) error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeDecimal128
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteDouble implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteDouble(float64) error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeDouble
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteInt32 implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteInt32(int32) error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeInt32
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteInt64 implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteInt64(int64) error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeInt64
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteJavascript implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteJavascript(string) error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeJavascript
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteMaxKey implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteMaxKey() error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeMaxKey
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteMinKey implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteMinKey() error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeMinKey
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteNull implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteNull() error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeNull
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteObjectID implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteObjectID(ObjectID) error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeObjectID
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteRegex implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteRegex(string, string) error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeRegex
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteString implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteString(string) error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeString
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteDocument implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteDocument() (DocumentWriter, error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeDocument
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return nil, llvrw.Err
|
||||
}
|
||||
return llvrw, nil
|
||||
}
|
||||
|
||||
// WriteSymbol implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteSymbol(string) error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeSymbol
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteTimestamp implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteTimestamp(uint32, uint32) error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeTimestamp
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteUndefined implements the ValueWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteUndefined() error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeUndefined
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadElement implements the DocumentReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadElement() (string, ValueReader, error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readElement
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return "", nil, llvrw.Err
|
||||
}
|
||||
|
||||
return "", llvrw, nil
|
||||
}
|
||||
|
||||
// WriteDocumentElement implements the DocumentWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteDocumentElement(string) (ValueWriter, error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeDocumentElement
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return nil, llvrw.Err
|
||||
}
|
||||
|
||||
return llvrw, nil
|
||||
}
|
||||
|
||||
// WriteDocumentEnd implements the DocumentWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteDocumentEnd() error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeDocumentEnd
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadValue implements the ArrayReader interface.
|
||||
func (llvrw *valueReaderWriter) ReadValue() (ValueReader, error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = readValue
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return nil, llvrw.Err
|
||||
}
|
||||
|
||||
return llvrw, nil
|
||||
}
|
||||
|
||||
// WriteArrayElement implements the ArrayWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteArrayElement() (ValueWriter, error) {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeArrayElement
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return nil, llvrw.Err
|
||||
}
|
||||
|
||||
return llvrw, nil
|
||||
}
|
||||
|
||||
// WriteArrayEnd implements the ArrayWriter interface.
|
||||
func (llvrw *valueReaderWriter) WriteArrayEnd() error {
|
||||
llvrw.checkdepth()
|
||||
llvrw.invoked = writeArrayEnd
|
||||
if llvrw.ErrAfter == llvrw.invoked {
|
||||
return llvrw.Err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
97
byte_slice_codec.go
Normal file
97
byte_slice_codec.go
Normal file
@ -0,0 +1,97 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// byteSliceCodec is the Codec used for []byte values.
|
||||
type byteSliceCodec struct {
|
||||
// encodeNilAsEmpty causes EncodeValue to marshal nil Go byte slices as empty BSON binary values
|
||||
// instead of BSON null.
|
||||
encodeNilAsEmpty bool
|
||||
}
|
||||
|
||||
// Assert that byteSliceCodec satisfies the typeDecoder interface, which allows it to be
|
||||
// used by collection type decoders (e.g. map, slice, etc) to set individual values in a
|
||||
// collection.
|
||||
var _ typeDecoder = &byteSliceCodec{}
|
||||
|
||||
// EncodeValue is the ValueEncoder for []byte.
|
||||
func (bsc *byteSliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tByteSlice {
|
||||
return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val}
|
||||
}
|
||||
if val.IsNil() && !bsc.encodeNilAsEmpty && !ec.nilByteSliceAsEmpty {
|
||||
return vw.WriteNull()
|
||||
}
|
||||
return vw.WriteBinary(val.Interface().([]byte))
|
||||
}
|
||||
|
||||
func (bsc *byteSliceCodec) decodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
|
||||
if t != tByteSlice {
|
||||
return emptyValue, ValueDecoderError{
|
||||
Name: "ByteSliceDecodeValue",
|
||||
Types: []reflect.Type{tByteSlice},
|
||||
Received: reflect.Zero(t),
|
||||
}
|
||||
}
|
||||
|
||||
var data []byte
|
||||
var err error
|
||||
switch vrType := vr.Type(); vrType {
|
||||
case TypeString:
|
||||
str, err := vr.ReadString()
|
||||
if err != nil {
|
||||
return emptyValue, err
|
||||
}
|
||||
data = []byte(str)
|
||||
case TypeSymbol:
|
||||
sym, err := vr.ReadSymbol()
|
||||
if err != nil {
|
||||
return emptyValue, err
|
||||
}
|
||||
data = []byte(sym)
|
||||
case TypeBinary:
|
||||
var subtype byte
|
||||
data, subtype, err = vr.ReadBinary()
|
||||
if err != nil {
|
||||
return emptyValue, err
|
||||
}
|
||||
if subtype != TypeBinaryGeneric && subtype != TypeBinaryBinaryOld {
|
||||
return emptyValue, decodeBinaryError{subtype: subtype, typeName: "[]byte"}
|
||||
}
|
||||
case TypeNull:
|
||||
err = vr.ReadNull()
|
||||
case TypeUndefined:
|
||||
err = vr.ReadUndefined()
|
||||
default:
|
||||
return emptyValue, fmt.Errorf("cannot decode %v into a []byte", vrType)
|
||||
}
|
||||
if err != nil {
|
||||
return emptyValue, err
|
||||
}
|
||||
|
||||
return reflect.ValueOf(data), nil
|
||||
}
|
||||
|
||||
// DecodeValue is the ValueDecoder for []byte.
|
||||
func (bsc *byteSliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
|
||||
if !val.CanSet() || val.Type() != tByteSlice {
|
||||
return ValueDecoderError{Name: "ByteSliceDecodeValue", Types: []reflect.Type{tByteSlice}, Received: val}
|
||||
}
|
||||
|
||||
elem, err := bsc.decodeType(dc, vr, tByteSlice)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val.Set(elem)
|
||||
return nil
|
||||
}
|
166
codec_cache.go
Normal file
166
codec_cache.go
Normal file
@ -0,0 +1,166 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Runtime check that the kind encoder and decoder caches can store any valid
|
||||
// reflect.Kind constant.
|
||||
func init() {
|
||||
if s := reflect.Kind(len(kindEncoderCache{}.entries)).String(); s != "kind27" {
|
||||
panic("The capacity of kindEncoderCache is too small.\n" +
|
||||
"This is due to a new type being added to reflect.Kind.")
|
||||
}
|
||||
}
|
||||
|
||||
// statically assert array size
|
||||
var _ = (kindEncoderCache{}).entries[reflect.UnsafePointer]
|
||||
var _ = (kindDecoderCache{}).entries[reflect.UnsafePointer]
|
||||
|
||||
type typeEncoderCache struct {
|
||||
cache sync.Map // map[reflect.Type]ValueEncoder
|
||||
}
|
||||
|
||||
func (c *typeEncoderCache) Store(rt reflect.Type, enc ValueEncoder) {
|
||||
c.cache.Store(rt, enc)
|
||||
}
|
||||
|
||||
func (c *typeEncoderCache) Load(rt reflect.Type) (ValueEncoder, bool) {
|
||||
if v, _ := c.cache.Load(rt); v != nil {
|
||||
return v.(ValueEncoder), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (c *typeEncoderCache) LoadOrStore(rt reflect.Type, enc ValueEncoder) ValueEncoder {
|
||||
if v, loaded := c.cache.LoadOrStore(rt, enc); loaded {
|
||||
enc = v.(ValueEncoder)
|
||||
}
|
||||
return enc
|
||||
}
|
||||
|
||||
func (c *typeEncoderCache) Clone() *typeEncoderCache {
|
||||
cc := new(typeEncoderCache)
|
||||
c.cache.Range(func(k, v interface{}) bool {
|
||||
if k != nil && v != nil {
|
||||
cc.cache.Store(k, v)
|
||||
}
|
||||
return true
|
||||
})
|
||||
return cc
|
||||
}
|
||||
|
||||
type typeDecoderCache struct {
|
||||
cache sync.Map // map[reflect.Type]ValueDecoder
|
||||
}
|
||||
|
||||
func (c *typeDecoderCache) Store(rt reflect.Type, dec ValueDecoder) {
|
||||
c.cache.Store(rt, dec)
|
||||
}
|
||||
|
||||
func (c *typeDecoderCache) Load(rt reflect.Type) (ValueDecoder, bool) {
|
||||
if v, _ := c.cache.Load(rt); v != nil {
|
||||
return v.(ValueDecoder), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (c *typeDecoderCache) LoadOrStore(rt reflect.Type, dec ValueDecoder) ValueDecoder {
|
||||
if v, loaded := c.cache.LoadOrStore(rt, dec); loaded {
|
||||
dec = v.(ValueDecoder)
|
||||
}
|
||||
return dec
|
||||
}
|
||||
|
||||
func (c *typeDecoderCache) Clone() *typeDecoderCache {
|
||||
cc := new(typeDecoderCache)
|
||||
c.cache.Range(func(k, v interface{}) bool {
|
||||
if k != nil && v != nil {
|
||||
cc.cache.Store(k, v)
|
||||
}
|
||||
return true
|
||||
})
|
||||
return cc
|
||||
}
|
||||
|
||||
// atomic.Value requires that all calls to Store() have the same concrete type
|
||||
// so we wrap the ValueEncoder with a kindEncoderCacheEntry to ensure the type
|
||||
// is always the same (since different concrete types may implement the
|
||||
// ValueEncoder interface).
|
||||
type kindEncoderCacheEntry struct {
|
||||
enc ValueEncoder
|
||||
}
|
||||
|
||||
type kindEncoderCache struct {
|
||||
entries [reflect.UnsafePointer + 1]atomic.Value // *kindEncoderCacheEntry
|
||||
}
|
||||
|
||||
func (c *kindEncoderCache) Store(rt reflect.Kind, enc ValueEncoder) {
|
||||
if enc != nil && rt < reflect.Kind(len(c.entries)) {
|
||||
c.entries[rt].Store(&kindEncoderCacheEntry{enc: enc})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *kindEncoderCache) Load(rt reflect.Kind) (ValueEncoder, bool) {
|
||||
if rt < reflect.Kind(len(c.entries)) {
|
||||
if ent, ok := c.entries[rt].Load().(*kindEncoderCacheEntry); ok {
|
||||
return ent.enc, ent.enc != nil
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (c *kindEncoderCache) Clone() *kindEncoderCache {
|
||||
cc := new(kindEncoderCache)
|
||||
for i, v := range c.entries {
|
||||
if val := v.Load(); val != nil {
|
||||
cc.entries[i].Store(val)
|
||||
}
|
||||
}
|
||||
return cc
|
||||
}
|
||||
|
||||
// atomic.Value requires that all calls to Store() have the same concrete type
|
||||
// so we wrap the ValueDecoder with a kindDecoderCacheEntry to ensure the type
|
||||
// is always the same (since different concrete types may implement the
|
||||
// ValueDecoder interface).
|
||||
type kindDecoderCacheEntry struct {
|
||||
dec ValueDecoder
|
||||
}
|
||||
|
||||
type kindDecoderCache struct {
|
||||
entries [reflect.UnsafePointer + 1]atomic.Value // *kindDecoderCacheEntry
|
||||
}
|
||||
|
||||
func (c *kindDecoderCache) Store(rt reflect.Kind, dec ValueDecoder) {
|
||||
if rt < reflect.Kind(len(c.entries)) {
|
||||
c.entries[rt].Store(&kindDecoderCacheEntry{dec: dec})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *kindDecoderCache) Load(rt reflect.Kind) (ValueDecoder, bool) {
|
||||
if rt < reflect.Kind(len(c.entries)) {
|
||||
if ent, ok := c.entries[rt].Load().(*kindDecoderCacheEntry); ok {
|
||||
return ent.dec, ent.dec != nil
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (c *kindDecoderCache) Clone() *kindDecoderCache {
|
||||
cc := new(kindDecoderCache)
|
||||
for i, v := range c.entries {
|
||||
if val := v.Load(); val != nil {
|
||||
cc.entries[i].Store(val)
|
||||
}
|
||||
}
|
||||
return cc
|
||||
}
|
176
codec_cache_test.go
Normal file
176
codec_cache_test.go
Normal file
@ -0,0 +1,176 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// NB(charlie): the array size is a power of 2 because we use the remainder of
|
||||
// it (mod) in benchmarks and that is faster when the size is a power of 2.
|
||||
var codecCacheTestTypes = [16]reflect.Type{
|
||||
reflect.TypeOf(uint8(0)),
|
||||
reflect.TypeOf(uint16(0)),
|
||||
reflect.TypeOf(uint32(0)),
|
||||
reflect.TypeOf(uint64(0)),
|
||||
reflect.TypeOf(uint(0)),
|
||||
reflect.TypeOf(uintptr(0)),
|
||||
reflect.TypeOf(int8(0)),
|
||||
reflect.TypeOf(int16(0)),
|
||||
reflect.TypeOf(int32(0)),
|
||||
reflect.TypeOf(int64(0)),
|
||||
reflect.TypeOf(int(0)),
|
||||
reflect.TypeOf(float32(0)),
|
||||
reflect.TypeOf(float64(0)),
|
||||
reflect.TypeOf(true),
|
||||
reflect.TypeOf(struct{ A int }{}),
|
||||
reflect.TypeOf(map[int]int{}),
|
||||
}
|
||||
|
||||
func TestTypeCache(t *testing.T) {
|
||||
rt := reflect.TypeOf(int(0))
|
||||
ec := new(typeEncoderCache)
|
||||
dc := new(typeDecoderCache)
|
||||
|
||||
codec := new(fakeCodec)
|
||||
ec.Store(rt, codec)
|
||||
dc.Store(rt, codec)
|
||||
if v, ok := ec.Load(rt); !ok || !reflect.DeepEqual(v, codec) {
|
||||
t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, codec, true)
|
||||
}
|
||||
if v, ok := dc.Load(rt); !ok || !reflect.DeepEqual(v, codec) {
|
||||
t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, codec, true)
|
||||
}
|
||||
|
||||
// Make sure we overwrite the stored value with nil
|
||||
ec.Store(rt, nil)
|
||||
dc.Store(rt, nil)
|
||||
if v, ok := ec.Load(rt); ok || v != nil {
|
||||
t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, nil, false)
|
||||
}
|
||||
if v, ok := dc.Load(rt); ok || v != nil {
|
||||
t.Errorf("Load(%s) = %v, %t; want: %v, %t", rt, v, ok, nil, false)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeCacheClone(t *testing.T) {
|
||||
codec := new(fakeCodec)
|
||||
ec1 := new(typeEncoderCache)
|
||||
dc1 := new(typeDecoderCache)
|
||||
for _, rt := range codecCacheTestTypes {
|
||||
ec1.Store(rt, codec)
|
||||
dc1.Store(rt, codec)
|
||||
}
|
||||
ec2 := ec1.Clone()
|
||||
dc2 := dc1.Clone()
|
||||
for _, rt := range codecCacheTestTypes {
|
||||
if v, _ := ec2.Load(rt); !reflect.DeepEqual(v, codec) {
|
||||
t.Errorf("Load(%s) = %#v; want: %#v", rt, v, codec)
|
||||
}
|
||||
if v, _ := dc2.Load(rt); !reflect.DeepEqual(v, codec) {
|
||||
t.Errorf("Load(%s) = %#v; want: %#v", rt, v, codec)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestKindCacheArray(t *testing.T) {
|
||||
// Check array bounds
|
||||
var c kindEncoderCache
|
||||
codec := new(fakeCodec)
|
||||
c.Store(reflect.UnsafePointer, codec) // valid
|
||||
c.Store(reflect.UnsafePointer+1, codec) // ignored
|
||||
if v, ok := c.Load(reflect.UnsafePointer); !ok || v != codec {
|
||||
t.Errorf("Load(reflect.UnsafePointer) = %v, %t; want: %v, %t", v, ok, codec, true)
|
||||
}
|
||||
if v, ok := c.Load(reflect.UnsafePointer + 1); ok || v != nil {
|
||||
t.Errorf("Load(reflect.UnsafePointer + 1) = %v, %t; want: %v, %t", v, ok, nil, false)
|
||||
}
|
||||
|
||||
// Make sure that reflect.UnsafePointer is the last/largest reflect.Type.
|
||||
//
|
||||
// The String() method of invalid reflect.Type types are of the format
|
||||
// "kind{NUMBER}".
|
||||
for rt := reflect.UnsafePointer + 1; rt < reflect.UnsafePointer+16; rt++ {
|
||||
s := rt.String()
|
||||
if !strings.Contains(s, strconv.Itoa(int(rt))) {
|
||||
t.Errorf("reflect.Type(%d) appears to be valid: %q", rt, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestKindCacheClone(t *testing.T) {
|
||||
e1 := new(kindEncoderCache)
|
||||
d1 := new(kindDecoderCache)
|
||||
codec := new(fakeCodec)
|
||||
for k := reflect.Invalid; k <= reflect.UnsafePointer; k++ {
|
||||
e1.Store(k, codec)
|
||||
d1.Store(k, codec)
|
||||
}
|
||||
e2 := e1.Clone()
|
||||
for k := reflect.Invalid; k <= reflect.UnsafePointer; k++ {
|
||||
v1, ok1 := e1.Load(k)
|
||||
v2, ok2 := e2.Load(k)
|
||||
if ok1 != ok2 || !reflect.DeepEqual(v1, v2) || v1 == nil || v2 == nil {
|
||||
t.Errorf("Encoder(%s): %#v, %t != %#v, %t", k, v1, ok1, v2, ok2)
|
||||
}
|
||||
}
|
||||
d2 := d1.Clone()
|
||||
for k := reflect.Invalid; k <= reflect.UnsafePointer; k++ {
|
||||
v1, ok1 := d1.Load(k)
|
||||
v2, ok2 := d2.Load(k)
|
||||
if ok1 != ok2 || !reflect.DeepEqual(v1, v2) || v1 == nil || v2 == nil {
|
||||
t.Errorf("Decoder(%s): %#v, %t != %#v, %t", k, v1, ok1, v2, ok2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestKindCacheEncoderNilEncoder(t *testing.T) {
|
||||
t.Run("Encoder", func(t *testing.T) {
|
||||
c := new(kindEncoderCache)
|
||||
c.Store(reflect.Invalid, ValueEncoder(nil))
|
||||
v, ok := c.Load(reflect.Invalid)
|
||||
if v != nil || ok {
|
||||
t.Errorf("Load of nil ValueEncoder should return: nil, false; got: %v, %t", v, ok)
|
||||
}
|
||||
})
|
||||
t.Run("Decoder", func(t *testing.T) {
|
||||
c := new(kindDecoderCache)
|
||||
c.Store(reflect.Invalid, ValueDecoder(nil))
|
||||
v, ok := c.Load(reflect.Invalid)
|
||||
if v != nil || ok {
|
||||
t.Errorf("Load of nil ValueDecoder should return: nil, false; got: %v, %t", v, ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkEncoderCacheLoad(b *testing.B) {
|
||||
c := new(typeEncoderCache)
|
||||
codec := new(fakeCodec)
|
||||
typs := codecCacheTestTypes
|
||||
for _, t := range typs {
|
||||
c.Store(t, codec)
|
||||
}
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for i := 0; pb.Next(); i++ {
|
||||
c.Load(typs[i%len(typs)])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkEncoderCacheStore(b *testing.B) {
|
||||
c := new(typeEncoderCache)
|
||||
codec := new(fakeCodec)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
typs := codecCacheTestTypes
|
||||
for i := 0; pb.Next(); i++ {
|
||||
c.Store(typs[i%len(typs)], codec)
|
||||
}
|
||||
})
|
||||
}
|
61
cond_addr_codec.go
Normal file
61
cond_addr_codec.go
Normal file
@ -0,0 +1,61 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// condAddrEncoder is the encoder used when a pointer to the encoding value has an encoder.
|
||||
type condAddrEncoder struct {
|
||||
canAddrEnc ValueEncoder
|
||||
elseEnc ValueEncoder
|
||||
}
|
||||
|
||||
var _ ValueEncoder = &condAddrEncoder{}
|
||||
|
||||
// newCondAddrEncoder returns an condAddrEncoder.
|
||||
func newCondAddrEncoder(canAddrEnc, elseEnc ValueEncoder) *condAddrEncoder {
|
||||
encoder := condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc}
|
||||
return &encoder
|
||||
}
|
||||
|
||||
// EncodeValue is the ValueEncoderFunc for a value that may be addressable.
|
||||
func (cae *condAddrEncoder) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if val.CanAddr() {
|
||||
return cae.canAddrEnc.EncodeValue(ec, vw, val)
|
||||
}
|
||||
if cae.elseEnc != nil {
|
||||
return cae.elseEnc.EncodeValue(ec, vw, val)
|
||||
}
|
||||
return errNoEncoder{Type: val.Type()}
|
||||
}
|
||||
|
||||
// condAddrDecoder is the decoder used when a pointer to the value has a decoder.
|
||||
type condAddrDecoder struct {
|
||||
canAddrDec ValueDecoder
|
||||
elseDec ValueDecoder
|
||||
}
|
||||
|
||||
var _ ValueDecoder = &condAddrDecoder{}
|
||||
|
||||
// newCondAddrDecoder returns an CondAddrDecoder.
|
||||
func newCondAddrDecoder(canAddrDec, elseDec ValueDecoder) *condAddrDecoder {
|
||||
decoder := condAddrDecoder{canAddrDec: canAddrDec, elseDec: elseDec}
|
||||
return &decoder
|
||||
}
|
||||
|
||||
// DecodeValue is the ValueDecoderFunc for a value that may be addressable.
|
||||
func (cad *condAddrDecoder) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
|
||||
if val.CanAddr() {
|
||||
return cad.canAddrDec.DecodeValue(dc, vr, val)
|
||||
}
|
||||
if cad.elseDec != nil {
|
||||
return cad.elseDec.DecodeValue(dc, vr, val)
|
||||
}
|
||||
return errNoDecoder{Type: val.Type()}
|
||||
}
|
95
cond_addr_codec_test.go
Normal file
95
cond_addr_codec_test.go
Normal file
@ -0,0 +1,95 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/assert"
|
||||
)
|
||||
|
||||
func TestCondAddrCodec(t *testing.T) {
|
||||
var inner int
|
||||
canAddrVal := reflect.ValueOf(&inner)
|
||||
addressable := canAddrVal.Elem()
|
||||
unaddressable := reflect.ValueOf(inner)
|
||||
rw := &valueReaderWriter{}
|
||||
|
||||
t.Run("addressEncode", func(t *testing.T) {
|
||||
invoked := 0
|
||||
encode1 := ValueEncoderFunc(func(EncodeContext, ValueWriter, reflect.Value) error {
|
||||
invoked = 1
|
||||
return nil
|
||||
})
|
||||
encode2 := ValueEncoderFunc(func(EncodeContext, ValueWriter, reflect.Value) error {
|
||||
invoked = 2
|
||||
return nil
|
||||
})
|
||||
condEncoder := newCondAddrEncoder(encode1, encode2)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
val reflect.Value
|
||||
invoked int
|
||||
}{
|
||||
{"canAddr", addressable, 1},
|
||||
{"else", unaddressable, 2},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := condEncoder.EncodeValue(EncodeContext{}, rw, tc.val)
|
||||
assert.Nil(t, err, "CondAddrEncoder error: %v", err)
|
||||
|
||||
assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("error", func(t *testing.T) {
|
||||
errEncoder := newCondAddrEncoder(encode1, nil)
|
||||
err := errEncoder.EncodeValue(EncodeContext{}, rw, unaddressable)
|
||||
want := errNoEncoder{Type: unaddressable.Type()}
|
||||
assert.Equal(t, err, want, "expected error %v, got %v", want, err)
|
||||
})
|
||||
})
|
||||
t.Run("addressDecode", func(t *testing.T) {
|
||||
invoked := 0
|
||||
decode1 := ValueDecoderFunc(func(DecodeContext, ValueReader, reflect.Value) error {
|
||||
invoked = 1
|
||||
return nil
|
||||
})
|
||||
decode2 := ValueDecoderFunc(func(DecodeContext, ValueReader, reflect.Value) error {
|
||||
invoked = 2
|
||||
return nil
|
||||
})
|
||||
condDecoder := newCondAddrDecoder(decode1, decode2)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
val reflect.Value
|
||||
invoked int
|
||||
}{
|
||||
{"canAddr", addressable, 1},
|
||||
{"else", unaddressable, 2},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := condDecoder.DecodeValue(DecodeContext{}, rw, tc.val)
|
||||
assert.Nil(t, err, "CondAddrDecoder error: %v", err)
|
||||
|
||||
assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("error", func(t *testing.T) {
|
||||
errDecoder := newCondAddrDecoder(decode1, nil)
|
||||
err := errDecoder.DecodeValue(DecodeContext{}, rw, unaddressable)
|
||||
want := errNoDecoder{Type: unaddressable.Type()}
|
||||
assert.Equal(t, err, want, "expected error %v, got %v", want, err)
|
||||
})
|
||||
})
|
||||
}
|
431
copier.go
Normal file
431
copier.go
Normal file
@ -0,0 +1,431 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
// copyDocument handles copying one document from the src to the dst.
|
||||
func copyDocument(dst ValueWriter, src ValueReader) error {
|
||||
dr, err := src.ReadDocument()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dw, err := dst.WriteDocument()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return copyDocumentCore(dw, dr)
|
||||
}
|
||||
|
||||
// copyArrayFromBytes copies the values from a BSON array represented as a
|
||||
// []byte to a ValueWriter.
|
||||
func copyArrayFromBytes(dst ValueWriter, src []byte) error {
|
||||
aw, err := dst.WriteArray()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = copyBytesToArrayWriter(aw, src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return aw.WriteArrayEnd()
|
||||
}
|
||||
|
||||
// copyDocumentFromBytes copies the values from a BSON document represented as a
|
||||
// []byte to a ValueWriter.
|
||||
func copyDocumentFromBytes(dst ValueWriter, src []byte) error {
|
||||
dw, err := dst.WriteDocument()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = copyBytesToDocumentWriter(dw, src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return dw.WriteDocumentEnd()
|
||||
}
|
||||
|
||||
type writeElementFn func(key string) (ValueWriter, error)
|
||||
|
||||
// copyBytesToArrayWriter copies the values from a BSON Array represented as a []byte to an
|
||||
// ArrayWriter.
|
||||
func copyBytesToArrayWriter(dst ArrayWriter, src []byte) error {
|
||||
wef := func(_ string) (ValueWriter, error) {
|
||||
return dst.WriteArrayElement()
|
||||
}
|
||||
|
||||
return copyBytesToValueWriter(src, wef)
|
||||
}
|
||||
|
||||
// copyBytesToDocumentWriter copies the values from a BSON document represented as a []byte to a
|
||||
// DocumentWriter.
|
||||
func copyBytesToDocumentWriter(dst DocumentWriter, src []byte) error {
|
||||
wef := func(key string) (ValueWriter, error) {
|
||||
return dst.WriteDocumentElement(key)
|
||||
}
|
||||
|
||||
return copyBytesToValueWriter(src, wef)
|
||||
}
|
||||
|
||||
func copyBytesToValueWriter(src []byte, wef writeElementFn) error {
|
||||
// TODO(skriptble): Create errors types here. Anything that is a tag should be a property.
|
||||
length, rem, ok := bsoncore.ReadLength(src)
|
||||
if !ok {
|
||||
return fmt.Errorf("couldn't read length from src, not enough bytes. length=%d", len(src))
|
||||
}
|
||||
if len(src) < int(length) {
|
||||
return fmt.Errorf("length read exceeds number of bytes available. length=%d bytes=%d", len(src), length)
|
||||
}
|
||||
rem = rem[:length-4]
|
||||
|
||||
var t bsoncore.Type
|
||||
var key string
|
||||
var val bsoncore.Value
|
||||
for {
|
||||
t, rem, ok = bsoncore.ReadType(rem)
|
||||
if !ok {
|
||||
return io.EOF
|
||||
}
|
||||
if t == bsoncore.Type(0) {
|
||||
if len(rem) != 0 {
|
||||
return fmt.Errorf("document end byte found before end of document. remaining bytes=%v", rem)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
key, rem, ok = bsoncore.ReadKey(rem)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid key found. remaining bytes=%v", rem)
|
||||
}
|
||||
|
||||
// write as either array element or document element using writeElementFn
|
||||
vw, err := wef(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val, rem, ok = bsoncore.ReadValue(rem, t)
|
||||
if !ok {
|
||||
return fmt.Errorf("not enough bytes available to read type. bytes=%d type=%s", len(rem), t)
|
||||
}
|
||||
err = copyValueFromBytes(vw, Type(t), val.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// copyDocumentToBytes copies an entire document from the ValueReader and
|
||||
// returns it as bytes.
|
||||
func copyDocumentToBytes(src ValueReader) ([]byte, error) {
|
||||
return appendDocumentBytes(nil, src)
|
||||
}
|
||||
|
||||
// appendDocumentBytes functions the same as CopyDocumentToBytes, but will
|
||||
// append the result to dst.
|
||||
func appendDocumentBytes(dst []byte, src ValueReader) ([]byte, error) {
|
||||
if br, ok := src.(bytesReader); ok {
|
||||
_, dst, err := br.readValueBytes(dst)
|
||||
return dst, err
|
||||
}
|
||||
|
||||
vw := vwPool.Get().(*valueWriter)
|
||||
defer putValueWriter(vw)
|
||||
|
||||
vw.reset(dst)
|
||||
|
||||
err := copyDocument(vw, src)
|
||||
dst = vw.buf
|
||||
return dst, err
|
||||
}
|
||||
|
||||
// appendArrayBytes copies an array from the ValueReader to dst.
|
||||
func appendArrayBytes(dst []byte, src ValueReader) ([]byte, error) {
|
||||
if br, ok := src.(bytesReader); ok {
|
||||
_, dst, err := br.readValueBytes(dst)
|
||||
return dst, err
|
||||
}
|
||||
|
||||
vw := vwPool.Get().(*valueWriter)
|
||||
defer putValueWriter(vw)
|
||||
|
||||
vw.reset(dst)
|
||||
|
||||
err := copyArray(vw, src)
|
||||
dst = vw.buf
|
||||
return dst, err
|
||||
}
|
||||
|
||||
// copyValueFromBytes will write the value represtend by t and src to dst.
|
||||
func copyValueFromBytes(dst ValueWriter, t Type, src []byte) error {
|
||||
if wvb, ok := dst.(bytesWriter); ok {
|
||||
return wvb.writeValueBytes(t, src)
|
||||
}
|
||||
|
||||
vr := newDocumentReader(bytes.NewReader(src))
|
||||
vr.pushElement(t)
|
||||
|
||||
return copyValue(dst, vr)
|
||||
}
|
||||
|
||||
// copyValueToBytes copies a value from src and returns it as a Type and a
|
||||
// []byte.
|
||||
func copyValueToBytes(src ValueReader) (Type, []byte, error) {
|
||||
if br, ok := src.(bytesReader); ok {
|
||||
return br.readValueBytes(nil)
|
||||
}
|
||||
|
||||
vw := vwPool.Get().(*valueWriter)
|
||||
defer putValueWriter(vw)
|
||||
|
||||
vw.reset(nil)
|
||||
vw.push(mElement)
|
||||
|
||||
err := copyValue(vw, src)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
return Type(vw.buf[0]), vw.buf[2:], nil
|
||||
}
|
||||
|
||||
// copyValue will copy a single value from src to dst.
|
||||
func copyValue(dst ValueWriter, src ValueReader) error {
|
||||
var err error
|
||||
switch src.Type() {
|
||||
case TypeDouble:
|
||||
var f64 float64
|
||||
f64, err = src.ReadDouble()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = dst.WriteDouble(f64)
|
||||
case TypeString:
|
||||
var str string
|
||||
str, err = src.ReadString()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = dst.WriteString(str)
|
||||
case TypeEmbeddedDocument:
|
||||
err = copyDocument(dst, src)
|
||||
case TypeArray:
|
||||
err = copyArray(dst, src)
|
||||
case TypeBinary:
|
||||
var data []byte
|
||||
var subtype byte
|
||||
data, subtype, err = src.ReadBinary()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = dst.WriteBinaryWithSubtype(data, subtype)
|
||||
case TypeUndefined:
|
||||
err = src.ReadUndefined()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = dst.WriteUndefined()
|
||||
case TypeObjectID:
|
||||
var oid ObjectID
|
||||
oid, err = src.ReadObjectID()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = dst.WriteObjectID(oid)
|
||||
case TypeBoolean:
|
||||
var b bool
|
||||
b, err = src.ReadBoolean()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = dst.WriteBoolean(b)
|
||||
case TypeDateTime:
|
||||
var dt int64
|
||||
dt, err = src.ReadDateTime()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = dst.WriteDateTime(dt)
|
||||
case TypeNull:
|
||||
err = src.ReadNull()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = dst.WriteNull()
|
||||
case TypeRegex:
|
||||
var pattern, options string
|
||||
pattern, options, err = src.ReadRegex()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = dst.WriteRegex(pattern, options)
|
||||
case TypeDBPointer:
|
||||
var ns string
|
||||
var pointer ObjectID
|
||||
ns, pointer, err = src.ReadDBPointer()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = dst.WriteDBPointer(ns, pointer)
|
||||
case TypeJavaScript:
|
||||
var js string
|
||||
js, err = src.ReadJavascript()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = dst.WriteJavascript(js)
|
||||
case TypeSymbol:
|
||||
var symbol string
|
||||
symbol, err = src.ReadSymbol()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = dst.WriteSymbol(symbol)
|
||||
case TypeCodeWithScope:
|
||||
var code string
|
||||
var srcScope DocumentReader
|
||||
code, srcScope, err = src.ReadCodeWithScope()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
var dstScope DocumentWriter
|
||||
dstScope, err = dst.WriteCodeWithScope(code)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = copyDocumentCore(dstScope, srcScope)
|
||||
case TypeInt32:
|
||||
var i32 int32
|
||||
i32, err = src.ReadInt32()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = dst.WriteInt32(i32)
|
||||
case TypeTimestamp:
|
||||
var t, i uint32
|
||||
t, i, err = src.ReadTimestamp()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = dst.WriteTimestamp(t, i)
|
||||
case TypeInt64:
|
||||
var i64 int64
|
||||
i64, err = src.ReadInt64()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = dst.WriteInt64(i64)
|
||||
case TypeDecimal128:
|
||||
var d128 Decimal128
|
||||
d128, err = src.ReadDecimal128()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = dst.WriteDecimal128(d128)
|
||||
case TypeMinKey:
|
||||
err = src.ReadMinKey()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = dst.WriteMinKey()
|
||||
case TypeMaxKey:
|
||||
err = src.ReadMaxKey()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = dst.WriteMaxKey()
|
||||
default:
|
||||
err = fmt.Errorf("cannot copy unknown BSON type %s", src.Type())
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func copyArray(dst ValueWriter, src ValueReader) error {
|
||||
ar, err := src.ReadArray()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
aw, err := dst.WriteArray()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
vr, err := ar.ReadValue()
|
||||
if errors.Is(err, ErrEOA) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
vw, err := aw.WriteArrayElement()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = copyValue(vw, vr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return aw.WriteArrayEnd()
|
||||
}
|
||||
|
||||
func copyDocumentCore(dw DocumentWriter, dr DocumentReader) error {
|
||||
for {
|
||||
key, vr, err := dr.ReadElement()
|
||||
if errors.Is(err, ErrEOD) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
vw, err := dw.WriteDocumentElement(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = copyValue(vw, vr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return dw.WriteDocumentEnd()
|
||||
}
|
||||
|
||||
// bytesReader is the interface used to read BSON bytes from a valueReader.
|
||||
//
|
||||
// The bytes of the value will be appended to dst.
|
||||
type bytesReader interface {
|
||||
readValueBytes(dst []byte) (Type, []byte, error)
|
||||
}
|
||||
|
||||
// bytesWriter is the interface used to write BSON bytes to a valueWriter.
|
||||
type bytesWriter interface {
|
||||
writeValueBytes(t Type, b []byte) error
|
||||
}
|
528
copier_test.go
Normal file
528
copier_test.go
Normal file
@ -0,0 +1,528 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/assert"
|
||||
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
func TestCopier(t *testing.T) {
|
||||
t.Run("CopyDocument", func(t *testing.T) {
|
||||
t.Run("ReadDocument Error", func(t *testing.T) {
|
||||
want := errors.New("ReadDocumentError")
|
||||
src := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwReadDocument}
|
||||
got := copyDocument(nil, src)
|
||||
if !assert.CompareErrors(got, want) {
|
||||
t.Errorf("Did not receive correct error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("WriteDocument Error", func(t *testing.T) {
|
||||
want := errors.New("WriteDocumentError")
|
||||
src := &TestValueReaderWriter{}
|
||||
dst := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwWriteDocument}
|
||||
got := copyDocument(dst, src)
|
||||
if !assert.CompareErrors(got, want) {
|
||||
t.Errorf("Did not receive correct error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("success", func(t *testing.T) {
|
||||
idx, doc := bsoncore.AppendDocumentStart(nil)
|
||||
doc = bsoncore.AppendStringElement(doc, "Hello", "world")
|
||||
doc, err := bsoncore.AppendDocumentEnd(doc, idx)
|
||||
noerr(t, err)
|
||||
src := newDocumentReader(bytes.NewReader(doc))
|
||||
dst := newValueWriterFromSlice(make([]byte, 0))
|
||||
want := doc
|
||||
err = copyDocument(dst, src)
|
||||
noerr(t, err)
|
||||
got := dst.buf
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Errorf("Bytes are not equal. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
})
|
||||
t.Run("copyArray", func(t *testing.T) {
|
||||
t.Run("ReadArray Error", func(t *testing.T) {
|
||||
want := errors.New("ReadArrayError")
|
||||
src := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwReadArray}
|
||||
got := copyArray(nil, src)
|
||||
if !assert.CompareErrors(got, want) {
|
||||
t.Errorf("Did not receive correct error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("WriteArray Error", func(t *testing.T) {
|
||||
want := errors.New("WriteArrayError")
|
||||
src := &TestValueReaderWriter{}
|
||||
dst := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwWriteArray}
|
||||
got := copyArray(dst, src)
|
||||
if !assert.CompareErrors(got, want) {
|
||||
t.Errorf("Did not receive correct error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("success", func(t *testing.T) {
|
||||
idx, doc := bsoncore.AppendDocumentStart(nil)
|
||||
aidx, doc := bsoncore.AppendArrayElementStart(doc, "foo")
|
||||
doc = bsoncore.AppendStringElement(doc, "0", "Hello, world!")
|
||||
doc, err := bsoncore.AppendArrayEnd(doc, aidx)
|
||||
noerr(t, err)
|
||||
doc, err = bsoncore.AppendDocumentEnd(doc, idx)
|
||||
noerr(t, err)
|
||||
src := newDocumentReader(bytes.NewReader(doc))
|
||||
|
||||
_, err = src.ReadDocument()
|
||||
noerr(t, err)
|
||||
_, _, err = src.ReadElement()
|
||||
noerr(t, err)
|
||||
|
||||
dst := newValueWriterFromSlice(make([]byte, 0))
|
||||
_, err = dst.WriteDocument()
|
||||
noerr(t, err)
|
||||
_, err = dst.WriteDocumentElement("foo")
|
||||
noerr(t, err)
|
||||
want := doc
|
||||
|
||||
err = copyArray(dst, src)
|
||||
noerr(t, err)
|
||||
|
||||
err = dst.WriteDocumentEnd()
|
||||
noerr(t, err)
|
||||
|
||||
got := dst.buf
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Errorf("Bytes are not equal. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
})
|
||||
t.Run("CopyValue", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
dst *TestValueReaderWriter
|
||||
src *TestValueReaderWriter
|
||||
err error
|
||||
}{
|
||||
{
|
||||
"Double/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeDouble, err: errors.New("1"), errAfter: llvrwReadDouble},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"Double/dst/error",
|
||||
&TestValueReaderWriter{bsontype: TypeDouble, err: errors.New("2"), errAfter: llvrwWriteDouble},
|
||||
&TestValueReaderWriter{bsontype: TypeDouble, readval: float64(3.14159)},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"String/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeString, err: errors.New("1"), errAfter: llvrwReadString},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"String/dst/error",
|
||||
&TestValueReaderWriter{bsontype: TypeString, err: errors.New("2"), errAfter: llvrwWriteString},
|
||||
&TestValueReaderWriter{bsontype: TypeString, readval: "hello, world"},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"Document/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeEmbeddedDocument, err: errors.New("1"), errAfter: llvrwReadDocument},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"Array/dst/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeArray, err: errors.New("2"), errAfter: llvrwReadArray},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"Binary/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeBinary, err: errors.New("1"), errAfter: llvrwReadBinary},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"Binary/dst/error",
|
||||
&TestValueReaderWriter{bsontype: TypeBinary, err: errors.New("2"), errAfter: llvrwWriteBinaryWithSubtype},
|
||||
&TestValueReaderWriter{
|
||||
bsontype: TypeBinary,
|
||||
readval: bsoncore.Value{
|
||||
Type: bsoncore.TypeBinary,
|
||||
Data: []byte{0x03, 0x00, 0x00, 0x00, 0xFF, 0x01, 0x02, 0x03},
|
||||
},
|
||||
},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"Undefined/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeUndefined, err: errors.New("1"), errAfter: llvrwReadUndefined},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"Undefined/dst/error",
|
||||
&TestValueReaderWriter{bsontype: TypeUndefined, err: errors.New("2"), errAfter: llvrwWriteUndefined},
|
||||
&TestValueReaderWriter{bsontype: TypeUndefined},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"ObjectID/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeObjectID, err: errors.New("1"), errAfter: llvrwReadObjectID},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"ObjectID/dst/error",
|
||||
&TestValueReaderWriter{bsontype: TypeObjectID, err: errors.New("2"), errAfter: llvrwWriteObjectID},
|
||||
&TestValueReaderWriter{bsontype: TypeObjectID, readval: ObjectID{0x01, 0x02, 0x03}},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"Boolean/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeBoolean, err: errors.New("1"), errAfter: llvrwReadBoolean},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"Boolean/dst/error",
|
||||
&TestValueReaderWriter{bsontype: TypeBoolean, err: errors.New("2"), errAfter: llvrwWriteBoolean},
|
||||
&TestValueReaderWriter{bsontype: TypeBoolean, readval: bool(true)},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"DateTime/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeDateTime, err: errors.New("1"), errAfter: llvrwReadDateTime},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"DateTime/dst/error",
|
||||
&TestValueReaderWriter{bsontype: TypeDateTime, err: errors.New("2"), errAfter: llvrwWriteDateTime},
|
||||
&TestValueReaderWriter{bsontype: TypeDateTime, readval: int64(1234567890)},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"Null/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeNull, err: errors.New("1"), errAfter: llvrwReadNull},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"Null/dst/error",
|
||||
&TestValueReaderWriter{bsontype: TypeNull, err: errors.New("2"), errAfter: llvrwWriteNull},
|
||||
&TestValueReaderWriter{bsontype: TypeNull},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"Regex/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeRegex, err: errors.New("1"), errAfter: llvrwReadRegex},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"Regex/dst/error",
|
||||
&TestValueReaderWriter{bsontype: TypeRegex, err: errors.New("2"), errAfter: llvrwWriteRegex},
|
||||
&TestValueReaderWriter{
|
||||
bsontype: TypeRegex,
|
||||
readval: bsoncore.Value{
|
||||
Type: bsoncore.TypeRegex,
|
||||
Data: bsoncore.AppendRegex(nil, "hello", "world"),
|
||||
},
|
||||
},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"DBPointer/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeDBPointer, err: errors.New("1"), errAfter: llvrwReadDBPointer},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"DBPointer/dst/error",
|
||||
&TestValueReaderWriter{bsontype: TypeDBPointer, err: errors.New("2"), errAfter: llvrwWriteDBPointer},
|
||||
&TestValueReaderWriter{
|
||||
bsontype: TypeDBPointer,
|
||||
readval: bsoncore.Value{
|
||||
Type: bsoncore.TypeDBPointer,
|
||||
Data: bsoncore.AppendDBPointer(nil, "foo", ObjectID{0x01, 0x02, 0x03}),
|
||||
},
|
||||
},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"Javascript/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeJavaScript, err: errors.New("1"), errAfter: llvrwReadJavascript},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"Javascript/dst/error",
|
||||
&TestValueReaderWriter{bsontype: TypeJavaScript, err: errors.New("2"), errAfter: llvrwWriteJavascript},
|
||||
&TestValueReaderWriter{bsontype: TypeJavaScript, readval: "hello, world"},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"Symbol/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeSymbol, err: errors.New("1"), errAfter: llvrwReadSymbol},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"Symbol/dst/error",
|
||||
&TestValueReaderWriter{bsontype: TypeSymbol, err: errors.New("2"), errAfter: llvrwWriteSymbol},
|
||||
&TestValueReaderWriter{
|
||||
bsontype: TypeSymbol,
|
||||
readval: bsoncore.Value{
|
||||
Type: bsoncore.TypeSymbol,
|
||||
Data: bsoncore.AppendSymbol(nil, "hello, world"),
|
||||
},
|
||||
},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"CodeWithScope/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeCodeWithScope, err: errors.New("1"), errAfter: llvrwReadCodeWithScope},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"CodeWithScope/dst/error",
|
||||
&TestValueReaderWriter{bsontype: TypeCodeWithScope, err: errors.New("2"), errAfter: llvrwWriteCodeWithScope},
|
||||
&TestValueReaderWriter{bsontype: TypeCodeWithScope},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"CodeWithScope/dst/copyDocumentCore error",
|
||||
&TestValueReaderWriter{err: errors.New("3"), errAfter: llvrwWriteDocumentElement},
|
||||
&TestValueReaderWriter{bsontype: TypeCodeWithScope},
|
||||
errors.New("3"),
|
||||
},
|
||||
{
|
||||
"Int32/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeInt32, err: errors.New("1"), errAfter: llvrwReadInt32},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"Int32/dst/error",
|
||||
&TestValueReaderWriter{bsontype: TypeInt32, err: errors.New("2"), errAfter: llvrwWriteInt32},
|
||||
&TestValueReaderWriter{bsontype: TypeInt32, readval: int32(12345)},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"Timestamp/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeTimestamp, err: errors.New("1"), errAfter: llvrwReadTimestamp},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"Timestamp/dst/error",
|
||||
&TestValueReaderWriter{bsontype: TypeTimestamp, err: errors.New("2"), errAfter: llvrwWriteTimestamp},
|
||||
&TestValueReaderWriter{
|
||||
bsontype: TypeTimestamp,
|
||||
readval: bsoncore.Value{
|
||||
Type: bsoncore.TypeTimestamp,
|
||||
Data: bsoncore.AppendTimestamp(nil, 12345, 67890),
|
||||
},
|
||||
},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"Int64/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeInt64, err: errors.New("1"), errAfter: llvrwReadInt64},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"Int64/dst/error",
|
||||
&TestValueReaderWriter{bsontype: TypeInt64, err: errors.New("2"), errAfter: llvrwWriteInt64},
|
||||
&TestValueReaderWriter{bsontype: TypeInt64, readval: int64(1234567890)},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"Decimal128/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeDecimal128, err: errors.New("1"), errAfter: llvrwReadDecimal128},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"Decimal128/dst/error",
|
||||
&TestValueReaderWriter{bsontype: TypeDecimal128, err: errors.New("2"), errAfter: llvrwWriteDecimal128},
|
||||
&TestValueReaderWriter{bsontype: TypeDecimal128, readval: NewDecimal128(12345, 67890)},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"MinKey/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeMinKey, err: errors.New("1"), errAfter: llvrwReadMinKey},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"MinKey/dst/error",
|
||||
&TestValueReaderWriter{bsontype: TypeMinKey, err: errors.New("2"), errAfter: llvrwWriteMinKey},
|
||||
&TestValueReaderWriter{bsontype: TypeMinKey},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"MaxKey/src/error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{bsontype: TypeMaxKey, err: errors.New("1"), errAfter: llvrwReadMaxKey},
|
||||
errors.New("1"),
|
||||
},
|
||||
{
|
||||
"MaxKey/dst/error",
|
||||
&TestValueReaderWriter{bsontype: TypeMaxKey, err: errors.New("2"), errAfter: llvrwWriteMaxKey},
|
||||
&TestValueReaderWriter{bsontype: TypeMaxKey},
|
||||
errors.New("2"),
|
||||
},
|
||||
{
|
||||
"Unknown BSON type error",
|
||||
&TestValueReaderWriter{},
|
||||
&TestValueReaderWriter{},
|
||||
fmt.Errorf("cannot copy unknown BSON type %s", Type(0)),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tc.dst.t, tc.src.t = t, t
|
||||
err := copyValue(tc.dst, tc.src)
|
||||
if !assert.CompareErrors(err, tc.err) {
|
||||
t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("CopyValueFromBytes", func(t *testing.T) {
|
||||
t.Run("BytesWriter", func(t *testing.T) {
|
||||
vw := newValueWriterFromSlice(make([]byte, 0))
|
||||
_, err := vw.WriteDocument()
|
||||
noerr(t, err)
|
||||
_, err = vw.WriteDocumentElement("foo")
|
||||
noerr(t, err)
|
||||
err = copyValueFromBytes(vw, TypeString, bsoncore.AppendString(nil, "bar"))
|
||||
noerr(t, err)
|
||||
err = vw.WriteDocumentEnd()
|
||||
noerr(t, err)
|
||||
var idx int32
|
||||
want, err := bsoncore.AppendDocumentEnd(
|
||||
bsoncore.AppendStringElement(
|
||||
bsoncore.AppendDocumentStartInline(nil, &idx),
|
||||
"foo", "bar",
|
||||
),
|
||||
idx,
|
||||
)
|
||||
noerr(t, err)
|
||||
got := vw.buf
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Errorf("Bytes are not equal. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("Non BytesWriter", func(t *testing.T) {
|
||||
llvrw := &TestValueReaderWriter{t: t}
|
||||
err := copyValueFromBytes(llvrw, TypeString, bsoncore.AppendString(nil, "bar"))
|
||||
noerr(t, err)
|
||||
got, want := llvrw.invoked, llvrwWriteString
|
||||
if got != want {
|
||||
t.Errorf("Incorrect method invoked on llvrw. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
})
|
||||
t.Run("CopyValueToBytes", func(t *testing.T) {
|
||||
t.Run("BytesReader", func(t *testing.T) {
|
||||
var idx int32
|
||||
b, err := bsoncore.AppendDocumentEnd(
|
||||
bsoncore.AppendStringElement(
|
||||
bsoncore.AppendDocumentStartInline(nil, &idx),
|
||||
"hello", "world",
|
||||
),
|
||||
idx,
|
||||
)
|
||||
noerr(t, err)
|
||||
vr := newDocumentReader(bytes.NewReader(b))
|
||||
_, err = vr.ReadDocument()
|
||||
noerr(t, err)
|
||||
_, _, err = vr.ReadElement()
|
||||
noerr(t, err)
|
||||
btype, got, err := copyValueToBytes(vr)
|
||||
noerr(t, err)
|
||||
want := bsoncore.AppendString(nil, "world")
|
||||
if btype != TypeString {
|
||||
t.Errorf("Incorrect type returned. got %v; want %v", btype, TypeString)
|
||||
}
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Errorf("Bytes do not match. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("Non BytesReader", func(t *testing.T) {
|
||||
llvrw := &TestValueReaderWriter{t: t, bsontype: TypeString, readval: "Hello, world!"}
|
||||
btype, got, err := copyValueToBytes(llvrw)
|
||||
noerr(t, err)
|
||||
want := bsoncore.AppendString(nil, "Hello, world!")
|
||||
if btype != TypeString {
|
||||
t.Errorf("Incorrect type returned. got %v; want %v", btype, TypeString)
|
||||
}
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Errorf("Bytes do not match. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
})
|
||||
t.Run("AppendValueBytes", func(t *testing.T) {
|
||||
t.Run("BytesReader", func(t *testing.T) {
|
||||
var idx int32
|
||||
b, err := bsoncore.AppendDocumentEnd(
|
||||
bsoncore.AppendStringElement(
|
||||
bsoncore.AppendDocumentStartInline(nil, &idx),
|
||||
"hello", "world",
|
||||
),
|
||||
idx,
|
||||
)
|
||||
noerr(t, err)
|
||||
vr := newDocumentReader(bytes.NewReader(b))
|
||||
_, err = vr.ReadDocument()
|
||||
noerr(t, err)
|
||||
_, _, err = vr.ReadElement()
|
||||
noerr(t, err)
|
||||
btype, got, err := copyValueToBytes(vr)
|
||||
noerr(t, err)
|
||||
want := bsoncore.AppendString(nil, "world")
|
||||
if btype != TypeString {
|
||||
t.Errorf("Incorrect type returned. got %v; want %v", btype, TypeString)
|
||||
}
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Errorf("Bytes do not match. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("Non BytesReader", func(t *testing.T) {
|
||||
llvrw := &TestValueReaderWriter{t: t, bsontype: TypeString, readval: "Hello, world!"}
|
||||
btype, got, err := copyValueToBytes(llvrw)
|
||||
noerr(t, err)
|
||||
want := bsoncore.AppendString(nil, "Hello, world!")
|
||||
if btype != TypeString {
|
||||
t.Errorf("Incorrect type returned. got %v; want %v", btype, TypeString)
|
||||
}
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Errorf("Bytes do not match. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("CopyValue error", func(t *testing.T) {
|
||||
want := errors.New("CopyValue error")
|
||||
llvrw := &TestValueReaderWriter{t: t, bsontype: TypeString, err: want, errAfter: llvrwReadString}
|
||||
_, _, got := copyValueToBytes(llvrw)
|
||||
if !assert.CompareErrors(got, want) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
339
decimal.go
Normal file
339
decimal.go
Normal file
@ -0,0 +1,339 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on gopkg.in/mgo.v2/bson by Gustavo Niemeyer
|
||||
// See THIRD-PARTY-NOTICES for original license terms.
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/decimal128"
|
||||
)
|
||||
|
||||
// These constants are the maximum and minimum values for the exponent field in a decimal128 value.
|
||||
const (
|
||||
MaxDecimal128Exp = 6111
|
||||
MinDecimal128Exp = -6176
|
||||
)
|
||||
|
||||
// These errors are returned when an invalid value is parsed as a big.Int.
|
||||
var (
|
||||
ErrParseNaN = errors.New("cannot parse NaN as a *big.Int")
|
||||
ErrParseInf = errors.New("cannot parse Infinity as a *big.Int")
|
||||
ErrParseNegInf = errors.New("cannot parse -Infinity as a *big.Int")
|
||||
)
|
||||
|
||||
// Decimal128 holds decimal128 BSON values.
|
||||
type Decimal128 struct {
|
||||
h, l uint64
|
||||
}
|
||||
|
||||
// NewDecimal128 creates a Decimal128 using the provide high and low uint64s.
|
||||
func NewDecimal128(h, l uint64) Decimal128 {
|
||||
return Decimal128{h: h, l: l}
|
||||
}
|
||||
|
||||
// GetBytes returns the underlying bytes of the BSON decimal value as two uint64 values. The first
|
||||
// contains the most first 8 bytes of the value and the second contains the latter.
|
||||
func (d Decimal128) GetBytes() (uint64, uint64) {
|
||||
return d.h, d.l
|
||||
}
|
||||
|
||||
// String returns a string representation of the decimal value.
|
||||
func (d Decimal128) String() string {
|
||||
return decimal128.String(d.h, d.l)
|
||||
}
|
||||
|
||||
// BigInt returns significand as big.Int and exponent, bi * 10 ^ exp.
|
||||
func (d Decimal128) BigInt() (*big.Int, int, error) {
|
||||
high, low := d.GetBytes()
|
||||
posSign := high>>63&1 == 0 // positive sign
|
||||
|
||||
switch high >> 58 & (1<<5 - 1) {
|
||||
case 0x1F:
|
||||
return nil, 0, ErrParseNaN
|
||||
case 0x1E:
|
||||
if posSign {
|
||||
return nil, 0, ErrParseInf
|
||||
}
|
||||
return nil, 0, ErrParseNegInf
|
||||
}
|
||||
|
||||
var exp int
|
||||
if high>>61&3 == 3 {
|
||||
// Bits: 1*sign 2*ignored 14*exponent 111*significand.
|
||||
// Implicit 0b100 prefix in significand.
|
||||
exp = int(high >> 47 & (1<<14 - 1))
|
||||
// Spec says all of these values are out of range.
|
||||
high, low = 0, 0
|
||||
} else {
|
||||
// Bits: 1*sign 14*exponent 113*significand
|
||||
exp = int(high >> 49 & (1<<14 - 1))
|
||||
high &= (1<<49 - 1)
|
||||
}
|
||||
exp += MinDecimal128Exp
|
||||
|
||||
// Would be handled by the logic below, but that's trivial and common.
|
||||
if high == 0 && low == 0 && exp == 0 {
|
||||
return new(big.Int), 0, nil
|
||||
}
|
||||
|
||||
bi := big.NewInt(0)
|
||||
const host32bit = ^uint(0)>>32 == 0
|
||||
if host32bit {
|
||||
bi.SetBits([]big.Word{big.Word(low), big.Word(low >> 32), big.Word(high), big.Word(high >> 32)})
|
||||
} else {
|
||||
bi.SetBits([]big.Word{big.Word(low), big.Word(high)})
|
||||
}
|
||||
|
||||
if !posSign {
|
||||
return bi.Neg(bi), exp, nil
|
||||
}
|
||||
return bi, exp, nil
|
||||
}
|
||||
|
||||
// IsNaN returns whether d is NaN.
|
||||
func (d Decimal128) IsNaN() bool {
|
||||
return d.h>>58&(1<<5-1) == 0x1F
|
||||
}
|
||||
|
||||
// IsInf returns:
|
||||
//
|
||||
// +1 d == Infinity
|
||||
// 0 other case
|
||||
// -1 d == -Infinity
|
||||
func (d Decimal128) IsInf() int {
|
||||
if d.h>>58&(1<<5-1) != 0x1E {
|
||||
return 0
|
||||
}
|
||||
|
||||
if d.h>>63&1 == 0 {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// IsZero returns true if d is the empty Decimal128.
|
||||
func (d Decimal128) IsZero() bool {
|
||||
return d.h == 0 && d.l == 0
|
||||
}
|
||||
|
||||
// MarshalJSON returns Decimal128 as a string.
|
||||
func (d Decimal128) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(d.String())
|
||||
}
|
||||
|
||||
// UnmarshalJSON creates a Decimal128 from a JSON string, an extended JSON $numberDecimal value, or the string
|
||||
// "null". If b is a JSON string or extended JSON value, d will have the value of that string, and if b is "null", d will
|
||||
// be unchanged.
|
||||
func (d *Decimal128) UnmarshalJSON(b []byte) error {
|
||||
// Ignore "null" to keep parity with the standard library. Decoding a JSON null into a non-pointer Decimal128 field
|
||||
// will leave the field unchanged. For pointer values, encoding/json will set the pointer to nil and will not
|
||||
// enter the UnmarshalJSON hook.
|
||||
if string(b) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var res interface{}
|
||||
err := json.Unmarshal(b, &res)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
str, ok := res.(string)
|
||||
|
||||
// Extended JSON
|
||||
if !ok {
|
||||
m, ok := res.(map[string]interface{})
|
||||
if !ok {
|
||||
return errors.New("not an extended JSON Decimal128: expected document")
|
||||
}
|
||||
d128, ok := m["$numberDecimal"]
|
||||
if !ok {
|
||||
return errors.New("not an extended JSON Decimal128: expected key $numberDecimal")
|
||||
}
|
||||
str, ok = d128.(string)
|
||||
if !ok {
|
||||
return errors.New("not an extended JSON Decimal128: expected decimal to be string")
|
||||
}
|
||||
}
|
||||
|
||||
*d, err = ParseDecimal128(str)
|
||||
return err
|
||||
}
|
||||
|
||||
var dNaN = Decimal128{0x1F << 58, 0}
|
||||
var dPosInf = Decimal128{0x1E << 58, 0}
|
||||
var dNegInf = Decimal128{0x3E << 58, 0}
|
||||
|
||||
func dErr(s string) (Decimal128, error) {
|
||||
return dNaN, fmt.Errorf("cannot parse %q as a decimal128", s)
|
||||
}
|
||||
|
||||
// match scientific notation number, example -10.15e-18
|
||||
var normalNumber = regexp.MustCompile(`^(?P<int>[-+]?\d*)?(?:\.(?P<dec>\d*))?(?:[Ee](?P<exp>[-+]?\d+))?$`)
|
||||
|
||||
// ParseDecimal128 takes the given string and attempts to parse it into a valid
|
||||
// Decimal128 value.
|
||||
func ParseDecimal128(s string) (Decimal128, error) {
|
||||
if s == "" {
|
||||
return dErr(s)
|
||||
}
|
||||
|
||||
matches := normalNumber.FindStringSubmatch(s)
|
||||
if len(matches) == 0 {
|
||||
orig := s
|
||||
neg := s[0] == '-'
|
||||
if neg || s[0] == '+' {
|
||||
s = s[1:]
|
||||
}
|
||||
|
||||
if s == "NaN" || s == "nan" || strings.EqualFold(s, "nan") {
|
||||
return dNaN, nil
|
||||
}
|
||||
if s == "Inf" || s == "inf" || strings.EqualFold(s, "inf") || strings.EqualFold(s, "infinity") {
|
||||
if neg {
|
||||
return dNegInf, nil
|
||||
}
|
||||
return dPosInf, nil
|
||||
}
|
||||
return dErr(orig)
|
||||
}
|
||||
|
||||
intPart := matches[1]
|
||||
decPart := matches[2]
|
||||
expPart := matches[3]
|
||||
|
||||
var err error
|
||||
exp := 0
|
||||
if expPart != "" {
|
||||
exp, err = strconv.Atoi(expPart)
|
||||
if err != nil {
|
||||
return dErr(s)
|
||||
}
|
||||
}
|
||||
if decPart != "" {
|
||||
exp -= len(decPart)
|
||||
}
|
||||
|
||||
if len(strings.Trim(intPart+decPart, "-0")) > 35 {
|
||||
return dErr(s)
|
||||
}
|
||||
|
||||
// Parse the significand (i.e. the non-exponent part) as a big.Int.
|
||||
bi, ok := new(big.Int).SetString(intPart+decPart, 10)
|
||||
if !ok {
|
||||
return dErr(s)
|
||||
}
|
||||
|
||||
d, ok := ParseDecimal128FromBigInt(bi, exp)
|
||||
if !ok {
|
||||
return dErr(s)
|
||||
}
|
||||
|
||||
if bi.Sign() == 0 && s[0] == '-' {
|
||||
d.h |= 1 << 63
|
||||
}
|
||||
|
||||
return d, nil
|
||||
}
|
||||
|
||||
var (
|
||||
ten = big.NewInt(10)
|
||||
zero = new(big.Int)
|
||||
|
||||
maxS, _ = new(big.Int).SetString("9999999999999999999999999999999999", 10)
|
||||
)
|
||||
|
||||
// ParseDecimal128FromBigInt attempts to parse the given significand and exponent into a valid Decimal128 value.
|
||||
func ParseDecimal128FromBigInt(bi *big.Int, exp int) (Decimal128, bool) {
|
||||
// copy
|
||||
bi = new(big.Int).Set(bi)
|
||||
|
||||
q := new(big.Int)
|
||||
r := new(big.Int)
|
||||
|
||||
// If the significand is zero, the logical value will always be zero, independent of the
|
||||
// exponent. However, the loops for handling out-of-range exponent values below may be extremely
|
||||
// slow for zero values because the significand never changes. Limit the exponent value to the
|
||||
// supported range here to prevent entering the loops below.
|
||||
if bi.Cmp(zero) == 0 {
|
||||
if exp > MaxDecimal128Exp {
|
||||
exp = MaxDecimal128Exp
|
||||
}
|
||||
if exp < MinDecimal128Exp {
|
||||
exp = MinDecimal128Exp
|
||||
}
|
||||
}
|
||||
|
||||
for bigIntCmpAbs(bi, maxS) == 1 {
|
||||
bi, _ = q.QuoRem(bi, ten, r)
|
||||
if r.Cmp(zero) != 0 {
|
||||
return Decimal128{}, false
|
||||
}
|
||||
exp++
|
||||
if exp > MaxDecimal128Exp {
|
||||
return Decimal128{}, false
|
||||
}
|
||||
}
|
||||
|
||||
for exp < MinDecimal128Exp {
|
||||
// Subnormal.
|
||||
bi, _ = q.QuoRem(bi, ten, r)
|
||||
if r.Cmp(zero) != 0 {
|
||||
return Decimal128{}, false
|
||||
}
|
||||
exp++
|
||||
}
|
||||
for exp > MaxDecimal128Exp {
|
||||
// Clamped.
|
||||
bi.Mul(bi, ten)
|
||||
if bigIntCmpAbs(bi, maxS) == 1 {
|
||||
return Decimal128{}, false
|
||||
}
|
||||
exp--
|
||||
}
|
||||
|
||||
b := bi.Bytes()
|
||||
var h, l uint64
|
||||
for i := 0; i < len(b); i++ {
|
||||
if i < len(b)-8 {
|
||||
h = h<<8 | uint64(b[i])
|
||||
continue
|
||||
}
|
||||
l = l<<8 | uint64(b[i])
|
||||
}
|
||||
|
||||
h |= uint64(exp-MinDecimal128Exp) & uint64(1<<14-1) << 49
|
||||
if bi.Sign() == -1 {
|
||||
h |= 1 << 63
|
||||
}
|
||||
|
||||
return Decimal128{h: h, l: l}, true
|
||||
}
|
||||
|
||||
// bigIntCmpAbs computes big.Int.Cmp(absoluteValue(x), absoluteValue(y)).
|
||||
func bigIntCmpAbs(x, y *big.Int) int {
|
||||
xAbs := bigIntAbsValue(x)
|
||||
yAbs := bigIntAbsValue(y)
|
||||
return xAbs.Cmp(yAbs)
|
||||
}
|
||||
|
||||
// bigIntAbsValue returns a big.Int containing the absolute value of b.
|
||||
// If b is already a non-negative number, it is returned without any changes or copies.
|
||||
func bigIntAbsValue(b *big.Int) *big.Int {
|
||||
if b.Sign() >= 0 {
|
||||
return b // already positive
|
||||
}
|
||||
return new(big.Int).Abs(b)
|
||||
}
|
236
decimal_test.go
Normal file
236
decimal_test.go
Normal file
@ -0,0 +1,236 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/assert"
|
||||
"gitea.psichedelico.com/go/bson/internal/require"
|
||||
)
|
||||
|
||||
type bigIntTestCase struct {
|
||||
s string
|
||||
|
||||
h uint64
|
||||
l uint64
|
||||
|
||||
bi *big.Int
|
||||
exp int
|
||||
|
||||
remark string
|
||||
}
|
||||
|
||||
func parseBigInt(s string) *big.Int {
|
||||
bi, _ := new(big.Int).SetString(s, 10)
|
||||
return bi
|
||||
}
|
||||
|
||||
var (
|
||||
one = big.NewInt(1)
|
||||
|
||||
biMaxS = new(big.Int).SetBytes([]byte{0x1, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
|
||||
biNMaxS = new(big.Int).Neg(biMaxS)
|
||||
|
||||
biOverflow = new(big.Int).Add(biMaxS, one)
|
||||
biNOverflow = new(big.Int).Neg(biOverflow)
|
||||
|
||||
bi12345 = parseBigInt("12345")
|
||||
biN12345 = parseBigInt("-12345")
|
||||
|
||||
bi9_14 = parseBigInt("90123456789012")
|
||||
biN9_14 = parseBigInt("-90123456789012")
|
||||
|
||||
bi9_34 = parseBigInt("9999999999999999999999999999999999")
|
||||
biN9_34 = parseBigInt("-9999999999999999999999999999999999")
|
||||
)
|
||||
|
||||
var bigIntTestCases = []bigIntTestCase{
|
||||
{s: "12345", h: 0x3040000000000000, l: 12345, bi: bi12345},
|
||||
{s: "-12345", h: 0xB040000000000000, l: 12345, bi: biN12345},
|
||||
|
||||
{s: "90123456.789012", h: 0x3034000000000000, l: 90123456789012, bi: bi9_14, exp: -6},
|
||||
{s: "-90123456.789012", h: 0xB034000000000000, l: 90123456789012, bi: biN9_14, exp: -6},
|
||||
{s: "9.0123456789012E+22", h: 0x3052000000000000, l: 90123456789012, bi: bi9_14, exp: 9},
|
||||
{s: "-9.0123456789012E+22", h: 0xB052000000000000, l: 90123456789012, bi: biN9_14, exp: 9},
|
||||
{s: "9.0123456789012E-8", h: 0x3016000000000000, l: 90123456789012, bi: bi9_14, exp: -21},
|
||||
{s: "-9.0123456789012E-8", h: 0xB016000000000000, l: 90123456789012, bi: biN9_14, exp: -21},
|
||||
|
||||
{s: "9999999999999999999999999999999999", h: 3477321013416265664, l: 4003012203950112767, bi: bi9_34},
|
||||
{s: "-9999999999999999999999999999999999", h: 12700693050271041472, l: 4003012203950112767, bi: biN9_34},
|
||||
{s: "0.9999999999999999999999999999999999", h: 3458180714999941056, l: 4003012203950112767, bi: bi9_34, exp: -34},
|
||||
{s: "-0.9999999999999999999999999999999999", h: 12681552751854716864, l: 4003012203950112767, bi: biN9_34, exp: -34},
|
||||
{s: "99999999999999999.99999999999999999", h: 3467750864208103360, l: 4003012203950112767, bi: bi9_34, exp: -17},
|
||||
{s: "-99999999999999999.99999999999999999", h: 12691122901062879168, l: 4003012203950112767, bi: biN9_34, exp: -17},
|
||||
{s: "9.999999999999999999999999999999999E+35", h: 3478446913323108288, l: 4003012203950112767, bi: bi9_34, exp: 2},
|
||||
{s: "-9.999999999999999999999999999999999E+35", h: 12701818950177884096, l: 4003012203950112767, bi: biN9_34, exp: 2},
|
||||
{s: "9.999999999999999999999999999999999E+40", h: 3481261663090214848, l: 4003012203950112767, bi: bi9_34, exp: 7},
|
||||
{s: "-9.999999999999999999999999999999999E+40", h: 12704633699944990656, l: 4003012203950112767, bi: biN9_34, exp: 7},
|
||||
{s: "99999999999999999999999999999.99999", h: 3474506263649159104, l: 4003012203950112767, bi: bi9_34, exp: -5},
|
||||
{s: "-99999999999999999999999999999.99999", h: 12697878300503934912, l: 4003012203950112767, bi: biN9_34, exp: -5},
|
||||
|
||||
{s: "1.038459371706965525706099265844019E-6143", remark: "subnormal", h: 0x333333333333, l: 0x3333333333333333, bi: parseBigInt("10384593717069655257060992658440190"), exp: MinDecimal128Exp - 1},
|
||||
{s: "-1.038459371706965525706099265844019E-6143", remark: "subnormal", h: 0x8000333333333333, l: 0x3333333333333333, bi: parseBigInt("-10384593717069655257060992658440190"), exp: MinDecimal128Exp - 1},
|
||||
|
||||
{s: "rounding overflow 1", remark: "overflow", bi: parseBigInt("103845937170696552570609926584401910"), exp: MaxDecimal128Exp},
|
||||
{s: "rounding overflow 2", remark: "overflow", bi: parseBigInt("103845937170696552570609926584401910"), exp: MaxDecimal128Exp},
|
||||
|
||||
{s: "subnormal overflow 1", remark: "overflow", bi: biMaxS, exp: MinDecimal128Exp - 1},
|
||||
{s: "subnormal overflow 2", remark: "overflow", bi: biNMaxS, exp: MinDecimal128Exp - 1},
|
||||
|
||||
{s: "clamped overflow 1", remark: "overflow", bi: biMaxS, exp: MaxDecimal128Exp + 1},
|
||||
{s: "clamped overflow 2", remark: "overflow", bi: biNMaxS, exp: MaxDecimal128Exp + 1},
|
||||
|
||||
{s: "biMaxS+1 overflow", remark: "overflow", bi: biOverflow, exp: MaxDecimal128Exp},
|
||||
{s: "biNMaxS-1 overflow", remark: "overflow", bi: biNOverflow, exp: MaxDecimal128Exp},
|
||||
|
||||
{s: "NaN", h: 0x7c00000000000000, l: 0, remark: "NaN"},
|
||||
{s: "Infinity", h: 0x7800000000000000, l: 0, remark: "Infinity"},
|
||||
{s: "-Infinity", h: 0xf800000000000000, l: 0, remark: "-Infinity"},
|
||||
}
|
||||
|
||||
func TestDecimal128_BigInt(t *testing.T) {
|
||||
for _, c := range bigIntTestCases {
|
||||
t.Run(c.s, func(t *testing.T) {
|
||||
switch c.remark {
|
||||
case "NaN", "Infinity", "-Infinity":
|
||||
d128 := NewDecimal128(c.h, c.l)
|
||||
_, _, err := d128.BigInt()
|
||||
require.Error(t, err, "case %s", c.s)
|
||||
case "":
|
||||
d128 := NewDecimal128(c.h, c.l)
|
||||
bi, e, err := d128.BigInt()
|
||||
require.NoError(t, err, "case %s", c.s)
|
||||
require.Equal(t, 0, c.bi.Cmp(bi), "case %s e:%s a:%s", c.s, c.bi.String(), bi.String())
|
||||
require.Equal(t, c.exp, e, "case %s", c.s, d128.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDecimal128FromBigInt(t *testing.T) {
|
||||
for _, c := range bigIntTestCases {
|
||||
switch c.remark {
|
||||
case "overflow":
|
||||
d128, ok := ParseDecimal128FromBigInt(c.bi, c.exp)
|
||||
require.Equal(t, false, ok, "case %s %s", c.s, d128.String(), c.remark)
|
||||
case "", "rounding", "subnormal", "clamped":
|
||||
d128, ok := ParseDecimal128FromBigInt(c.bi, c.exp)
|
||||
require.Equal(t, true, ok, "case %s", c.s)
|
||||
require.Equal(t, c.s, d128.String(), "case %s", c.s)
|
||||
|
||||
require.Equal(t, c.h, d128.h, "case %s", c.s, d128.l)
|
||||
require.Equal(t, c.l, d128.l, "case %s", c.s, d128.h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDecimal128(t *testing.T) {
|
||||
cases := make([]bigIntTestCase, 0, len(bigIntTestCases))
|
||||
cases = append(cases, bigIntTestCases...)
|
||||
cases = append(cases,
|
||||
bigIntTestCase{s: "-0001231.453454000000565600000000E-21", h: 0xafe6000003faa269, l: 0x81cfeceaabdb1800},
|
||||
bigIntTestCase{s: "12345E+21", h: 0x306a000000000000, l: 12345},
|
||||
bigIntTestCase{s: "0.10000000000000000000000000000000000000000001", remark: "parse fail"},
|
||||
bigIntTestCase{s: ".125e1", h: 0x303c000000000000, l: 125},
|
||||
bigIntTestCase{s: ".125", h: 0x303a000000000000, l: 125},
|
||||
// Test that parsing negative zero returns negative zero with a zero exponent.
|
||||
bigIntTestCase{s: "-0", h: 0xb040000000000000, l: 0},
|
||||
// Test that parsing negative zero with an in-range exponent returns negative zero and
|
||||
// preserves the specified exponent value.
|
||||
bigIntTestCase{s: "-0E999", h: 0xb80e000000000000, l: 0},
|
||||
// Test that parsing zero with an out-of-range positive exponent returns zero with the
|
||||
// maximum positive exponent (i.e. 0e+6111).
|
||||
bigIntTestCase{s: "0E2000000000000", h: 0x5ffe000000000000, l: 0},
|
||||
// Test that parsing zero with an out-of-range negative exponent returns zero with the
|
||||
// minimum negative exponent (i.e. 0e-6176).
|
||||
bigIntTestCase{s: "-0E2000000000000", h: 0xdffe000000000000, l: 0},
|
||||
bigIntTestCase{s: "", remark: "parse fail"})
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.s, func(t *testing.T) {
|
||||
switch c.remark {
|
||||
case "overflow", "parse fail":
|
||||
_, err := ParseDecimal128(c.s)
|
||||
assert.Error(t, err, "ParseDecimal128(%q) should return an error", c.s)
|
||||
default:
|
||||
got, err := ParseDecimal128(c.s)
|
||||
require.NoError(t, err, "ParseDecimal128(%q) error", c.s)
|
||||
|
||||
want := Decimal128{h: c.h, l: c.l}
|
||||
// Decimal128 doesn't implement an equality function, so compare the expected
|
||||
// low/high uint64 values directly. Also print the string representation of each
|
||||
// number to make debugging failures easier.
|
||||
assert.Equal(t, want, got, "ParseDecimal128(%q) = %s, want %s", c.s, got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecimal128_JSON(t *testing.T) {
|
||||
t.Run("roundTrip", func(t *testing.T) {
|
||||
decimal := NewDecimal128(0x3040000000000000, 12345)
|
||||
bytes, err := json.Marshal(decimal)
|
||||
assert.Nil(t, err, "json.Marshal error: %v", err)
|
||||
got := NewDecimal128(0, 0)
|
||||
err = json.Unmarshal(bytes, &got)
|
||||
assert.Nil(t, err, "json.Unmarshal error: %v", err)
|
||||
assert.Equal(t, decimal.h, got.h, "expected h: %v got: %v", decimal.h, got.h)
|
||||
assert.Equal(t, decimal.l, got.l, "expected l: %v got: %v", decimal.l, got.l)
|
||||
})
|
||||
t.Run("unmarshal extendedJSON", func(t *testing.T) {
|
||||
want := NewDecimal128(0x3040000000000000, 12345)
|
||||
extJSON := fmt.Sprintf(`{"$numberDecimal": %q}`, want.String())
|
||||
|
||||
got := NewDecimal128(0, 0)
|
||||
err := json.Unmarshal([]byte(extJSON), &got)
|
||||
assert.Nil(t, err, "json.Unmarshal error: %v", err)
|
||||
assert.Equal(t, want.h, got.h, "expected h: %v got: %v", want.h, got.h)
|
||||
assert.Equal(t, want.l, got.l, "expected l: %v got: %v", want.l, got.l)
|
||||
})
|
||||
t.Run("unmarshal null", func(t *testing.T) {
|
||||
want := NewDecimal128(0, 0)
|
||||
extJSON := `null`
|
||||
|
||||
got := NewDecimal128(0, 0)
|
||||
err := json.Unmarshal([]byte(extJSON), &got)
|
||||
assert.Nil(t, err, "json.Unmarshal error: %v", err)
|
||||
assert.Equal(t, want.h, got.h, "expected h: %v got: %v", want.h, got.h)
|
||||
assert.Equal(t, want.l, got.l, "expected l: %v got: %v", want.l, got.l)
|
||||
})
|
||||
t.Run("unmarshal", func(t *testing.T) {
|
||||
cases := make([]bigIntTestCase, 0, len(bigIntTestCases))
|
||||
cases = append(cases, bigIntTestCases...)
|
||||
cases = append(cases,
|
||||
bigIntTestCase{s: "-0001231.453454000000565600000000E-21", h: 0xafe6000003faa269, l: 0x81cfeceaabdb1800},
|
||||
bigIntTestCase{s: "12345E+21", h: 0x306a000000000000, l: 12345},
|
||||
bigIntTestCase{s: "0.10000000000000000000000000000000000000000001", remark: "parse fail"},
|
||||
bigIntTestCase{s: ".125e1", h: 0x303c000000000000, l: 125},
|
||||
bigIntTestCase{s: ".125", h: 0x303a000000000000, l: 125})
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.s, func(t *testing.T) {
|
||||
input := fmt.Sprintf(`{"foo": %q}`, c.s)
|
||||
var got map[string]Decimal128
|
||||
err := json.Unmarshal([]byte(input), &got)
|
||||
|
||||
switch c.remark {
|
||||
case "overflow", "parse fail":
|
||||
assert.NotNil(t, err, "expected Unmarshal error, got nil")
|
||||
default:
|
||||
assert.Nil(t, err, "Unmarshal error: %v", err)
|
||||
gotDecimal := got["foo"]
|
||||
assert.Equal(t, c.h, gotDecimal.h, "expected h: %v got: %v", c.h, gotDecimal.l)
|
||||
assert.Equal(t, c.l, gotDecimal.l, "expected l: %v got: %v", c.l, gotDecimal.h)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
136
decoder.go
Normal file
136
decoder.go
Normal file
@ -0,0 +1,136 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ErrDecodeToNil is the error returned when trying to decode to a nil value
|
||||
var ErrDecodeToNil = errors.New("cannot Decode to nil value")
|
||||
|
||||
// This pool is used to keep the allocations of Decoders down. This is only used for the Marshal*
|
||||
// methods and is not consumable from outside of this package. The Decoders retrieved from this pool
|
||||
// must have both Reset and SetRegistry called on them.
|
||||
var decPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(Decoder)
|
||||
},
|
||||
}
|
||||
|
||||
// A Decoder reads and decodes BSON documents from a stream. It reads from a ValueReader as
|
||||
// the source of BSON data.
|
||||
type Decoder struct {
|
||||
dc DecodeContext
|
||||
vr ValueReader
|
||||
}
|
||||
|
||||
// NewDecoder returns a new decoder that reads from vr.
|
||||
func NewDecoder(vr ValueReader) *Decoder {
|
||||
return &Decoder{
|
||||
dc: DecodeContext{Registry: defaultRegistry},
|
||||
vr: vr,
|
||||
}
|
||||
}
|
||||
|
||||
// Decode reads the next BSON document from the stream and decodes it into the
|
||||
// value pointed to by val.
|
||||
//
|
||||
// See [Unmarshal] for details about BSON unmarshaling behavior.
|
||||
func (d *Decoder) Decode(val interface{}) error {
|
||||
if unmarshaler, ok := val.(Unmarshaler); ok {
|
||||
// TODO(skriptble): Reuse a []byte here and use the AppendDocumentBytes method.
|
||||
buf, err := copyDocumentToBytes(d.vr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return unmarshaler.UnmarshalBSON(buf)
|
||||
}
|
||||
|
||||
rval := reflect.ValueOf(val)
|
||||
switch rval.Kind() {
|
||||
case reflect.Ptr:
|
||||
if rval.IsNil() {
|
||||
return ErrDecodeToNil
|
||||
}
|
||||
rval = rval.Elem()
|
||||
case reflect.Map:
|
||||
if rval.IsNil() {
|
||||
return ErrDecodeToNil
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("argument to Decode must be a pointer or a map, but got %v", rval)
|
||||
}
|
||||
decoder, err := d.dc.LookupDecoder(rval.Type())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return decoder.DecodeValue(d.dc, d.vr, rval)
|
||||
}
|
||||
|
||||
// Reset will reset the state of the decoder, using the same *DecodeContext used in
|
||||
// the original construction but using vr for reading.
|
||||
func (d *Decoder) Reset(vr ValueReader) {
|
||||
d.vr = vr
|
||||
}
|
||||
|
||||
// SetRegistry replaces the current registry of the decoder with r.
|
||||
func (d *Decoder) SetRegistry(r *Registry) {
|
||||
d.dc.Registry = r
|
||||
}
|
||||
|
||||
// DefaultDocumentM causes the Decoder to always unmarshal documents into the bson.M type. This
|
||||
// behavior is restricted to data typed as "interface{}" or "map[string]interface{}".
|
||||
func (d *Decoder) DefaultDocumentM() {
|
||||
d.dc.defaultDocumentType = reflect.TypeOf(M{})
|
||||
}
|
||||
|
||||
// AllowTruncatingDoubles causes the Decoder to truncate the fractional part of BSON "double" values
|
||||
// when attempting to unmarshal them into a Go integer (int, int8, int16, int32, or int64) struct
|
||||
// field. The truncation logic does not apply to BSON "decimal128" values.
|
||||
func (d *Decoder) AllowTruncatingDoubles() {
|
||||
d.dc.truncate = true
|
||||
}
|
||||
|
||||
// BinaryAsSlice causes the Decoder to unmarshal BSON binary field values that are the "Generic" or
|
||||
// "Old" BSON binary subtype as a Go byte slice instead of a bson.Binary.
|
||||
func (d *Decoder) BinaryAsSlice() {
|
||||
d.dc.binaryAsSlice = true
|
||||
}
|
||||
|
||||
// ObjectIDAsHexString causes the Decoder to decode object IDs to their hex representation.
|
||||
func (d *Decoder) ObjectIDAsHexString() {
|
||||
d.dc.objectIDAsHexString = true
|
||||
}
|
||||
|
||||
// UseJSONStructTags causes the Decoder to fall back to using the "json" struct tag if a "bson"
|
||||
// struct tag is not specified.
|
||||
func (d *Decoder) UseJSONStructTags() {
|
||||
d.dc.useJSONStructTags = true
|
||||
}
|
||||
|
||||
// UseLocalTimeZone causes the Decoder to unmarshal time.Time values in the local timezone instead
|
||||
// of the UTC timezone.
|
||||
func (d *Decoder) UseLocalTimeZone() {
|
||||
d.dc.useLocalTimeZone = true
|
||||
}
|
||||
|
||||
// ZeroMaps causes the Decoder to delete any existing values from Go maps in the destination value
|
||||
// passed to Decode before unmarshaling BSON documents into them.
|
||||
func (d *Decoder) ZeroMaps() {
|
||||
d.dc.zeroMaps = true
|
||||
}
|
||||
|
||||
// ZeroStructs causes the Decoder to delete any existing values from Go structs in the destination
|
||||
// value passed to Decode before unmarshaling BSON documents into them.
|
||||
func (d *Decoder) ZeroStructs() {
|
||||
d.dc.zeroStructs = true
|
||||
}
|
208
decoder_example_test.go
Normal file
208
decoder_example_test.go
Normal file
@ -0,0 +1,208 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"gitea.psichedelico.com/go/bson"
|
||||
)
|
||||
|
||||
func ExampleDecoder() {
|
||||
// Marshal a BSON document that contains the name, SKU, and price (in cents)
|
||||
// of a product.
|
||||
doc := bson.D{
|
||||
{Key: "name", Value: "Cereal Rounds"},
|
||||
{Key: "sku", Value: "AB12345"},
|
||||
{Key: "price_cents", Value: 399},
|
||||
}
|
||||
data, err := bson.Marshal(doc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Create a Decoder that reads the marshaled BSON document and use it to
|
||||
// unmarshal the document into a Product struct.
|
||||
decoder := bson.NewDecoder(bson.NewDocumentReader(bytes.NewReader(data)))
|
||||
|
||||
type Product struct {
|
||||
Name string `bson:"name"`
|
||||
SKU string `bson:"sku"`
|
||||
Price int64 `bson:"price_cents"`
|
||||
}
|
||||
|
||||
var res Product
|
||||
err = decoder.Decode(&res)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("%+v\n", res)
|
||||
// Output: {Name:Cereal Rounds SKU:AB12345 Price:399}
|
||||
}
|
||||
|
||||
func ExampleDecoder_DefaultDocumentM() {
|
||||
// Marshal a BSON document that contains a city name and a nested document
|
||||
// with various city properties.
|
||||
doc := bson.D{
|
||||
{Key: "name", Value: "New York"},
|
||||
{Key: "properties", Value: bson.D{
|
||||
{Key: "state", Value: "NY"},
|
||||
{Key: "population", Value: 8_804_190},
|
||||
{Key: "elevation", Value: 10},
|
||||
}},
|
||||
}
|
||||
data, err := bson.Marshal(doc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Create a Decoder that reads the marshaled BSON document and use it to unmarshal the document
|
||||
// into a City struct.
|
||||
decoder := bson.NewDecoder(bson.NewDocumentReader(bytes.NewReader(data)))
|
||||
|
||||
type City struct {
|
||||
Name string `bson:"name"`
|
||||
Properties interface{} `bson:"properties"`
|
||||
}
|
||||
|
||||
// Configure the Decoder to default to decoding BSON documents as the M
|
||||
// type if the decode destination has no type information. The Properties
|
||||
// field in the City struct will be decoded as a "M" (i.e. map) instead
|
||||
// of the default "D".
|
||||
decoder.DefaultDocumentM()
|
||||
|
||||
var res City
|
||||
err = decoder.Decode(&res)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
data, err = json.Marshal(res)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("%+v\n", string(data))
|
||||
// Output: {"Name":"New York","Properties":{"elevation":10,"population":8804190,"state":"NY"}}
|
||||
}
|
||||
|
||||
func ExampleDecoder_UseJSONStructTags() {
|
||||
// Marshal a BSON document that contains the name, SKU, and price (in cents)
|
||||
// of a product.
|
||||
doc := bson.D{
|
||||
{Key: "name", Value: "Cereal Rounds"},
|
||||
{Key: "sku", Value: "AB12345"},
|
||||
{Key: "price_cents", Value: 399},
|
||||
}
|
||||
data, err := bson.Marshal(doc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Create a Decoder that reads the marshaled BSON document and use it to
|
||||
// unmarshal the document into a Product struct.
|
||||
decoder := bson.NewDecoder(bson.NewDocumentReader(bytes.NewReader(data)))
|
||||
|
||||
type Product struct {
|
||||
Name string `json:"name"`
|
||||
SKU string `json:"sku"`
|
||||
Price int64 `json:"price_cents"`
|
||||
}
|
||||
|
||||
// Configure the Decoder to use "json" struct tags when decoding if "bson"
|
||||
// struct tags are not present.
|
||||
decoder.UseJSONStructTags()
|
||||
|
||||
var res Product
|
||||
err = decoder.Decode(&res)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("%+v\n", res)
|
||||
// Output: {Name:Cereal Rounds SKU:AB12345 Price:399}
|
||||
}
|
||||
|
||||
func ExampleDecoder_extendedJSON() {
|
||||
// Define an Extended JSON document that contains the name, SKU, and price
|
||||
// (in cents) of a product.
|
||||
data := []byte(`{"name":"Cereal Rounds","sku":"AB12345","price_cents":{"$numberLong":"399"}}`)
|
||||
|
||||
// Create a Decoder that reads the Extended JSON document and use it to
|
||||
// unmarshal the document into a Product struct.
|
||||
vr, err := bson.NewExtJSONValueReader(bytes.NewReader(data), true)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
decoder := bson.NewDecoder(vr)
|
||||
|
||||
type Product struct {
|
||||
Name string `bson:"name"`
|
||||
SKU string `bson:"sku"`
|
||||
Price int64 `bson:"price_cents"`
|
||||
}
|
||||
|
||||
var res Product
|
||||
err = decoder.Decode(&res)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("%+v\n", res)
|
||||
// Output: {Name:Cereal Rounds SKU:AB12345 Price:399}
|
||||
}
|
||||
|
||||
func ExampleDecoder_multipleExtendedJSONDocuments() {
|
||||
// Define a newline-separated sequence of Extended JSON documents that
|
||||
// contain X,Y coordinates.
|
||||
data := []byte(`
|
||||
{"x":{"$numberInt":"0"},"y":{"$numberInt":"0"}}
|
||||
{"x":{"$numberInt":"1"},"y":{"$numberInt":"1"}}
|
||||
{"x":{"$numberInt":"2"},"y":{"$numberInt":"2"}}
|
||||
{"x":{"$numberInt":"3"},"y":{"$numberInt":"3"}}
|
||||
{"x":{"$numberInt":"4"},"y":{"$numberInt":"4"}}
|
||||
`)
|
||||
|
||||
// Create a Decoder that reads the Extended JSON documents and use it to
|
||||
// unmarshal the documents Coordinate structs.
|
||||
vr, err := bson.NewExtJSONValueReader(bytes.NewReader(data), true)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
decoder := bson.NewDecoder(vr)
|
||||
|
||||
type Coordinate struct {
|
||||
X int
|
||||
Y int
|
||||
}
|
||||
|
||||
// Read and unmarshal each Extended JSON document from the sequence. If
|
||||
// Decode returns error io.EOF, that means the Decoder has reached the end
|
||||
// of the input, so break the loop.
|
||||
for {
|
||||
var res Coordinate
|
||||
err = decoder.Decode(&res)
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Printf("%+v\n", res)
|
||||
}
|
||||
// Output:
|
||||
// {X:0 Y:0}
|
||||
// {X:1 Y:1}
|
||||
// {X:2 Y:2}
|
||||
// {X:3 Y:3}
|
||||
// {X:4 Y:4}
|
||||
}
|
699
decoder_test.go
Normal file
699
decoder_test.go
Normal file
@ -0,0 +1,699 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/assert"
|
||||
"gitea.psichedelico.com/go/bson/internal/require"
|
||||
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
func TestDecodeValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tc := range unmarshalingTestCases() {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := reflect.New(tc.sType).Elem()
|
||||
vr := NewDocumentReader(bytes.NewReader(tc.data))
|
||||
reg := defaultRegistry
|
||||
decoder, err := reg.LookupDecoder(reflect.TypeOf(got))
|
||||
noerr(t, err)
|
||||
err = decoder.DecodeValue(DecodeContext{Registry: reg}, vr, got)
|
||||
noerr(t, err)
|
||||
assert.Equal(t, tc.want, got.Addr().Interface(), "Results do not match.")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodingInterfaces(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
stub func() ([]byte, interface{}, func(*testing.T))
|
||||
}
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "struct with interface containing a concrete value",
|
||||
stub: func() ([]byte, interface{}, func(*testing.T)) {
|
||||
type testStruct struct {
|
||||
Value interface{}
|
||||
}
|
||||
var value string
|
||||
|
||||
data := docToBytes(struct {
|
||||
Value string
|
||||
}{
|
||||
Value: "foo",
|
||||
})
|
||||
|
||||
receiver := testStruct{&value}
|
||||
|
||||
check := func(t *testing.T) {
|
||||
t.Helper()
|
||||
assert.Equal(t, "foo", value)
|
||||
}
|
||||
|
||||
return data, &receiver, check
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "struct with interface containing a struct",
|
||||
stub: func() ([]byte, interface{}, func(*testing.T)) {
|
||||
type demo struct {
|
||||
Data string
|
||||
}
|
||||
|
||||
type testStruct struct {
|
||||
Value interface{}
|
||||
}
|
||||
var value demo
|
||||
|
||||
data := docToBytes(struct {
|
||||
Value demo
|
||||
}{
|
||||
Value: demo{"foo"},
|
||||
})
|
||||
|
||||
receiver := testStruct{&value}
|
||||
|
||||
check := func(t *testing.T) {
|
||||
t.Helper()
|
||||
assert.Equal(t, "foo", value.Data)
|
||||
}
|
||||
|
||||
return data, &receiver, check
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "struct with interface containing a slice",
|
||||
stub: func() ([]byte, interface{}, func(*testing.T)) {
|
||||
type testStruct struct {
|
||||
Values interface{}
|
||||
}
|
||||
var values []string
|
||||
|
||||
data := docToBytes(struct {
|
||||
Values []string
|
||||
}{
|
||||
Values: []string{"foo", "bar"},
|
||||
})
|
||||
|
||||
receiver := testStruct{&values}
|
||||
|
||||
check := func(t *testing.T) {
|
||||
t.Helper()
|
||||
assert.Equal(t, []string{"foo", "bar"}, values)
|
||||
}
|
||||
|
||||
return data, &receiver, check
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "struct with interface containing an array",
|
||||
stub: func() ([]byte, interface{}, func(*testing.T)) {
|
||||
type testStruct struct {
|
||||
Values interface{}
|
||||
}
|
||||
var values [2]string
|
||||
|
||||
data := docToBytes(struct {
|
||||
Values []string
|
||||
}{
|
||||
Values: []string{"foo", "bar"},
|
||||
})
|
||||
|
||||
receiver := testStruct{&values}
|
||||
|
||||
check := func(t *testing.T) {
|
||||
t.Helper()
|
||||
assert.Equal(t, [2]string{"foo", "bar"}, values)
|
||||
}
|
||||
|
||||
return data, &receiver, check
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "struct with interface array containing concrete values",
|
||||
stub: func() ([]byte, interface{}, func(*testing.T)) {
|
||||
type testStruct struct {
|
||||
Values [3]interface{}
|
||||
}
|
||||
var str string
|
||||
var i, j int
|
||||
|
||||
data := docToBytes(struct {
|
||||
Values []interface{}
|
||||
}{
|
||||
Values: []interface{}{"foo", 42, nil},
|
||||
})
|
||||
|
||||
receiver := testStruct{[3]interface{}{&str, &i, &j}}
|
||||
|
||||
check := func(t *testing.T) {
|
||||
t.Helper()
|
||||
assert.Equal(t, "foo", str)
|
||||
assert.Equal(t, 42, i)
|
||||
assert.Equal(t, 0, j)
|
||||
assert.Equal(t, testStruct{[3]interface{}{&str, &i, nil}}, receiver)
|
||||
}
|
||||
|
||||
return data, &receiver, check
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
data, receiver, check := tc.stub()
|
||||
got := reflect.ValueOf(receiver).Elem()
|
||||
vr := NewDocumentReader(bytes.NewReader(data))
|
||||
reg := defaultRegistry
|
||||
decoder, err := reg.LookupDecoder(got.Type())
|
||||
noerr(t, err)
|
||||
err = decoder.DecodeValue(DecodeContext{Registry: reg}, vr, got)
|
||||
noerr(t, err)
|
||||
check(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoder(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Decode", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("basic", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tc := range unmarshalingTestCases() {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := reflect.New(tc.sType).Interface()
|
||||
vr := NewDocumentReader(bytes.NewReader(tc.data))
|
||||
dec := NewDecoder(vr)
|
||||
err := dec.Decode(got)
|
||||
noerr(t, err)
|
||||
assert.Equal(t, tc.want, got, "Results do not match.")
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("stream", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var buf bytes.Buffer
|
||||
vr := NewDocumentReader(&buf)
|
||||
dec := NewDecoder(vr)
|
||||
for _, tc := range unmarshalingTestCases() {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
buf.Write(tc.data)
|
||||
got := reflect.New(tc.sType).Interface()
|
||||
err := dec.Decode(got)
|
||||
noerr(t, err)
|
||||
assert.Equal(t, tc.want, got, "Results do not match.")
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("lookup error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type certainlydoesntexistelsewhereihope func(string, string) string
|
||||
// Avoid unused code lint error.
|
||||
_ = certainlydoesntexistelsewhereihope(func(string, string) string { return "" })
|
||||
|
||||
cdeih := func(string, string) string { return "certainlydoesntexistelsewhereihope" }
|
||||
dec := NewDecoder(NewDocumentReader(bytes.NewReader([]byte{})))
|
||||
want := errNoDecoder{Type: reflect.TypeOf(cdeih)}
|
||||
got := dec.Decode(&cdeih)
|
||||
assert.Equal(t, want, got, "Received unexpected error.")
|
||||
})
|
||||
t.Run("Unmarshaler", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
err error
|
||||
vr ValueReader
|
||||
invoked bool
|
||||
}{
|
||||
{
|
||||
"error",
|
||||
errors.New("Unmarshaler error"),
|
||||
&valueReaderWriter{BSONType: TypeEmbeddedDocument, Err: ErrEOD, ErrAfter: readElement},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"copy error",
|
||||
errors.New("copy error"),
|
||||
&valueReaderWriter{Err: errors.New("copy error"), ErrAfter: readDocument},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"success",
|
||||
nil,
|
||||
&valueReaderWriter{BSONType: TypeEmbeddedDocument, Err: ErrEOD, ErrAfter: readElement},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
unmarshaler := &testUnmarshaler{Err: tc.err}
|
||||
dec := NewDecoder(tc.vr)
|
||||
got := dec.Decode(unmarshaler)
|
||||
want := tc.err
|
||||
if !assert.CompareErrors(got, want) {
|
||||
t.Errorf("Did not receive expected error. got %v; want %v", got, want)
|
||||
}
|
||||
if unmarshaler.Invoked != tc.invoked {
|
||||
if tc.invoked {
|
||||
t.Error("Expected to have UnmarshalBSON invoked, but it wasn't.")
|
||||
} else {
|
||||
t.Error("Expected UnmarshalBSON to not be invoked, but it was.")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("Unmarshaler/success ValueReader", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
want := bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159))
|
||||
unmarshaler := &testUnmarshaler{}
|
||||
vr := NewDocumentReader(bytes.NewReader(want))
|
||||
dec := NewDecoder(vr)
|
||||
err := dec.Decode(unmarshaler)
|
||||
noerr(t, err)
|
||||
got := unmarshaler.Val
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Errorf("Did not unmarshal properly. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
t.Run("NewDecoder", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := NewDecoder(NewDocumentReader(bytes.NewReader([]byte{})))
|
||||
if got == nil {
|
||||
t.Errorf("Was expecting a non-nil Decoder, but got <nil>")
|
||||
}
|
||||
})
|
||||
})
|
||||
t.Run("NewDecoderWithContext", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := NewDecoder(NewDocumentReader(bytes.NewReader([]byte{})))
|
||||
if got == nil {
|
||||
t.Errorf("Was expecting a non-nil Decoder, but got <nil>")
|
||||
}
|
||||
})
|
||||
})
|
||||
t.Run("Decode doesn't zero struct", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type foo struct {
|
||||
Item string
|
||||
Qty int
|
||||
Bonus int
|
||||
}
|
||||
var got foo
|
||||
got.Item = "apple"
|
||||
got.Bonus = 2
|
||||
data := docToBytes(D{{"item", "canvas"}, {"qty", 4}})
|
||||
vr := NewDocumentReader(bytes.NewReader(data))
|
||||
dec := NewDecoder(vr)
|
||||
err := dec.Decode(&got)
|
||||
noerr(t, err)
|
||||
want := foo{Item: "canvas", Qty: 4, Bonus: 2}
|
||||
assert.Equal(t, want, got, "Results do not match.")
|
||||
})
|
||||
t.Run("Reset", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
vr1, vr2 := NewDocumentReader(bytes.NewReader([]byte{})), NewDocumentReader(bytes.NewReader([]byte{}))
|
||||
dec := NewDecoder(vr1)
|
||||
if dec.vr != vr1 {
|
||||
t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr1)
|
||||
}
|
||||
dec.Reset(vr2)
|
||||
if dec.vr != vr2 {
|
||||
t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr2)
|
||||
}
|
||||
})
|
||||
t.Run("SetRegistry", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r1, r2 := defaultRegistry, NewRegistry()
|
||||
dc1 := DecodeContext{Registry: r1}
|
||||
dc2 := DecodeContext{Registry: r2}
|
||||
dec := NewDecoder(NewDocumentReader(bytes.NewReader([]byte{})))
|
||||
if !reflect.DeepEqual(dec.dc, dc1) {
|
||||
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc1)
|
||||
}
|
||||
dec.SetRegistry(r2)
|
||||
if !reflect.DeepEqual(dec.dc, dc2) {
|
||||
t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc2)
|
||||
}
|
||||
})
|
||||
t.Run("DecodeToNil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
data := docToBytes(D{{"item", "canvas"}, {"qty", 4}})
|
||||
vr := NewDocumentReader(bytes.NewReader(data))
|
||||
dec := NewDecoder(vr)
|
||||
|
||||
var got *D
|
||||
err := dec.Decode(got)
|
||||
if !errors.Is(err, ErrDecodeToNil) {
|
||||
t.Fatalf("Decode error mismatch; expected %v, got %v", ErrDecodeToNil, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type testUnmarshaler struct {
|
||||
Invoked bool
|
||||
Val []byte
|
||||
Err error
|
||||
}
|
||||
|
||||
func (tu *testUnmarshaler) UnmarshalBSON(d []byte) error {
|
||||
tu.Invoked = true
|
||||
tu.Val = d
|
||||
return tu.Err
|
||||
}
|
||||
|
||||
func TestDecoderConfiguration(t *testing.T) {
|
||||
type truncateDoublesTest struct {
|
||||
MyInt int
|
||||
MyInt8 int8
|
||||
MyInt16 int16
|
||||
MyInt32 int32
|
||||
MyInt64 int64
|
||||
MyUint uint
|
||||
MyUint8 uint8
|
||||
MyUint16 uint16
|
||||
MyUint32 uint32
|
||||
MyUint64 uint64
|
||||
}
|
||||
|
||||
type objectIDTest struct {
|
||||
ID string
|
||||
}
|
||||
|
||||
type jsonStructTest struct {
|
||||
StructFieldName string `json:"jsonFieldName"`
|
||||
}
|
||||
|
||||
type localTimeZoneTest struct {
|
||||
MyTime time.Time
|
||||
}
|
||||
|
||||
type zeroMapsTest struct {
|
||||
MyMap map[string]string
|
||||
}
|
||||
|
||||
type zeroStructsTest struct {
|
||||
MyString string
|
||||
MyInt int
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
description string
|
||||
configure func(*Decoder)
|
||||
input []byte
|
||||
decodeInto func() interface{}
|
||||
want interface{}
|
||||
}{
|
||||
// Test that AllowTruncatingDoubles causes the Decoder to unmarshal BSON doubles with
|
||||
// fractional parts into Go integer types by truncating the fractional part.
|
||||
{
|
||||
description: "AllowTruncatingDoubles",
|
||||
configure: func(dec *Decoder) {
|
||||
dec.AllowTruncatingDoubles()
|
||||
},
|
||||
input: bsoncore.NewDocumentBuilder().
|
||||
AppendDouble("myInt", 1.999).
|
||||
AppendDouble("myInt8", 1.999).
|
||||
AppendDouble("myInt16", 1.999).
|
||||
AppendDouble("myInt32", 1.999).
|
||||
AppendDouble("myInt64", 1.999).
|
||||
AppendDouble("myUint", 1.999).
|
||||
AppendDouble("myUint8", 1.999).
|
||||
AppendDouble("myUint16", 1.999).
|
||||
AppendDouble("myUint32", 1.999).
|
||||
AppendDouble("myUint64", 1.999).
|
||||
Build(),
|
||||
decodeInto: func() interface{} { return &truncateDoublesTest{} },
|
||||
want: &truncateDoublesTest{
|
||||
MyInt: 1,
|
||||
MyInt8: 1,
|
||||
MyInt16: 1,
|
||||
MyInt32: 1,
|
||||
MyInt64: 1,
|
||||
MyUint: 1,
|
||||
MyUint8: 1,
|
||||
MyUint16: 1,
|
||||
MyUint32: 1,
|
||||
MyUint64: 1,
|
||||
},
|
||||
},
|
||||
// Test that BinaryAsSlice causes the Decoder to unmarshal BSON binary fields into Go byte
|
||||
// slices when there is no type information (e.g when unmarshaling into a bson.D).
|
||||
{
|
||||
description: "BinaryAsSlice",
|
||||
configure: func(dec *Decoder) {
|
||||
dec.BinaryAsSlice()
|
||||
},
|
||||
input: bsoncore.NewDocumentBuilder().
|
||||
AppendBinary("myBinary", TypeBinaryGeneric, []byte{}).
|
||||
Build(),
|
||||
decodeInto: func() interface{} { return &D{} },
|
||||
want: &D{{Key: "myBinary", Value: []byte{}}},
|
||||
},
|
||||
// Test that the default decoder always decodes BSON documents into bson.D values,
|
||||
// independent of the top-level Go value type.
|
||||
{
|
||||
description: "DocumentD nested by default",
|
||||
configure: func(_ *Decoder) {},
|
||||
input: bsoncore.NewDocumentBuilder().
|
||||
AppendDocument("myDocument", bsoncore.NewDocumentBuilder().
|
||||
AppendString("myString", "test value").
|
||||
Build()).
|
||||
Build(),
|
||||
decodeInto: func() interface{} { return M{} },
|
||||
want: M{
|
||||
"myDocument": D{{Key: "myString", Value: "test value"}},
|
||||
},
|
||||
},
|
||||
// Test that DefaultDocumentM always decodes BSON documents into bson.M values,
|
||||
// independent of the top-level Go value type.
|
||||
{
|
||||
description: "DefaultDocumentM nested",
|
||||
configure: func(dec *Decoder) {
|
||||
dec.DefaultDocumentM()
|
||||
},
|
||||
input: bsoncore.NewDocumentBuilder().
|
||||
AppendDocument("myDocument", bsoncore.NewDocumentBuilder().
|
||||
AppendString("myString", "test value").
|
||||
Build()).
|
||||
Build(),
|
||||
decodeInto: func() interface{} { return &D{} },
|
||||
want: &D{
|
||||
{Key: "myDocument", Value: M{"myString": "test value"}},
|
||||
},
|
||||
},
|
||||
// Test that ObjectIDAsHexString causes the Decoder to decode object ID to hex.
|
||||
{
|
||||
description: "ObjectIDAsHexString",
|
||||
configure: func(dec *Decoder) {
|
||||
dec.ObjectIDAsHexString()
|
||||
},
|
||||
input: bsoncore.NewDocumentBuilder().
|
||||
AppendObjectID("id", func() ObjectID {
|
||||
id, _ := ObjectIDFromHex("5ef7fdd91c19e3222b41b839")
|
||||
return id
|
||||
}()).
|
||||
Build(),
|
||||
decodeInto: func() interface{} { return &objectIDTest{} },
|
||||
want: &objectIDTest{ID: "5ef7fdd91c19e3222b41b839"},
|
||||
},
|
||||
// Test that UseJSONStructTags causes the Decoder to fall back to "json" struct tags if
|
||||
// "bson" struct tags are not available.
|
||||
{
|
||||
description: "UseJSONStructTags",
|
||||
configure: func(dec *Decoder) {
|
||||
dec.UseJSONStructTags()
|
||||
},
|
||||
input: bsoncore.NewDocumentBuilder().
|
||||
AppendString("jsonFieldName", "test value").
|
||||
Build(),
|
||||
decodeInto: func() interface{} { return &jsonStructTest{} },
|
||||
want: &jsonStructTest{StructFieldName: "test value"},
|
||||
},
|
||||
// Test that UseLocalTimeZone causes the Decoder to use the local time zone for decoded
|
||||
// time.Time values instead of UTC.
|
||||
{
|
||||
description: "UseLocalTimeZone",
|
||||
configure: func(dec *Decoder) {
|
||||
dec.UseLocalTimeZone()
|
||||
},
|
||||
input: bsoncore.NewDocumentBuilder().
|
||||
AppendDateTime("myTime", 1684349179939).
|
||||
Build(),
|
||||
decodeInto: func() interface{} { return &localTimeZoneTest{} },
|
||||
want: &localTimeZoneTest{MyTime: time.UnixMilli(1684349179939)},
|
||||
},
|
||||
// Test that ZeroMaps causes the Decoder to empty any Go map values before decoding BSON
|
||||
// documents into them.
|
||||
{
|
||||
description: "ZeroMaps",
|
||||
configure: func(dec *Decoder) {
|
||||
dec.ZeroMaps()
|
||||
},
|
||||
input: bsoncore.NewDocumentBuilder().
|
||||
AppendDocument("myMap", bsoncore.NewDocumentBuilder().
|
||||
AppendString("myString", "test value").
|
||||
Build()).
|
||||
Build(),
|
||||
decodeInto: func() interface{} {
|
||||
return &zeroMapsTest{MyMap: map[string]string{"myExtraValue": "extra value"}}
|
||||
},
|
||||
want: &zeroMapsTest{MyMap: map[string]string{"myString": "test value"}},
|
||||
},
|
||||
// Test that ZeroStructs causes the Decoder to empty any Go struct values before decoding
|
||||
// BSON documents into them.
|
||||
{
|
||||
description: "ZeroStructs",
|
||||
configure: func(dec *Decoder) {
|
||||
dec.ZeroStructs()
|
||||
},
|
||||
input: bsoncore.NewDocumentBuilder().
|
||||
AppendString("myString", "test value").
|
||||
Build(),
|
||||
decodeInto: func() interface{} {
|
||||
return &zeroStructsTest{MyInt: 1}
|
||||
},
|
||||
want: &zeroStructsTest{MyString: "test value"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // Capture range variable.
|
||||
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dec := NewDecoder(NewDocumentReader(bytes.NewReader(tc.input)))
|
||||
|
||||
tc.configure(dec)
|
||||
|
||||
got := tc.decodeInto()
|
||||
err := dec.Decode(got)
|
||||
require.NoError(t, err, "Decode error")
|
||||
|
||||
assert.Equal(t, tc.want, got, "expected and actual decode results do not match")
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("Decoding an object ID to string", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type objectIDTest struct {
|
||||
ID string
|
||||
}
|
||||
|
||||
doc := bsoncore.NewDocumentBuilder().
|
||||
AppendObjectID("id", func() ObjectID {
|
||||
id, _ := ObjectIDFromHex("5ef7fdd91c19e3222b41b839")
|
||||
return id
|
||||
}()).
|
||||
Build()
|
||||
|
||||
dec := NewDecoder(NewDocumentReader(bytes.NewReader(doc)))
|
||||
|
||||
var got objectIDTest
|
||||
err := dec.Decode(&got)
|
||||
const want = "error decoding key id: decoding an object ID into a string is not supported by default (set Decoder.ObjectIDAsHexString to enable decoding as a hexadecimal string)"
|
||||
assert.EqualError(t, err, want)
|
||||
})
|
||||
t.Run("DefaultDocumentM top-level", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := bsoncore.NewDocumentBuilder().
|
||||
AppendDocument("myDocument", bsoncore.NewDocumentBuilder().
|
||||
AppendString("myString", "test value").
|
||||
Build()).
|
||||
Build()
|
||||
|
||||
dec := NewDecoder(NewDocumentReader(bytes.NewReader(input)))
|
||||
|
||||
dec.DefaultDocumentM()
|
||||
|
||||
var got interface{}
|
||||
err := dec.Decode(&got)
|
||||
require.NoError(t, err, "Decode error")
|
||||
|
||||
want := M{
|
||||
"myDocument": M{
|
||||
"myString": "test value",
|
||||
},
|
||||
}
|
||||
assert.Equal(t, want, got, "expected and actual decode results do not match")
|
||||
})
|
||||
t.Run("Default decodes DocumentD for top-level", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := bsoncore.NewDocumentBuilder().
|
||||
AppendDocument("myDocument", bsoncore.NewDocumentBuilder().
|
||||
AppendString("myString", "test value").
|
||||
Build()).
|
||||
Build()
|
||||
|
||||
dec := NewDecoder(NewDocumentReader(bytes.NewReader(input)))
|
||||
|
||||
var got interface{}
|
||||
err := dec.Decode(&got)
|
||||
require.NoError(t, err, "Decode error")
|
||||
|
||||
want := D{
|
||||
{Key: "myDocument", Value: D{
|
||||
{Key: "myString", Value: "test value"},
|
||||
}},
|
||||
}
|
||||
assert.Equal(t, want, got, "expected and actual decode results do not match")
|
||||
})
|
||||
}
|
1497
default_value_decoders.go
Normal file
1497
default_value_decoders.go
Normal file
File diff suppressed because it is too large
Load Diff
3806
default_value_decoders_test.go
Normal file
3806
default_value_decoders_test.go
Normal file
File diff suppressed because it is too large
Load Diff
517
default_value_encoders.go
Normal file
517
default_value_encoders.go
Normal file
@ -0,0 +1,517 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
var bvwPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(valueWriter)
|
||||
},
|
||||
}
|
||||
|
||||
var errInvalidValue = errors.New("cannot encode invalid element")
|
||||
|
||||
var sliceWriterPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
sw := make(sliceWriter, 0)
|
||||
return &sw
|
||||
},
|
||||
}
|
||||
|
||||
func encodeElement(ec EncodeContext, dw DocumentWriter, e E) error {
|
||||
vw, err := dw.WriteDocumentElement(e.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if e.Value == nil {
|
||||
return vw.WriteNull()
|
||||
}
|
||||
encoder, err := ec.LookupEncoder(reflect.TypeOf(e.Value))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = encoder.EncodeValue(ec, vw, reflect.ValueOf(e.Value))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// registerDefaultEncoders will register the encoder methods attached to DefaultValueEncoders with
|
||||
// the provided RegistryBuilder.
|
||||
func registerDefaultEncoders(reg *Registry) {
|
||||
mapEncoder := &mapCodec{}
|
||||
uintCodec := &uintCodec{}
|
||||
|
||||
reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{})
|
||||
reg.RegisterTypeEncoder(tTime, &timeCodec{})
|
||||
reg.RegisterTypeEncoder(tEmpty, &emptyInterfaceCodec{})
|
||||
reg.RegisterTypeEncoder(tCoreArray, &arrayCodec{})
|
||||
reg.RegisterTypeEncoder(tOID, ValueEncoderFunc(objectIDEncodeValue))
|
||||
reg.RegisterTypeEncoder(tDecimal, ValueEncoderFunc(decimal128EncodeValue))
|
||||
reg.RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(jsonNumberEncodeValue))
|
||||
reg.RegisterTypeEncoder(tURL, ValueEncoderFunc(urlEncodeValue))
|
||||
reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue))
|
||||
reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue))
|
||||
reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue))
|
||||
reg.RegisterTypeEncoder(tVector, ValueEncoderFunc(vectorEncodeValue))
|
||||
reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue))
|
||||
reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue))
|
||||
reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue))
|
||||
reg.RegisterTypeEncoder(tRegex, ValueEncoderFunc(regexEncodeValue))
|
||||
reg.RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dbPointerEncodeValue))
|
||||
reg.RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(timestampEncodeValue))
|
||||
reg.RegisterTypeEncoder(tMinKey, ValueEncoderFunc(minKeyEncodeValue))
|
||||
reg.RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(maxKeyEncodeValue))
|
||||
reg.RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(coreDocumentEncodeValue))
|
||||
reg.RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(codeWithScopeEncodeValue))
|
||||
reg.RegisterKindEncoder(reflect.Bool, ValueEncoderFunc(booleanEncodeValue))
|
||||
reg.RegisterKindEncoder(reflect.Int, ValueEncoderFunc(intEncodeValue))
|
||||
reg.RegisterKindEncoder(reflect.Int8, ValueEncoderFunc(intEncodeValue))
|
||||
reg.RegisterKindEncoder(reflect.Int16, ValueEncoderFunc(intEncodeValue))
|
||||
reg.RegisterKindEncoder(reflect.Int32, ValueEncoderFunc(intEncodeValue))
|
||||
reg.RegisterKindEncoder(reflect.Int64, ValueEncoderFunc(intEncodeValue))
|
||||
reg.RegisterKindEncoder(reflect.Uint, uintCodec)
|
||||
reg.RegisterKindEncoder(reflect.Uint8, uintCodec)
|
||||
reg.RegisterKindEncoder(reflect.Uint16, uintCodec)
|
||||
reg.RegisterKindEncoder(reflect.Uint32, uintCodec)
|
||||
reg.RegisterKindEncoder(reflect.Uint64, uintCodec)
|
||||
reg.RegisterKindEncoder(reflect.Float32, ValueEncoderFunc(floatEncodeValue))
|
||||
reg.RegisterKindEncoder(reflect.Float64, ValueEncoderFunc(floatEncodeValue))
|
||||
reg.RegisterKindEncoder(reflect.Array, ValueEncoderFunc(arrayEncodeValue))
|
||||
reg.RegisterKindEncoder(reflect.Map, mapEncoder)
|
||||
reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{})
|
||||
reg.RegisterKindEncoder(reflect.String, &stringCodec{})
|
||||
reg.RegisterKindEncoder(reflect.Struct, newStructCodec(mapEncoder))
|
||||
reg.RegisterKindEncoder(reflect.Ptr, &pointerCodec{})
|
||||
reg.RegisterInterfaceEncoder(tValueMarshaler, ValueEncoderFunc(valueMarshalerEncodeValue))
|
||||
reg.RegisterInterfaceEncoder(tMarshaler, ValueEncoderFunc(marshalerEncodeValue))
|
||||
}
|
||||
|
||||
// booleanEncodeValue is the ValueEncoderFunc for bool types.
|
||||
func booleanEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Kind() != reflect.Bool {
|
||||
return ValueEncoderError{Name: "BooleanEncodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val}
|
||||
}
|
||||
return vw.WriteBoolean(val.Bool())
|
||||
}
|
||||
|
||||
func fitsIn32Bits(i int64) bool {
|
||||
return math.MinInt32 <= i && i <= math.MaxInt32
|
||||
}
|
||||
|
||||
// intEncodeValue is the ValueEncoderFunc for int types.
|
||||
func intEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
switch val.Kind() {
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32:
|
||||
return vw.WriteInt32(int32(val.Int()))
|
||||
case reflect.Int:
|
||||
i64 := val.Int()
|
||||
if fitsIn32Bits(i64) {
|
||||
return vw.WriteInt32(int32(i64))
|
||||
}
|
||||
return vw.WriteInt64(i64)
|
||||
case reflect.Int64:
|
||||
i64 := val.Int()
|
||||
if ec.minSize && fitsIn32Bits(i64) {
|
||||
return vw.WriteInt32(int32(i64))
|
||||
}
|
||||
return vw.WriteInt64(i64)
|
||||
}
|
||||
|
||||
return ValueEncoderError{
|
||||
Name: "IntEncodeValue",
|
||||
Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int},
|
||||
Received: val,
|
||||
}
|
||||
}
|
||||
|
||||
// floatEncodeValue is the ValueEncoderFunc for float types.
|
||||
func floatEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
switch val.Kind() {
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return vw.WriteDouble(val.Float())
|
||||
}
|
||||
|
||||
return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: val}
|
||||
}
|
||||
|
||||
// objectIDEncodeValue is the ValueEncoderFunc for ObjectID.
|
||||
func objectIDEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tOID {
|
||||
return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: val}
|
||||
}
|
||||
return vw.WriteObjectID(val.Interface().(ObjectID))
|
||||
}
|
||||
|
||||
// decimal128EncodeValue is the ValueEncoderFunc for Decimal128.
|
||||
func decimal128EncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tDecimal {
|
||||
return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: val}
|
||||
}
|
||||
return vw.WriteDecimal128(val.Interface().(Decimal128))
|
||||
}
|
||||
|
||||
// jsonNumberEncodeValue is the ValueEncoderFunc for json.Number.
|
||||
func jsonNumberEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tJSONNumber {
|
||||
return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val}
|
||||
}
|
||||
jsnum := val.Interface().(json.Number)
|
||||
|
||||
// Attempt int first, then float64
|
||||
if i64, err := jsnum.Int64(); err == nil {
|
||||
return intEncodeValue(ec, vw, reflect.ValueOf(i64))
|
||||
}
|
||||
|
||||
f64, err := jsnum.Float64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return floatEncodeValue(ec, vw, reflect.ValueOf(f64))
|
||||
}
|
||||
|
||||
// urlEncodeValue is the ValueEncoderFunc for url.URL.
|
||||
func urlEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tURL {
|
||||
return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val}
|
||||
}
|
||||
u := val.Interface().(url.URL)
|
||||
return vw.WriteString(u.String())
|
||||
}
|
||||
|
||||
// arrayEncodeValue is the ValueEncoderFunc for array types.
|
||||
func arrayEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Kind() != reflect.Array {
|
||||
return ValueEncoderError{Name: "ArrayEncodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val}
|
||||
}
|
||||
|
||||
// If we have a []E we want to treat it as a document instead of as an array.
|
||||
if val.Type().Elem() == tE {
|
||||
dw, err := vw.WriteDocument()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for idx := 0; idx < val.Len(); idx++ {
|
||||
e := val.Index(idx).Interface().(E)
|
||||
err = encodeElement(ec, dw, e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return dw.WriteDocumentEnd()
|
||||
}
|
||||
|
||||
// If we have a []byte we want to treat it as a binary instead of as an array.
|
||||
if val.Type().Elem() == tByte {
|
||||
var byteSlice []byte
|
||||
for idx := 0; idx < val.Len(); idx++ {
|
||||
byteSlice = append(byteSlice, val.Index(idx).Interface().(byte))
|
||||
}
|
||||
return vw.WriteBinary(byteSlice)
|
||||
}
|
||||
|
||||
aw, err := vw.WriteArray()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
elemType := val.Type().Elem()
|
||||
encoder, err := ec.LookupEncoder(elemType)
|
||||
if err != nil && elemType.Kind() != reflect.Interface {
|
||||
return err
|
||||
}
|
||||
|
||||
for idx := 0; idx < val.Len(); idx++ {
|
||||
currEncoder, currVal, lookupErr := lookupElementEncoder(ec, encoder, val.Index(idx))
|
||||
if lookupErr != nil && !errors.Is(lookupErr, errInvalidValue) {
|
||||
return lookupErr
|
||||
}
|
||||
|
||||
vw, err := aw.WriteArrayElement()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if errors.Is(lookupErr, errInvalidValue) {
|
||||
err = vw.WriteNull()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
err = currEncoder.EncodeValue(ec, vw, currVal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return aw.WriteArrayEnd()
|
||||
}
|
||||
|
||||
func lookupElementEncoder(ec EncodeContext, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) {
|
||||
if origEncoder != nil || (currVal.Kind() != reflect.Interface) {
|
||||
return origEncoder, currVal, nil
|
||||
}
|
||||
currVal = currVal.Elem()
|
||||
if !currVal.IsValid() {
|
||||
return nil, currVal, errInvalidValue
|
||||
}
|
||||
currEncoder, err := ec.LookupEncoder(currVal.Type())
|
||||
|
||||
return currEncoder, currVal, err
|
||||
}
|
||||
|
||||
// valueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations.
|
||||
func valueMarshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
// Either val or a pointer to val must implement ValueMarshaler
|
||||
switch {
|
||||
case !val.IsValid():
|
||||
return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val}
|
||||
case val.Type().Implements(tValueMarshaler):
|
||||
// If ValueMarshaler is implemented on a concrete type, make sure that val isn't a nil pointer
|
||||
if isImplementationNil(val, tValueMarshaler) {
|
||||
return vw.WriteNull()
|
||||
}
|
||||
case reflect.PtrTo(val.Type()).Implements(tValueMarshaler) && val.CanAddr():
|
||||
val = val.Addr()
|
||||
default:
|
||||
return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val}
|
||||
}
|
||||
|
||||
m, ok := val.Interface().(ValueMarshaler)
|
||||
if !ok {
|
||||
return vw.WriteNull()
|
||||
}
|
||||
t, data, err := m.MarshalBSONValue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return copyValueFromBytes(vw, Type(t), data)
|
||||
}
|
||||
|
||||
// marshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations.
|
||||
func marshalerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
// Either val or a pointer to val must implement Marshaler
|
||||
switch {
|
||||
case !val.IsValid():
|
||||
return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val}
|
||||
case val.Type().Implements(tMarshaler):
|
||||
// If Marshaler is implemented on a concrete type, make sure that val isn't a nil pointer
|
||||
if isImplementationNil(val, tMarshaler) {
|
||||
return vw.WriteNull()
|
||||
}
|
||||
case reflect.PtrTo(val.Type()).Implements(tMarshaler) && val.CanAddr():
|
||||
val = val.Addr()
|
||||
default:
|
||||
return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val}
|
||||
}
|
||||
|
||||
m, ok := val.Interface().(Marshaler)
|
||||
if !ok {
|
||||
return vw.WriteNull()
|
||||
}
|
||||
data, err := m.MarshalBSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return copyValueFromBytes(vw, TypeEmbeddedDocument, data)
|
||||
}
|
||||
|
||||
// javaScriptEncodeValue is the ValueEncoderFunc for the JavaScript type.
|
||||
func javaScriptEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tJavaScript {
|
||||
return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: val}
|
||||
}
|
||||
|
||||
return vw.WriteJavascript(val.String())
|
||||
}
|
||||
|
||||
// symbolEncodeValue is the ValueEncoderFunc for the Symbol type.
|
||||
func symbolEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tSymbol {
|
||||
return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: val}
|
||||
}
|
||||
|
||||
return vw.WriteSymbol(val.String())
|
||||
}
|
||||
|
||||
// binaryEncodeValue is the ValueEncoderFunc for Binary.
|
||||
func binaryEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tBinary {
|
||||
return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: val}
|
||||
}
|
||||
b := val.Interface().(Binary)
|
||||
|
||||
return vw.WriteBinaryWithSubtype(b.Data, b.Subtype)
|
||||
}
|
||||
|
||||
// vectorEncodeValue is the ValueEncoderFunc for Vector.
|
||||
func vectorEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
t := val.Type()
|
||||
if !val.IsValid() || t != tVector {
|
||||
return ValueEncoderError{Name: "VectorEncodeValue",
|
||||
Types: []reflect.Type{tVector},
|
||||
Received: val,
|
||||
}
|
||||
}
|
||||
v := val.Interface().(Vector)
|
||||
b := v.Binary()
|
||||
return vw.WriteBinaryWithSubtype(b.Data, b.Subtype)
|
||||
}
|
||||
|
||||
// undefinedEncodeValue is the ValueEncoderFunc for Undefined.
|
||||
func undefinedEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tUndefined {
|
||||
return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: val}
|
||||
}
|
||||
|
||||
return vw.WriteUndefined()
|
||||
}
|
||||
|
||||
// dateTimeEncodeValue is the ValueEncoderFunc for DateTime.
|
||||
func dateTimeEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tDateTime {
|
||||
return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: val}
|
||||
}
|
||||
|
||||
return vw.WriteDateTime(val.Int())
|
||||
}
|
||||
|
||||
// nullEncodeValue is the ValueEncoderFunc for Null.
|
||||
func nullEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tNull {
|
||||
return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: val}
|
||||
}
|
||||
|
||||
return vw.WriteNull()
|
||||
}
|
||||
|
||||
// regexEncodeValue is the ValueEncoderFunc for Regex.
|
||||
func regexEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tRegex {
|
||||
return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: val}
|
||||
}
|
||||
|
||||
regex := val.Interface().(Regex)
|
||||
|
||||
return vw.WriteRegex(regex.Pattern, regex.Options)
|
||||
}
|
||||
|
||||
// dbPointerEncodeValue is the ValueEncoderFunc for DBPointer.
|
||||
func dbPointerEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tDBPointer {
|
||||
return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: val}
|
||||
}
|
||||
|
||||
dbp := val.Interface().(DBPointer)
|
||||
|
||||
return vw.WriteDBPointer(dbp.DB, dbp.Pointer)
|
||||
}
|
||||
|
||||
// timestampEncodeValue is the ValueEncoderFunc for Timestamp.
|
||||
func timestampEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tTimestamp {
|
||||
return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: val}
|
||||
}
|
||||
|
||||
ts := val.Interface().(Timestamp)
|
||||
|
||||
return vw.WriteTimestamp(ts.T, ts.I)
|
||||
}
|
||||
|
||||
// minKeyEncodeValue is the ValueEncoderFunc for MinKey.
|
||||
func minKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tMinKey {
|
||||
return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: val}
|
||||
}
|
||||
|
||||
return vw.WriteMinKey()
|
||||
}
|
||||
|
||||
// maxKeyEncodeValue is the ValueEncoderFunc for MaxKey.
|
||||
func maxKeyEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tMaxKey {
|
||||
return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: val}
|
||||
}
|
||||
|
||||
return vw.WriteMaxKey()
|
||||
}
|
||||
|
||||
// coreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document.
|
||||
func coreDocumentEncodeValue(_ EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tCoreDocument {
|
||||
return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: val}
|
||||
}
|
||||
|
||||
cdoc := val.Interface().(bsoncore.Document)
|
||||
|
||||
return copyDocumentFromBytes(vw, cdoc)
|
||||
}
|
||||
|
||||
// codeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope.
|
||||
func codeWithScopeEncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tCodeWithScope {
|
||||
return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val}
|
||||
}
|
||||
|
||||
cws := val.Interface().(CodeWithScope)
|
||||
|
||||
dw, err := vw.WriteCodeWithScope(string(cws.Code))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sw := sliceWriterPool.Get().(*sliceWriter)
|
||||
defer sliceWriterPool.Put(sw)
|
||||
*sw = (*sw)[:0]
|
||||
|
||||
scopeVW := bvwPool.Get().(*valueWriter)
|
||||
scopeVW.reset(scopeVW.buf[:0])
|
||||
scopeVW.w = sw
|
||||
defer bvwPool.Put(scopeVW)
|
||||
|
||||
encoder, err := ec.LookupEncoder(reflect.TypeOf(cws.Scope))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = encoder.EncodeValue(ec, scopeVW, reflect.ValueOf(cws.Scope))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = copyBytesToDocumentWriter(dw, *sw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return dw.WriteDocumentEnd()
|
||||
}
|
||||
|
||||
// isImplementationNil returns if val is a nil pointer and inter is implemented on a concrete type
|
||||
func isImplementationNil(val reflect.Value, inter reflect.Type) bool {
|
||||
vt := val.Type()
|
||||
for vt.Kind() == reflect.Ptr {
|
||||
vt = vt.Elem()
|
||||
}
|
||||
return vt.Implements(inter) && val.Kind() == reflect.Ptr && val.IsNil()
|
||||
}
|
1758
default_value_encoders_test.go
Normal file
1758
default_value_encoders_test.go
Normal file
File diff suppressed because it is too large
Load Diff
155
doc.go
Normal file
155
doc.go
Normal file
@ -0,0 +1,155 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Package bson is a library for reading, writing, and manipulating BSON. BSON is a binary serialization format used to
|
||||
// store documents and make remote procedure calls in MongoDB. The BSON specification is located at https://bsonspec.org.
|
||||
// The BSON library handles marshaling and unmarshaling of values through a configurable codec system. For a description
|
||||
// of the codec system and examples of registering custom codecs, see the bsoncodec package. For additional information
|
||||
// and usage examples, check out the [Work with BSON] page in the Go Driver docs site.
|
||||
//
|
||||
// # Raw BSON
|
||||
//
|
||||
// The Raw family of types is used to validate and retrieve elements from a slice of bytes. This
|
||||
// type is most useful when you want do lookups on BSON bytes without unmarshaling it into another
|
||||
// type.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var raw bson.Raw = ... // bytes from somewhere
|
||||
// err := raw.Validate()
|
||||
// if err != nil { return err }
|
||||
// val := raw.Lookup("foo")
|
||||
// i32, ok := val.Int32OK()
|
||||
// // do something with i32...
|
||||
//
|
||||
// # Native Go Types
|
||||
//
|
||||
// The D and M types defined in this package can be used to build representations of BSON using native Go types. D is a
|
||||
// slice and M is a map. For more information about the use cases for these types, see the documentation on the type
|
||||
// definitions.
|
||||
//
|
||||
// Note that a D should not be constructed with duplicate key names, as that can cause undefined server behavior.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
|
||||
// bson.M{"foo": "bar", "hello": "world", "pi": 3.14159}
|
||||
//
|
||||
// When decoding BSON to a D or M, the following type mappings apply when unmarshaling:
|
||||
//
|
||||
// 1. BSON int32 unmarshals to an int32.
|
||||
// 2. BSON int64 unmarshals to an int64.
|
||||
// 3. BSON double unmarshals to a float64.
|
||||
// 4. BSON string unmarshals to a string.
|
||||
// 5. BSON boolean unmarshals to a bool.
|
||||
// 6. BSON embedded document unmarshals to the parent type (i.e. D for a D, M for an M).
|
||||
// 7. BSON array unmarshals to a bson.A.
|
||||
// 8. BSON ObjectId unmarshals to a bson.ObjectID.
|
||||
// 9. BSON datetime unmarshals to a bson.DateTime.
|
||||
// 10. BSON binary unmarshals to a bson.Binary.
|
||||
// 11. BSON regular expression unmarshals to a bson.Regex.
|
||||
// 12. BSON JavaScript unmarshals to a bson.JavaScript.
|
||||
// 13. BSON code with scope unmarshals to a bson.CodeWithScope.
|
||||
// 14. BSON timestamp unmarshals to an bson.Timestamp.
|
||||
// 15. BSON 128-bit decimal unmarshals to an bson.Decimal128.
|
||||
// 16. BSON min key unmarshals to an bson.MinKey.
|
||||
// 17. BSON max key unmarshals to an bson.MaxKey.
|
||||
// 18. BSON undefined unmarshals to a bson.Undefined.
|
||||
// 19. BSON null unmarshals to nil.
|
||||
// 20. BSON DBPointer unmarshals to a bson.DBPointer.
|
||||
// 21. BSON symbol unmarshals to a bson.Symbol.
|
||||
//
|
||||
// The above mappings also apply when marshaling a D or M to BSON. Some other useful marshaling mappings are:
|
||||
//
|
||||
// 1. time.Time marshals to a BSON datetime.
|
||||
// 2. int8, int16, and int32 marshal to a BSON int32.
|
||||
// 3. int marshals to a BSON int32 if the value is between math.MinInt32 and math.MaxInt32, inclusive, and a BSON int64
|
||||
// otherwise.
|
||||
// 4. int64 marshals to BSON int64 (unless [Encoder.IntMinSize] is set).
|
||||
// 5. uint8 and uint16 marshal to a BSON int32.
|
||||
// 6. uint, uint32, and uint64 marshal to a BSON int64 (unless [Encoder.IntMinSize] is set).
|
||||
// 7. BSON null and undefined values will unmarshal into the zero value of a field (e.g. unmarshaling a BSON null or
|
||||
// undefined value into a string will yield the empty string.).
|
||||
//
|
||||
// # Structs
|
||||
//
|
||||
// Structs can be marshaled/unmarshaled to/from BSON or Extended JSON. When transforming structs to/from BSON or Extended
|
||||
// JSON, the following rules apply:
|
||||
//
|
||||
// 1. Only exported fields in structs will be marshaled or unmarshaled.
|
||||
//
|
||||
// 2. When marshaling a struct, each field will be lowercased to generate the key for the corresponding BSON element.
|
||||
// For example, a struct field named "Foo" will generate key "foo". This can be overridden via a struct tag (e.g.
|
||||
// `bson:"fooField"` to generate key "fooField" instead).
|
||||
//
|
||||
// 3. An embedded struct field is marshaled as a subdocument. The key will be the lowercased name of the field's type.
|
||||
//
|
||||
// 4. A pointer field is marshaled as the underlying type if the pointer is non-nil. If the pointer is nil, it is
|
||||
// marshaled as a BSON null value.
|
||||
//
|
||||
// 5. When unmarshaling, a field of type interface{} will follow the D/M type mappings listed above. BSON documents
|
||||
// unmarshaled into an interface{} field will be unmarshaled as a D.
|
||||
//
|
||||
// The encoding of each struct field can be customized by the "bson" struct tag.
|
||||
//
|
||||
// This tag behavior is configurable, and different struct tag behavior can be configured by initializing a new
|
||||
// bsoncodec.StructCodec with the desired tag parser and registering that StructCodec onto the Registry. By default, JSON
|
||||
// tags are not honored, but that can be enabled by creating a StructCodec with JSONFallbackStructTagParser, like below:
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// structcodec, _ := bsoncodec.NewStructCodec(bsoncodec.JSONFallbackStructTagParser)
|
||||
//
|
||||
// The bson tag gives the name of the field, possibly followed by a comma-separated list of options.
|
||||
// The name may be empty in order to specify options without overriding the default field name. The following options can
|
||||
// be used to configure behavior:
|
||||
//
|
||||
// 1. omitempty: If the "omitempty" struct tag is specified on a field, the field will not be marshaled if it is set to
|
||||
// an "empty" value. Numbers, booleans, and strings are considered empty if their value is equal to the zero value for
|
||||
// the type (i.e. 0 for numbers, false for booleans, and "" for strings). Slices, maps, and arrays are considered
|
||||
// empty if they are of length zero. Interfaces and pointers are considered empty if their value is nil. By default,
|
||||
// structs are only considered empty if the struct type implements [bsoncodec.Zeroer] and the IsZero
|
||||
// method returns true. Struct types that do not implement [bsoncodec.Zeroer] are never considered empty and will be
|
||||
// marshaled as embedded documents. NOTE: It is recommended that this tag be used for all slice and map fields.
|
||||
//
|
||||
// 2. minsize: If the minsize struct tag is specified on a field of type int64, uint, uint32, or uint64 and the value of
|
||||
// the field can fit in a signed int32, the field will be serialized as a BSON int32 rather than a BSON int64. For
|
||||
// other types, this tag is ignored.
|
||||
//
|
||||
// 3. truncate: If the truncate struct tag is specified on a field with a non-float numeric type, BSON doubles
|
||||
// unmarshaled into that field will be truncated at the decimal point. For example, if 3.14 is unmarshaled into a
|
||||
// field of type int, it will be unmarshaled as 3. If this tag is not specified, the decoder will throw an error if
|
||||
// the value cannot be decoded without losing precision. For float64 or non-numeric types, this tag is ignored.
|
||||
//
|
||||
// 4. inline: If the inline struct tag is specified for a struct or map field, the field will be "flattened" when
|
||||
// marshaling and "un-flattened" when unmarshaling. This means that all of the fields in that struct/map will be
|
||||
// pulled up one level and will become top-level fields rather than being fields in a nested document. For example,
|
||||
// if a map field named "Map" with value map[string]interface{}{"foo": "bar"} is inlined, the resulting document will
|
||||
// be {"foo": "bar"} instead of {"map": {"foo": "bar"}}. There can only be one inlined map field in a struct. If
|
||||
// there are duplicated fields in the resulting document when an inlined struct is marshaled, the inlined field will
|
||||
// be overwritten. If there are duplicated fields in the resulting document when an inlined map is marshaled, an
|
||||
// error will be returned. This tag can be used with fields that are pointers to structs. If an inlined pointer field
|
||||
// is nil, it will not be marshaled. For fields that are not maps or structs, this tag is ignored.
|
||||
//
|
||||
// # Marshaling and Unmarshaling
|
||||
//
|
||||
// Manually marshaling and unmarshaling can be done with the Marshal and Unmarshal family of functions.
|
||||
//
|
||||
// bsoncodec code provides a system for encoding values to BSON representations and decoding
|
||||
// values from BSON representations. This package considers both binary BSON and ExtendedJSON as
|
||||
// BSON representations. The types in this package enable a flexible system for handling this
|
||||
// encoding and decoding.
|
||||
//
|
||||
// The codec system is composed of two parts:
|
||||
//
|
||||
// 1) [ValueEncoder] and [ValueDecoder] that handle encoding and decoding Go values to and from BSON
|
||||
// representations.
|
||||
//
|
||||
// 2) A [Registry] that holds these ValueEncoders and ValueDecoders and provides methods for
|
||||
// retrieving them.
|
||||
//
|
||||
// [Work with BSON]: https://www.mongodb.com/docs/drivers/go/current/fundamentals/bson/
|
||||
package bson
|
127
empty_interface_codec.go
Normal file
127
empty_interface_codec.go
Normal file
@ -0,0 +1,127 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// emptyInterfaceCodec is the Codec used for interface{} values.
|
||||
type emptyInterfaceCodec struct {
|
||||
// decodeBinaryAsSlice causes DecodeValue to unmarshal BSON binary field values that are the
|
||||
// "Generic" or "Old" BSON binary subtype as a Go byte slice instead of a Binary.
|
||||
decodeBinaryAsSlice bool
|
||||
}
|
||||
|
||||
// Assert that emptyInterfaceCodec satisfies the typeDecoder interface, which allows it
|
||||
// to be used by collection type decoders (e.g. map, slice, etc) to set individual values in a
|
||||
// collection.
|
||||
var _ typeDecoder = &emptyInterfaceCodec{}
|
||||
|
||||
// EncodeValue is the ValueEncoderFunc for interface{}.
|
||||
func (eic *emptyInterfaceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tEmpty {
|
||||
return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val}
|
||||
}
|
||||
|
||||
if val.IsNil() {
|
||||
return vw.WriteNull()
|
||||
}
|
||||
encoder, err := ec.LookupEncoder(val.Elem().Type())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return encoder.EncodeValue(ec, vw, val.Elem())
|
||||
}
|
||||
|
||||
func (eic *emptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, valueType Type) (reflect.Type, error) {
|
||||
isDocument := valueType == Type(0) || valueType == TypeEmbeddedDocument
|
||||
if isDocument {
|
||||
if dc.defaultDocumentType != nil {
|
||||
// If the bsontype is an embedded document and the DocumentType is set on the DecodeContext, then return
|
||||
// that type.
|
||||
return dc.defaultDocumentType, nil
|
||||
}
|
||||
}
|
||||
|
||||
rtype, err := dc.LookupTypeMapEntry(valueType)
|
||||
if err == nil {
|
||||
return rtype, nil
|
||||
}
|
||||
|
||||
if isDocument {
|
||||
// For documents, fallback to looking up a type map entry for Type(0) or TypeEmbeddedDocument,
|
||||
// depending on the original valueType.
|
||||
var lookupType Type
|
||||
switch valueType {
|
||||
case Type(0):
|
||||
lookupType = TypeEmbeddedDocument
|
||||
case TypeEmbeddedDocument:
|
||||
lookupType = Type(0)
|
||||
}
|
||||
|
||||
rtype, err = dc.LookupTypeMapEntry(lookupType)
|
||||
if err == nil {
|
||||
return rtype, nil
|
||||
}
|
||||
// fallback to bson.D
|
||||
return tD, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (eic *emptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) {
|
||||
if t != tEmpty {
|
||||
return emptyValue, ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.Zero(t)}
|
||||
}
|
||||
|
||||
rtype, err := eic.getEmptyInterfaceDecodeType(dc, vr.Type())
|
||||
if err != nil {
|
||||
switch vr.Type() {
|
||||
case TypeNull:
|
||||
return reflect.Zero(t), vr.ReadNull()
|
||||
default:
|
||||
return emptyValue, err
|
||||
}
|
||||
}
|
||||
|
||||
decoder, err := dc.LookupDecoder(rtype)
|
||||
if err != nil {
|
||||
return emptyValue, err
|
||||
}
|
||||
|
||||
elem, err := decodeTypeOrValueWithInfo(decoder, dc, vr, rtype)
|
||||
if err != nil {
|
||||
return emptyValue, err
|
||||
}
|
||||
|
||||
if (eic.decodeBinaryAsSlice || dc.binaryAsSlice) && rtype == tBinary {
|
||||
binElem := elem.Interface().(Binary)
|
||||
if binElem.Subtype == TypeBinaryGeneric || binElem.Subtype == TypeBinaryBinaryOld {
|
||||
elem = reflect.ValueOf(binElem.Data)
|
||||
}
|
||||
}
|
||||
|
||||
return elem, nil
|
||||
}
|
||||
|
||||
// DecodeValue is the ValueDecoderFunc for interface{}.
|
||||
func (eic *emptyInterfaceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error {
|
||||
if !val.CanSet() || val.Type() != tEmpty {
|
||||
return ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: val}
|
||||
}
|
||||
|
||||
elem, err := eic.decodeType(dc, vr, val.Type())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val.Set(elem)
|
||||
return nil
|
||||
}
|
123
encoder.go
Normal file
123
encoder.go
Normal file
@ -0,0 +1,123 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// This pool is used to keep the allocations of Encoders down. This is only used for the Marshal*
|
||||
// methods and is not consumable from outside of this package. The Encoders retrieved from this pool
|
||||
// must have both Reset and SetRegistry called on them.
|
||||
var encPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(Encoder)
|
||||
},
|
||||
}
|
||||
|
||||
// An Encoder writes a serialization format to an output stream. It writes to a ValueWriter
|
||||
// as the destination of BSON data.
|
||||
type Encoder struct {
|
||||
ec EncodeContext
|
||||
vw ValueWriter
|
||||
}
|
||||
|
||||
// NewEncoder returns a new encoder that writes to vw.
|
||||
func NewEncoder(vw ValueWriter) *Encoder {
|
||||
return &Encoder{
|
||||
ec: EncodeContext{Registry: defaultRegistry},
|
||||
vw: vw,
|
||||
}
|
||||
}
|
||||
|
||||
// Encode writes the BSON encoding of val to the stream.
|
||||
//
|
||||
// See [Marshal] for details about BSON marshaling behavior.
|
||||
func (e *Encoder) Encode(val interface{}) error {
|
||||
if marshaler, ok := val.(Marshaler); ok {
|
||||
// TODO(skriptble): Should we have a MarshalAppender interface so that we can have []byte reuse?
|
||||
buf, err := marshaler.MarshalBSON()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return copyDocumentFromBytes(e.vw, buf)
|
||||
}
|
||||
|
||||
encoder, err := e.ec.LookupEncoder(reflect.TypeOf(val))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return encoder.EncodeValue(e.ec, e.vw, reflect.ValueOf(val))
|
||||
}
|
||||
|
||||
// Reset will reset the state of the Encoder, using the same *EncodeContext used in
|
||||
// the original construction but using vw.
|
||||
func (e *Encoder) Reset(vw ValueWriter) {
|
||||
e.vw = vw
|
||||
}
|
||||
|
||||
// SetRegistry replaces the current registry of the Encoder with r.
|
||||
func (e *Encoder) SetRegistry(r *Registry) {
|
||||
e.ec.Registry = r
|
||||
}
|
||||
|
||||
// ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in
|
||||
// the marshaled BSON when the "inline" struct tag option is set.
|
||||
func (e *Encoder) ErrorOnInlineDuplicates() {
|
||||
e.ec.errorOnInlineDuplicates = true
|
||||
}
|
||||
|
||||
// IntMinSize causes the Encoder to marshal Go integer values (int, int8, int16, int32, int64, uint,
|
||||
// uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) that can
|
||||
// represent the integer value.
|
||||
func (e *Encoder) IntMinSize() {
|
||||
e.ec.minSize = true
|
||||
}
|
||||
|
||||
// StringifyMapKeysWithFmt causes the Encoder to convert Go map keys to BSON document field name
|
||||
// strings using fmt.Sprint instead of the default string conversion logic.
|
||||
func (e *Encoder) StringifyMapKeysWithFmt() {
|
||||
e.ec.stringifyMapKeysWithFmt = true
|
||||
}
|
||||
|
||||
// NilMapAsEmpty causes the Encoder to marshal nil Go maps as empty BSON documents instead of BSON
|
||||
// null.
|
||||
func (e *Encoder) NilMapAsEmpty() {
|
||||
e.ec.nilMapAsEmpty = true
|
||||
}
|
||||
|
||||
// NilSliceAsEmpty causes the Encoder to marshal nil Go slices as empty BSON arrays instead of BSON
|
||||
// null.
|
||||
func (e *Encoder) NilSliceAsEmpty() {
|
||||
e.ec.nilSliceAsEmpty = true
|
||||
}
|
||||
|
||||
// NilByteSliceAsEmpty causes the Encoder to marshal nil Go byte slices as empty BSON binary values
|
||||
// instead of BSON null.
|
||||
func (e *Encoder) NilByteSliceAsEmpty() {
|
||||
e.ec.nilByteSliceAsEmpty = true
|
||||
}
|
||||
|
||||
// TODO(GODRIVER-2820): Update the description to remove the note about only examining exported
|
||||
// TODO struct fields once the logic is updated to also inspect private struct fields.
|
||||
|
||||
// OmitZeroStruct causes the Encoder to consider the zero value for a struct (e.g. MyStruct{})
|
||||
// as empty and omit it from the marshaled BSON when the "omitempty" struct tag option is set.
|
||||
//
|
||||
// Note that the Encoder only examines exported struct fields when determining if a struct is the
|
||||
// zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty.
|
||||
func (e *Encoder) OmitZeroStruct() {
|
||||
e.ec.omitZeroStruct = true
|
||||
}
|
||||
|
||||
// UseJSONStructTags causes the Encoder to fall back to using the "json" struct tag if a "bson"
|
||||
// struct tag is not specified.
|
||||
func (e *Encoder) UseJSONStructTags() {
|
||||
e.ec.useJSONStructTags = true
|
||||
}
|
240
encoder_example_test.go
Normal file
240
encoder_example_test.go
Normal file
@ -0,0 +1,240 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"gitea.psichedelico.com/go/bson"
|
||||
)
|
||||
|
||||
func ExampleEncoder() {
|
||||
// Create an Encoder that writes BSON values to a bytes.Buffer.
|
||||
buf := new(bytes.Buffer)
|
||||
vw := bson.NewDocumentWriter(buf)
|
||||
encoder := bson.NewEncoder(vw)
|
||||
|
||||
type Product struct {
|
||||
Name string `bson:"name"`
|
||||
SKU string `bson:"sku"`
|
||||
Price int64 `bson:"price_cents"`
|
||||
}
|
||||
|
||||
// Use the Encoder to marshal a BSON document that contains the name, SKU,
|
||||
// and price (in cents) of a product.
|
||||
product := Product{
|
||||
Name: "Cereal Rounds",
|
||||
SKU: "AB12345",
|
||||
Price: 399,
|
||||
}
|
||||
err := encoder.Encode(product)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Print the BSON document as Extended JSON by converting it to bson.Raw.
|
||||
fmt.Println(bson.Raw(buf.Bytes()).String())
|
||||
// Output: {"name": "Cereal Rounds","sku": "AB12345","price_cents": {"$numberLong":"399"}}
|
||||
}
|
||||
|
||||
type CityState struct {
|
||||
City string
|
||||
State string
|
||||
}
|
||||
|
||||
func (k CityState) String() string {
|
||||
return fmt.Sprintf("%s, %s", k.City, k.State)
|
||||
}
|
||||
|
||||
func ExampleEncoder_StringifyMapKeysWithFmt() {
|
||||
// Create an Encoder that writes BSON values to a bytes.Buffer.
|
||||
buf := new(bytes.Buffer)
|
||||
vw := bson.NewDocumentWriter(buf)
|
||||
encoder := bson.NewEncoder(vw)
|
||||
|
||||
// Configure the Encoder to convert Go map keys to BSON document field names
|
||||
// using fmt.Sprintf instead of the default string conversion logic.
|
||||
encoder.StringifyMapKeysWithFmt()
|
||||
|
||||
// Use the Encoder to marshal a BSON document that contains is a map of
|
||||
// city and state to a list of zip codes in that city.
|
||||
zipCodes := map[CityState][]int{
|
||||
{City: "New York", State: "NY"}: {10001, 10301, 10451},
|
||||
}
|
||||
err := encoder.Encode(zipCodes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Print the BSON document as Extended JSON by converting it to bson.Raw.
|
||||
fmt.Println(bson.Raw(buf.Bytes()).String())
|
||||
// Output: {"New York, NY": [{"$numberInt":"10001"},{"$numberInt":"10301"},{"$numberInt":"10451"}]}
|
||||
}
|
||||
|
||||
func ExampleEncoder_UseJSONStructTags() {
|
||||
// Create an Encoder that writes BSON values to a bytes.Buffer.
|
||||
buf := new(bytes.Buffer)
|
||||
vw := bson.NewDocumentWriter(buf)
|
||||
encoder := bson.NewEncoder(vw)
|
||||
|
||||
type Product struct {
|
||||
Name string `json:"name"`
|
||||
SKU string `json:"sku"`
|
||||
Price int64 `json:"price_cents"`
|
||||
}
|
||||
|
||||
// Configure the Encoder to use "json" struct tags when decoding if "bson"
|
||||
// struct tags are not present.
|
||||
encoder.UseJSONStructTags()
|
||||
|
||||
// Use the Encoder to marshal a BSON document that contains the name, SKU,
|
||||
// and price (in cents) of a product.
|
||||
product := Product{
|
||||
Name: "Cereal Rounds",
|
||||
SKU: "AB12345",
|
||||
Price: 399,
|
||||
}
|
||||
err := encoder.Encode(product)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Print the BSON document as Extended JSON by converting it to bson.Raw.
|
||||
fmt.Println(bson.Raw(buf.Bytes()).String())
|
||||
// Output: {"name": "Cereal Rounds","sku": "AB12345","price_cents": {"$numberLong":"399"}}
|
||||
}
|
||||
|
||||
func ExampleEncoder_multipleBSONDocuments() {
|
||||
// Create an Encoder that writes BSON values to a bytes.Buffer.
|
||||
buf := new(bytes.Buffer)
|
||||
vw := bson.NewDocumentWriter(buf)
|
||||
encoder := bson.NewEncoder(vw)
|
||||
|
||||
type Coordinate struct {
|
||||
X int
|
||||
Y int
|
||||
}
|
||||
|
||||
// Use the encoder to marshal 5 Coordinate values as a sequence of BSON
|
||||
// documents.
|
||||
for i := 0; i < 5; i++ {
|
||||
err := encoder.Encode(Coordinate{
|
||||
X: i,
|
||||
Y: i + 1,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Read each marshaled BSON document from the buffer and print them as
|
||||
// Extended JSON by converting them to bson.Raw.
|
||||
for {
|
||||
doc, err := bson.ReadDocument(buf)
|
||||
if errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Println(doc.String())
|
||||
}
|
||||
// Output:
|
||||
// {"x": {"$numberInt":"0"},"y": {"$numberInt":"1"}}
|
||||
// {"x": {"$numberInt":"1"},"y": {"$numberInt":"2"}}
|
||||
// {"x": {"$numberInt":"2"},"y": {"$numberInt":"3"}}
|
||||
// {"x": {"$numberInt":"3"},"y": {"$numberInt":"4"}}
|
||||
// {"x": {"$numberInt":"4"},"y": {"$numberInt":"5"}}
|
||||
}
|
||||
|
||||
func ExampleEncoder_extendedJSON() {
|
||||
// Create an Encoder that writes canonical Extended JSON values to a
|
||||
// bytes.Buffer.
|
||||
buf := new(bytes.Buffer)
|
||||
vw := bson.NewExtJSONValueWriter(buf, true, false)
|
||||
encoder := bson.NewEncoder(vw)
|
||||
|
||||
type Product struct {
|
||||
Name string `bson:"name"`
|
||||
SKU string `bson:"sku"`
|
||||
Price int64 `bson:"price_cents"`
|
||||
}
|
||||
|
||||
// Use the Encoder to marshal a BSON document that contains the name, SKU,
|
||||
// and price (in cents) of a product.
|
||||
product := Product{
|
||||
Name: "Cereal Rounds",
|
||||
SKU: "AB12345",
|
||||
Price: 399,
|
||||
}
|
||||
err := encoder.Encode(product)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Println(buf.String())
|
||||
// Output: {"name":"Cereal Rounds","sku":"AB12345","price_cents":{"$numberLong":"399"}}
|
||||
}
|
||||
|
||||
func ExampleEncoder_multipleExtendedJSONDocuments() {
|
||||
// Create an Encoder that writes canonical Extended JSON values to a
|
||||
// bytes.Buffer.
|
||||
buf := new(bytes.Buffer)
|
||||
vw := bson.NewExtJSONValueWriter(buf, true, false)
|
||||
encoder := bson.NewEncoder(vw)
|
||||
|
||||
type Coordinate struct {
|
||||
X int
|
||||
Y int
|
||||
}
|
||||
|
||||
// Use the encoder to marshal 5 Coordinate values as a sequence of Extended
|
||||
// JSON documents.
|
||||
for i := 0; i < 5; i++ {
|
||||
err := encoder.Encode(Coordinate{
|
||||
X: i,
|
||||
Y: i + 1,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println(buf.String())
|
||||
// Output:
|
||||
// {"x":{"$numberInt":"0"},"y":{"$numberInt":"1"}}
|
||||
// {"x":{"$numberInt":"1"},"y":{"$numberInt":"2"}}
|
||||
// {"x":{"$numberInt":"2"},"y":{"$numberInt":"3"}}
|
||||
// {"x":{"$numberInt":"3"},"y":{"$numberInt":"4"}}
|
||||
// {"x":{"$numberInt":"4"},"y":{"$numberInt":"5"}}
|
||||
}
|
||||
|
||||
func ExampleEncoder_IntMinSize() {
|
||||
// Create an encoder that will marshal integers as the minimum BSON int size
|
||||
// (either 32 or 64 bits) that can represent the integer value.
|
||||
type foo struct {
|
||||
Bar uint32
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
vw := bson.NewDocumentWriter(buf)
|
||||
|
||||
enc := bson.NewEncoder(vw)
|
||||
enc.IntMinSize()
|
||||
|
||||
err := enc.Encode(foo{2})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Println(bson.Raw(buf.Bytes()).String())
|
||||
// Output:
|
||||
// {"bar": {"$numberInt":"2"}}
|
||||
}
|
303
encoder_test.go
Normal file
303
encoder_test.go
Normal file
@ -0,0 +1,303 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/assert"
|
||||
"gitea.psichedelico.com/go/bson/internal/require"
|
||||
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
func TestBasicEncode(t *testing.T) {
|
||||
for _, tc := range marshalingTestCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := make(sliceWriter, 0, 1024)
|
||||
vw := NewDocumentWriter(&got)
|
||||
reg := defaultRegistry
|
||||
encoder, err := reg.LookupEncoder(reflect.TypeOf(tc.val))
|
||||
noerr(t, err)
|
||||
err = encoder.EncodeValue(EncodeContext{Registry: reg}, vw, reflect.ValueOf(tc.val))
|
||||
noerr(t, err)
|
||||
|
||||
if !bytes.Equal(got, tc.want) {
|
||||
t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
|
||||
t.Errorf("Bytes:\n%v\n%v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncoderEncode(t *testing.T) {
|
||||
for _, tc := range marshalingTestCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := make(sliceWriter, 0, 1024)
|
||||
vw := NewDocumentWriter(&got)
|
||||
enc := NewEncoder(vw)
|
||||
err := enc.Encode(tc.val)
|
||||
noerr(t, err)
|
||||
|
||||
if !bytes.Equal(got, tc.want) {
|
||||
t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
|
||||
t.Errorf("Bytes:\n%v\n%v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("Marshaler", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
buf []byte
|
||||
err error
|
||||
wanterr error
|
||||
vw ValueWriter
|
||||
}{
|
||||
{
|
||||
"error",
|
||||
nil,
|
||||
errors.New("Marshaler error"),
|
||||
errors.New("Marshaler error"),
|
||||
&valueReaderWriter{},
|
||||
},
|
||||
{
|
||||
"copy error",
|
||||
[]byte{0x05, 0x00, 0x00, 0x00, 0x00},
|
||||
nil,
|
||||
errors.New("copy error"),
|
||||
&valueReaderWriter{Err: errors.New("copy error"), ErrAfter: writeDocument},
|
||||
},
|
||||
{
|
||||
"success",
|
||||
[]byte{0x07, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
marshaler := testMarshaler{buf: tc.buf, err: tc.err}
|
||||
|
||||
var vw ValueWriter
|
||||
b := make(sliceWriter, 0, 100)
|
||||
compareVW := false
|
||||
if tc.vw != nil {
|
||||
vw = tc.vw
|
||||
} else {
|
||||
compareVW = true
|
||||
vw = NewDocumentWriter(&b)
|
||||
}
|
||||
enc := NewEncoder(vw)
|
||||
got := enc.Encode(marshaler)
|
||||
want := tc.wanterr
|
||||
if !assert.CompareErrors(got, want) {
|
||||
t.Errorf("Did not receive expected error. got %v; want %v", got, want)
|
||||
}
|
||||
if compareVW {
|
||||
buf := b
|
||||
if !bytes.Equal(buf, tc.buf) {
|
||||
t.Errorf("Copied bytes do not match. got %v; want %v", buf, tc.buf)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type testMarshaler struct {
|
||||
buf []byte
|
||||
err error
|
||||
}
|
||||
|
||||
func (tm testMarshaler) MarshalBSON() ([]byte, error) { return tm.buf, tm.err }
|
||||
|
||||
func docToBytes(d interface{}) []byte {
|
||||
b, err := Marshal(d)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
type stringerTest struct{}
|
||||
|
||||
func (stringerTest) String() string {
|
||||
return "test key"
|
||||
}
|
||||
|
||||
func TestEncoderConfiguration(t *testing.T) {
|
||||
type inlineDuplicateInner struct {
|
||||
Duplicate string
|
||||
}
|
||||
|
||||
type inlineDuplicateOuter struct {
|
||||
Inline inlineDuplicateInner `bson:",inline"`
|
||||
Duplicate string
|
||||
}
|
||||
|
||||
type zeroStruct struct {
|
||||
MyString string
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
description string
|
||||
configure func(*Encoder)
|
||||
input interface{}
|
||||
want []byte
|
||||
wantErr error
|
||||
}{
|
||||
// Test that ErrorOnInlineDuplicates causes the Encoder to return an error if there are any
|
||||
// duplicate fields in the marshaled document caused by using the "inline" struct tag.
|
||||
{
|
||||
description: "ErrorOnInlineDuplicates",
|
||||
configure: func(enc *Encoder) {
|
||||
enc.ErrorOnInlineDuplicates()
|
||||
},
|
||||
input: inlineDuplicateOuter{
|
||||
Inline: inlineDuplicateInner{Duplicate: "inner"},
|
||||
Duplicate: "outer",
|
||||
},
|
||||
wantErr: errors.New("struct bson.inlineDuplicateOuter has duplicated key duplicate"),
|
||||
},
|
||||
// Test that IntMinSize encodes Go int and int64 values as BSON int32 if the value is small
|
||||
// enough.
|
||||
{
|
||||
description: "IntMinSize",
|
||||
configure: func(enc *Encoder) {
|
||||
enc.IntMinSize()
|
||||
},
|
||||
input: D{
|
||||
{Key: "myInt", Value: int(1)},
|
||||
{Key: "myInt64", Value: int64(1)},
|
||||
{Key: "myUint", Value: uint(1)},
|
||||
{Key: "myUint32", Value: uint32(1)},
|
||||
{Key: "myUint64", Value: uint64(1)},
|
||||
},
|
||||
want: bsoncore.NewDocumentBuilder().
|
||||
AppendInt32("myInt", 1).
|
||||
AppendInt32("myInt64", 1).
|
||||
AppendInt32("myUint", 1).
|
||||
AppendInt32("myUint32", 1).
|
||||
AppendInt32("myUint64", 1).
|
||||
Build(),
|
||||
},
|
||||
// Test that StringifyMapKeysWithFmt uses fmt.Sprint to convert map keys to BSON field names.
|
||||
{
|
||||
description: "StringifyMapKeysWithFmt",
|
||||
configure: func(enc *Encoder) {
|
||||
enc.StringifyMapKeysWithFmt()
|
||||
},
|
||||
input: map[stringerTest]string{
|
||||
{}: "test value",
|
||||
},
|
||||
want: bsoncore.NewDocumentBuilder().
|
||||
AppendString("test key", "test value").
|
||||
Build(),
|
||||
},
|
||||
// Test that NilMapAsEmpty encodes nil Go maps as empty BSON documents.
|
||||
{
|
||||
description: "NilMapAsEmpty",
|
||||
configure: func(enc *Encoder) {
|
||||
enc.NilMapAsEmpty()
|
||||
},
|
||||
input: D{{Key: "myMap", Value: map[string]string(nil)}},
|
||||
want: bsoncore.NewDocumentBuilder().
|
||||
AppendDocument("myMap", bsoncore.NewDocumentBuilder().Build()).
|
||||
Build(),
|
||||
},
|
||||
// Test that NilSliceAsEmpty encodes nil Go slices as empty BSON arrays.
|
||||
{
|
||||
description: "NilSliceAsEmpty",
|
||||
configure: func(enc *Encoder) {
|
||||
enc.NilSliceAsEmpty()
|
||||
},
|
||||
input: D{{Key: "mySlice", Value: []string(nil)}},
|
||||
want: bsoncore.NewDocumentBuilder().
|
||||
AppendArray("mySlice", bsoncore.NewArrayBuilder().Build()).
|
||||
Build(),
|
||||
},
|
||||
// Test that NilByteSliceAsEmpty encodes nil Go byte slices as empty BSON binary elements.
|
||||
{
|
||||
description: "NilByteSliceAsEmpty",
|
||||
configure: func(enc *Encoder) {
|
||||
enc.NilByteSliceAsEmpty()
|
||||
},
|
||||
input: D{{Key: "myBytes", Value: []byte(nil)}},
|
||||
want: bsoncore.NewDocumentBuilder().
|
||||
AppendBinary("myBytes", TypeBinaryGeneric, []byte{}).
|
||||
Build(),
|
||||
},
|
||||
// Test that OmitZeroStruct omits empty structs from the marshaled document if the
|
||||
// "omitempty" struct tag is used.
|
||||
{
|
||||
description: "OmitZeroStruct",
|
||||
configure: func(enc *Encoder) {
|
||||
enc.OmitZeroStruct()
|
||||
},
|
||||
input: struct {
|
||||
Zero zeroStruct `bson:",omitempty"`
|
||||
}{},
|
||||
want: bsoncore.NewDocumentBuilder().Build(),
|
||||
},
|
||||
// Test that UseJSONStructTags causes the Encoder to fall back to "json" struct tags if
|
||||
// "bson" struct tags are not available.
|
||||
{
|
||||
description: "UseJSONStructTags",
|
||||
configure: func(enc *Encoder) {
|
||||
enc.UseJSONStructTags()
|
||||
},
|
||||
input: struct {
|
||||
StructFieldName string `json:"jsonFieldName"`
|
||||
}{
|
||||
StructFieldName: "test value",
|
||||
},
|
||||
want: bsoncore.NewDocumentBuilder().
|
||||
AppendString("jsonFieldName", "test value").
|
||||
Build(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // Capture range variable.
|
||||
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := new(bytes.Buffer)
|
||||
vw := NewDocumentWriter(got)
|
||||
enc := NewEncoder(vw)
|
||||
|
||||
tc.configure(enc)
|
||||
|
||||
err := enc.Encode(tc.input)
|
||||
if tc.wantErr != nil {
|
||||
assert.Equal(t, tc.wantErr, err, "expected and actual errors do not match")
|
||||
return
|
||||
}
|
||||
require.NoError(t, err, "Encode error")
|
||||
|
||||
assert.Equal(t, tc.want, got.Bytes(), "expected and actual encoded BSON do not match")
|
||||
|
||||
// After we compare the raw bytes, also decode the expected and actual BSON as a bson.D
|
||||
// and compare them. The goal is to make assertion failures easier to debug because
|
||||
// binary diffs are very difficult to understand.
|
||||
var wantDoc D
|
||||
err = Unmarshal(tc.want, &wantDoc)
|
||||
require.NoError(t, err, "Unmarshal error")
|
||||
var gotDoc D
|
||||
err = Unmarshal(got.Bytes(), &gotDoc)
|
||||
require.NoError(t, err, "Unmarshal error")
|
||||
|
||||
assert.Equal(t, wantDoc, gotDoc, "expected and actual decoded documents do not match")
|
||||
})
|
||||
}
|
||||
}
|
143
example_test.go
Normal file
143
example_test.go
Normal file
@ -0,0 +1,143 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gitea.psichedelico.com/go/bson"
|
||||
)
|
||||
|
||||
// This example uses Raw to skip parsing a nested document in a BSON message.
|
||||
func ExampleRaw_unmarshal() {
|
||||
b, err := bson.Marshal(bson.M{
|
||||
"Word": "beach",
|
||||
"Synonyms": bson.A{"coast", "shore", "waterfront"},
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var res struct {
|
||||
Word string
|
||||
Synonyms bson.Raw // Don't parse the whole list, we just want to count the elements.
|
||||
}
|
||||
|
||||
err = bson.Unmarshal(b, &res)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
elems, err := res.Synonyms.Elements()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Printf("%s, synonyms count: %d\n", res.Word, len(elems))
|
||||
|
||||
// Output: beach, synonyms count: 3
|
||||
}
|
||||
|
||||
// This example uses Raw to add a precomputed BSON document during marshal.
|
||||
func ExampleRaw_marshal() {
|
||||
precomputed, err := bson.Marshal(bson.M{"Precomputed": true})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
msg := struct {
|
||||
Message string
|
||||
Metadata bson.Raw
|
||||
}{
|
||||
Message: "Hello World!",
|
||||
Metadata: precomputed,
|
||||
}
|
||||
|
||||
b, err := bson.Marshal(msg)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// Print the Extended JSON by converting BSON to bson.Raw.
|
||||
fmt.Println(bson.Raw(b).String())
|
||||
|
||||
// Output: {"message": "Hello World!","metadata": {"Precomputed": true}}
|
||||
}
|
||||
|
||||
// This example uses RawValue to delay parsing a value in a BSON message.
|
||||
func ExampleRawValue_unmarshal() {
|
||||
b1, err := bson.Marshal(bson.M{
|
||||
"Format": "UNIX",
|
||||
"Timestamp": 1675282389,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
b2, err := bson.Marshal(bson.M{
|
||||
"Format": "RFC3339",
|
||||
"Timestamp": time.Unix(1675282389, 0).Format(time.RFC3339),
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for _, b := range [][]byte{b1, b2} {
|
||||
var res struct {
|
||||
Format string
|
||||
Timestamp bson.RawValue // Delay parsing until we know the timestamp format.
|
||||
}
|
||||
|
||||
err = bson.Unmarshal(b, &res)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var t time.Time
|
||||
switch res.Format {
|
||||
case "UNIX":
|
||||
t = time.Unix(res.Timestamp.AsInt64(), 0)
|
||||
case "RFC3339":
|
||||
t, err = time.Parse(time.RFC3339, res.Timestamp.StringValue())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
fmt.Println(res.Format, t.Unix())
|
||||
}
|
||||
|
||||
// Output:
|
||||
// UNIX 1675282389
|
||||
// RFC3339 1675282389
|
||||
}
|
||||
|
||||
// This example uses RawValue to add a precomputed BSON string value during marshal.
|
||||
func ExampleRawValue_marshal() {
|
||||
t, val, err := bson.MarshalValue("Precomputed message!")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
precomputed := bson.RawValue{
|
||||
Type: t,
|
||||
Value: val,
|
||||
}
|
||||
|
||||
msg := struct {
|
||||
Message bson.RawValue
|
||||
Time time.Time
|
||||
}{
|
||||
Message: precomputed,
|
||||
Time: time.Unix(1675282389, 0),
|
||||
}
|
||||
|
||||
b, err := bson.Marshal(msg)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// Print the Extended JSON by converting BSON to bson.Raw.
|
||||
fmt.Println(bson.Raw(b).String())
|
||||
|
||||
// Output: {"message": "Precomputed message!","time": {"$date":{"$numberLong":"1675282389000"}}}
|
||||
}
|
804
extjson_parser.go
Normal file
804
extjson_parser.go
Normal file
@ -0,0 +1,804 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const maxNestingDepth = 200
|
||||
|
||||
// ErrInvalidJSON indicates the JSON input is invalid
|
||||
var ErrInvalidJSON = errors.New("invalid JSON input")
|
||||
|
||||
type jsonParseState byte
|
||||
|
||||
const (
|
||||
jpsStartState jsonParseState = iota
|
||||
jpsSawBeginObject
|
||||
jpsSawEndObject
|
||||
jpsSawBeginArray
|
||||
jpsSawEndArray
|
||||
jpsSawColon
|
||||
jpsSawComma
|
||||
jpsSawKey
|
||||
jpsSawValue
|
||||
jpsDoneState
|
||||
jpsInvalidState
|
||||
)
|
||||
|
||||
type jsonParseMode byte
|
||||
|
||||
const (
|
||||
jpmInvalidMode jsonParseMode = iota
|
||||
jpmObjectMode
|
||||
jpmArrayMode
|
||||
)
|
||||
|
||||
type extJSONValue struct {
|
||||
t Type
|
||||
v interface{}
|
||||
}
|
||||
|
||||
type extJSONObject struct {
|
||||
keys []string
|
||||
values []*extJSONValue
|
||||
}
|
||||
|
||||
type extJSONParser struct {
|
||||
js *jsonScanner
|
||||
s jsonParseState
|
||||
m []jsonParseMode
|
||||
k string
|
||||
v *extJSONValue
|
||||
|
||||
err error
|
||||
canonicalOnly bool
|
||||
depth int
|
||||
maxDepth int
|
||||
|
||||
emptyObject bool
|
||||
relaxedUUID bool
|
||||
}
|
||||
|
||||
// newExtJSONParser returns a new extended JSON parser, ready to to begin
|
||||
// parsing from the first character of the argued json input. It will not
|
||||
// perform any read-ahead and will therefore not report any errors about
|
||||
// malformed JSON at this point.
|
||||
func newExtJSONParser(r io.Reader, canonicalOnly bool) *extJSONParser {
|
||||
return &extJSONParser{
|
||||
js: &jsonScanner{r: r},
|
||||
s: jpsStartState,
|
||||
m: []jsonParseMode{},
|
||||
canonicalOnly: canonicalOnly,
|
||||
maxDepth: maxNestingDepth,
|
||||
}
|
||||
}
|
||||
|
||||
// peekType examines the next value and returns its BSON Type
|
||||
func (ejp *extJSONParser) peekType() (Type, error) {
|
||||
var t Type
|
||||
var err error
|
||||
initialState := ejp.s
|
||||
|
||||
ejp.advanceState()
|
||||
switch ejp.s {
|
||||
case jpsSawValue:
|
||||
t = ejp.v.t
|
||||
case jpsSawBeginArray:
|
||||
t = TypeArray
|
||||
case jpsInvalidState:
|
||||
err = ejp.err
|
||||
case jpsSawComma:
|
||||
// in array mode, seeing a comma means we need to progress again to actually observe a type
|
||||
if ejp.peekMode() == jpmArrayMode {
|
||||
return ejp.peekType()
|
||||
}
|
||||
case jpsSawEndArray:
|
||||
// this would only be a valid state if we were in array mode, so return end-of-array error
|
||||
err = ErrEOA
|
||||
case jpsSawBeginObject:
|
||||
// peek key to determine type
|
||||
ejp.advanceState()
|
||||
switch ejp.s {
|
||||
case jpsSawEndObject: // empty embedded document
|
||||
t = TypeEmbeddedDocument
|
||||
ejp.emptyObject = true
|
||||
case jpsInvalidState:
|
||||
err = ejp.err
|
||||
case jpsSawKey:
|
||||
if initialState == jpsStartState {
|
||||
return TypeEmbeddedDocument, nil
|
||||
}
|
||||
t = wrapperKeyBSONType(ejp.k)
|
||||
|
||||
// if $uuid is encountered, parse as binary subtype 4
|
||||
if ejp.k == "$uuid" {
|
||||
ejp.relaxedUUID = true
|
||||
t = TypeBinary
|
||||
}
|
||||
|
||||
switch t {
|
||||
case TypeJavaScript:
|
||||
// just saw $code, need to check for $scope at same level
|
||||
_, err = ejp.readValue(TypeJavaScript)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
switch ejp.s {
|
||||
case jpsSawEndObject: // type is TypeJavaScript
|
||||
case jpsSawComma:
|
||||
ejp.advanceState()
|
||||
|
||||
if ejp.s == jpsSawKey && ejp.k == "$scope" {
|
||||
t = TypeCodeWithScope
|
||||
} else {
|
||||
err = fmt.Errorf("invalid extended JSON: unexpected key %s in CodeWithScope object", ejp.k)
|
||||
}
|
||||
case jpsInvalidState:
|
||||
err = ejp.err
|
||||
default:
|
||||
err = ErrInvalidJSON
|
||||
}
|
||||
case TypeCodeWithScope:
|
||||
err = errors.New("invalid extended JSON: code with $scope must contain $code before $scope")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return t, err
|
||||
}
|
||||
|
||||
// readKey parses the next key and its type and returns them
|
||||
func (ejp *extJSONParser) readKey() (string, Type, error) {
|
||||
if ejp.emptyObject {
|
||||
ejp.emptyObject = false
|
||||
return "", 0, ErrEOD
|
||||
}
|
||||
|
||||
// advance to key (or return with error)
|
||||
switch ejp.s {
|
||||
case jpsStartState:
|
||||
ejp.advanceState()
|
||||
if ejp.s == jpsSawBeginObject {
|
||||
ejp.advanceState()
|
||||
}
|
||||
case jpsSawBeginObject:
|
||||
ejp.advanceState()
|
||||
case jpsSawValue, jpsSawEndObject, jpsSawEndArray:
|
||||
ejp.advanceState()
|
||||
switch ejp.s {
|
||||
case jpsSawBeginObject, jpsSawComma:
|
||||
ejp.advanceState()
|
||||
case jpsSawEndObject:
|
||||
return "", 0, ErrEOD
|
||||
case jpsDoneState:
|
||||
return "", 0, io.EOF
|
||||
case jpsInvalidState:
|
||||
return "", 0, ejp.err
|
||||
default:
|
||||
return "", 0, ErrInvalidJSON
|
||||
}
|
||||
case jpsSawKey: // do nothing (key was peeked before)
|
||||
default:
|
||||
return "", 0, invalidRequestError("key")
|
||||
}
|
||||
|
||||
// read key
|
||||
var key string
|
||||
|
||||
switch ejp.s {
|
||||
case jpsSawKey:
|
||||
key = ejp.k
|
||||
case jpsSawEndObject:
|
||||
return "", 0, ErrEOD
|
||||
case jpsInvalidState:
|
||||
return "", 0, ejp.err
|
||||
default:
|
||||
return "", 0, invalidRequestError("key")
|
||||
}
|
||||
|
||||
// check for colon
|
||||
ejp.advanceState()
|
||||
if err := ensureColon(ejp.s, key); err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
// peek at the value to determine type
|
||||
t, err := ejp.peekType()
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
return key, t, nil
|
||||
}
|
||||
|
||||
// readValue returns the value corresponding to the Type returned by peekType
|
||||
func (ejp *extJSONParser) readValue(t Type) (*extJSONValue, error) {
|
||||
if ejp.s == jpsInvalidState {
|
||||
return nil, ejp.err
|
||||
}
|
||||
|
||||
var v *extJSONValue
|
||||
|
||||
switch t {
|
||||
case TypeNull, TypeBoolean, TypeString:
|
||||
if ejp.s != jpsSawValue {
|
||||
return nil, invalidRequestError(t.String())
|
||||
}
|
||||
v = ejp.v
|
||||
case TypeInt32, TypeInt64, TypeDouble:
|
||||
// relaxed version allows these to be literal number values
|
||||
if ejp.s == jpsSawValue {
|
||||
v = ejp.v
|
||||
break
|
||||
}
|
||||
fallthrough
|
||||
case TypeDecimal128, TypeSymbol, TypeObjectID, TypeMinKey, TypeMaxKey, TypeUndefined:
|
||||
switch ejp.s {
|
||||
case jpsSawKey:
|
||||
// read colon
|
||||
ejp.advanceState()
|
||||
if err := ensureColon(ejp.s, ejp.k); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// read value
|
||||
ejp.advanceState()
|
||||
if ejp.s != jpsSawValue || !ejp.ensureExtValueType(t) {
|
||||
return nil, invalidJSONErrorForType("value", t)
|
||||
}
|
||||
|
||||
v = ejp.v
|
||||
|
||||
// read end object
|
||||
ejp.advanceState()
|
||||
if ejp.s != jpsSawEndObject {
|
||||
return nil, invalidJSONErrorForType("} after value", t)
|
||||
}
|
||||
default:
|
||||
return nil, invalidRequestError(t.String())
|
||||
}
|
||||
case TypeBinary, TypeRegex, TypeTimestamp, TypeDBPointer:
|
||||
if ejp.s != jpsSawKey {
|
||||
return nil, invalidRequestError(t.String())
|
||||
}
|
||||
// read colon
|
||||
ejp.advanceState()
|
||||
if err := ensureColon(ejp.s, ejp.k); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ejp.advanceState()
|
||||
if t == TypeBinary && ejp.s == jpsSawValue {
|
||||
// convert relaxed $uuid format
|
||||
if ejp.relaxedUUID {
|
||||
defer func() { ejp.relaxedUUID = false }()
|
||||
uuid, err := ejp.v.parseSymbol()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// RFC 4122 defines the length of a UUID as 36 and the hyphens in a UUID as appearing
|
||||
// in the 8th, 13th, 18th, and 23rd characters.
|
||||
//
|
||||
// See https://tools.ietf.org/html/rfc4122#section-3
|
||||
valid := len(uuid) == 36 &&
|
||||
string(uuid[8]) == "-" &&
|
||||
string(uuid[13]) == "-" &&
|
||||
string(uuid[18]) == "-" &&
|
||||
string(uuid[23]) == "-"
|
||||
if !valid {
|
||||
return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens")
|
||||
}
|
||||
|
||||
// remove hyphens
|
||||
uuidNoHyphens := strings.ReplaceAll(uuid, "-", "")
|
||||
if len(uuidNoHyphens) != 32 {
|
||||
return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens")
|
||||
}
|
||||
|
||||
// convert hex to bytes
|
||||
bytes, err := hex.DecodeString(uuidNoHyphens)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding hex bytes: %w", err)
|
||||
}
|
||||
|
||||
ejp.advanceState()
|
||||
if ejp.s != jpsSawEndObject {
|
||||
return nil, invalidJSONErrorForType("$uuid and value and then }", TypeBinary)
|
||||
}
|
||||
|
||||
base64 := &extJSONValue{
|
||||
t: TypeString,
|
||||
v: base64.StdEncoding.EncodeToString(bytes),
|
||||
}
|
||||
subType := &extJSONValue{
|
||||
t: TypeString,
|
||||
v: "04",
|
||||
}
|
||||
|
||||
v = &extJSONValue{
|
||||
t: TypeEmbeddedDocument,
|
||||
v: &extJSONObject{
|
||||
keys: []string{"base64", "subType"},
|
||||
values: []*extJSONValue{base64, subType},
|
||||
},
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
// convert legacy $binary format
|
||||
base64 := ejp.v
|
||||
|
||||
ejp.advanceState()
|
||||
if ejp.s != jpsSawComma {
|
||||
return nil, invalidJSONErrorForType(",", TypeBinary)
|
||||
}
|
||||
|
||||
ejp.advanceState()
|
||||
key, t, err := ejp.readKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if key != "$type" {
|
||||
return nil, invalidJSONErrorForType("$type", TypeBinary)
|
||||
}
|
||||
|
||||
subType, err := ejp.readValue(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ejp.advanceState()
|
||||
if ejp.s != jpsSawEndObject {
|
||||
return nil, invalidJSONErrorForType("2 key-value pairs and then }", TypeBinary)
|
||||
}
|
||||
|
||||
v = &extJSONValue{
|
||||
t: TypeEmbeddedDocument,
|
||||
v: &extJSONObject{
|
||||
keys: []string{"base64", "subType"},
|
||||
values: []*extJSONValue{base64, subType},
|
||||
},
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// read KV pairs
|
||||
if ejp.s != jpsSawBeginObject {
|
||||
return nil, invalidJSONErrorForType("{", t)
|
||||
}
|
||||
|
||||
keys, vals, err := ejp.readObject(2, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ejp.advanceState()
|
||||
if ejp.s != jpsSawEndObject {
|
||||
return nil, invalidJSONErrorForType("2 key-value pairs and then }", t)
|
||||
}
|
||||
|
||||
v = &extJSONValue{t: TypeEmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}
|
||||
|
||||
case TypeDateTime:
|
||||
switch ejp.s {
|
||||
case jpsSawValue:
|
||||
v = ejp.v
|
||||
case jpsSawKey:
|
||||
// read colon
|
||||
ejp.advanceState()
|
||||
if err := ensureColon(ejp.s, ejp.k); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ejp.advanceState()
|
||||
switch ejp.s {
|
||||
case jpsSawBeginObject:
|
||||
keys, vals, err := ejp.readObject(1, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v = &extJSONValue{t: TypeEmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}
|
||||
case jpsSawValue:
|
||||
if ejp.canonicalOnly {
|
||||
return nil, invalidJSONError("{")
|
||||
}
|
||||
v = ejp.v
|
||||
default:
|
||||
if ejp.canonicalOnly {
|
||||
return nil, invalidJSONErrorForType("object", t)
|
||||
}
|
||||
return nil, invalidJSONErrorForType("ISO-8601 Internet Date/Time Format as described in RFC-3339", t)
|
||||
}
|
||||
|
||||
ejp.advanceState()
|
||||
if ejp.s != jpsSawEndObject {
|
||||
return nil, invalidJSONErrorForType("value and then }", t)
|
||||
}
|
||||
default:
|
||||
return nil, invalidRequestError(t.String())
|
||||
}
|
||||
case TypeJavaScript:
|
||||
switch ejp.s {
|
||||
case jpsSawKey:
|
||||
// read colon
|
||||
ejp.advanceState()
|
||||
if err := ensureColon(ejp.s, ejp.k); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// read value
|
||||
ejp.advanceState()
|
||||
if ejp.s != jpsSawValue {
|
||||
return nil, invalidJSONErrorForType("value", t)
|
||||
}
|
||||
v = ejp.v
|
||||
|
||||
// read end object or comma and just return
|
||||
ejp.advanceState()
|
||||
case jpsSawEndObject:
|
||||
v = ejp.v
|
||||
default:
|
||||
return nil, invalidRequestError(t.String())
|
||||
}
|
||||
case TypeCodeWithScope:
|
||||
if ejp.s == jpsSawKey && ejp.k == "$scope" {
|
||||
v = ejp.v // this is the $code string from earlier
|
||||
|
||||
// read colon
|
||||
ejp.advanceState()
|
||||
if err := ensureColon(ejp.s, ejp.k); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// read {
|
||||
ejp.advanceState()
|
||||
if ejp.s != jpsSawBeginObject {
|
||||
return nil, invalidJSONError("$scope to be embedded document")
|
||||
}
|
||||
} else {
|
||||
return nil, invalidRequestError(t.String())
|
||||
}
|
||||
case TypeEmbeddedDocument, TypeArray:
|
||||
return nil, invalidRequestError(t.String())
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// readObject is a utility method for reading full objects of known (or expected) size
|
||||
// it is useful for extended JSON types such as binary, datetime, regex, and timestamp
|
||||
func (ejp *extJSONParser) readObject(numKeys int, started bool) ([]string, []*extJSONValue, error) {
|
||||
keys := make([]string, numKeys)
|
||||
vals := make([]*extJSONValue, numKeys)
|
||||
|
||||
if !started {
|
||||
ejp.advanceState()
|
||||
if ejp.s != jpsSawBeginObject {
|
||||
return nil, nil, invalidJSONError("{")
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < numKeys; i++ {
|
||||
key, t, err := ejp.readKey()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
switch ejp.s {
|
||||
case jpsSawKey:
|
||||
v, err := ejp.readValue(t)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
keys[i] = key
|
||||
vals[i] = v
|
||||
case jpsSawValue:
|
||||
keys[i] = key
|
||||
vals[i] = ejp.v
|
||||
default:
|
||||
return nil, nil, invalidJSONError("value")
|
||||
}
|
||||
}
|
||||
|
||||
ejp.advanceState()
|
||||
if ejp.s != jpsSawEndObject {
|
||||
return nil, nil, invalidJSONError("}")
|
||||
}
|
||||
|
||||
return keys, vals, nil
|
||||
}
|
||||
|
||||
// advanceState reads the next JSON token from the scanner and transitions
|
||||
// from the current state based on that token's type
|
||||
func (ejp *extJSONParser) advanceState() {
|
||||
if ejp.s == jpsDoneState || ejp.s == jpsInvalidState {
|
||||
return
|
||||
}
|
||||
|
||||
jt, err := ejp.js.nextToken()
|
||||
|
||||
if err != nil {
|
||||
ejp.err = err
|
||||
ejp.s = jpsInvalidState
|
||||
return
|
||||
}
|
||||
|
||||
valid := ejp.validateToken(jt.t)
|
||||
if !valid {
|
||||
ejp.err = unexpectedTokenError(jt)
|
||||
ejp.s = jpsInvalidState
|
||||
return
|
||||
}
|
||||
|
||||
switch jt.t {
|
||||
case jttBeginObject:
|
||||
ejp.s = jpsSawBeginObject
|
||||
ejp.pushMode(jpmObjectMode)
|
||||
ejp.depth++
|
||||
|
||||
if ejp.depth > ejp.maxDepth {
|
||||
ejp.err = nestingDepthError(jt.p, ejp.depth)
|
||||
ejp.s = jpsInvalidState
|
||||
}
|
||||
case jttEndObject:
|
||||
ejp.s = jpsSawEndObject
|
||||
ejp.depth--
|
||||
|
||||
if ejp.popMode() != jpmObjectMode {
|
||||
ejp.err = unexpectedTokenError(jt)
|
||||
ejp.s = jpsInvalidState
|
||||
}
|
||||
case jttBeginArray:
|
||||
ejp.s = jpsSawBeginArray
|
||||
ejp.pushMode(jpmArrayMode)
|
||||
case jttEndArray:
|
||||
ejp.s = jpsSawEndArray
|
||||
|
||||
if ejp.popMode() != jpmArrayMode {
|
||||
ejp.err = unexpectedTokenError(jt)
|
||||
ejp.s = jpsInvalidState
|
||||
}
|
||||
case jttColon:
|
||||
ejp.s = jpsSawColon
|
||||
case jttComma:
|
||||
ejp.s = jpsSawComma
|
||||
case jttEOF:
|
||||
ejp.s = jpsDoneState
|
||||
if len(ejp.m) != 0 {
|
||||
ejp.err = unexpectedTokenError(jt)
|
||||
ejp.s = jpsInvalidState
|
||||
}
|
||||
case jttString:
|
||||
switch ejp.s {
|
||||
case jpsSawComma:
|
||||
if ejp.peekMode() == jpmArrayMode {
|
||||
ejp.s = jpsSawValue
|
||||
ejp.v = extendJSONToken(jt)
|
||||
return
|
||||
}
|
||||
fallthrough
|
||||
case jpsSawBeginObject:
|
||||
ejp.s = jpsSawKey
|
||||
ejp.k = jt.v.(string)
|
||||
return
|
||||
}
|
||||
fallthrough
|
||||
default:
|
||||
ejp.s = jpsSawValue
|
||||
ejp.v = extendJSONToken(jt)
|
||||
}
|
||||
}
|
||||
|
||||
var jpsValidTransitionTokens = map[jsonParseState]map[jsonTokenType]bool{
|
||||
jpsStartState: {
|
||||
jttBeginObject: true,
|
||||
jttBeginArray: true,
|
||||
jttInt32: true,
|
||||
jttInt64: true,
|
||||
jttDouble: true,
|
||||
jttString: true,
|
||||
jttBool: true,
|
||||
jttNull: true,
|
||||
jttEOF: true,
|
||||
},
|
||||
jpsSawBeginObject: {
|
||||
jttEndObject: true,
|
||||
jttString: true,
|
||||
},
|
||||
jpsSawEndObject: {
|
||||
jttEndObject: true,
|
||||
jttEndArray: true,
|
||||
jttComma: true,
|
||||
jttEOF: true,
|
||||
},
|
||||
jpsSawBeginArray: {
|
||||
jttBeginObject: true,
|
||||
jttBeginArray: true,
|
||||
jttEndArray: true,
|
||||
jttInt32: true,
|
||||
jttInt64: true,
|
||||
jttDouble: true,
|
||||
jttString: true,
|
||||
jttBool: true,
|
||||
jttNull: true,
|
||||
},
|
||||
jpsSawEndArray: {
|
||||
jttEndObject: true,
|
||||
jttEndArray: true,
|
||||
jttComma: true,
|
||||
jttEOF: true,
|
||||
},
|
||||
jpsSawColon: {
|
||||
jttBeginObject: true,
|
||||
jttBeginArray: true,
|
||||
jttInt32: true,
|
||||
jttInt64: true,
|
||||
jttDouble: true,
|
||||
jttString: true,
|
||||
jttBool: true,
|
||||
jttNull: true,
|
||||
},
|
||||
jpsSawComma: {
|
||||
jttBeginObject: true,
|
||||
jttBeginArray: true,
|
||||
jttInt32: true,
|
||||
jttInt64: true,
|
||||
jttDouble: true,
|
||||
jttString: true,
|
||||
jttBool: true,
|
||||
jttNull: true,
|
||||
},
|
||||
jpsSawKey: {
|
||||
jttColon: true,
|
||||
},
|
||||
jpsSawValue: {
|
||||
jttEndObject: true,
|
||||
jttEndArray: true,
|
||||
jttComma: true,
|
||||
jttEOF: true,
|
||||
},
|
||||
jpsDoneState: {},
|
||||
jpsInvalidState: {},
|
||||
}
|
||||
|
||||
func (ejp *extJSONParser) validateToken(jtt jsonTokenType) bool {
|
||||
switch ejp.s {
|
||||
case jpsSawEndObject:
|
||||
// if we are at depth zero and the next token is a '{',
|
||||
// we can consider it valid only if we are not in array mode.
|
||||
if jtt == jttBeginObject && ejp.depth == 0 {
|
||||
return ejp.peekMode() != jpmArrayMode
|
||||
}
|
||||
case jpsSawComma:
|
||||
switch ejp.peekMode() {
|
||||
// the only valid next token after a comma inside a document is a string (a key)
|
||||
case jpmObjectMode:
|
||||
return jtt == jttString
|
||||
case jpmInvalidMode:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
_, ok := jpsValidTransitionTokens[ejp.s][jtt]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ensureExtValueType returns true if the current value has the expected
|
||||
// value type for single-key extended JSON types. For example,
|
||||
// {"$numberInt": v} v must be TypeString
|
||||
func (ejp *extJSONParser) ensureExtValueType(t Type) bool {
|
||||
switch t {
|
||||
case TypeMinKey, TypeMaxKey:
|
||||
return ejp.v.t == TypeInt32
|
||||
case TypeUndefined:
|
||||
return ejp.v.t == TypeBoolean
|
||||
case TypeInt32, TypeInt64, TypeDouble, TypeDecimal128, TypeSymbol, TypeObjectID:
|
||||
return ejp.v.t == TypeString
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (ejp *extJSONParser) pushMode(m jsonParseMode) {
|
||||
ejp.m = append(ejp.m, m)
|
||||
}
|
||||
|
||||
func (ejp *extJSONParser) popMode() jsonParseMode {
|
||||
l := len(ejp.m)
|
||||
if l == 0 {
|
||||
return jpmInvalidMode
|
||||
}
|
||||
|
||||
m := ejp.m[l-1]
|
||||
ejp.m = ejp.m[:l-1]
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (ejp *extJSONParser) peekMode() jsonParseMode {
|
||||
l := len(ejp.m)
|
||||
if l == 0 {
|
||||
return jpmInvalidMode
|
||||
}
|
||||
|
||||
return ejp.m[l-1]
|
||||
}
|
||||
|
||||
func extendJSONToken(jt *jsonToken) *extJSONValue {
|
||||
var t Type
|
||||
|
||||
switch jt.t {
|
||||
case jttInt32:
|
||||
t = TypeInt32
|
||||
case jttInt64:
|
||||
t = TypeInt64
|
||||
case jttDouble:
|
||||
t = TypeDouble
|
||||
case jttString:
|
||||
t = TypeString
|
||||
case jttBool:
|
||||
t = TypeBoolean
|
||||
case jttNull:
|
||||
t = TypeNull
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
return &extJSONValue{t: t, v: jt.v}
|
||||
}
|
||||
|
||||
func ensureColon(s jsonParseState, key string) error {
|
||||
if s != jpsSawColon {
|
||||
return fmt.Errorf("invalid JSON input: missing colon after key \"%s\"", key)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func invalidRequestError(s string) error {
|
||||
return fmt.Errorf("invalid request to read %s", s)
|
||||
}
|
||||
|
||||
func invalidJSONError(expected string) error {
|
||||
return fmt.Errorf("invalid JSON input; expected %s", expected)
|
||||
}
|
||||
|
||||
func invalidJSONErrorForType(expected string, t Type) error {
|
||||
return fmt.Errorf("invalid JSON input; expected %s for %s", expected, t)
|
||||
}
|
||||
|
||||
func unexpectedTokenError(jt *jsonToken) error {
|
||||
switch jt.t {
|
||||
case jttInt32, jttInt64, jttDouble:
|
||||
return fmt.Errorf("invalid JSON input; unexpected number (%v) at position %d", jt.v, jt.p)
|
||||
case jttString:
|
||||
return fmt.Errorf("invalid JSON input; unexpected string (\"%v\") at position %d", jt.v, jt.p)
|
||||
case jttBool:
|
||||
return fmt.Errorf("invalid JSON input; unexpected boolean literal (%v) at position %d", jt.v, jt.p)
|
||||
case jttNull:
|
||||
return fmt.Errorf("invalid JSON input; unexpected null literal at position %d", jt.p)
|
||||
case jttEOF:
|
||||
return fmt.Errorf("invalid JSON input; unexpected end of input at position %d", jt.p)
|
||||
default:
|
||||
return fmt.Errorf("invalid JSON input; unexpected %c at position %d", jt.v.(byte), jt.p)
|
||||
}
|
||||
}
|
||||
|
||||
func nestingDepthError(p, depth int) error {
|
||||
return fmt.Errorf("invalid JSON input; nesting too deep (%d levels) at position %d", depth, p)
|
||||
}
|
804
extjson_parser_test.go
Normal file
804
extjson_parser_test.go
Normal file
@ -0,0 +1,804 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
var (
|
||||
keyDiff = specificDiff("key")
|
||||
typDiff = specificDiff("type")
|
||||
valDiff = specificDiff("value")
|
||||
|
||||
expectErrEOF = expectSpecificError(io.EOF)
|
||||
expectErrEOD = expectSpecificError(ErrEOD)
|
||||
expectErrEOA = expectSpecificError(ErrEOA)
|
||||
)
|
||||
|
||||
type expectedErrorFunc func(t *testing.T, err error, desc string)
|
||||
|
||||
type peekTypeTestCase struct {
|
||||
desc string
|
||||
input string
|
||||
typs []Type
|
||||
errFs []expectedErrorFunc
|
||||
}
|
||||
|
||||
type readKeyValueTestCase struct {
|
||||
desc string
|
||||
input string
|
||||
keys []string
|
||||
typs []Type
|
||||
vals []*extJSONValue
|
||||
|
||||
keyEFs []expectedErrorFunc
|
||||
valEFs []expectedErrorFunc
|
||||
}
|
||||
|
||||
func expectNoError(t *testing.T, err error, desc string) {
|
||||
if err != nil {
|
||||
t.Helper()
|
||||
t.Errorf("%s: Unepexted error: %v", desc, err)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func expectError(t *testing.T, err error, desc string) {
|
||||
if err == nil {
|
||||
t.Helper()
|
||||
t.Errorf("%s: Expected error", desc)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func expectSpecificError(expected error) expectedErrorFunc {
|
||||
return func(t *testing.T, err error, desc string) {
|
||||
if !errors.Is(err, expected) {
|
||||
t.Helper()
|
||||
t.Errorf("%s: Expected %v but got: %v", desc, expected, err)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func specificDiff(name string) func(t *testing.T, expected, actual interface{}, desc string) {
|
||||
return func(t *testing.T, expected, actual interface{}, desc string) {
|
||||
if diff := cmp.Diff(expected, actual); diff != "" {
|
||||
t.Helper()
|
||||
t.Errorf("%s: Incorrect JSON %s (-want, +got): %s\n", desc, name, diff)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func expectErrorNOOP(_ *testing.T, _ error, _ string) {
|
||||
}
|
||||
|
||||
func readKeyDiff(t *testing.T, eKey, aKey string, eTyp, aTyp Type, err error, errF expectedErrorFunc, desc string) {
|
||||
keyDiff(t, eKey, aKey, desc)
|
||||
typDiff(t, eTyp, aTyp, desc)
|
||||
errF(t, err, desc)
|
||||
}
|
||||
|
||||
func readValueDiff(t *testing.T, eVal, aVal *extJSONValue, err error, errF expectedErrorFunc, desc string) {
|
||||
if aVal != nil {
|
||||
typDiff(t, eVal.t, aVal.t, desc)
|
||||
valDiff(t, eVal.v, aVal.v, desc)
|
||||
} else {
|
||||
valDiff(t, eVal, aVal, desc)
|
||||
}
|
||||
|
||||
errF(t, err, desc)
|
||||
}
|
||||
|
||||
func TestExtJSONParserPeekType(t *testing.T) {
|
||||
makeValidPeekTypeTestCase := func(input string, typ Type, desc string) peekTypeTestCase {
|
||||
return peekTypeTestCase{
|
||||
desc: desc, input: input,
|
||||
typs: []Type{typ},
|
||||
errFs: []expectedErrorFunc{expectNoError},
|
||||
}
|
||||
}
|
||||
makeInvalidTestCase := func(desc, input string, lastEF expectedErrorFunc) peekTypeTestCase {
|
||||
return peekTypeTestCase{
|
||||
desc: desc, input: input,
|
||||
typs: []Type{Type(0)},
|
||||
errFs: []expectedErrorFunc{lastEF},
|
||||
}
|
||||
}
|
||||
|
||||
makeInvalidPeekTypeTestCase := func(desc, input string, lastEF expectedErrorFunc) peekTypeTestCase {
|
||||
return peekTypeTestCase{
|
||||
desc: desc, input: input,
|
||||
typs: []Type{TypeArray, TypeString, Type(0)},
|
||||
errFs: []expectedErrorFunc{expectNoError, expectNoError, lastEF},
|
||||
}
|
||||
}
|
||||
|
||||
cases := []peekTypeTestCase{
|
||||
makeValidPeekTypeTestCase(`null`, TypeNull, "Null"),
|
||||
makeValidPeekTypeTestCase(`"string"`, TypeString, "String"),
|
||||
makeValidPeekTypeTestCase(`true`, TypeBoolean, "Boolean--true"),
|
||||
makeValidPeekTypeTestCase(`false`, TypeBoolean, "Boolean--false"),
|
||||
makeValidPeekTypeTestCase(`{"$minKey": 1}`, TypeMinKey, "MinKey"),
|
||||
makeValidPeekTypeTestCase(`{"$maxKey": 1}`, TypeMaxKey, "MaxKey"),
|
||||
makeValidPeekTypeTestCase(`{"$numberInt": "42"}`, TypeInt32, "Int32"),
|
||||
makeValidPeekTypeTestCase(`{"$numberLong": "42"}`, TypeInt64, "Int64"),
|
||||
makeValidPeekTypeTestCase(`{"$symbol": "symbol"}`, TypeSymbol, "Symbol"),
|
||||
makeValidPeekTypeTestCase(`{"$numberDouble": "42.42"}`, TypeDouble, "Double"),
|
||||
makeValidPeekTypeTestCase(`{"$undefined": true}`, TypeUndefined, "Undefined"),
|
||||
makeValidPeekTypeTestCase(`{"$numberDouble": "NaN"}`, TypeDouble, "Double--NaN"),
|
||||
makeValidPeekTypeTestCase(`{"$numberDecimal": "1234"}`, TypeDecimal128, "Decimal"),
|
||||
makeValidPeekTypeTestCase(`{"foo": "bar"}`, TypeEmbeddedDocument, "Toplevel document"),
|
||||
makeValidPeekTypeTestCase(`{"$date": {"$numberLong": "0"}}`, TypeDateTime, "Datetime"),
|
||||
makeValidPeekTypeTestCase(`{"$code": "function() {}"}`, TypeJavaScript, "Code no scope"),
|
||||
makeValidPeekTypeTestCase(`[{"$numberInt": "1"},{"$numberInt": "2"}]`, TypeArray, "Array"),
|
||||
makeValidPeekTypeTestCase(`{"$timestamp": {"t": 42, "i": 1}}`, TypeTimestamp, "Timestamp"),
|
||||
makeValidPeekTypeTestCase(`{"$oid": "57e193d7a9cc81b4027498b5"}`, TypeObjectID, "Object ID"),
|
||||
makeValidPeekTypeTestCase(`{"$binary": {"base64": "AQIDBAU=", "subType": "80"}}`, TypeBinary, "Binary"),
|
||||
makeValidPeekTypeTestCase(`{"$code": "function() {}", "$scope": {}}`, TypeCodeWithScope, "Code With Scope"),
|
||||
makeValidPeekTypeTestCase(`{"$binary": {"base64": "o0w498Or7cijeBSpkquNtg==", "subType": "03"}}`, TypeBinary, "Binary"),
|
||||
makeValidPeekTypeTestCase(`{"$binary": "o0w498Or7cijeBSpkquNtg==", "$type": "03"}`, TypeBinary, "Binary"),
|
||||
makeValidPeekTypeTestCase(`{"$regularExpression": {"pattern": "foo*", "options": "ix"}}`, TypeRegex, "Regular expression"),
|
||||
makeValidPeekTypeTestCase(`{"$dbPointer": {"$ref": "db.collection", "$id": {"$oid": "57e193d7a9cc81b4027498b1"}}}`, TypeDBPointer, "DBPointer"),
|
||||
makeValidPeekTypeTestCase(`{"$ref": "collection", "$id": {"$oid": "57fd71e96e32ab4225b723fb"}, "$db": "database"}`, TypeEmbeddedDocument, "DBRef"),
|
||||
makeInvalidPeekTypeTestCase("invalid array--missing ]", `["a"`, expectError),
|
||||
makeInvalidPeekTypeTestCase("invalid array--colon in array", `["a":`, expectError),
|
||||
makeInvalidPeekTypeTestCase("invalid array--extra comma", `["a",,`, expectError),
|
||||
makeInvalidPeekTypeTestCase("invalid array--trailing comma", `["a",]`, expectError),
|
||||
makeInvalidPeekTypeTestCase("peekType after end of array", `["a"]`, expectErrEOA),
|
||||
{
|
||||
desc: "invalid array--leading comma",
|
||||
input: `[,`,
|
||||
typs: []Type{TypeArray, Type(0)},
|
||||
errFs: []expectedErrorFunc{expectNoError, expectError},
|
||||
},
|
||||
makeInvalidTestCase("lone $scope", `{"$scope": {}}`, expectError),
|
||||
makeInvalidTestCase("empty code with unknown extra key", `{"$code":"", "0":""}`, expectError),
|
||||
makeInvalidTestCase("non-empty code with unknown extra key", `{"$code":"foobar", "0":""}`, expectError),
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
ejp := newExtJSONParser(strings.NewReader(tc.input), true)
|
||||
// Manually set the parser's starting state to jpsSawColon so peekType will read ahead to find the extjson
|
||||
// type of the value. If not set, the parser will be in jpsStartState and advance to jpsSawKey, which will
|
||||
// cause it to return without peeking the extjson type.
|
||||
ejp.s = jpsSawColon
|
||||
|
||||
for i, eTyp := range tc.typs {
|
||||
errF := tc.errFs[i]
|
||||
|
||||
typ, err := ejp.peekType()
|
||||
errF(t, err, tc.desc)
|
||||
if err != nil {
|
||||
// Don't inspect the type if there was an error
|
||||
return
|
||||
}
|
||||
|
||||
typDiff(t, eTyp, typ, tc.desc)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtJSONParserReadKeyReadValue(t *testing.T) {
|
||||
// several test cases will use the same keys, types, and values, and only differ on input structure
|
||||
|
||||
keys := []string{"_id", "Symbol", "String", "Int32", "Int64", "Int", "MinKey"}
|
||||
types := []Type{TypeObjectID, TypeSymbol, TypeString, TypeInt32, TypeInt64, TypeInt32, TypeMinKey}
|
||||
values := []*extJSONValue{
|
||||
{t: TypeString, v: "57e193d7a9cc81b4027498b5"},
|
||||
{t: TypeString, v: "symbol"},
|
||||
{t: TypeString, v: "string"},
|
||||
{t: TypeString, v: "42"},
|
||||
{t: TypeString, v: "42"},
|
||||
{t: TypeInt32, v: int32(42)},
|
||||
{t: TypeInt32, v: int32(1)},
|
||||
}
|
||||
|
||||
errFuncs := make([]expectedErrorFunc, 7)
|
||||
for i := 0; i < 7; i++ {
|
||||
errFuncs[i] = expectNoError
|
||||
}
|
||||
|
||||
firstKeyError := func(desc, input string) readKeyValueTestCase {
|
||||
return readKeyValueTestCase{
|
||||
desc: desc,
|
||||
input: input,
|
||||
keys: []string{""},
|
||||
typs: []Type{Type(0)},
|
||||
vals: []*extJSONValue{nil},
|
||||
keyEFs: []expectedErrorFunc{expectError},
|
||||
valEFs: []expectedErrorFunc{expectErrorNOOP},
|
||||
}
|
||||
}
|
||||
|
||||
secondKeyError := func(desc, input, firstKey string, firstType Type, firstValue *extJSONValue) readKeyValueTestCase {
|
||||
return readKeyValueTestCase{
|
||||
desc: desc,
|
||||
input: input,
|
||||
keys: []string{firstKey, ""},
|
||||
typs: []Type{firstType, Type(0)},
|
||||
vals: []*extJSONValue{firstValue, nil},
|
||||
keyEFs: []expectedErrorFunc{expectNoError, expectError},
|
||||
valEFs: []expectedErrorFunc{expectNoError, expectErrorNOOP},
|
||||
}
|
||||
}
|
||||
|
||||
cases := []readKeyValueTestCase{
|
||||
{
|
||||
desc: "normal spacing",
|
||||
input: `{
|
||||
"_id": { "$oid": "57e193d7a9cc81b4027498b5" },
|
||||
"Symbol": { "$symbol": "symbol" },
|
||||
"String": "string",
|
||||
"Int32": { "$numberInt": "42" },
|
||||
"Int64": { "$numberLong": "42" },
|
||||
"Int": 42,
|
||||
"MinKey": { "$minKey": 1 }
|
||||
}`,
|
||||
keys: keys, typs: types, vals: values,
|
||||
keyEFs: errFuncs, valEFs: errFuncs,
|
||||
},
|
||||
{
|
||||
desc: "new line before comma",
|
||||
input: `{ "_id": { "$oid": "57e193d7a9cc81b4027498b5" }
|
||||
, "Symbol": { "$symbol": "symbol" }
|
||||
, "String": "string"
|
||||
, "Int32": { "$numberInt": "42" }
|
||||
, "Int64": { "$numberLong": "42" }
|
||||
, "Int": 42
|
||||
, "MinKey": { "$minKey": 1 }
|
||||
}`,
|
||||
keys: keys, typs: types, vals: values,
|
||||
keyEFs: errFuncs, valEFs: errFuncs,
|
||||
},
|
||||
{
|
||||
desc: "tabs around colons",
|
||||
input: `{
|
||||
"_id": { "$oid" : "57e193d7a9cc81b4027498b5" },
|
||||
"Symbol": { "$symbol" : "symbol" },
|
||||
"String": "string",
|
||||
"Int32": { "$numberInt" : "42" },
|
||||
"Int64": { "$numberLong": "42" },
|
||||
"Int": 42,
|
||||
"MinKey": { "$minKey": 1 }
|
||||
}`,
|
||||
keys: keys, typs: types, vals: values,
|
||||
keyEFs: errFuncs, valEFs: errFuncs,
|
||||
},
|
||||
{
|
||||
desc: "no whitespace",
|
||||
input: `{"_id":{"$oid":"57e193d7a9cc81b4027498b5"},"Symbol":{"$symbol":"symbol"},"String":"string","Int32":{"$numberInt":"42"},"Int64":{"$numberLong":"42"},"Int":42,"MinKey":{"$minKey":1}}`,
|
||||
keys: keys, typs: types, vals: values,
|
||||
keyEFs: errFuncs, valEFs: errFuncs,
|
||||
},
|
||||
{
|
||||
desc: "mixed whitespace",
|
||||
input: ` {
|
||||
"_id" : { "$oid": "57e193d7a9cc81b4027498b5" },
|
||||
"Symbol" : { "$symbol": "symbol" } ,
|
||||
"String" : "string",
|
||||
"Int32" : { "$numberInt": "42" } ,
|
||||
"Int64" : {"$numberLong" : "42"},
|
||||
"Int" : 42,
|
||||
"MinKey" : { "$minKey": 1 } } `,
|
||||
keys: keys, typs: types, vals: values,
|
||||
keyEFs: errFuncs, valEFs: errFuncs,
|
||||
},
|
||||
{
|
||||
desc: "nested object",
|
||||
input: `{"k1": 1, "k2": { "k3": { "k4": 4 } }, "k5": 5}`,
|
||||
keys: []string{"k1", "k2", "k3", "k4", "", "", "k5", ""},
|
||||
typs: []Type{TypeInt32, TypeEmbeddedDocument, TypeEmbeddedDocument, TypeInt32, Type(0), Type(0), TypeInt32, Type(0)},
|
||||
vals: []*extJSONValue{
|
||||
{t: TypeInt32, v: int32(1)}, nil, nil, {t: TypeInt32, v: int32(4)}, nil, nil, {t: TypeInt32, v: int32(5)}, nil,
|
||||
},
|
||||
keyEFs: []expectedErrorFunc{
|
||||
expectNoError, expectNoError, expectNoError, expectNoError, expectErrEOD,
|
||||
expectErrEOD, expectNoError, expectErrEOD,
|
||||
},
|
||||
valEFs: []expectedErrorFunc{
|
||||
expectNoError, expectError, expectError, expectNoError, expectErrorNOOP,
|
||||
expectErrorNOOP, expectNoError, expectErrorNOOP,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "invalid input: invalid values for extended type",
|
||||
input: `{"a": {"$numberInt": "1", "x"`,
|
||||
keys: []string{"a"},
|
||||
typs: []Type{TypeInt32},
|
||||
vals: []*extJSONValue{nil},
|
||||
keyEFs: []expectedErrorFunc{expectNoError},
|
||||
valEFs: []expectedErrorFunc{expectError},
|
||||
},
|
||||
firstKeyError("invalid input: missing key--EOF", "{"),
|
||||
firstKeyError("invalid input: missing key--colon first", "{:"),
|
||||
firstKeyError("invalid input: missing value", `{"a":`),
|
||||
firstKeyError("invalid input: missing colon", `{"a" 1`),
|
||||
firstKeyError("invalid input: extra colon", `{"a"::`),
|
||||
secondKeyError("invalid input: missing }", `{"a": 1`, "a", TypeInt32, &extJSONValue{t: TypeInt32, v: int32(1)}),
|
||||
secondKeyError("invalid input: missing comma", `{"a": 1 "b"`, "a", TypeInt32, &extJSONValue{t: TypeInt32, v: int32(1)}),
|
||||
secondKeyError("invalid input: extra comma", `{"a": 1,, "b"`, "a", TypeInt32, &extJSONValue{t: TypeInt32, v: int32(1)}),
|
||||
secondKeyError("invalid input: trailing comma in object", `{"a": 1,}`, "a", TypeInt32, &extJSONValue{t: TypeInt32, v: int32(1)}),
|
||||
{
|
||||
desc: "invalid input: lone scope after a complete value",
|
||||
input: `{"a": "", "b": {"$scope: ""}}`,
|
||||
keys: []string{"a"},
|
||||
typs: []Type{TypeString},
|
||||
vals: []*extJSONValue{{TypeString, ""}},
|
||||
keyEFs: []expectedErrorFunc{expectNoError, expectNoError},
|
||||
valEFs: []expectedErrorFunc{expectNoError, expectError},
|
||||
},
|
||||
{
|
||||
desc: "invalid input: lone scope nested",
|
||||
input: `{"a":{"b":{"$scope":{`,
|
||||
keys: []string{},
|
||||
typs: []Type{},
|
||||
vals: []*extJSONValue{nil},
|
||||
keyEFs: []expectedErrorFunc{expectNoError},
|
||||
valEFs: []expectedErrorFunc{expectError},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
ejp := newExtJSONParser(strings.NewReader(tc.input), true)
|
||||
|
||||
for i, eKey := range tc.keys {
|
||||
eTyp := tc.typs[i]
|
||||
eVal := tc.vals[i]
|
||||
|
||||
keyErrF := tc.keyEFs[i]
|
||||
valErrF := tc.valEFs[i]
|
||||
|
||||
k, typ, err := ejp.readKey()
|
||||
readKeyDiff(t, eKey, k, eTyp, typ, err, keyErrF, tc.desc)
|
||||
|
||||
v, err := ejp.readValue(typ)
|
||||
readValueDiff(t, eVal, v, err, valErrF, tc.desc)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type ejpExpectationTest func(t *testing.T, p *extJSONParser, expectedKey string, expectedType Type, expectedValue interface{})
|
||||
|
||||
type ejpTestCase struct {
|
||||
f ejpExpectationTest
|
||||
p *extJSONParser
|
||||
k string
|
||||
t Type
|
||||
v interface{}
|
||||
}
|
||||
|
||||
// expectSingleValue is used for simple JSON types (strings, numbers, literals) and for extended JSON types that
|
||||
// have single key-value pairs (i.e. { "$minKey": 1 }, { "$numberLong": "42.42" })
|
||||
func expectSingleValue(t *testing.T, p *extJSONParser, expectedKey string, expectedType Type, expectedValue interface{}) {
|
||||
eVal := expectedValue.(*extJSONValue)
|
||||
|
||||
k, typ, err := p.readKey()
|
||||
readKeyDiff(t, expectedKey, k, expectedType, typ, err, expectNoError, expectedKey)
|
||||
|
||||
v, err := p.readValue(typ)
|
||||
readValueDiff(t, eVal, v, err, expectNoError, expectedKey)
|
||||
}
|
||||
|
||||
// expectMultipleValues is used for values that are subdocuments of known size and with known keys (such as extended
|
||||
// JSON types { "$timestamp": {"t": 1, "i": 1} } and { "$regularExpression": {"pattern": "", options: ""} })
|
||||
func expectMultipleValues(t *testing.T, p *extJSONParser, expectedKey string, expectedType Type, expectedValue interface{}) {
|
||||
k, typ, err := p.readKey()
|
||||
readKeyDiff(t, expectedKey, k, expectedType, typ, err, expectNoError, expectedKey)
|
||||
|
||||
v, err := p.readValue(typ)
|
||||
expectNoError(t, err, "")
|
||||
typDiff(t, TypeEmbeddedDocument, v.t, expectedKey)
|
||||
|
||||
actObj := v.v.(*extJSONObject)
|
||||
expObj := expectedValue.(*extJSONObject)
|
||||
|
||||
for i, actKey := range actObj.keys {
|
||||
expKey := expObj.keys[i]
|
||||
actVal := actObj.values[i]
|
||||
expVal := expObj.values[i]
|
||||
|
||||
keyDiff(t, expKey, actKey, expectedKey)
|
||||
typDiff(t, expVal.t, actVal.t, expectedKey)
|
||||
valDiff(t, expVal.v, actVal.v, expectedKey)
|
||||
}
|
||||
}
|
||||
|
||||
type ejpKeyTypValTriple struct {
|
||||
key string
|
||||
typ Type
|
||||
val *extJSONValue
|
||||
}
|
||||
|
||||
type ejpSubDocumentTestValue struct {
|
||||
code string // code is only used for TypeCodeWithScope (and is ignored for TypeEmbeddedDocument
|
||||
ktvs []ejpKeyTypValTriple // list of (key, type, value) triples; this is "scope" for TypeCodeWithScope
|
||||
}
|
||||
|
||||
// expectSubDocument is used for embedded documents and code with scope types; it reads all the keys and values
|
||||
// in the embedded document (or scope for codeWithScope) and compares them to the expectedValue's list of (key, type,
|
||||
// value) triples
|
||||
func expectSubDocument(t *testing.T, p *extJSONParser, expectedKey string, expectedType Type, expectedValue interface{}) {
|
||||
subdoc := expectedValue.(ejpSubDocumentTestValue)
|
||||
|
||||
k, typ, err := p.readKey()
|
||||
readKeyDiff(t, expectedKey, k, expectedType, typ, err, expectNoError, expectedKey)
|
||||
|
||||
if expectedType == TypeCodeWithScope {
|
||||
v, err := p.readValue(typ)
|
||||
readValueDiff(t, &extJSONValue{t: TypeString, v: subdoc.code}, v, err, expectNoError, expectedKey)
|
||||
}
|
||||
|
||||
for _, ktv := range subdoc.ktvs {
|
||||
eKey := ktv.key
|
||||
eTyp := ktv.typ
|
||||
eVal := ktv.val
|
||||
|
||||
k, typ, err = p.readKey()
|
||||
readKeyDiff(t, eKey, k, eTyp, typ, err, expectNoError, expectedKey)
|
||||
|
||||
v, err := p.readValue(typ)
|
||||
readValueDiff(t, eVal, v, err, expectNoError, expectedKey)
|
||||
}
|
||||
|
||||
if expectedType == TypeCodeWithScope {
|
||||
// expect scope doc to close
|
||||
k, typ, err = p.readKey()
|
||||
readKeyDiff(t, "", k, Type(0), typ, err, expectErrEOD, expectedKey)
|
||||
}
|
||||
|
||||
// expect subdoc to close
|
||||
k, typ, err = p.readKey()
|
||||
readKeyDiff(t, "", k, Type(0), typ, err, expectErrEOD, expectedKey)
|
||||
}
|
||||
|
||||
// expectArray takes the expectedKey, ignores the expectedType, and uses the expectedValue
|
||||
// as a slice of (type Type, value *extJSONValue) pairs
|
||||
func expectArray(t *testing.T, p *extJSONParser, expectedKey string, _ Type, expectedValue interface{}) {
|
||||
ktvs := expectedValue.([]ejpKeyTypValTriple)
|
||||
|
||||
k, typ, err := p.readKey()
|
||||
readKeyDiff(t, expectedKey, k, TypeArray, typ, err, expectNoError, expectedKey)
|
||||
|
||||
for _, ktv := range ktvs {
|
||||
eTyp := ktv.typ
|
||||
eVal := ktv.val
|
||||
|
||||
typ, err = p.peekType()
|
||||
typDiff(t, eTyp, typ, expectedKey)
|
||||
expectNoError(t, err, expectedKey)
|
||||
|
||||
v, err := p.readValue(typ)
|
||||
readValueDiff(t, eVal, v, err, expectNoError, expectedKey)
|
||||
}
|
||||
|
||||
// expect array to end
|
||||
typ, err = p.peekType()
|
||||
typDiff(t, Type(0), typ, expectedKey)
|
||||
expectErrEOA(t, err, expectedKey)
|
||||
}
|
||||
|
||||
func TestExtJSONParserAllTypes(t *testing.T) {
|
||||
in := ` { "_id" : { "$oid": "57e193d7a9cc81b4027498b5"}
|
||||
, "Symbol" : { "$symbol": "symbol"}
|
||||
, "String" : "string"
|
||||
, "Int32" : { "$numberInt": "42"}
|
||||
, "Int64" : { "$numberLong": "42"}
|
||||
, "Double" : { "$numberDouble": "42.42"}
|
||||
, "SpecialFloat" : { "$numberDouble": "NaN" }
|
||||
, "Decimal" : { "$numberDecimal": "1234" }
|
||||
, "Binary" : { "$binary": { "base64": "o0w498Or7cijeBSpkquNtg==", "subType": "03" } }
|
||||
, "BinaryLegacy" : { "$binary": "o0w498Or7cijeBSpkquNtg==", "$type": "03" }
|
||||
, "BinaryUserDefined" : { "$binary": { "base64": "AQIDBAU=", "subType": "80" } }
|
||||
, "Code" : { "$code": "function() {}" }
|
||||
, "CodeWithEmptyScope" : { "$code": "function() {}", "$scope": {} }
|
||||
, "CodeWithScope" : { "$code": "function() {}", "$scope": { "x": 1 } }
|
||||
, "EmptySubdocument" : {}
|
||||
, "Subdocument" : { "foo": "bar", "baz": { "$numberInt": "42" } }
|
||||
, "Array" : [{"$numberInt": "1"}, {"$numberLong": "2"}, {"$numberDouble": "3"}, 4, "string", 5.0]
|
||||
, "Timestamp" : { "$timestamp": { "t": 42, "i": 1 } }
|
||||
, "RegularExpression" : { "$regularExpression": { "pattern": "foo*", "options": "ix" } }
|
||||
, "DatetimeEpoch" : { "$date": { "$numberLong": "0" } }
|
||||
, "DatetimePositive" : { "$date": { "$numberLong": "9223372036854775807" } }
|
||||
, "DatetimeNegative" : { "$date": { "$numberLong": "-9223372036854775808" } }
|
||||
, "True" : true
|
||||
, "False" : false
|
||||
, "DBPointer" : { "$dbPointer": { "$ref": "db.collection", "$id": { "$oid": "57e193d7a9cc81b4027498b1" } } }
|
||||
, "DBRef" : { "$ref": "collection", "$id": { "$oid": "57fd71e96e32ab4225b723fb" }, "$db": "database" }
|
||||
, "DBRefNoDB" : { "$ref": "collection", "$id": { "$oid": "57fd71e96e32ab4225b723fb" } }
|
||||
, "MinKey" : { "$minKey": 1 }
|
||||
, "MaxKey" : { "$maxKey": 1 }
|
||||
, "Null" : null
|
||||
, "Undefined" : { "$undefined": true }
|
||||
}`
|
||||
|
||||
ejp := newExtJSONParser(strings.NewReader(in), true)
|
||||
|
||||
cases := []ejpTestCase{
|
||||
{
|
||||
f: expectSingleValue, p: ejp,
|
||||
k: "_id", t: TypeObjectID, v: &extJSONValue{t: TypeString, v: "57e193d7a9cc81b4027498b5"},
|
||||
},
|
||||
{
|
||||
f: expectSingleValue, p: ejp,
|
||||
k: "Symbol", t: TypeSymbol, v: &extJSONValue{t: TypeString, v: "symbol"},
|
||||
},
|
||||
{
|
||||
f: expectSingleValue, p: ejp,
|
||||
k: "String", t: TypeString, v: &extJSONValue{t: TypeString, v: "string"},
|
||||
},
|
||||
{
|
||||
f: expectSingleValue, p: ejp,
|
||||
k: "Int32", t: TypeInt32, v: &extJSONValue{t: TypeString, v: "42"},
|
||||
},
|
||||
{
|
||||
f: expectSingleValue, p: ejp,
|
||||
k: "Int64", t: TypeInt64, v: &extJSONValue{t: TypeString, v: "42"},
|
||||
},
|
||||
{
|
||||
f: expectSingleValue, p: ejp,
|
||||
k: "Double", t: TypeDouble, v: &extJSONValue{t: TypeString, v: "42.42"},
|
||||
},
|
||||
{
|
||||
f: expectSingleValue, p: ejp,
|
||||
k: "SpecialFloat", t: TypeDouble, v: &extJSONValue{t: TypeString, v: "NaN"},
|
||||
},
|
||||
{
|
||||
f: expectSingleValue, p: ejp,
|
||||
k: "Decimal", t: TypeDecimal128, v: &extJSONValue{t: TypeString, v: "1234"},
|
||||
},
|
||||
{
|
||||
f: expectMultipleValues, p: ejp,
|
||||
k: "Binary", t: TypeBinary,
|
||||
v: &extJSONObject{
|
||||
keys: []string{"base64", "subType"},
|
||||
values: []*extJSONValue{
|
||||
{t: TypeString, v: "o0w498Or7cijeBSpkquNtg=="},
|
||||
{t: TypeString, v: "03"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
f: expectMultipleValues, p: ejp,
|
||||
k: "BinaryLegacy", t: TypeBinary,
|
||||
v: &extJSONObject{
|
||||
keys: []string{"base64", "subType"},
|
||||
values: []*extJSONValue{
|
||||
{t: TypeString, v: "o0w498Or7cijeBSpkquNtg=="},
|
||||
{t: TypeString, v: "03"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
f: expectMultipleValues, p: ejp,
|
||||
k: "BinaryUserDefined", t: TypeBinary,
|
||||
v: &extJSONObject{
|
||||
keys: []string{"base64", "subType"},
|
||||
values: []*extJSONValue{
|
||||
{t: TypeString, v: "AQIDBAU="},
|
||||
{t: TypeString, v: "80"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
f: expectSingleValue, p: ejp,
|
||||
k: "Code", t: TypeJavaScript, v: &extJSONValue{t: TypeString, v: "function() {}"},
|
||||
},
|
||||
{
|
||||
f: expectSubDocument, p: ejp,
|
||||
k: "CodeWithEmptyScope", t: TypeCodeWithScope,
|
||||
v: ejpSubDocumentTestValue{
|
||||
code: "function() {}",
|
||||
ktvs: []ejpKeyTypValTriple{},
|
||||
},
|
||||
},
|
||||
{
|
||||
f: expectSubDocument, p: ejp,
|
||||
k: "CodeWithScope", t: TypeCodeWithScope,
|
||||
v: ejpSubDocumentTestValue{
|
||||
code: "function() {}",
|
||||
ktvs: []ejpKeyTypValTriple{
|
||||
{"x", TypeInt32, &extJSONValue{t: TypeInt32, v: int32(1)}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
f: expectSubDocument, p: ejp,
|
||||
k: "EmptySubdocument", t: TypeEmbeddedDocument,
|
||||
v: ejpSubDocumentTestValue{
|
||||
ktvs: []ejpKeyTypValTriple{},
|
||||
},
|
||||
},
|
||||
{
|
||||
f: expectSubDocument, p: ejp,
|
||||
k: "Subdocument", t: TypeEmbeddedDocument,
|
||||
v: ejpSubDocumentTestValue{
|
||||
ktvs: []ejpKeyTypValTriple{
|
||||
{"foo", TypeString, &extJSONValue{t: TypeString, v: "bar"}},
|
||||
{"baz", TypeInt32, &extJSONValue{t: TypeString, v: "42"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
f: expectArray, p: ejp,
|
||||
k: "Array", t: TypeArray,
|
||||
v: []ejpKeyTypValTriple{
|
||||
{typ: TypeInt32, val: &extJSONValue{t: TypeString, v: "1"}},
|
||||
{typ: TypeInt64, val: &extJSONValue{t: TypeString, v: "2"}},
|
||||
{typ: TypeDouble, val: &extJSONValue{t: TypeString, v: "3"}},
|
||||
{typ: TypeInt32, val: &extJSONValue{t: TypeInt32, v: int32(4)}},
|
||||
{typ: TypeString, val: &extJSONValue{t: TypeString, v: "string"}},
|
||||
{typ: TypeDouble, val: &extJSONValue{t: TypeDouble, v: 5.0}},
|
||||
},
|
||||
},
|
||||
{
|
||||
f: expectMultipleValues, p: ejp,
|
||||
k: "Timestamp", t: TypeTimestamp,
|
||||
v: &extJSONObject{
|
||||
keys: []string{"t", "i"},
|
||||
values: []*extJSONValue{
|
||||
{t: TypeInt32, v: int32(42)},
|
||||
{t: TypeInt32, v: int32(1)},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
f: expectMultipleValues, p: ejp,
|
||||
k: "RegularExpression", t: TypeRegex,
|
||||
v: &extJSONObject{
|
||||
keys: []string{"pattern", "options"},
|
||||
values: []*extJSONValue{
|
||||
{t: TypeString, v: "foo*"},
|
||||
{t: TypeString, v: "ix"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
f: expectMultipleValues, p: ejp,
|
||||
k: "DatetimeEpoch", t: TypeDateTime,
|
||||
v: &extJSONObject{
|
||||
keys: []string{"$numberLong"},
|
||||
values: []*extJSONValue{
|
||||
{t: TypeString, v: "0"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
f: expectMultipleValues, p: ejp,
|
||||
k: "DatetimePositive", t: TypeDateTime,
|
||||
v: &extJSONObject{
|
||||
keys: []string{"$numberLong"},
|
||||
values: []*extJSONValue{
|
||||
{t: TypeString, v: "9223372036854775807"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
f: expectMultipleValues, p: ejp,
|
||||
k: "DatetimeNegative", t: TypeDateTime,
|
||||
v: &extJSONObject{
|
||||
keys: []string{"$numberLong"},
|
||||
values: []*extJSONValue{
|
||||
{t: TypeString, v: "-9223372036854775808"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
f: expectSingleValue, p: ejp,
|
||||
k: "True", t: TypeBoolean, v: &extJSONValue{t: TypeBoolean, v: true},
|
||||
},
|
||||
{
|
||||
f: expectSingleValue, p: ejp,
|
||||
k: "False", t: TypeBoolean, v: &extJSONValue{t: TypeBoolean, v: false},
|
||||
},
|
||||
{
|
||||
f: expectMultipleValues, p: ejp,
|
||||
k: "DBPointer", t: TypeDBPointer,
|
||||
v: &extJSONObject{
|
||||
keys: []string{"$ref", "$id"},
|
||||
values: []*extJSONValue{
|
||||
{t: TypeString, v: "db.collection"},
|
||||
{t: TypeString, v: "57e193d7a9cc81b4027498b1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
f: expectSubDocument, p: ejp,
|
||||
k: "DBRef", t: TypeEmbeddedDocument,
|
||||
v: ejpSubDocumentTestValue{
|
||||
ktvs: []ejpKeyTypValTriple{
|
||||
{"$ref", TypeString, &extJSONValue{t: TypeString, v: "collection"}},
|
||||
{"$id", TypeObjectID, &extJSONValue{t: TypeString, v: "57fd71e96e32ab4225b723fb"}},
|
||||
{"$db", TypeString, &extJSONValue{t: TypeString, v: "database"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
f: expectSubDocument, p: ejp,
|
||||
k: "DBRefNoDB", t: TypeEmbeddedDocument,
|
||||
v: ejpSubDocumentTestValue{
|
||||
ktvs: []ejpKeyTypValTriple{
|
||||
{"$ref", TypeString, &extJSONValue{t: TypeString, v: "collection"}},
|
||||
{"$id", TypeObjectID, &extJSONValue{t: TypeString, v: "57fd71e96e32ab4225b723fb"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
f: expectSingleValue, p: ejp,
|
||||
k: "MinKey", t: TypeMinKey, v: &extJSONValue{t: TypeInt32, v: int32(1)},
|
||||
},
|
||||
{
|
||||
f: expectSingleValue, p: ejp,
|
||||
k: "MaxKey", t: TypeMaxKey, v: &extJSONValue{t: TypeInt32, v: int32(1)},
|
||||
},
|
||||
{
|
||||
f: expectSingleValue, p: ejp,
|
||||
k: "Null", t: TypeNull, v: &extJSONValue{t: TypeNull, v: nil},
|
||||
},
|
||||
{
|
||||
f: expectSingleValue, p: ejp,
|
||||
k: "Undefined", t: TypeUndefined, v: &extJSONValue{t: TypeBoolean, v: true},
|
||||
},
|
||||
}
|
||||
|
||||
// run the test cases
|
||||
for _, tc := range cases {
|
||||
tc.f(t, tc.p, tc.k, tc.t, tc.v)
|
||||
}
|
||||
|
||||
// expect end of whole document: read final }
|
||||
k, typ, err := ejp.readKey()
|
||||
readKeyDiff(t, "", k, Type(0), typ, err, expectErrEOD, "")
|
||||
|
||||
// expect end of whole document: read EOF
|
||||
k, typ, err = ejp.readKey()
|
||||
readKeyDiff(t, "", k, Type(0), typ, err, expectErrEOF, "")
|
||||
if diff := cmp.Diff(jpsDoneState, ejp.s); diff != "" {
|
||||
t.Errorf("expected parser to be in done state but instead is in %v\n", ejp.s)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtJSONValue(t *testing.T) {
|
||||
t.Run("Large Date", func(t *testing.T) {
|
||||
val := &extJSONValue{
|
||||
t: TypeString,
|
||||
v: "3001-01-01T00:00:00Z",
|
||||
}
|
||||
|
||||
intVal, err := val.parseDateTime()
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing date time: %v", err)
|
||||
}
|
||||
|
||||
if intVal <= 0 {
|
||||
t.Fatalf("expected value above 0, got %v", intVal)
|
||||
}
|
||||
})
|
||||
t.Run("fallback time format", func(t *testing.T) {
|
||||
val := &extJSONValue{
|
||||
t: TypeString,
|
||||
v: "2019-06-04T14:54:31.416+0000",
|
||||
}
|
||||
|
||||
_, err := val.parseDateTime()
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing date time: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
46
extjson_prose_test.go
Normal file
46
extjson_prose_test.go
Normal file
@ -0,0 +1,46 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/assert"
|
||||
)
|
||||
|
||||
func TestExtJSON(t *testing.T) {
|
||||
timestampNegativeInt32Err := fmt.Errorf("$timestamp i number should be uint32: -1")
|
||||
timestampNegativeInt64Err := fmt.Errorf("$timestamp i number should be uint32: -2147483649")
|
||||
timestampLargeValueErr := fmt.Errorf("$timestamp i number should be uint32: 4294967296")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
canonical bool
|
||||
err error
|
||||
}{
|
||||
{"timestamp - negative int32 value", `{"":{"$timestamp":{"t":0,"i":-1}}}`, false, timestampNegativeInt32Err},
|
||||
{"timestamp - negative int64 value", `{"":{"$timestamp":{"t":0,"i":-2147483649}}}`, false, timestampNegativeInt64Err},
|
||||
{"timestamp - value overflows uint32", `{"":{"$timestamp":{"t":0,"i":4294967296}}}`, false, timestampLargeValueErr},
|
||||
{"top level key is not treated as special", `{"$code": "foo"}`, false, nil},
|
||||
{"escaped single quote errors", `{"f\'oo": "bar"}`, false, ErrInvalidJSON},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var res Raw
|
||||
err := UnmarshalExtJSON([]byte(tc.input), tc.canonical, &res)
|
||||
if tc.err == nil {
|
||||
assert.Nil(t, err, "UnmarshalExtJSON error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NotNil(t, err, "expected error %v, got nil", tc.err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error(), "expected error %v, got %v", tc.err, err)
|
||||
})
|
||||
}
|
||||
}
|
606
extjson_reader.go
Normal file
606
extjson_reader.go
Normal file
@ -0,0 +1,606 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
type ejvrState struct {
|
||||
mode mode
|
||||
vType Type
|
||||
depth int
|
||||
}
|
||||
|
||||
// extJSONValueReader is for reading extended JSON.
|
||||
type extJSONValueReader struct {
|
||||
p *extJSONParser
|
||||
|
||||
stack []ejvrState
|
||||
frame int
|
||||
}
|
||||
|
||||
// NewExtJSONValueReader returns a ValueReader that reads Extended JSON values
|
||||
// from r. If canonicalOnly is true, reading values from the ValueReader returns
|
||||
// an error if the Extended JSON was not marshaled in canonical mode.
|
||||
func NewExtJSONValueReader(r io.Reader, canonicalOnly bool) (ValueReader, error) {
|
||||
return newExtJSONValueReader(r, canonicalOnly)
|
||||
}
|
||||
|
||||
func newExtJSONValueReader(r io.Reader, canonicalOnly bool) (*extJSONValueReader, error) {
|
||||
ejvr := new(extJSONValueReader)
|
||||
return ejvr.reset(r, canonicalOnly)
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) reset(r io.Reader, canonicalOnly bool) (*extJSONValueReader, error) {
|
||||
p := newExtJSONParser(r, canonicalOnly)
|
||||
typ, err := p.peekType()
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrInvalidJSON
|
||||
}
|
||||
|
||||
var m mode
|
||||
switch typ {
|
||||
case TypeEmbeddedDocument:
|
||||
m = mTopLevel
|
||||
case TypeArray:
|
||||
m = mArray
|
||||
default:
|
||||
m = mValue
|
||||
}
|
||||
|
||||
stack := make([]ejvrState, 1, 5)
|
||||
stack[0] = ejvrState{
|
||||
mode: m,
|
||||
vType: typ,
|
||||
}
|
||||
return &extJSONValueReader{
|
||||
p: p,
|
||||
stack: stack,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) advanceFrame() {
|
||||
if ejvr.frame+1 >= len(ejvr.stack) { // We need to grow the stack
|
||||
length := len(ejvr.stack)
|
||||
if length+1 >= cap(ejvr.stack) {
|
||||
// double it
|
||||
buf := make([]ejvrState, 2*cap(ejvr.stack)+1)
|
||||
copy(buf, ejvr.stack)
|
||||
ejvr.stack = buf
|
||||
}
|
||||
ejvr.stack = ejvr.stack[:length+1]
|
||||
}
|
||||
ejvr.frame++
|
||||
|
||||
// Clean the stack
|
||||
ejvr.stack[ejvr.frame].mode = 0
|
||||
ejvr.stack[ejvr.frame].vType = 0
|
||||
ejvr.stack[ejvr.frame].depth = 0
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) pushDocument() {
|
||||
ejvr.advanceFrame()
|
||||
|
||||
ejvr.stack[ejvr.frame].mode = mDocument
|
||||
ejvr.stack[ejvr.frame].depth = ejvr.p.depth
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) pushCodeWithScope() {
|
||||
ejvr.advanceFrame()
|
||||
|
||||
ejvr.stack[ejvr.frame].mode = mCodeWithScope
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) pushArray() {
|
||||
ejvr.advanceFrame()
|
||||
|
||||
ejvr.stack[ejvr.frame].mode = mArray
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) push(m mode, t Type) {
|
||||
ejvr.advanceFrame()
|
||||
|
||||
ejvr.stack[ejvr.frame].mode = m
|
||||
ejvr.stack[ejvr.frame].vType = t
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) pop() {
|
||||
switch ejvr.stack[ejvr.frame].mode {
|
||||
case mElement, mValue:
|
||||
ejvr.frame--
|
||||
case mDocument, mArray, mCodeWithScope:
|
||||
ejvr.frame -= 2 // we pop twice to jump over the vrElement: vrDocument -> vrElement -> vrDocument/TopLevel/etc...
|
||||
}
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) skipObject() {
|
||||
// read entire object until depth returns to 0 (last ending } or ] seen)
|
||||
depth := 1
|
||||
for depth > 0 {
|
||||
ejvr.p.advanceState()
|
||||
|
||||
// If object is empty, raise depth and continue. When emptyObject is true, the
|
||||
// parser has already read both the opening and closing brackets of an empty
|
||||
// object ("{}"), so the next valid token will be part of the parent document,
|
||||
// not part of the nested document.
|
||||
//
|
||||
// If there is a comma, there are remaining fields, emptyObject must be set back
|
||||
// to false, and comma must be skipped with advanceState().
|
||||
if ejvr.p.emptyObject {
|
||||
if ejvr.p.s == jpsSawComma {
|
||||
ejvr.p.emptyObject = false
|
||||
ejvr.p.advanceState()
|
||||
}
|
||||
depth--
|
||||
continue
|
||||
}
|
||||
|
||||
switch ejvr.p.s {
|
||||
case jpsSawBeginObject, jpsSawBeginArray:
|
||||
depth++
|
||||
case jpsSawEndObject, jpsSawEndArray:
|
||||
depth--
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) invalidTransitionErr(destination mode, name string, modes []mode) error {
|
||||
te := TransitionError{
|
||||
name: name,
|
||||
current: ejvr.stack[ejvr.frame].mode,
|
||||
destination: destination,
|
||||
modes: modes,
|
||||
action: "read",
|
||||
}
|
||||
if ejvr.frame != 0 {
|
||||
te.parent = ejvr.stack[ejvr.frame-1].mode
|
||||
}
|
||||
return te
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) typeError(t Type) error {
|
||||
return fmt.Errorf("positioned on %s, but attempted to read %s", ejvr.stack[ejvr.frame].vType, t)
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ensureElementValue(t Type, destination mode, callerName string, addModes ...mode) error {
|
||||
switch ejvr.stack[ejvr.frame].mode {
|
||||
case mElement, mValue:
|
||||
if ejvr.stack[ejvr.frame].vType != t {
|
||||
return ejvr.typeError(t)
|
||||
}
|
||||
default:
|
||||
modes := []mode{mElement, mValue}
|
||||
if addModes != nil {
|
||||
modes = append(modes, addModes...)
|
||||
}
|
||||
return ejvr.invalidTransitionErr(destination, callerName, modes)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) Type() Type {
|
||||
return ejvr.stack[ejvr.frame].vType
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) Skip() error {
|
||||
switch ejvr.stack[ejvr.frame].mode {
|
||||
case mElement, mValue:
|
||||
default:
|
||||
return ejvr.invalidTransitionErr(0, "Skip", []mode{mElement, mValue})
|
||||
}
|
||||
|
||||
defer ejvr.pop()
|
||||
|
||||
t := ejvr.stack[ejvr.frame].vType
|
||||
switch t {
|
||||
case TypeArray, TypeEmbeddedDocument, TypeCodeWithScope:
|
||||
// read entire array, doc or CodeWithScope
|
||||
ejvr.skipObject()
|
||||
default:
|
||||
_, err := ejvr.p.readValue(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadArray() (ArrayReader, error) {
|
||||
switch ejvr.stack[ejvr.frame].mode {
|
||||
case mTopLevel: // allow reading array from top level
|
||||
case mArray:
|
||||
return ejvr, nil
|
||||
default:
|
||||
if err := ejvr.ensureElementValue(TypeArray, mArray, "ReadArray", mTopLevel, mArray); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
ejvr.pushArray()
|
||||
|
||||
return ejvr, nil
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadBinary() (b []byte, btype byte, err error) {
|
||||
if err := ejvr.ensureElementValue(TypeBinary, 0, "ReadBinary"); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
v, err := ejvr.p.readValue(TypeBinary)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
b, btype, err = v.parseBinary()
|
||||
|
||||
ejvr.pop()
|
||||
return b, btype, err
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadBoolean() (bool, error) {
|
||||
if err := ejvr.ensureElementValue(TypeBoolean, 0, "ReadBoolean"); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
v, err := ejvr.p.readValue(TypeBoolean)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if v.t != TypeBoolean {
|
||||
return false, fmt.Errorf("expected type bool, but got type %s", v.t)
|
||||
}
|
||||
|
||||
ejvr.pop()
|
||||
return v.v.(bool), nil
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadDocument() (DocumentReader, error) {
|
||||
switch ejvr.stack[ejvr.frame].mode {
|
||||
case mTopLevel:
|
||||
return ejvr, nil
|
||||
case mElement, mValue:
|
||||
if ejvr.stack[ejvr.frame].vType != TypeEmbeddedDocument {
|
||||
return nil, ejvr.typeError(TypeEmbeddedDocument)
|
||||
}
|
||||
|
||||
ejvr.pushDocument()
|
||||
return ejvr, nil
|
||||
default:
|
||||
return nil, ejvr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue})
|
||||
}
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadCodeWithScope() (code string, dr DocumentReader, err error) {
|
||||
if err = ejvr.ensureElementValue(TypeCodeWithScope, 0, "ReadCodeWithScope"); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
v, err := ejvr.p.readValue(TypeCodeWithScope)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
code, err = v.parseJavascript()
|
||||
|
||||
ejvr.pushCodeWithScope()
|
||||
return code, ejvr, err
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadDBPointer() (ns string, oid ObjectID, err error) {
|
||||
if err = ejvr.ensureElementValue(TypeDBPointer, 0, "ReadDBPointer"); err != nil {
|
||||
return "", NilObjectID, err
|
||||
}
|
||||
|
||||
v, err := ejvr.p.readValue(TypeDBPointer)
|
||||
if err != nil {
|
||||
return "", NilObjectID, err
|
||||
}
|
||||
|
||||
ns, oid, err = v.parseDBPointer()
|
||||
|
||||
ejvr.pop()
|
||||
return ns, oid, err
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadDateTime() (int64, error) {
|
||||
if err := ejvr.ensureElementValue(TypeDateTime, 0, "ReadDateTime"); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
v, err := ejvr.p.readValue(TypeDateTime)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
d, err := v.parseDateTime()
|
||||
|
||||
ejvr.pop()
|
||||
return d, err
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadDecimal128() (Decimal128, error) {
|
||||
if err := ejvr.ensureElementValue(TypeDecimal128, 0, "ReadDecimal128"); err != nil {
|
||||
return Decimal128{}, err
|
||||
}
|
||||
|
||||
v, err := ejvr.p.readValue(TypeDecimal128)
|
||||
if err != nil {
|
||||
return Decimal128{}, err
|
||||
}
|
||||
|
||||
d, err := v.parseDecimal128()
|
||||
|
||||
ejvr.pop()
|
||||
return d, err
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadDouble() (float64, error) {
|
||||
if err := ejvr.ensureElementValue(TypeDouble, 0, "ReadDouble"); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
v, err := ejvr.p.readValue(TypeDouble)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
d, err := v.parseDouble()
|
||||
|
||||
ejvr.pop()
|
||||
return d, err
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadInt32() (int32, error) {
|
||||
if err := ejvr.ensureElementValue(TypeInt32, 0, "ReadInt32"); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
v, err := ejvr.p.readValue(TypeInt32)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
i, err := v.parseInt32()
|
||||
|
||||
ejvr.pop()
|
||||
return i, err
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadInt64() (int64, error) {
|
||||
if err := ejvr.ensureElementValue(TypeInt64, 0, "ReadInt64"); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
v, err := ejvr.p.readValue(TypeInt64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
i, err := v.parseInt64()
|
||||
|
||||
ejvr.pop()
|
||||
return i, err
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadJavascript() (code string, err error) {
|
||||
if err = ejvr.ensureElementValue(TypeJavaScript, 0, "ReadJavascript"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
v, err := ejvr.p.readValue(TypeJavaScript)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
code, err = v.parseJavascript()
|
||||
|
||||
ejvr.pop()
|
||||
return code, err
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadMaxKey() error {
|
||||
if err := ejvr.ensureElementValue(TypeMaxKey, 0, "ReadMaxKey"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v, err := ejvr.p.readValue(TypeMaxKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = v.parseMinMaxKey("max")
|
||||
|
||||
ejvr.pop()
|
||||
return err
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadMinKey() error {
|
||||
if err := ejvr.ensureElementValue(TypeMinKey, 0, "ReadMinKey"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v, err := ejvr.p.readValue(TypeMinKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = v.parseMinMaxKey("min")
|
||||
|
||||
ejvr.pop()
|
||||
return err
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadNull() error {
|
||||
if err := ejvr.ensureElementValue(TypeNull, 0, "ReadNull"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v, err := ejvr.p.readValue(TypeNull)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if v.t != TypeNull {
|
||||
return fmt.Errorf("expected type null but got type %s", v.t)
|
||||
}
|
||||
|
||||
ejvr.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadObjectID() (ObjectID, error) {
|
||||
if err := ejvr.ensureElementValue(TypeObjectID, 0, "ReadObjectID"); err != nil {
|
||||
return ObjectID{}, err
|
||||
}
|
||||
|
||||
v, err := ejvr.p.readValue(TypeObjectID)
|
||||
if err != nil {
|
||||
return ObjectID{}, err
|
||||
}
|
||||
|
||||
oid, err := v.parseObjectID()
|
||||
|
||||
ejvr.pop()
|
||||
return oid, err
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadRegex() (pattern string, options string, err error) {
|
||||
if err = ejvr.ensureElementValue(TypeRegex, 0, "ReadRegex"); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
v, err := ejvr.p.readValue(TypeRegex)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
pattern, options, err = v.parseRegex()
|
||||
|
||||
ejvr.pop()
|
||||
return pattern, options, err
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadString() (string, error) {
|
||||
if err := ejvr.ensureElementValue(TypeString, 0, "ReadString"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
v, err := ejvr.p.readValue(TypeString)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if v.t != TypeString {
|
||||
return "", fmt.Errorf("expected type string but got type %s", v.t)
|
||||
}
|
||||
|
||||
ejvr.pop()
|
||||
return v.v.(string), nil
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadSymbol() (symbol string, err error) {
|
||||
if err = ejvr.ensureElementValue(TypeSymbol, 0, "ReadSymbol"); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
v, err := ejvr.p.readValue(TypeSymbol)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
symbol, err = v.parseSymbol()
|
||||
|
||||
ejvr.pop()
|
||||
return symbol, err
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadTimestamp() (t uint32, i uint32, err error) {
|
||||
if err = ejvr.ensureElementValue(TypeTimestamp, 0, "ReadTimestamp"); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
v, err := ejvr.p.readValue(TypeTimestamp)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
t, i, err = v.parseTimestamp()
|
||||
|
||||
ejvr.pop()
|
||||
return t, i, err
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadUndefined() error {
|
||||
if err := ejvr.ensureElementValue(TypeUndefined, 0, "ReadUndefined"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v, err := ejvr.p.readValue(TypeUndefined)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = v.parseUndefined()
|
||||
|
||||
ejvr.pop()
|
||||
return err
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadElement() (string, ValueReader, error) {
|
||||
switch ejvr.stack[ejvr.frame].mode {
|
||||
case mTopLevel, mDocument, mCodeWithScope:
|
||||
default:
|
||||
return "", nil, ejvr.invalidTransitionErr(mElement, "ReadElement", []mode{mTopLevel, mDocument, mCodeWithScope})
|
||||
}
|
||||
|
||||
name, t, err := ejvr.p.readKey()
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrEOD) {
|
||||
if ejvr.stack[ejvr.frame].mode == mCodeWithScope {
|
||||
_, err := ejvr.p.peekType()
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
ejvr.pop()
|
||||
}
|
||||
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
ejvr.push(mElement, t)
|
||||
return name, ejvr, nil
|
||||
}
|
||||
|
||||
func (ejvr *extJSONValueReader) ReadValue() (ValueReader, error) {
|
||||
switch ejvr.stack[ejvr.frame].mode {
|
||||
case mArray:
|
||||
default:
|
||||
return nil, ejvr.invalidTransitionErr(mValue, "ReadValue", []mode{mArray})
|
||||
}
|
||||
|
||||
t, err := ejvr.p.peekType()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrEOA) {
|
||||
ejvr.pop()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ejvr.push(mValue, t)
|
||||
return ejvr, nil
|
||||
}
|
168
extjson_reader_test.go
Normal file
168
extjson_reader_test.go
Normal file
@ -0,0 +1,168 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/assert"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestExtJSONReader(t *testing.T) {
|
||||
t.Run("ReadDocument", func(t *testing.T) {
|
||||
t.Run("EmbeddedDocument", func(t *testing.T) {
|
||||
ejvr := &extJSONValueReader{
|
||||
stack: []ejvrState{
|
||||
{mode: mTopLevel},
|
||||
{mode: mElement, vType: TypeBoolean},
|
||||
},
|
||||
frame: 1,
|
||||
}
|
||||
|
||||
ejvr.stack[1].mode = mArray
|
||||
wanterr := ejvr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue})
|
||||
_, err := ejvr.ReadDocument()
|
||||
if err == nil || err.Error() != wanterr.Error() {
|
||||
t.Errorf("Incorrect returned error. got %v; want %v", err, wanterr)
|
||||
}
|
||||
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("invalid transition", func(t *testing.T) {
|
||||
t.Run("Skip", func(t *testing.T) {
|
||||
ejvr := &extJSONValueReader{stack: []ejvrState{{mode: mTopLevel}}}
|
||||
wanterr := (&extJSONValueReader{stack: []ejvrState{{mode: mTopLevel}}}).invalidTransitionErr(0, "Skip", []mode{mElement, mValue})
|
||||
goterr := ejvr.Skip()
|
||||
if !cmp.Equal(goterr, wanterr, cmp.Comparer(assert.CompareErrors)) {
|
||||
t.Errorf("Expected correct invalid transition error. got %v; want %v", goterr, wanterr)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadMultipleTopLevelDocuments(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expected [][]byte
|
||||
}{
|
||||
{
|
||||
"single top-level document",
|
||||
"{\"foo\":1}",
|
||||
[][]byte{
|
||||
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
},
|
||||
{
|
||||
"single top-level document with leading and trailing whitespace",
|
||||
"\n\n {\"foo\":1} \n",
|
||||
[][]byte{
|
||||
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
},
|
||||
{
|
||||
"two top-level documents",
|
||||
"{\"foo\":1}{\"foo\":2}",
|
||||
[][]byte{
|
||||
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
|
||||
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x02, 0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
},
|
||||
{
|
||||
"two top-level documents with leading and trailing whitespace and whitespace separation ",
|
||||
"\n\n {\"foo\":1}\n{\"foo\":2}\n ",
|
||||
[][]byte{
|
||||
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
|
||||
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x02, 0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
},
|
||||
{
|
||||
"top-level array with single document",
|
||||
"[{\"foo\":1}]",
|
||||
[][]byte{
|
||||
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
},
|
||||
{
|
||||
"top-level array with 2 documents",
|
||||
"[{\"foo\":1},{\"foo\":2}]",
|
||||
[][]byte{
|
||||
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
|
||||
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x02, 0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := strings.NewReader(tc.input)
|
||||
vr, err := NewExtJSONValueReader(r, false)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, but got %v", err)
|
||||
}
|
||||
|
||||
actual, err := readAllDocuments(vr)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, but got %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tc.expected, actual); diff != "" {
|
||||
t.Fatalf("expected does not match actual: %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func readAllDocuments(vr ValueReader) ([][]byte, error) {
|
||||
var actual [][]byte
|
||||
|
||||
switch vr.Type() {
|
||||
case TypeEmbeddedDocument:
|
||||
for {
|
||||
result, err := copyDocumentToBytes(vr)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
actual = append(actual, result)
|
||||
}
|
||||
case TypeArray:
|
||||
ar, err := vr.ReadArray()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for {
|
||||
evr, err := ar.ReadValue()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrEOA) {
|
||||
break
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := copyDocumentToBytes(evr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
actual = append(actual, result)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("expected an array or a document, but got %s", vr.Type())
|
||||
}
|
||||
|
||||
return actual, nil
|
||||
}
|
223
extjson_tables.go
Normal file
223
extjson_tables.go
Normal file
@ -0,0 +1,223 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/golang/go by The Go Authors
|
||||
// See THIRD-PARTY-NOTICES for original license terms.
|
||||
|
||||
package bson
|
||||
|
||||
import "unicode/utf8"
|
||||
|
||||
// safeSet holds the value true if the ASCII character with the given array
|
||||
// position can be represented inside a JSON string without any further
|
||||
// escaping.
|
||||
//
|
||||
// All values are true except for the ASCII control characters (0-31), the
|
||||
// double quote ("), and the backslash character ("\").
|
||||
var safeSet = [utf8.RuneSelf]bool{
|
||||
' ': true,
|
||||
'!': true,
|
||||
'"': false,
|
||||
'#': true,
|
||||
'$': true,
|
||||
'%': true,
|
||||
'&': true,
|
||||
'\'': true,
|
||||
'(': true,
|
||||
')': true,
|
||||
'*': true,
|
||||
'+': true,
|
||||
',': true,
|
||||
'-': true,
|
||||
'.': true,
|
||||
'/': true,
|
||||
'0': true,
|
||||
'1': true,
|
||||
'2': true,
|
||||
'3': true,
|
||||
'4': true,
|
||||
'5': true,
|
||||
'6': true,
|
||||
'7': true,
|
||||
'8': true,
|
||||
'9': true,
|
||||
':': true,
|
||||
';': true,
|
||||
'<': true,
|
||||
'=': true,
|
||||
'>': true,
|
||||
'?': true,
|
||||
'@': true,
|
||||
'A': true,
|
||||
'B': true,
|
||||
'C': true,
|
||||
'D': true,
|
||||
'E': true,
|
||||
'F': true,
|
||||
'G': true,
|
||||
'H': true,
|
||||
'I': true,
|
||||
'J': true,
|
||||
'K': true,
|
||||
'L': true,
|
||||
'M': true,
|
||||
'N': true,
|
||||
'O': true,
|
||||
'P': true,
|
||||
'Q': true,
|
||||
'R': true,
|
||||
'S': true,
|
||||
'T': true,
|
||||
'U': true,
|
||||
'V': true,
|
||||
'W': true,
|
||||
'X': true,
|
||||
'Y': true,
|
||||
'Z': true,
|
||||
'[': true,
|
||||
'\\': false,
|
||||
']': true,
|
||||
'^': true,
|
||||
'_': true,
|
||||
'`': true,
|
||||
'a': true,
|
||||
'b': true,
|
||||
'c': true,
|
||||
'd': true,
|
||||
'e': true,
|
||||
'f': true,
|
||||
'g': true,
|
||||
'h': true,
|
||||
'i': true,
|
||||
'j': true,
|
||||
'k': true,
|
||||
'l': true,
|
||||
'm': true,
|
||||
'n': true,
|
||||
'o': true,
|
||||
'p': true,
|
||||
'q': true,
|
||||
'r': true,
|
||||
's': true,
|
||||
't': true,
|
||||
'u': true,
|
||||
'v': true,
|
||||
'w': true,
|
||||
'x': true,
|
||||
'y': true,
|
||||
'z': true,
|
||||
'{': true,
|
||||
'|': true,
|
||||
'}': true,
|
||||
'~': true,
|
||||
'\u007f': true,
|
||||
}
|
||||
|
||||
// htmlSafeSet holds the value true if the ASCII character with the given
|
||||
// array position can be safely represented inside a JSON string, embedded
|
||||
// inside of HTML <script> tags, without any additional escaping.
|
||||
//
|
||||
// All values are true except for the ASCII control characters (0-31), the
|
||||
// double quote ("), the backslash character ("\"), HTML opening and closing
|
||||
// tags ("<" and ">"), and the ampersand ("&").
|
||||
var htmlSafeSet = [utf8.RuneSelf]bool{
|
||||
' ': true,
|
||||
'!': true,
|
||||
'"': false,
|
||||
'#': true,
|
||||
'$': true,
|
||||
'%': true,
|
||||
'&': false,
|
||||
'\'': true,
|
||||
'(': true,
|
||||
')': true,
|
||||
'*': true,
|
||||
'+': true,
|
||||
',': true,
|
||||
'-': true,
|
||||
'.': true,
|
||||
'/': true,
|
||||
'0': true,
|
||||
'1': true,
|
||||
'2': true,
|
||||
'3': true,
|
||||
'4': true,
|
||||
'5': true,
|
||||
'6': true,
|
||||
'7': true,
|
||||
'8': true,
|
||||
'9': true,
|
||||
':': true,
|
||||
';': true,
|
||||
'<': false,
|
||||
'=': true,
|
||||
'>': false,
|
||||
'?': true,
|
||||
'@': true,
|
||||
'A': true,
|
||||
'B': true,
|
||||
'C': true,
|
||||
'D': true,
|
||||
'E': true,
|
||||
'F': true,
|
||||
'G': true,
|
||||
'H': true,
|
||||
'I': true,
|
||||
'J': true,
|
||||
'K': true,
|
||||
'L': true,
|
||||
'M': true,
|
||||
'N': true,
|
||||
'O': true,
|
||||
'P': true,
|
||||
'Q': true,
|
||||
'R': true,
|
||||
'S': true,
|
||||
'T': true,
|
||||
'U': true,
|
||||
'V': true,
|
||||
'W': true,
|
||||
'X': true,
|
||||
'Y': true,
|
||||
'Z': true,
|
||||
'[': true,
|
||||
'\\': false,
|
||||
']': true,
|
||||
'^': true,
|
||||
'_': true,
|
||||
'`': true,
|
||||
'a': true,
|
||||
'b': true,
|
||||
'c': true,
|
||||
'd': true,
|
||||
'e': true,
|
||||
'f': true,
|
||||
'g': true,
|
||||
'h': true,
|
||||
'i': true,
|
||||
'j': true,
|
||||
'k': true,
|
||||
'l': true,
|
||||
'm': true,
|
||||
'n': true,
|
||||
'o': true,
|
||||
'p': true,
|
||||
'q': true,
|
||||
'r': true,
|
||||
's': true,
|
||||
't': true,
|
||||
'u': true,
|
||||
'v': true,
|
||||
'w': true,
|
||||
'x': true,
|
||||
'y': true,
|
||||
'z': true,
|
||||
'{': true,
|
||||
'|': true,
|
||||
'}': true,
|
||||
'~': true,
|
||||
'\u007f': true,
|
||||
}
|
489
extjson_wrappers.go
Normal file
489
extjson_wrappers.go
Normal file
@ -0,0 +1,489 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
func wrapperKeyBSONType(key string) Type {
|
||||
switch key {
|
||||
case "$numberInt":
|
||||
return TypeInt32
|
||||
case "$numberLong":
|
||||
return TypeInt64
|
||||
case "$oid":
|
||||
return TypeObjectID
|
||||
case "$symbol":
|
||||
return TypeSymbol
|
||||
case "$numberDouble":
|
||||
return TypeDouble
|
||||
case "$numberDecimal":
|
||||
return TypeDecimal128
|
||||
case "$binary":
|
||||
return TypeBinary
|
||||
case "$code":
|
||||
return TypeJavaScript
|
||||
case "$scope":
|
||||
return TypeCodeWithScope
|
||||
case "$timestamp":
|
||||
return TypeTimestamp
|
||||
case "$regularExpression":
|
||||
return TypeRegex
|
||||
case "$dbPointer":
|
||||
return TypeDBPointer
|
||||
case "$date":
|
||||
return TypeDateTime
|
||||
case "$minKey":
|
||||
return TypeMinKey
|
||||
case "$maxKey":
|
||||
return TypeMaxKey
|
||||
case "$undefined":
|
||||
return TypeUndefined
|
||||
}
|
||||
|
||||
return TypeEmbeddedDocument
|
||||
}
|
||||
|
||||
func (ejv *extJSONValue) parseBinary() (b []byte, subType byte, err error) {
|
||||
if ejv.t != TypeEmbeddedDocument {
|
||||
return nil, 0, fmt.Errorf("$binary value should be object, but instead is %s", ejv.t)
|
||||
}
|
||||
|
||||
binObj := ejv.v.(*extJSONObject)
|
||||
bFound := false
|
||||
stFound := false
|
||||
|
||||
for i, key := range binObj.keys {
|
||||
val := binObj.values[i]
|
||||
|
||||
switch key {
|
||||
case "base64":
|
||||
if bFound {
|
||||
return nil, 0, errors.New("duplicate base64 key in $binary")
|
||||
}
|
||||
|
||||
if val.t != TypeString {
|
||||
return nil, 0, fmt.Errorf("$binary base64 value should be string, but instead is %s", val.t)
|
||||
}
|
||||
|
||||
base64Bytes, err := base64.StdEncoding.DecodeString(val.v.(string))
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("invalid $binary base64 string: %s", val.v.(string))
|
||||
}
|
||||
|
||||
b = base64Bytes
|
||||
bFound = true
|
||||
case "subType":
|
||||
if stFound {
|
||||
return nil, 0, errors.New("duplicate subType key in $binary")
|
||||
}
|
||||
|
||||
if val.t != TypeString {
|
||||
return nil, 0, fmt.Errorf("$binary subType value should be string, but instead is %s", val.t)
|
||||
}
|
||||
|
||||
i, err := strconv.ParseUint(val.v.(string), 16, 8)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("invalid $binary subType string: %q: %w", val.v.(string), err)
|
||||
}
|
||||
|
||||
subType = byte(i)
|
||||
stFound = true
|
||||
default:
|
||||
return nil, 0, fmt.Errorf("invalid key in $binary object: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
if !bFound {
|
||||
return nil, 0, errors.New("missing base64 field in $binary object")
|
||||
}
|
||||
|
||||
if !stFound {
|
||||
return nil, 0, errors.New("missing subType field in $binary object")
|
||||
|
||||
}
|
||||
|
||||
return b, subType, nil
|
||||
}
|
||||
|
||||
func (ejv *extJSONValue) parseDBPointer() (ns string, oid ObjectID, err error) {
|
||||
if ejv.t != TypeEmbeddedDocument {
|
||||
return "", NilObjectID, fmt.Errorf("$dbPointer value should be object, but instead is %s", ejv.t)
|
||||
}
|
||||
|
||||
dbpObj := ejv.v.(*extJSONObject)
|
||||
oidFound := false
|
||||
nsFound := false
|
||||
|
||||
for i, key := range dbpObj.keys {
|
||||
val := dbpObj.values[i]
|
||||
|
||||
switch key {
|
||||
case "$ref":
|
||||
if nsFound {
|
||||
return "", NilObjectID, errors.New("duplicate $ref key in $dbPointer")
|
||||
}
|
||||
|
||||
if val.t != TypeString {
|
||||
return "", NilObjectID, fmt.Errorf("$dbPointer $ref value should be string, but instead is %s", val.t)
|
||||
}
|
||||
|
||||
ns = val.v.(string)
|
||||
nsFound = true
|
||||
case "$id":
|
||||
if oidFound {
|
||||
return "", NilObjectID, errors.New("duplicate $id key in $dbPointer")
|
||||
}
|
||||
|
||||
if val.t != TypeString {
|
||||
return "", NilObjectID, fmt.Errorf("$dbPointer $id value should be string, but instead is %s", val.t)
|
||||
}
|
||||
|
||||
oid, err = ObjectIDFromHex(val.v.(string))
|
||||
if err != nil {
|
||||
return "", NilObjectID, err
|
||||
}
|
||||
|
||||
oidFound = true
|
||||
default:
|
||||
return "", NilObjectID, fmt.Errorf("invalid key in $dbPointer object: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
if !nsFound {
|
||||
return "", oid, errors.New("missing $ref field in $dbPointer object")
|
||||
}
|
||||
|
||||
if !oidFound {
|
||||
return "", oid, errors.New("missing $id field in $dbPointer object")
|
||||
}
|
||||
|
||||
return ns, oid, nil
|
||||
}
|
||||
|
||||
const (
|
||||
rfc3339Milli = "2006-01-02T15:04:05.999Z07:00"
|
||||
)
|
||||
|
||||
var (
|
||||
timeFormats = []string{rfc3339Milli, "2006-01-02T15:04:05.999Z0700"}
|
||||
)
|
||||
|
||||
func (ejv *extJSONValue) parseDateTime() (int64, error) {
|
||||
switch ejv.t {
|
||||
case TypeInt32:
|
||||
return int64(ejv.v.(int32)), nil
|
||||
case TypeInt64:
|
||||
return ejv.v.(int64), nil
|
||||
case TypeString:
|
||||
return parseDatetimeString(ejv.v.(string))
|
||||
case TypeEmbeddedDocument:
|
||||
return parseDatetimeObject(ejv.v.(*extJSONObject))
|
||||
default:
|
||||
return 0, fmt.Errorf("$date value should be string or object, but instead is %s", ejv.t)
|
||||
}
|
||||
}
|
||||
|
||||
func parseDatetimeString(data string) (int64, error) {
|
||||
var t time.Time
|
||||
var err error
|
||||
// try acceptable time formats until one matches
|
||||
for _, format := range timeFormats {
|
||||
t, err = time.Parse(format, data)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid $date value string: %s", data)
|
||||
}
|
||||
|
||||
return int64(NewDateTimeFromTime(t)), nil
|
||||
}
|
||||
|
||||
func parseDatetimeObject(data *extJSONObject) (d int64, err error) {
|
||||
dFound := false
|
||||
|
||||
for i, key := range data.keys {
|
||||
val := data.values[i]
|
||||
|
||||
switch key {
|
||||
case "$numberLong":
|
||||
if dFound {
|
||||
return 0, errors.New("duplicate $numberLong key in $date")
|
||||
}
|
||||
|
||||
if val.t != TypeString {
|
||||
return 0, fmt.Errorf("$date $numberLong field should be string, but instead is %s", val.t)
|
||||
}
|
||||
|
||||
d, err = val.parseInt64()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
dFound = true
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid key in $date object: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
if !dFound {
|
||||
return 0, errors.New("missing $numberLong field in $date object")
|
||||
}
|
||||
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (ejv *extJSONValue) parseDecimal128() (Decimal128, error) {
|
||||
if ejv.t != TypeString {
|
||||
return Decimal128{}, fmt.Errorf("$numberDecimal value should be string, but instead is %s", ejv.t)
|
||||
}
|
||||
|
||||
d, err := ParseDecimal128(ejv.v.(string))
|
||||
if err != nil {
|
||||
return Decimal128{}, fmt.Errorf("$invalid $numberDecimal string: %s", ejv.v.(string))
|
||||
}
|
||||
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (ejv *extJSONValue) parseDouble() (float64, error) {
|
||||
if ejv.t == TypeDouble {
|
||||
return ejv.v.(float64), nil
|
||||
}
|
||||
|
||||
if ejv.t != TypeString {
|
||||
return 0, fmt.Errorf("$numberDouble value should be string, but instead is %s", ejv.t)
|
||||
}
|
||||
|
||||
switch ejv.v.(string) {
|
||||
case "Infinity":
|
||||
return math.Inf(1), nil
|
||||
case "-Infinity":
|
||||
return math.Inf(-1), nil
|
||||
case "NaN":
|
||||
return math.NaN(), nil
|
||||
}
|
||||
|
||||
f, err := strconv.ParseFloat(ejv.v.(string), 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (ejv *extJSONValue) parseInt32() (int32, error) {
|
||||
if ejv.t == TypeInt32 {
|
||||
return ejv.v.(int32), nil
|
||||
}
|
||||
|
||||
if ejv.t != TypeString {
|
||||
return 0, fmt.Errorf("$numberInt value should be string, but instead is %s", ejv.t)
|
||||
}
|
||||
|
||||
i, err := strconv.ParseInt(ejv.v.(string), 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if i < math.MinInt32 || i > math.MaxInt32 {
|
||||
return 0, fmt.Errorf("$numberInt value should be int32 but instead is int64: %d", i)
|
||||
}
|
||||
|
||||
return int32(i), nil
|
||||
}
|
||||
|
||||
func (ejv *extJSONValue) parseInt64() (int64, error) {
|
||||
if ejv.t == TypeInt64 {
|
||||
return ejv.v.(int64), nil
|
||||
}
|
||||
|
||||
if ejv.t != TypeString {
|
||||
return 0, fmt.Errorf("$numberLong value should be string, but instead is %s", ejv.t)
|
||||
}
|
||||
|
||||
i, err := strconv.ParseInt(ejv.v.(string), 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func (ejv *extJSONValue) parseJavascript() (code string, err error) {
|
||||
if ejv.t != TypeString {
|
||||
return "", fmt.Errorf("$code value should be string, but instead is %s", ejv.t)
|
||||
}
|
||||
|
||||
return ejv.v.(string), nil
|
||||
}
|
||||
|
||||
func (ejv *extJSONValue) parseMinMaxKey(minmax string) error {
|
||||
if ejv.t != TypeInt32 {
|
||||
return fmt.Errorf("$%sKey value should be int32, but instead is %s", minmax, ejv.t)
|
||||
}
|
||||
|
||||
if ejv.v.(int32) != 1 {
|
||||
return fmt.Errorf("$%sKey value must be 1, but instead is %d", minmax, ejv.v.(int32))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejv *extJSONValue) parseObjectID() (ObjectID, error) {
|
||||
if ejv.t != TypeString {
|
||||
return NilObjectID, fmt.Errorf("$oid value should be string, but instead is %s", ejv.t)
|
||||
}
|
||||
|
||||
return ObjectIDFromHex(ejv.v.(string))
|
||||
}
|
||||
|
||||
func (ejv *extJSONValue) parseRegex() (pattern, options string, err error) {
|
||||
if ejv.t != TypeEmbeddedDocument {
|
||||
return "", "", fmt.Errorf("$regularExpression value should be object, but instead is %s", ejv.t)
|
||||
}
|
||||
|
||||
regexObj := ejv.v.(*extJSONObject)
|
||||
patFound := false
|
||||
optFound := false
|
||||
|
||||
for i, key := range regexObj.keys {
|
||||
val := regexObj.values[i]
|
||||
|
||||
switch key {
|
||||
case "pattern":
|
||||
if patFound {
|
||||
return "", "", errors.New("duplicate pattern key in $regularExpression")
|
||||
}
|
||||
|
||||
if val.t != TypeString {
|
||||
return "", "", fmt.Errorf("$regularExpression pattern value should be string, but instead is %s", val.t)
|
||||
}
|
||||
|
||||
pattern = val.v.(string)
|
||||
patFound = true
|
||||
case "options":
|
||||
if optFound {
|
||||
return "", "", errors.New("duplicate options key in $regularExpression")
|
||||
}
|
||||
|
||||
if val.t != TypeString {
|
||||
return "", "", fmt.Errorf("$regularExpression options value should be string, but instead is %s", val.t)
|
||||
}
|
||||
|
||||
options = val.v.(string)
|
||||
optFound = true
|
||||
default:
|
||||
return "", "", fmt.Errorf("invalid key in $regularExpression object: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
if !patFound {
|
||||
return "", "", errors.New("missing pattern field in $regularExpression object")
|
||||
}
|
||||
|
||||
if !optFound {
|
||||
return "", "", errors.New("missing options field in $regularExpression object")
|
||||
|
||||
}
|
||||
|
||||
return pattern, options, nil
|
||||
}
|
||||
|
||||
func (ejv *extJSONValue) parseSymbol() (string, error) {
|
||||
if ejv.t != TypeString {
|
||||
return "", fmt.Errorf("$symbol value should be string, but instead is %s", ejv.t)
|
||||
}
|
||||
|
||||
return ejv.v.(string), nil
|
||||
}
|
||||
|
||||
func (ejv *extJSONValue) parseTimestamp() (t, i uint32, err error) {
|
||||
if ejv.t != TypeEmbeddedDocument {
|
||||
return 0, 0, fmt.Errorf("$timestamp value should be object, but instead is %s", ejv.t)
|
||||
}
|
||||
|
||||
handleKey := func(key string, val *extJSONValue, flag bool) (uint32, error) {
|
||||
if flag {
|
||||
return 0, fmt.Errorf("duplicate %s key in $timestamp", key)
|
||||
}
|
||||
|
||||
switch val.t {
|
||||
case TypeInt32:
|
||||
value := val.v.(int32)
|
||||
|
||||
if value < 0 {
|
||||
return 0, fmt.Errorf("$timestamp %s number should be uint32: %d", key, value)
|
||||
}
|
||||
|
||||
return uint32(value), nil
|
||||
case TypeInt64:
|
||||
value := val.v.(int64)
|
||||
if value < 0 || value > int64(math.MaxUint32) {
|
||||
return 0, fmt.Errorf("$timestamp %s number should be uint32: %d", key, value)
|
||||
}
|
||||
|
||||
return uint32(value), nil
|
||||
default:
|
||||
return 0, fmt.Errorf("$timestamp %s value should be uint32, but instead is %s", key, val.t)
|
||||
}
|
||||
}
|
||||
|
||||
tsObj := ejv.v.(*extJSONObject)
|
||||
tFound := false
|
||||
iFound := false
|
||||
|
||||
for j, key := range tsObj.keys {
|
||||
val := tsObj.values[j]
|
||||
|
||||
switch key {
|
||||
case "t":
|
||||
if t, err = handleKey(key, val, tFound); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
tFound = true
|
||||
case "i":
|
||||
if i, err = handleKey(key, val, iFound); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
iFound = true
|
||||
default:
|
||||
return 0, 0, fmt.Errorf("invalid key in $timestamp object: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
if !tFound {
|
||||
return 0, 0, errors.New("missing t field in $timestamp object")
|
||||
}
|
||||
|
||||
if !iFound {
|
||||
return 0, 0, errors.New("missing i field in $timestamp object")
|
||||
}
|
||||
|
||||
return t, i, nil
|
||||
}
|
||||
|
||||
func (ejv *extJSONValue) parseUndefined() error {
|
||||
if ejv.t != TypeBoolean {
|
||||
return fmt.Errorf("undefined value should be boolean, but instead is %s", ejv.t)
|
||||
}
|
||||
|
||||
if !ejv.v.(bool) {
|
||||
return fmt.Errorf("$undefined balue boolean should be true, but instead is %v", ejv.v.(bool))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
690
extjson_writer.go
Normal file
690
extjson_writer.go
Normal file
@ -0,0 +1,690 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
type ejvwState struct {
|
||||
mode mode
|
||||
}
|
||||
|
||||
type extJSONValueWriter struct {
|
||||
w io.Writer
|
||||
buf []byte
|
||||
|
||||
stack []ejvwState
|
||||
frame int64
|
||||
canonical bool
|
||||
escapeHTML bool
|
||||
newlines bool
|
||||
}
|
||||
|
||||
// NewExtJSONValueWriter creates a ValueWriter that writes Extended JSON to w.
|
||||
func NewExtJSONValueWriter(w io.Writer, canonical, escapeHTML bool) ValueWriter {
|
||||
// Enable newlines for all Extended JSON value writers created by NewExtJSONValueWriter. We
|
||||
// expect these value writers to be used with an Encoder, which should add newlines after
|
||||
// encoded Extended JSON documents.
|
||||
return newExtJSONWriter(w, canonical, escapeHTML, true)
|
||||
}
|
||||
|
||||
func newExtJSONWriter(w io.Writer, canonical, escapeHTML, newlines bool) *extJSONValueWriter {
|
||||
stack := make([]ejvwState, 1, 5)
|
||||
stack[0] = ejvwState{mode: mTopLevel}
|
||||
|
||||
return &extJSONValueWriter{
|
||||
w: w,
|
||||
buf: []byte{},
|
||||
stack: stack,
|
||||
canonical: canonical,
|
||||
escapeHTML: escapeHTML,
|
||||
newlines: newlines,
|
||||
}
|
||||
}
|
||||
|
||||
func newExtJSONWriterFromSlice(buf []byte, canonical, escapeHTML bool) *extJSONValueWriter {
|
||||
stack := make([]ejvwState, 1, 5)
|
||||
stack[0] = ejvwState{mode: mTopLevel}
|
||||
|
||||
return &extJSONValueWriter{
|
||||
buf: buf,
|
||||
stack: stack,
|
||||
canonical: canonical,
|
||||
escapeHTML: escapeHTML,
|
||||
}
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) reset(buf []byte, canonical, escapeHTML bool) {
|
||||
if ejvw.stack == nil {
|
||||
ejvw.stack = make([]ejvwState, 1, 5)
|
||||
}
|
||||
|
||||
ejvw.stack = ejvw.stack[:1]
|
||||
ejvw.stack[0] = ejvwState{mode: mTopLevel}
|
||||
ejvw.canonical = canonical
|
||||
ejvw.escapeHTML = escapeHTML
|
||||
ejvw.frame = 0
|
||||
ejvw.buf = buf
|
||||
ejvw.w = nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) advanceFrame() {
|
||||
if ejvw.frame+1 >= int64(len(ejvw.stack)) { // We need to grow the stack
|
||||
length := len(ejvw.stack)
|
||||
if length+1 >= cap(ejvw.stack) {
|
||||
// double it
|
||||
buf := make([]ejvwState, 2*cap(ejvw.stack)+1)
|
||||
copy(buf, ejvw.stack)
|
||||
ejvw.stack = buf
|
||||
}
|
||||
ejvw.stack = ejvw.stack[:length+1]
|
||||
}
|
||||
ejvw.frame++
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) push(m mode) {
|
||||
ejvw.advanceFrame()
|
||||
|
||||
ejvw.stack[ejvw.frame].mode = m
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) pop() {
|
||||
switch ejvw.stack[ejvw.frame].mode {
|
||||
case mElement, mValue:
|
||||
ejvw.frame--
|
||||
case mDocument, mArray, mCodeWithScope:
|
||||
ejvw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc...
|
||||
}
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) invalidTransitionErr(destination mode, name string, modes []mode) error {
|
||||
te := TransitionError{
|
||||
name: name,
|
||||
current: ejvw.stack[ejvw.frame].mode,
|
||||
destination: destination,
|
||||
modes: modes,
|
||||
action: "write",
|
||||
}
|
||||
if ejvw.frame != 0 {
|
||||
te.parent = ejvw.stack[ejvw.frame-1].mode
|
||||
}
|
||||
return te
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) ensureElementValue(destination mode, callerName string, addmodes ...mode) error {
|
||||
switch ejvw.stack[ejvw.frame].mode {
|
||||
case mElement, mValue:
|
||||
default:
|
||||
modes := []mode{mElement, mValue}
|
||||
if addmodes != nil {
|
||||
modes = append(modes, addmodes...)
|
||||
}
|
||||
return ejvw.invalidTransitionErr(destination, callerName, modes)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) writeExtendedSingleValue(key string, value string, quotes bool) {
|
||||
var s string
|
||||
if quotes {
|
||||
s = fmt.Sprintf(`{"$%s":"%s"}`, key, value)
|
||||
} else {
|
||||
s = fmt.Sprintf(`{"$%s":%s}`, key, value)
|
||||
}
|
||||
|
||||
ejvw.buf = append(ejvw.buf, []byte(s)...)
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteArray() (ArrayWriter, error) {
|
||||
if err := ejvw.ensureElementValue(mArray, "WriteArray"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ejvw.buf = append(ejvw.buf, '[')
|
||||
|
||||
ejvw.push(mArray)
|
||||
return ejvw, nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteBinary(b []byte) error {
|
||||
return ejvw.WriteBinaryWithSubtype(b, 0x00)
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
|
||||
if err := ejvw.ensureElementValue(mode(0), "WriteBinaryWithSubtype"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString(`{"$binary":{"base64":"`)
|
||||
buf.WriteString(base64.StdEncoding.EncodeToString(b))
|
||||
buf.WriteString(fmt.Sprintf(`","subType":"%02x"}},`, btype))
|
||||
|
||||
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
|
||||
|
||||
ejvw.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteBoolean(b bool) error {
|
||||
if err := ejvw.ensureElementValue(mode(0), "WriteBoolean"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ejvw.buf = append(ejvw.buf, []byte(strconv.FormatBool(b))...)
|
||||
ejvw.buf = append(ejvw.buf, ',')
|
||||
|
||||
ejvw.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
|
||||
if err := ejvw.ensureElementValue(mCodeWithScope, "WriteCodeWithScope"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString(`{"$code":`)
|
||||
writeStringWithEscapes(code, &buf, ejvw.escapeHTML)
|
||||
buf.WriteString(`,"$scope":{`)
|
||||
|
||||
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
|
||||
|
||||
ejvw.push(mCodeWithScope)
|
||||
return ejvw, nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteDBPointer(ns string, oid ObjectID) error {
|
||||
if err := ejvw.ensureElementValue(mode(0), "WriteDBPointer"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString(`{"$dbPointer":{"$ref":"`)
|
||||
buf.WriteString(ns)
|
||||
buf.WriteString(`","$id":{"$oid":"`)
|
||||
buf.WriteString(oid.Hex())
|
||||
buf.WriteString(`"}}},`)
|
||||
|
||||
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
|
||||
|
||||
ejvw.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteDateTime(dt int64) error {
|
||||
if err := ejvw.ensureElementValue(mode(0), "WriteDateTime"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t := time.Unix(dt/1e3, dt%1e3*1e6).UTC()
|
||||
|
||||
if ejvw.canonical || t.Year() < 1970 || t.Year() > 9999 {
|
||||
s := fmt.Sprintf(`{"$numberLong":"%d"}`, dt)
|
||||
ejvw.writeExtendedSingleValue("date", s, false)
|
||||
} else {
|
||||
ejvw.writeExtendedSingleValue("date", t.Format(rfc3339Milli), true)
|
||||
}
|
||||
|
||||
ejvw.buf = append(ejvw.buf, ',')
|
||||
|
||||
ejvw.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteDecimal128(d Decimal128) error {
|
||||
if err := ejvw.ensureElementValue(mode(0), "WriteDecimal128"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ejvw.writeExtendedSingleValue("numberDecimal", d.String(), true)
|
||||
ejvw.buf = append(ejvw.buf, ',')
|
||||
|
||||
ejvw.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteDocument() (DocumentWriter, error) {
|
||||
if ejvw.stack[ejvw.frame].mode == mTopLevel {
|
||||
ejvw.buf = append(ejvw.buf, '{')
|
||||
return ejvw, nil
|
||||
}
|
||||
|
||||
if err := ejvw.ensureElementValue(mDocument, "WriteDocument", mTopLevel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ejvw.buf = append(ejvw.buf, '{')
|
||||
ejvw.push(mDocument)
|
||||
return ejvw, nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteDouble(f float64) error {
|
||||
if err := ejvw.ensureElementValue(mode(0), "WriteDouble"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s := formatDouble(f)
|
||||
|
||||
if ejvw.canonical {
|
||||
ejvw.writeExtendedSingleValue("numberDouble", s, true)
|
||||
} else {
|
||||
switch s {
|
||||
case "Infinity":
|
||||
fallthrough
|
||||
case "-Infinity":
|
||||
fallthrough
|
||||
case "NaN":
|
||||
s = fmt.Sprintf(`{"$numberDouble":"%s"}`, s)
|
||||
}
|
||||
ejvw.buf = append(ejvw.buf, []byte(s)...)
|
||||
}
|
||||
|
||||
ejvw.buf = append(ejvw.buf, ',')
|
||||
|
||||
ejvw.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteInt32(i int32) error {
|
||||
if err := ejvw.ensureElementValue(mode(0), "WriteInt32"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s := strconv.FormatInt(int64(i), 10)
|
||||
|
||||
if ejvw.canonical {
|
||||
ejvw.writeExtendedSingleValue("numberInt", s, true)
|
||||
} else {
|
||||
ejvw.buf = append(ejvw.buf, []byte(s)...)
|
||||
}
|
||||
|
||||
ejvw.buf = append(ejvw.buf, ',')
|
||||
|
||||
ejvw.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteInt64(i int64) error {
|
||||
if err := ejvw.ensureElementValue(mode(0), "WriteInt64"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s := strconv.FormatInt(i, 10)
|
||||
|
||||
if ejvw.canonical {
|
||||
ejvw.writeExtendedSingleValue("numberLong", s, true)
|
||||
} else {
|
||||
ejvw.buf = append(ejvw.buf, []byte(s)...)
|
||||
}
|
||||
|
||||
ejvw.buf = append(ejvw.buf, ',')
|
||||
|
||||
ejvw.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteJavascript(code string) error {
|
||||
if err := ejvw.ensureElementValue(mode(0), "WriteJavascript"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
writeStringWithEscapes(code, &buf, ejvw.escapeHTML)
|
||||
|
||||
ejvw.writeExtendedSingleValue("code", buf.String(), false)
|
||||
ejvw.buf = append(ejvw.buf, ',')
|
||||
|
||||
ejvw.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteMaxKey() error {
|
||||
if err := ejvw.ensureElementValue(mode(0), "WriteMaxKey"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ejvw.writeExtendedSingleValue("maxKey", "1", false)
|
||||
ejvw.buf = append(ejvw.buf, ',')
|
||||
|
||||
ejvw.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteMinKey() error {
|
||||
if err := ejvw.ensureElementValue(mode(0), "WriteMinKey"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ejvw.writeExtendedSingleValue("minKey", "1", false)
|
||||
ejvw.buf = append(ejvw.buf, ',')
|
||||
|
||||
ejvw.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteNull() error {
|
||||
if err := ejvw.ensureElementValue(mode(0), "WriteNull"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ejvw.buf = append(ejvw.buf, []byte("null")...)
|
||||
ejvw.buf = append(ejvw.buf, ',')
|
||||
|
||||
ejvw.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteObjectID(oid ObjectID) error {
|
||||
if err := ejvw.ensureElementValue(mode(0), "WriteObjectID"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ejvw.writeExtendedSingleValue("oid", oid.Hex(), true)
|
||||
ejvw.buf = append(ejvw.buf, ',')
|
||||
|
||||
ejvw.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteRegex(pattern string, options string) error {
|
||||
if err := ejvw.ensureElementValue(mode(0), "WriteRegex"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
options = sortStringAlphebeticAscending(options)
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString(`{"$regularExpression":{"pattern":`)
|
||||
writeStringWithEscapes(pattern, &buf, ejvw.escapeHTML)
|
||||
buf.WriteString(`,"options":`)
|
||||
writeStringWithEscapes(options, &buf, ejvw.escapeHTML)
|
||||
buf.WriteString(`}},`)
|
||||
|
||||
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
|
||||
|
||||
ejvw.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteString(s string) error {
|
||||
if err := ejvw.ensureElementValue(mode(0), "WriteString"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
writeStringWithEscapes(s, &buf, ejvw.escapeHTML)
|
||||
|
||||
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
|
||||
ejvw.buf = append(ejvw.buf, ',')
|
||||
|
||||
ejvw.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteSymbol(symbol string) error {
|
||||
if err := ejvw.ensureElementValue(mode(0), "WriteSymbol"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
writeStringWithEscapes(symbol, &buf, ejvw.escapeHTML)
|
||||
|
||||
ejvw.writeExtendedSingleValue("symbol", buf.String(), false)
|
||||
ejvw.buf = append(ejvw.buf, ',')
|
||||
|
||||
ejvw.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteTimestamp(t uint32, i uint32) error {
|
||||
if err := ejvw.ensureElementValue(mode(0), "WriteTimestamp"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString(`{"$timestamp":{"t":`)
|
||||
buf.WriteString(strconv.FormatUint(uint64(t), 10))
|
||||
buf.WriteString(`,"i":`)
|
||||
buf.WriteString(strconv.FormatUint(uint64(i), 10))
|
||||
buf.WriteString(`}},`)
|
||||
|
||||
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
|
||||
|
||||
ejvw.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteUndefined() error {
|
||||
if err := ejvw.ensureElementValue(mode(0), "WriteUndefined"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ejvw.writeExtendedSingleValue("undefined", "true", false)
|
||||
ejvw.buf = append(ejvw.buf, ',')
|
||||
|
||||
ejvw.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteDocumentElement(key string) (ValueWriter, error) {
|
||||
switch ejvw.stack[ejvw.frame].mode {
|
||||
case mDocument, mTopLevel, mCodeWithScope:
|
||||
var buf bytes.Buffer
|
||||
writeStringWithEscapes(key, &buf, ejvw.escapeHTML)
|
||||
|
||||
ejvw.buf = append(ejvw.buf, []byte(fmt.Sprintf(`%s:`, buf.String()))...)
|
||||
ejvw.push(mElement)
|
||||
default:
|
||||
return nil, ejvw.invalidTransitionErr(mElement, "WriteDocumentElement", []mode{mDocument, mTopLevel, mCodeWithScope})
|
||||
}
|
||||
|
||||
return ejvw, nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteDocumentEnd() error {
|
||||
switch ejvw.stack[ejvw.frame].mode {
|
||||
case mDocument, mTopLevel, mCodeWithScope:
|
||||
default:
|
||||
return fmt.Errorf("incorrect mode to end document: %s", ejvw.stack[ejvw.frame].mode)
|
||||
}
|
||||
|
||||
// close the document
|
||||
if ejvw.buf[len(ejvw.buf)-1] == ',' {
|
||||
ejvw.buf[len(ejvw.buf)-1] = '}'
|
||||
} else {
|
||||
ejvw.buf = append(ejvw.buf, '}')
|
||||
}
|
||||
|
||||
switch ejvw.stack[ejvw.frame].mode {
|
||||
case mCodeWithScope:
|
||||
ejvw.buf = append(ejvw.buf, '}')
|
||||
fallthrough
|
||||
case mDocument:
|
||||
ejvw.buf = append(ejvw.buf, ',')
|
||||
case mTopLevel:
|
||||
// If the value writer has newlines enabled, end top-level documents with a newline so that
|
||||
// multiple documents encoded to the same writer are separated by newlines. That matches the
|
||||
// Go json.Encoder behavior and also works with NewExtJSONValueReader.
|
||||
if ejvw.newlines {
|
||||
ejvw.buf = append(ejvw.buf, '\n')
|
||||
}
|
||||
if ejvw.w != nil {
|
||||
if _, err := ejvw.w.Write(ejvw.buf); err != nil {
|
||||
return err
|
||||
}
|
||||
ejvw.buf = ejvw.buf[:0]
|
||||
}
|
||||
}
|
||||
|
||||
ejvw.pop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteArrayElement() (ValueWriter, error) {
|
||||
switch ejvw.stack[ejvw.frame].mode {
|
||||
case mArray:
|
||||
ejvw.push(mValue)
|
||||
default:
|
||||
return nil, ejvw.invalidTransitionErr(mValue, "WriteArrayElement", []mode{mArray})
|
||||
}
|
||||
|
||||
return ejvw, nil
|
||||
}
|
||||
|
||||
func (ejvw *extJSONValueWriter) WriteArrayEnd() error {
|
||||
switch ejvw.stack[ejvw.frame].mode {
|
||||
case mArray:
|
||||
// close the array
|
||||
if ejvw.buf[len(ejvw.buf)-1] == ',' {
|
||||
ejvw.buf[len(ejvw.buf)-1] = ']'
|
||||
} else {
|
||||
ejvw.buf = append(ejvw.buf, ']')
|
||||
}
|
||||
|
||||
ejvw.buf = append(ejvw.buf, ',')
|
||||
|
||||
ejvw.pop()
|
||||
default:
|
||||
return fmt.Errorf("incorrect mode to end array: %s", ejvw.stack[ejvw.frame].mode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatDouble(f float64) string {
|
||||
var s string
|
||||
switch {
|
||||
case math.IsInf(f, 1):
|
||||
s = "Infinity"
|
||||
case math.IsInf(f, -1):
|
||||
s = "-Infinity"
|
||||
case math.IsNaN(f):
|
||||
s = "NaN"
|
||||
default:
|
||||
// Print exactly one decimalType place for integers; otherwise, print as many are necessary to
|
||||
// perfectly represent it.
|
||||
s = strconv.FormatFloat(f, 'G', -1, 64)
|
||||
if !strings.ContainsRune(s, 'E') && !strings.ContainsRune(s, '.') {
|
||||
s += ".0"
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
var hexChars = "0123456789abcdef"
|
||||
|
||||
func writeStringWithEscapes(s string, buf *bytes.Buffer, escapeHTML bool) {
|
||||
buf.WriteByte('"')
|
||||
start := 0
|
||||
for i := 0; i < len(s); {
|
||||
if b := s[i]; b < utf8.RuneSelf {
|
||||
if htmlSafeSet[b] || (!escapeHTML && safeSet[b]) {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if start < i {
|
||||
buf.WriteString(s[start:i])
|
||||
}
|
||||
switch b {
|
||||
case '\\', '"':
|
||||
buf.WriteByte('\\')
|
||||
buf.WriteByte(b)
|
||||
case '\n':
|
||||
buf.WriteByte('\\')
|
||||
buf.WriteByte('n')
|
||||
case '\r':
|
||||
buf.WriteByte('\\')
|
||||
buf.WriteByte('r')
|
||||
case '\t':
|
||||
buf.WriteByte('\\')
|
||||
buf.WriteByte('t')
|
||||
case '\b':
|
||||
buf.WriteByte('\\')
|
||||
buf.WriteByte('b')
|
||||
case '\f':
|
||||
buf.WriteByte('\\')
|
||||
buf.WriteByte('f')
|
||||
default:
|
||||
// This encodes bytes < 0x20 except for \t, \n and \r.
|
||||
// If escapeHTML is set, it also escapes <, >, and &
|
||||
// because they can lead to security holes when
|
||||
// user-controlled strings are rendered into JSON
|
||||
// and served to some browsers.
|
||||
buf.WriteString(`\u00`)
|
||||
buf.WriteByte(hexChars[b>>4])
|
||||
buf.WriteByte(hexChars[b&0xF])
|
||||
}
|
||||
i++
|
||||
start = i
|
||||
continue
|
||||
}
|
||||
c, size := utf8.DecodeRuneInString(s[i:])
|
||||
if c == utf8.RuneError && size == 1 {
|
||||
if start < i {
|
||||
buf.WriteString(s[start:i])
|
||||
}
|
||||
buf.WriteString(`\ufffd`)
|
||||
i += size
|
||||
start = i
|
||||
continue
|
||||
}
|
||||
// U+2028 is LINE SEPARATOR.
|
||||
// U+2029 is PARAGRAPH SEPARATOR.
|
||||
// They are both technically valid characters in JSON strings,
|
||||
// but don't work in JSONP, which has to be evaluated as JavaScript,
|
||||
// and can lead to security holes there. It is valid JSON to
|
||||
// escape them, so we do so unconditionally.
|
||||
// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
|
||||
if c == '\u2028' || c == '\u2029' {
|
||||
if start < i {
|
||||
buf.WriteString(s[start:i])
|
||||
}
|
||||
buf.WriteString(`\u202`)
|
||||
buf.WriteByte(hexChars[c&0xF])
|
||||
i += size
|
||||
start = i
|
||||
continue
|
||||
}
|
||||
i += size
|
||||
}
|
||||
if start < len(s) {
|
||||
buf.WriteString(s[start:])
|
||||
}
|
||||
buf.WriteByte('"')
|
||||
}
|
||||
|
||||
type sortableString []rune
|
||||
|
||||
func (ss sortableString) Len() int {
|
||||
return len(ss)
|
||||
}
|
||||
|
||||
func (ss sortableString) Less(i, j int) bool {
|
||||
return ss[i] < ss[j]
|
||||
}
|
||||
|
||||
func (ss sortableString) Swap(i, j int) {
|
||||
ss[i], ss[j] = ss[j], ss[i]
|
||||
}
|
||||
|
||||
func sortStringAlphebeticAscending(s string) string {
|
||||
ss := sortableString([]rune(s))
|
||||
sort.Sort(ss)
|
||||
return string([]rune(ss))
|
||||
}
|
259
extjson_writer_test.go
Normal file
259
extjson_writer_test.go
Normal file
@ -0,0 +1,259 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/assert"
|
||||
)
|
||||
|
||||
func TestExtJSONValueWriter(t *testing.T) {
|
||||
oid := ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}
|
||||
testCases := []struct {
|
||||
name string
|
||||
fn interface{}
|
||||
params []interface{}
|
||||
}{
|
||||
{
|
||||
"WriteBinary",
|
||||
(*extJSONValueWriter).WriteBinary,
|
||||
[]interface{}{[]byte{0x01, 0x02, 0x03}},
|
||||
},
|
||||
{
|
||||
"WriteBinaryWithSubtype (not 0x02)",
|
||||
(*extJSONValueWriter).WriteBinaryWithSubtype,
|
||||
[]interface{}{[]byte{0x01, 0x02, 0x03}, byte(0xFF)},
|
||||
},
|
||||
{
|
||||
"WriteBinaryWithSubtype (0x02)",
|
||||
(*extJSONValueWriter).WriteBinaryWithSubtype,
|
||||
[]interface{}{[]byte{0x01, 0x02, 0x03}, byte(0x02)},
|
||||
},
|
||||
{
|
||||
"WriteBoolean",
|
||||
(*extJSONValueWriter).WriteBoolean,
|
||||
[]interface{}{true},
|
||||
},
|
||||
{
|
||||
"WriteDBPointer",
|
||||
(*extJSONValueWriter).WriteDBPointer,
|
||||
[]interface{}{"bar", oid},
|
||||
},
|
||||
{
|
||||
"WriteDateTime",
|
||||
(*extJSONValueWriter).WriteDateTime,
|
||||
[]interface{}{int64(12345678)},
|
||||
},
|
||||
{
|
||||
"WriteDecimal128",
|
||||
(*extJSONValueWriter).WriteDecimal128,
|
||||
[]interface{}{NewDecimal128(10, 20)},
|
||||
},
|
||||
{
|
||||
"WriteDouble",
|
||||
(*extJSONValueWriter).WriteDouble,
|
||||
[]interface{}{float64(3.14159)},
|
||||
},
|
||||
{
|
||||
"WriteInt32",
|
||||
(*extJSONValueWriter).WriteInt32,
|
||||
[]interface{}{int32(123456)},
|
||||
},
|
||||
{
|
||||
"WriteInt64",
|
||||
(*extJSONValueWriter).WriteInt64,
|
||||
[]interface{}{int64(1234567890)},
|
||||
},
|
||||
{
|
||||
"WriteJavascript",
|
||||
(*extJSONValueWriter).WriteJavascript,
|
||||
[]interface{}{"var foo = 'bar';"},
|
||||
},
|
||||
{
|
||||
"WriteMaxKey",
|
||||
(*extJSONValueWriter).WriteMaxKey,
|
||||
[]interface{}{},
|
||||
},
|
||||
{
|
||||
"WriteMinKey",
|
||||
(*extJSONValueWriter).WriteMinKey,
|
||||
[]interface{}{},
|
||||
},
|
||||
{
|
||||
"WriteNull",
|
||||
(*extJSONValueWriter).WriteNull,
|
||||
[]interface{}{},
|
||||
},
|
||||
{
|
||||
"WriteObjectID",
|
||||
(*extJSONValueWriter).WriteObjectID,
|
||||
[]interface{}{oid},
|
||||
},
|
||||
{
|
||||
"WriteRegex",
|
||||
(*extJSONValueWriter).WriteRegex,
|
||||
[]interface{}{"bar", "baz"},
|
||||
},
|
||||
{
|
||||
"WriteString",
|
||||
(*extJSONValueWriter).WriteString,
|
||||
[]interface{}{"hello, world!"},
|
||||
},
|
||||
{
|
||||
"WriteSymbol",
|
||||
(*extJSONValueWriter).WriteSymbol,
|
||||
[]interface{}{"symbollolz"},
|
||||
},
|
||||
{
|
||||
"WriteTimestamp",
|
||||
(*extJSONValueWriter).WriteTimestamp,
|
||||
[]interface{}{uint32(10), uint32(20)},
|
||||
},
|
||||
{
|
||||
"WriteUndefined",
|
||||
(*extJSONValueWriter).WriteUndefined,
|
||||
[]interface{}{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
fn := reflect.ValueOf(tc.fn)
|
||||
if fn.Kind() != reflect.Func {
|
||||
t.Fatalf("fn must be of kind Func but it is a %v", fn.Kind())
|
||||
}
|
||||
if fn.Type().NumIn() != len(tc.params)+1 || fn.Type().In(0) != reflect.TypeOf((*extJSONValueWriter)(nil)) {
|
||||
t.Fatalf("fn must have at least one parameter and the first parameter must be a *valueWriter")
|
||||
}
|
||||
if fn.Type().NumOut() != 1 || fn.Type().Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
|
||||
t.Fatalf("fn must have one return value and it must be an error.")
|
||||
}
|
||||
params := make([]reflect.Value, 1, len(tc.params)+1)
|
||||
ejvw := newExtJSONWriter(io.Discard, true, true, false)
|
||||
params[0] = reflect.ValueOf(ejvw)
|
||||
for _, param := range tc.params {
|
||||
params = append(params, reflect.ValueOf(param))
|
||||
}
|
||||
|
||||
t.Run("incorrect transition", func(t *testing.T) {
|
||||
results := fn.Call(params)
|
||||
got := results[0].Interface().(error)
|
||||
fnName := tc.name
|
||||
if strings.Contains(fnName, "WriteBinary") {
|
||||
fnName = "WriteBinaryWithSubtype"
|
||||
}
|
||||
want := TransitionError{current: mTopLevel, name: fnName, modes: []mode{mElement, mValue},
|
||||
action: "write"}
|
||||
if !assert.CompareErrors(got, want) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("WriteArray", func(t *testing.T) {
|
||||
ejvw := newExtJSONWriter(io.Discard, true, true, false)
|
||||
ejvw.push(mArray)
|
||||
want := TransitionError{current: mArray, destination: mArray, parent: mTopLevel,
|
||||
name: "WriteArray", modes: []mode{mElement, mValue}, action: "write"}
|
||||
_, got := ejvw.WriteArray()
|
||||
if !assert.CompareErrors(got, want) {
|
||||
t.Errorf("Did not get expected error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("WriteCodeWithScope", func(t *testing.T) {
|
||||
ejvw := newExtJSONWriter(io.Discard, true, true, false)
|
||||
ejvw.push(mArray)
|
||||
want := TransitionError{current: mArray, destination: mCodeWithScope, parent: mTopLevel,
|
||||
name: "WriteCodeWithScope", modes: []mode{mElement, mValue}, action: "write"}
|
||||
_, got := ejvw.WriteCodeWithScope("")
|
||||
if !assert.CompareErrors(got, want) {
|
||||
t.Errorf("Did not get expected error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("WriteDocument", func(t *testing.T) {
|
||||
ejvw := newExtJSONWriter(io.Discard, true, true, false)
|
||||
ejvw.push(mArray)
|
||||
want := TransitionError{current: mArray, destination: mDocument, parent: mTopLevel,
|
||||
name: "WriteDocument", modes: []mode{mElement, mValue, mTopLevel}, action: "write"}
|
||||
_, got := ejvw.WriteDocument()
|
||||
if !assert.CompareErrors(got, want) {
|
||||
t.Errorf("Did not get expected error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("WriteDocumentElement", func(t *testing.T) {
|
||||
ejvw := newExtJSONWriter(io.Discard, true, true, false)
|
||||
ejvw.push(mElement)
|
||||
want := TransitionError{current: mElement,
|
||||
destination: mElement,
|
||||
parent: mTopLevel,
|
||||
name: "WriteDocumentElement",
|
||||
modes: []mode{mDocument, mTopLevel, mCodeWithScope},
|
||||
action: "write"}
|
||||
_, got := ejvw.WriteDocumentElement("")
|
||||
if !assert.CompareErrors(got, want) {
|
||||
t.Errorf("Did not get expected error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("WriteDocumentEnd", func(t *testing.T) {
|
||||
ejvw := newExtJSONWriter(io.Discard, true, true, false)
|
||||
ejvw.push(mElement)
|
||||
want := fmt.Errorf("incorrect mode to end document: %s", mElement)
|
||||
got := ejvw.WriteDocumentEnd()
|
||||
if !assert.CompareErrors(got, want) {
|
||||
t.Errorf("Did not get expected error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("WriteArrayElement", func(t *testing.T) {
|
||||
ejvw := newExtJSONWriter(io.Discard, true, true, false)
|
||||
ejvw.push(mElement)
|
||||
want := TransitionError{current: mElement,
|
||||
destination: mValue,
|
||||
parent: mTopLevel,
|
||||
name: "WriteArrayElement",
|
||||
modes: []mode{mArray},
|
||||
action: "write"}
|
||||
_, got := ejvw.WriteArrayElement()
|
||||
if !assert.CompareErrors(got, want) {
|
||||
t.Errorf("Did not get expected error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("WriteArrayEnd", func(t *testing.T) {
|
||||
ejvw := newExtJSONWriter(io.Discard, true, true, false)
|
||||
ejvw.push(mElement)
|
||||
want := fmt.Errorf("incorrect mode to end array: %s", mElement)
|
||||
got := ejvw.WriteArrayEnd()
|
||||
if !assert.CompareErrors(got, want) {
|
||||
t.Errorf("Did not get expected error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("WriteBytes", func(t *testing.T) {
|
||||
t.Run("writeElementHeader error", func(t *testing.T) {
|
||||
ejvw := newExtJSONWriterFromSlice(nil, true, true)
|
||||
want := TransitionError{current: mTopLevel, destination: mode(0),
|
||||
name: "WriteBinaryWithSubtype", modes: []mode{mElement, mValue}, action: "write"}
|
||||
got := ejvw.WriteBinaryWithSubtype(nil, (byte)(TypeEmbeddedDocument))
|
||||
if !assert.CompareErrors(got, want) {
|
||||
t.Errorf("Did not received expected error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("FormatDoubleWithExponent", func(t *testing.T) {
|
||||
want := "3E-12"
|
||||
got := formatDouble(float64(0.000000000003))
|
||||
if got != want {
|
||||
t.Errorf("Did not receive expected string. got %s: want %s", got, want)
|
||||
}
|
||||
})
|
||||
}
|
40
fuzz_test.go
Normal file
40
fuzz_test.go
Normal file
@ -0,0 +1,40 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bson
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func FuzzDecode(f *testing.F) {
|
||||
seedBSONCorpus(f)
|
||||
|
||||
f.Fuzz(func(t *testing.T, data []byte) {
|
||||
for _, typ := range []func() interface{}{
|
||||
func() interface{} { return new(D) },
|
||||
func() interface{} { return new([]E) },
|
||||
func() interface{} { return new(M) },
|
||||
func() interface{} { return new(interface{}) },
|
||||
func() interface{} { return make(map[string]interface{}) },
|
||||
func() interface{} { return new([]interface{}) },
|
||||
} {
|
||||
i := typ()
|
||||
if err := Unmarshal(data, i); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
encoded, err := Marshal(i)
|
||||
if err != nil {
|
||||
t.Fatal("failed to marshal", err)
|
||||
}
|
||||
|
||||
if err := Unmarshal(encoded, i); err != nil {
|
||||
t.Fatal("failed to unmarshal", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
12
go.mod
Normal file
12
go.mod
Normal file
@ -0,0 +1,12 @@
|
||||
module gitea.psichedelico.com/go/bson
|
||||
|
||||
go 1.23.0
|
||||
|
||||
toolchain go1.23.7
|
||||
|
||||
require (
|
||||
github.com/google/go-cmp v0.7.0
|
||||
golang.org/x/sync v0.12.0
|
||||
)
|
||||
|
||||
require github.com/davecgh/go-spew v1.1.1
|
6
go.sum
Normal file
6
go.sum
Normal file
@ -0,0 +1,6 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
481
internal/assert/assertion_compare.go
Normal file
481
internal/assert/assertion_compare.go
Normal file
@ -0,0 +1,481 @@
|
||||
// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare.go
|
||||
|
||||
// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved.
|
||||
// Use of this source code is governed by an MIT-style license that can be found in
|
||||
// the THIRD-PARTY-NOTICES file.
|
||||
|
||||
package assert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
type CompareType int
|
||||
|
||||
const (
|
||||
compareLess CompareType = iota - 1
|
||||
compareEqual
|
||||
compareGreater
|
||||
)
|
||||
|
||||
var (
|
||||
intType = reflect.TypeOf(int(1))
|
||||
int8Type = reflect.TypeOf(int8(1))
|
||||
int16Type = reflect.TypeOf(int16(1))
|
||||
int32Type = reflect.TypeOf(int32(1))
|
||||
int64Type = reflect.TypeOf(int64(1))
|
||||
|
||||
uintType = reflect.TypeOf(uint(1))
|
||||
uint8Type = reflect.TypeOf(uint8(1))
|
||||
uint16Type = reflect.TypeOf(uint16(1))
|
||||
uint32Type = reflect.TypeOf(uint32(1))
|
||||
uint64Type = reflect.TypeOf(uint64(1))
|
||||
|
||||
float32Type = reflect.TypeOf(float32(1))
|
||||
float64Type = reflect.TypeOf(float64(1))
|
||||
|
||||
stringType = reflect.TypeOf("")
|
||||
|
||||
timeType = reflect.TypeOf(time.Time{})
|
||||
bytesType = reflect.TypeOf([]byte{})
|
||||
)
|
||||
|
||||
func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
|
||||
obj1Value := reflect.ValueOf(obj1)
|
||||
obj2Value := reflect.ValueOf(obj2)
|
||||
|
||||
// throughout this switch we try and avoid calling .Convert() if possible,
|
||||
// as this has a pretty big performance impact
|
||||
switch kind {
|
||||
case reflect.Int:
|
||||
{
|
||||
intobj1, ok := obj1.(int)
|
||||
if !ok {
|
||||
intobj1 = obj1Value.Convert(intType).Interface().(int)
|
||||
}
|
||||
intobj2, ok := obj2.(int)
|
||||
if !ok {
|
||||
intobj2 = obj2Value.Convert(intType).Interface().(int)
|
||||
}
|
||||
if intobj1 > intobj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if intobj1 == intobj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if intobj1 < intobj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Int8:
|
||||
{
|
||||
int8obj1, ok := obj1.(int8)
|
||||
if !ok {
|
||||
int8obj1 = obj1Value.Convert(int8Type).Interface().(int8)
|
||||
}
|
||||
int8obj2, ok := obj2.(int8)
|
||||
if !ok {
|
||||
int8obj2 = obj2Value.Convert(int8Type).Interface().(int8)
|
||||
}
|
||||
if int8obj1 > int8obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if int8obj1 == int8obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if int8obj1 < int8obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Int16:
|
||||
{
|
||||
int16obj1, ok := obj1.(int16)
|
||||
if !ok {
|
||||
int16obj1 = obj1Value.Convert(int16Type).Interface().(int16)
|
||||
}
|
||||
int16obj2, ok := obj2.(int16)
|
||||
if !ok {
|
||||
int16obj2 = obj2Value.Convert(int16Type).Interface().(int16)
|
||||
}
|
||||
if int16obj1 > int16obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if int16obj1 == int16obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if int16obj1 < int16obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Int32:
|
||||
{
|
||||
int32obj1, ok := obj1.(int32)
|
||||
if !ok {
|
||||
int32obj1 = obj1Value.Convert(int32Type).Interface().(int32)
|
||||
}
|
||||
int32obj2, ok := obj2.(int32)
|
||||
if !ok {
|
||||
int32obj2 = obj2Value.Convert(int32Type).Interface().(int32)
|
||||
}
|
||||
if int32obj1 > int32obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if int32obj1 == int32obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if int32obj1 < int32obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Int64:
|
||||
{
|
||||
int64obj1, ok := obj1.(int64)
|
||||
if !ok {
|
||||
int64obj1 = obj1Value.Convert(int64Type).Interface().(int64)
|
||||
}
|
||||
int64obj2, ok := obj2.(int64)
|
||||
if !ok {
|
||||
int64obj2 = obj2Value.Convert(int64Type).Interface().(int64)
|
||||
}
|
||||
if int64obj1 > int64obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if int64obj1 == int64obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if int64obj1 < int64obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Uint:
|
||||
{
|
||||
uintobj1, ok := obj1.(uint)
|
||||
if !ok {
|
||||
uintobj1 = obj1Value.Convert(uintType).Interface().(uint)
|
||||
}
|
||||
uintobj2, ok := obj2.(uint)
|
||||
if !ok {
|
||||
uintobj2 = obj2Value.Convert(uintType).Interface().(uint)
|
||||
}
|
||||
if uintobj1 > uintobj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if uintobj1 == uintobj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if uintobj1 < uintobj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Uint8:
|
||||
{
|
||||
uint8obj1, ok := obj1.(uint8)
|
||||
if !ok {
|
||||
uint8obj1 = obj1Value.Convert(uint8Type).Interface().(uint8)
|
||||
}
|
||||
uint8obj2, ok := obj2.(uint8)
|
||||
if !ok {
|
||||
uint8obj2 = obj2Value.Convert(uint8Type).Interface().(uint8)
|
||||
}
|
||||
if uint8obj1 > uint8obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if uint8obj1 == uint8obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if uint8obj1 < uint8obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Uint16:
|
||||
{
|
||||
uint16obj1, ok := obj1.(uint16)
|
||||
if !ok {
|
||||
uint16obj1 = obj1Value.Convert(uint16Type).Interface().(uint16)
|
||||
}
|
||||
uint16obj2, ok := obj2.(uint16)
|
||||
if !ok {
|
||||
uint16obj2 = obj2Value.Convert(uint16Type).Interface().(uint16)
|
||||
}
|
||||
if uint16obj1 > uint16obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if uint16obj1 == uint16obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if uint16obj1 < uint16obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Uint32:
|
||||
{
|
||||
uint32obj1, ok := obj1.(uint32)
|
||||
if !ok {
|
||||
uint32obj1 = obj1Value.Convert(uint32Type).Interface().(uint32)
|
||||
}
|
||||
uint32obj2, ok := obj2.(uint32)
|
||||
if !ok {
|
||||
uint32obj2 = obj2Value.Convert(uint32Type).Interface().(uint32)
|
||||
}
|
||||
if uint32obj1 > uint32obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if uint32obj1 == uint32obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if uint32obj1 < uint32obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Uint64:
|
||||
{
|
||||
uint64obj1, ok := obj1.(uint64)
|
||||
if !ok {
|
||||
uint64obj1 = obj1Value.Convert(uint64Type).Interface().(uint64)
|
||||
}
|
||||
uint64obj2, ok := obj2.(uint64)
|
||||
if !ok {
|
||||
uint64obj2 = obj2Value.Convert(uint64Type).Interface().(uint64)
|
||||
}
|
||||
if uint64obj1 > uint64obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if uint64obj1 == uint64obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if uint64obj1 < uint64obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Float32:
|
||||
{
|
||||
float32obj1, ok := obj1.(float32)
|
||||
if !ok {
|
||||
float32obj1 = obj1Value.Convert(float32Type).Interface().(float32)
|
||||
}
|
||||
float32obj2, ok := obj2.(float32)
|
||||
if !ok {
|
||||
float32obj2 = obj2Value.Convert(float32Type).Interface().(float32)
|
||||
}
|
||||
if float32obj1 > float32obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if float32obj1 == float32obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if float32obj1 < float32obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.Float64:
|
||||
{
|
||||
float64obj1, ok := obj1.(float64)
|
||||
if !ok {
|
||||
float64obj1 = obj1Value.Convert(float64Type).Interface().(float64)
|
||||
}
|
||||
float64obj2, ok := obj2.(float64)
|
||||
if !ok {
|
||||
float64obj2 = obj2Value.Convert(float64Type).Interface().(float64)
|
||||
}
|
||||
if float64obj1 > float64obj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if float64obj1 == float64obj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if float64obj1 < float64obj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
case reflect.String:
|
||||
{
|
||||
stringobj1, ok := obj1.(string)
|
||||
if !ok {
|
||||
stringobj1 = obj1Value.Convert(stringType).Interface().(string)
|
||||
}
|
||||
stringobj2, ok := obj2.(string)
|
||||
if !ok {
|
||||
stringobj2 = obj2Value.Convert(stringType).Interface().(string)
|
||||
}
|
||||
if stringobj1 > stringobj2 {
|
||||
return compareGreater, true
|
||||
}
|
||||
if stringobj1 == stringobj2 {
|
||||
return compareEqual, true
|
||||
}
|
||||
if stringobj1 < stringobj2 {
|
||||
return compareLess, true
|
||||
}
|
||||
}
|
||||
// Check for known struct types we can check for compare results.
|
||||
case reflect.Struct:
|
||||
{
|
||||
// All structs enter here. We're not interested in most types.
|
||||
if !canConvert(obj1Value, timeType) {
|
||||
break
|
||||
}
|
||||
|
||||
// time.Time can compared!
|
||||
timeObj1, ok := obj1.(time.Time)
|
||||
if !ok {
|
||||
timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
|
||||
}
|
||||
|
||||
timeObj2, ok := obj2.(time.Time)
|
||||
if !ok {
|
||||
timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time)
|
||||
}
|
||||
|
||||
return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
|
||||
}
|
||||
case reflect.Slice:
|
||||
{
|
||||
// We only care about the []byte type.
|
||||
if !canConvert(obj1Value, bytesType) {
|
||||
break
|
||||
}
|
||||
|
||||
// []byte can be compared!
|
||||
bytesObj1, ok := obj1.([]byte)
|
||||
if !ok {
|
||||
bytesObj1 = obj1Value.Convert(bytesType).Interface().([]byte)
|
||||
|
||||
}
|
||||
bytesObj2, ok := obj2.([]byte)
|
||||
if !ok {
|
||||
bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte)
|
||||
}
|
||||
|
||||
return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true
|
||||
}
|
||||
}
|
||||
|
||||
return compareEqual, false
|
||||
}
|
||||
|
||||
// Greater asserts that the first element is greater than the second
|
||||
//
|
||||
// assert.Greater(t, 2, 1)
|
||||
// assert.Greater(t, float64(2), float64(1))
|
||||
// assert.Greater(t, "b", "a")
|
||||
func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
|
||||
}
|
||||
|
||||
// GreaterOrEqual asserts that the first element is greater than or equal to the second
|
||||
//
|
||||
// assert.GreaterOrEqual(t, 2, 1)
|
||||
// assert.GreaterOrEqual(t, 2, 2)
|
||||
// assert.GreaterOrEqual(t, "b", "a")
|
||||
// assert.GreaterOrEqual(t, "b", "b")
|
||||
func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
|
||||
}
|
||||
|
||||
// Less asserts that the first element is less than the second
|
||||
//
|
||||
// assert.Less(t, 1, 2)
|
||||
// assert.Less(t, float64(1), float64(2))
|
||||
// assert.Less(t, "a", "b")
|
||||
func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
|
||||
}
|
||||
|
||||
// LessOrEqual asserts that the first element is less than or equal to the second
|
||||
//
|
||||
// assert.LessOrEqual(t, 1, 2)
|
||||
// assert.LessOrEqual(t, 2, 2)
|
||||
// assert.LessOrEqual(t, "a", "b")
|
||||
// assert.LessOrEqual(t, "b", "b")
|
||||
func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
|
||||
}
|
||||
|
||||
// Positive asserts that the specified element is positive
|
||||
//
|
||||
// assert.Positive(t, 1)
|
||||
// assert.Positive(t, 1.23)
|
||||
func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
zero := reflect.Zero(reflect.TypeOf(e))
|
||||
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...)
|
||||
}
|
||||
|
||||
// Negative asserts that the specified element is negative
|
||||
//
|
||||
// assert.Negative(t, -1)
|
||||
// assert.Negative(t, -1.23)
|
||||
func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
zero := reflect.Zero(reflect.TypeOf(e))
|
||||
return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...)
|
||||
}
|
||||
|
||||
func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
|
||||
e1Kind := reflect.ValueOf(e1).Kind()
|
||||
e2Kind := reflect.ValueOf(e2).Kind()
|
||||
if e1Kind != e2Kind {
|
||||
return Fail(t, "Elements should be the same type", msgAndArgs...)
|
||||
}
|
||||
|
||||
compareResult, isComparable := compare(e1, e2, e1Kind)
|
||||
if !isComparable {
|
||||
return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...)
|
||||
}
|
||||
|
||||
if !containsValue(allowedComparesResults, compareResult) {
|
||||
return Fail(t, fmt.Sprintf(failMessage, e1, e2), msgAndArgs...)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func containsValue(values []CompareType, value CompareType) bool {
|
||||
for _, v := range values {
|
||||
if v == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// CompareErrors asserts two errors
|
||||
func CompareErrors(err1, err2 error) bool {
|
||||
if err1 == nil && err2 == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if err1 == nil || err2 == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if err1.Error() != err2.Error() {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
18
internal/assert/assertion_compare_can_convert.go
Normal file
18
internal/assert/assertion_compare_can_convert.go
Normal file
@ -0,0 +1,18 @@
|
||||
// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare_can_convert.go
|
||||
|
||||
// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved.
|
||||
// Use of this source code is governed by an MIT-style license that can be found in
|
||||
// the THIRD-PARTY-NOTICES file.
|
||||
|
||||
//go:build go1.17
|
||||
// +build go1.17
|
||||
|
||||
package assert
|
||||
|
||||
import "reflect"
|
||||
|
||||
// Wrapper around reflect.Value.CanConvert, for compatibility
|
||||
// reasons.
|
||||
func canConvert(value reflect.Value, to reflect.Type) bool {
|
||||
return value.CanConvert(to)
|
||||
}
|
184
internal/assert/assertion_compare_go1.17_test.go
Normal file
184
internal/assert/assertion_compare_go1.17_test.go
Normal file
@ -0,0 +1,184 @@
|
||||
// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare_go1.17_test.go
|
||||
|
||||
// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved.
|
||||
// Use of this source code is governed by an MIT-style license that can be found in
|
||||
// the THIRD-PARTY-NOTICES file.
|
||||
|
||||
//go:build go1.17
|
||||
// +build go1.17
|
||||
|
||||
package assert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCompare17(t *testing.T) {
|
||||
type customTime time.Time
|
||||
type customBytes []byte
|
||||
for _, currCase := range []struct {
|
||||
less interface{}
|
||||
greater interface{}
|
||||
cType string
|
||||
}{
|
||||
{less: time.Now(), greater: time.Now().Add(time.Hour), cType: "time.Time"},
|
||||
{less: customTime(time.Now()), greater: customTime(time.Now().Add(time.Hour)), cType: "time.Time"},
|
||||
{less: []byte{1, 1}, greater: []byte{1, 2}, cType: "[]byte"},
|
||||
{less: customBytes([]byte{1, 1}), greater: customBytes([]byte{1, 2}), cType: "[]byte"},
|
||||
} {
|
||||
resLess, isComparable := compare(currCase.less, currCase.greater, reflect.ValueOf(currCase.less).Kind())
|
||||
if !isComparable {
|
||||
t.Error("object should be comparable for type " + currCase.cType)
|
||||
}
|
||||
|
||||
if resLess != compareLess {
|
||||
t.Errorf("object less (%v) should be less than greater (%v) for type "+currCase.cType,
|
||||
currCase.less, currCase.greater)
|
||||
}
|
||||
|
||||
resGreater, isComparable := compare(currCase.greater, currCase.less, reflect.ValueOf(currCase.less).Kind())
|
||||
if !isComparable {
|
||||
t.Error("object are comparable for type " + currCase.cType)
|
||||
}
|
||||
|
||||
if resGreater != compareGreater {
|
||||
t.Errorf("object greater should be greater than less for type " + currCase.cType)
|
||||
}
|
||||
|
||||
resEqual, isComparable := compare(currCase.less, currCase.less, reflect.ValueOf(currCase.less).Kind())
|
||||
if !isComparable {
|
||||
t.Error("object are comparable for type " + currCase.cType)
|
||||
}
|
||||
|
||||
if resEqual != 0 {
|
||||
t.Errorf("objects should be equal for type " + currCase.cType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGreater17(t *testing.T) {
|
||||
mockT := new(testing.T)
|
||||
|
||||
if !Greater(mockT, 2, 1) {
|
||||
t.Error("Greater should return true")
|
||||
}
|
||||
|
||||
if Greater(mockT, 1, 1) {
|
||||
t.Error("Greater should return false")
|
||||
}
|
||||
|
||||
if Greater(mockT, 1, 2) {
|
||||
t.Error("Greater should return false")
|
||||
}
|
||||
|
||||
// Check error report
|
||||
for _, currCase := range []struct {
|
||||
less interface{}
|
||||
greater interface{}
|
||||
msg string
|
||||
}{
|
||||
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 1]" is not greater than "[1 2]"`},
|
||||
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 00:00:00 +0000 UTC" is not greater than "0001-01-01 01:00:00 +0000 UTC"`},
|
||||
} {
|
||||
out := &outputT{buf: bytes.NewBuffer(nil)}
|
||||
False(t, Greater(out, currCase.less, currCase.greater))
|
||||
Contains(t, out.buf.String(), currCase.msg)
|
||||
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.Greater")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGreaterOrEqual17(t *testing.T) {
|
||||
mockT := new(testing.T)
|
||||
|
||||
if !GreaterOrEqual(mockT, 2, 1) {
|
||||
t.Error("GreaterOrEqual should return true")
|
||||
}
|
||||
|
||||
if !GreaterOrEqual(mockT, 1, 1) {
|
||||
t.Error("GreaterOrEqual should return true")
|
||||
}
|
||||
|
||||
if GreaterOrEqual(mockT, 1, 2) {
|
||||
t.Error("GreaterOrEqual should return false")
|
||||
}
|
||||
|
||||
// Check error report
|
||||
for _, currCase := range []struct {
|
||||
less interface{}
|
||||
greater interface{}
|
||||
msg string
|
||||
}{
|
||||
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 1]" is not greater than or equal to "[1 2]"`},
|
||||
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 00:00:00 +0000 UTC" is not greater than or equal to "0001-01-01 01:00:00 +0000 UTC"`},
|
||||
} {
|
||||
out := &outputT{buf: bytes.NewBuffer(nil)}
|
||||
False(t, GreaterOrEqual(out, currCase.less, currCase.greater))
|
||||
Contains(t, out.buf.String(), currCase.msg)
|
||||
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.GreaterOrEqual")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLess17(t *testing.T) {
|
||||
mockT := new(testing.T)
|
||||
|
||||
if !Less(mockT, 1, 2) {
|
||||
t.Error("Less should return true")
|
||||
}
|
||||
|
||||
if Less(mockT, 1, 1) {
|
||||
t.Error("Less should return false")
|
||||
}
|
||||
|
||||
if Less(mockT, 2, 1) {
|
||||
t.Error("Less should return false")
|
||||
}
|
||||
|
||||
// Check error report
|
||||
for _, currCase := range []struct {
|
||||
less interface{}
|
||||
greater interface{}
|
||||
msg string
|
||||
}{
|
||||
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 2]" is not less than "[1 1]"`},
|
||||
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 01:00:00 +0000 UTC" is not less than "0001-01-01 00:00:00 +0000 UTC"`},
|
||||
} {
|
||||
out := &outputT{buf: bytes.NewBuffer(nil)}
|
||||
False(t, Less(out, currCase.greater, currCase.less))
|
||||
Contains(t, out.buf.String(), currCase.msg)
|
||||
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.Less")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLessOrEqual17(t *testing.T) {
|
||||
mockT := new(testing.T)
|
||||
|
||||
if !LessOrEqual(mockT, 1, 2) {
|
||||
t.Error("LessOrEqual should return true")
|
||||
}
|
||||
|
||||
if !LessOrEqual(mockT, 1, 1) {
|
||||
t.Error("LessOrEqual should return true")
|
||||
}
|
||||
|
||||
if LessOrEqual(mockT, 2, 1) {
|
||||
t.Error("LessOrEqual should return false")
|
||||
}
|
||||
|
||||
// Check error report
|
||||
for _, currCase := range []struct {
|
||||
less interface{}
|
||||
greater interface{}
|
||||
msg string
|
||||
}{
|
||||
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 2]" is not less than or equal to "[1 1]"`},
|
||||
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 01:00:00 +0000 UTC" is not less than or equal to "0001-01-01 00:00:00 +0000 UTC"`},
|
||||
} {
|
||||
out := &outputT{buf: bytes.NewBuffer(nil)}
|
||||
False(t, LessOrEqual(out, currCase.greater, currCase.less))
|
||||
Contains(t, out.buf.String(), currCase.msg)
|
||||
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.LessOrEqual")
|
||||
}
|
||||
}
|
18
internal/assert/assertion_compare_legacy.go
Normal file
18
internal/assert/assertion_compare_legacy.go
Normal file
@ -0,0 +1,18 @@
|
||||
// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare_legacy.go
|
||||
|
||||
// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved.
|
||||
// Use of this source code is governed by an MIT-style license that can be found in
|
||||
// the THIRD-PARTY-NOTICES file.
|
||||
|
||||
//go:build !go1.17
|
||||
// +build !go1.17
|
||||
|
||||
package assert
|
||||
|
||||
import "reflect"
|
||||
|
||||
// Older versions of Go does not have the reflect.Value.CanConvert
|
||||
// method.
|
||||
func canConvert(value reflect.Value, to reflect.Type) bool {
|
||||
return false
|
||||
}
|
455
internal/assert/assertion_compare_test.go
Normal file
455
internal/assert/assertion_compare_test.go
Normal file
@ -0,0 +1,455 @@
|
||||
// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_compare_test.go
|
||||
|
||||
// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved.
|
||||
// Use of this source code is governed by an MIT-style license that can be found in
|
||||
// the THIRD-PARTY-NOTICES file.
|
||||
|
||||
package assert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompare(t *testing.T) {
|
||||
type customInt int
|
||||
type customInt8 int8
|
||||
type customInt16 int16
|
||||
type customInt32 int32
|
||||
type customInt64 int64
|
||||
type customUInt uint
|
||||
type customUInt8 uint8
|
||||
type customUInt16 uint16
|
||||
type customUInt32 uint32
|
||||
type customUInt64 uint64
|
||||
type customFloat32 float32
|
||||
type customFloat64 float64
|
||||
type customString string
|
||||
for _, currCase := range []struct {
|
||||
less interface{}
|
||||
greater interface{}
|
||||
cType string
|
||||
}{
|
||||
{less: customString("a"), greater: customString("b"), cType: "string"},
|
||||
{less: "a", greater: "b", cType: "string"},
|
||||
{less: customInt(1), greater: customInt(2), cType: "int"},
|
||||
{less: int(1), greater: int(2), cType: "int"},
|
||||
{less: customInt8(1), greater: customInt8(2), cType: "int8"},
|
||||
{less: int8(1), greater: int8(2), cType: "int8"},
|
||||
{less: customInt16(1), greater: customInt16(2), cType: "int16"},
|
||||
{less: int16(1), greater: int16(2), cType: "int16"},
|
||||
{less: customInt32(1), greater: customInt32(2), cType: "int32"},
|
||||
{less: int32(1), greater: int32(2), cType: "int32"},
|
||||
{less: customInt64(1), greater: customInt64(2), cType: "int64"},
|
||||
{less: int64(1), greater: int64(2), cType: "int64"},
|
||||
{less: customUInt(1), greater: customUInt(2), cType: "uint"},
|
||||
{less: uint8(1), greater: uint8(2), cType: "uint8"},
|
||||
{less: customUInt8(1), greater: customUInt8(2), cType: "uint8"},
|
||||
{less: uint16(1), greater: uint16(2), cType: "uint16"},
|
||||
{less: customUInt16(1), greater: customUInt16(2), cType: "uint16"},
|
||||
{less: uint32(1), greater: uint32(2), cType: "uint32"},
|
||||
{less: customUInt32(1), greater: customUInt32(2), cType: "uint32"},
|
||||
{less: uint64(1), greater: uint64(2), cType: "uint64"},
|
||||
{less: customUInt64(1), greater: customUInt64(2), cType: "uint64"},
|
||||
{less: float32(1.23), greater: float32(2.34), cType: "float32"},
|
||||
{less: customFloat32(1.23), greater: customFloat32(2.23), cType: "float32"},
|
||||
{less: float64(1.23), greater: float64(2.34), cType: "float64"},
|
||||
{less: customFloat64(1.23), greater: customFloat64(2.34), cType: "float64"},
|
||||
} {
|
||||
resLess, isComparable := compare(currCase.less, currCase.greater, reflect.ValueOf(currCase.less).Kind())
|
||||
if !isComparable {
|
||||
t.Error("object should be comparable for type " + currCase.cType)
|
||||
}
|
||||
|
||||
if resLess != compareLess {
|
||||
t.Errorf("object less (%v) should be less than greater (%v) for type "+currCase.cType,
|
||||
currCase.less, currCase.greater)
|
||||
}
|
||||
|
||||
resGreater, isComparable := compare(currCase.greater, currCase.less, reflect.ValueOf(currCase.less).Kind())
|
||||
if !isComparable {
|
||||
t.Error("object are comparable for type " + currCase.cType)
|
||||
}
|
||||
|
||||
if resGreater != compareGreater {
|
||||
t.Errorf("object greater should be greater than less for type " + currCase.cType)
|
||||
}
|
||||
|
||||
resEqual, isComparable := compare(currCase.less, currCase.less, reflect.ValueOf(currCase.less).Kind())
|
||||
if !isComparable {
|
||||
t.Error("object are comparable for type " + currCase.cType)
|
||||
}
|
||||
|
||||
if resEqual != 0 {
|
||||
t.Errorf("objects should be equal for type " + currCase.cType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type outputT struct {
|
||||
buf *bytes.Buffer
|
||||
helpers map[string]struct{}
|
||||
}
|
||||
|
||||
// Implements TestingT
|
||||
func (t *outputT) Errorf(format string, args ...interface{}) {
|
||||
s := fmt.Sprintf(format, args...)
|
||||
t.buf.WriteString(s)
|
||||
}
|
||||
|
||||
func (t *outputT) Helper() {
|
||||
if t.helpers == nil {
|
||||
t.helpers = make(map[string]struct{})
|
||||
}
|
||||
t.helpers[callerName(1)] = struct{}{}
|
||||
}
|
||||
|
||||
// callerName gives the function name (qualified with a package path)
|
||||
// for the caller after skip frames (where 0 means the current function).
|
||||
func callerName(skip int) string {
|
||||
// Make room for the skip PC.
|
||||
var pc [1]uintptr
|
||||
n := runtime.Callers(skip+2, pc[:]) // skip + runtime.Callers + callerName
|
||||
if n == 0 {
|
||||
panic("testing: zero callers found")
|
||||
}
|
||||
frames := runtime.CallersFrames(pc[:n])
|
||||
frame, _ := frames.Next()
|
||||
return frame.Function
|
||||
}
|
||||
|
||||
func TestGreater(t *testing.T) {
|
||||
mockT := new(testing.T)
|
||||
|
||||
if !Greater(mockT, 2, 1) {
|
||||
t.Error("Greater should return true")
|
||||
}
|
||||
|
||||
if Greater(mockT, 1, 1) {
|
||||
t.Error("Greater should return false")
|
||||
}
|
||||
|
||||
if Greater(mockT, 1, 2) {
|
||||
t.Error("Greater should return false")
|
||||
}
|
||||
|
||||
// Check error report
|
||||
for _, currCase := range []struct {
|
||||
less interface{}
|
||||
greater interface{}
|
||||
msg string
|
||||
}{
|
||||
{less: "a", greater: "b", msg: `"a" is not greater than "b"`},
|
||||
{less: int(1), greater: int(2), msg: `"1" is not greater than "2"`},
|
||||
{less: int8(1), greater: int8(2), msg: `"1" is not greater than "2"`},
|
||||
{less: int16(1), greater: int16(2), msg: `"1" is not greater than "2"`},
|
||||
{less: int32(1), greater: int32(2), msg: `"1" is not greater than "2"`},
|
||||
{less: int64(1), greater: int64(2), msg: `"1" is not greater than "2"`},
|
||||
{less: uint8(1), greater: uint8(2), msg: `"1" is not greater than "2"`},
|
||||
{less: uint16(1), greater: uint16(2), msg: `"1" is not greater than "2"`},
|
||||
{less: uint32(1), greater: uint32(2), msg: `"1" is not greater than "2"`},
|
||||
{less: uint64(1), greater: uint64(2), msg: `"1" is not greater than "2"`},
|
||||
{less: float32(1.23), greater: float32(2.34), msg: `"1.23" is not greater than "2.34"`},
|
||||
{less: float64(1.23), greater: float64(2.34), msg: `"1.23" is not greater than "2.34"`},
|
||||
} {
|
||||
out := &outputT{buf: bytes.NewBuffer(nil)}
|
||||
False(t, Greater(out, currCase.less, currCase.greater))
|
||||
Contains(t, out.buf.String(), currCase.msg)
|
||||
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.Greater")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGreaterOrEqual(t *testing.T) {
|
||||
mockT := new(testing.T)
|
||||
|
||||
if !GreaterOrEqual(mockT, 2, 1) {
|
||||
t.Error("GreaterOrEqual should return true")
|
||||
}
|
||||
|
||||
if !GreaterOrEqual(mockT, 1, 1) {
|
||||
t.Error("GreaterOrEqual should return true")
|
||||
}
|
||||
|
||||
if GreaterOrEqual(mockT, 1, 2) {
|
||||
t.Error("GreaterOrEqual should return false")
|
||||
}
|
||||
|
||||
// Check error report
|
||||
for _, currCase := range []struct {
|
||||
less interface{}
|
||||
greater interface{}
|
||||
msg string
|
||||
}{
|
||||
{less: "a", greater: "b", msg: `"a" is not greater than or equal to "b"`},
|
||||
{less: int(1), greater: int(2), msg: `"1" is not greater than or equal to "2"`},
|
||||
{less: int8(1), greater: int8(2), msg: `"1" is not greater than or equal to "2"`},
|
||||
{less: int16(1), greater: int16(2), msg: `"1" is not greater than or equal to "2"`},
|
||||
{less: int32(1), greater: int32(2), msg: `"1" is not greater than or equal to "2"`},
|
||||
{less: int64(1), greater: int64(2), msg: `"1" is not greater than or equal to "2"`},
|
||||
{less: uint8(1), greater: uint8(2), msg: `"1" is not greater than or equal to "2"`},
|
||||
{less: uint16(1), greater: uint16(2), msg: `"1" is not greater than or equal to "2"`},
|
||||
{less: uint32(1), greater: uint32(2), msg: `"1" is not greater than or equal to "2"`},
|
||||
{less: uint64(1), greater: uint64(2), msg: `"1" is not greater than or equal to "2"`},
|
||||
{less: float32(1.23), greater: float32(2.34), msg: `"1.23" is not greater than or equal to "2.34"`},
|
||||
{less: float64(1.23), greater: float64(2.34), msg: `"1.23" is not greater than or equal to "2.34"`},
|
||||
} {
|
||||
out := &outputT{buf: bytes.NewBuffer(nil)}
|
||||
False(t, GreaterOrEqual(out, currCase.less, currCase.greater))
|
||||
Contains(t, out.buf.String(), currCase.msg)
|
||||
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.GreaterOrEqual")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLess(t *testing.T) {
|
||||
mockT := new(testing.T)
|
||||
|
||||
if !Less(mockT, 1, 2) {
|
||||
t.Error("Less should return true")
|
||||
}
|
||||
|
||||
if Less(mockT, 1, 1) {
|
||||
t.Error("Less should return false")
|
||||
}
|
||||
|
||||
if Less(mockT, 2, 1) {
|
||||
t.Error("Less should return false")
|
||||
}
|
||||
|
||||
// Check error report
|
||||
for _, currCase := range []struct {
|
||||
less interface{}
|
||||
greater interface{}
|
||||
msg string
|
||||
}{
|
||||
{less: "a", greater: "b", msg: `"b" is not less than "a"`},
|
||||
{less: int(1), greater: int(2), msg: `"2" is not less than "1"`},
|
||||
{less: int8(1), greater: int8(2), msg: `"2" is not less than "1"`},
|
||||
{less: int16(1), greater: int16(2), msg: `"2" is not less than "1"`},
|
||||
{less: int32(1), greater: int32(2), msg: `"2" is not less than "1"`},
|
||||
{less: int64(1), greater: int64(2), msg: `"2" is not less than "1"`},
|
||||
{less: uint8(1), greater: uint8(2), msg: `"2" is not less than "1"`},
|
||||
{less: uint16(1), greater: uint16(2), msg: `"2" is not less than "1"`},
|
||||
{less: uint32(1), greater: uint32(2), msg: `"2" is not less than "1"`},
|
||||
{less: uint64(1), greater: uint64(2), msg: `"2" is not less than "1"`},
|
||||
{less: float32(1.23), greater: float32(2.34), msg: `"2.34" is not less than "1.23"`},
|
||||
{less: float64(1.23), greater: float64(2.34), msg: `"2.34" is not less than "1.23"`},
|
||||
} {
|
||||
out := &outputT{buf: bytes.NewBuffer(nil)}
|
||||
False(t, Less(out, currCase.greater, currCase.less))
|
||||
Contains(t, out.buf.String(), currCase.msg)
|
||||
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.Less")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLessOrEqual(t *testing.T) {
|
||||
mockT := new(testing.T)
|
||||
|
||||
if !LessOrEqual(mockT, 1, 2) {
|
||||
t.Error("LessOrEqual should return true")
|
||||
}
|
||||
|
||||
if !LessOrEqual(mockT, 1, 1) {
|
||||
t.Error("LessOrEqual should return true")
|
||||
}
|
||||
|
||||
if LessOrEqual(mockT, 2, 1) {
|
||||
t.Error("LessOrEqual should return false")
|
||||
}
|
||||
|
||||
// Check error report
|
||||
for _, currCase := range []struct {
|
||||
less interface{}
|
||||
greater interface{}
|
||||
msg string
|
||||
}{
|
||||
{less: "a", greater: "b", msg: `"b" is not less than or equal to "a"`},
|
||||
{less: int(1), greater: int(2), msg: `"2" is not less than or equal to "1"`},
|
||||
{less: int8(1), greater: int8(2), msg: `"2" is not less than or equal to "1"`},
|
||||
{less: int16(1), greater: int16(2), msg: `"2" is not less than or equal to "1"`},
|
||||
{less: int32(1), greater: int32(2), msg: `"2" is not less than or equal to "1"`},
|
||||
{less: int64(1), greater: int64(2), msg: `"2" is not less than or equal to "1"`},
|
||||
{less: uint8(1), greater: uint8(2), msg: `"2" is not less than or equal to "1"`},
|
||||
{less: uint16(1), greater: uint16(2), msg: `"2" is not less than or equal to "1"`},
|
||||
{less: uint32(1), greater: uint32(2), msg: `"2" is not less than or equal to "1"`},
|
||||
{less: uint64(1), greater: uint64(2), msg: `"2" is not less than or equal to "1"`},
|
||||
{less: float32(1.23), greater: float32(2.34), msg: `"2.34" is not less than or equal to "1.23"`},
|
||||
{less: float64(1.23), greater: float64(2.34), msg: `"2.34" is not less than or equal to "1.23"`},
|
||||
} {
|
||||
out := &outputT{buf: bytes.NewBuffer(nil)}
|
||||
False(t, LessOrEqual(out, currCase.greater, currCase.less))
|
||||
Contains(t, out.buf.String(), currCase.msg)
|
||||
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.LessOrEqual")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPositive(t *testing.T) {
|
||||
mockT := new(testing.T)
|
||||
|
||||
if !Positive(mockT, 1) {
|
||||
t.Error("Positive should return true")
|
||||
}
|
||||
|
||||
if !Positive(mockT, 1.23) {
|
||||
t.Error("Positive should return true")
|
||||
}
|
||||
|
||||
if Positive(mockT, -1) {
|
||||
t.Error("Positive should return false")
|
||||
}
|
||||
|
||||
if Positive(mockT, -1.23) {
|
||||
t.Error("Positive should return false")
|
||||
}
|
||||
|
||||
// Check error report
|
||||
for _, currCase := range []struct {
|
||||
e interface{}
|
||||
msg string
|
||||
}{
|
||||
{e: int(-1), msg: `"-1" is not positive`},
|
||||
{e: int8(-1), msg: `"-1" is not positive`},
|
||||
{e: int16(-1), msg: `"-1" is not positive`},
|
||||
{e: int32(-1), msg: `"-1" is not positive`},
|
||||
{e: int64(-1), msg: `"-1" is not positive`},
|
||||
{e: float32(-1.23), msg: `"-1.23" is not positive`},
|
||||
{e: float64(-1.23), msg: `"-1.23" is not positive`},
|
||||
} {
|
||||
out := &outputT{buf: bytes.NewBuffer(nil)}
|
||||
False(t, Positive(out, currCase.e))
|
||||
Contains(t, out.buf.String(), currCase.msg)
|
||||
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.Positive")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNegative(t *testing.T) {
|
||||
mockT := new(testing.T)
|
||||
|
||||
if !Negative(mockT, -1) {
|
||||
t.Error("Negative should return true")
|
||||
}
|
||||
|
||||
if !Negative(mockT, -1.23) {
|
||||
t.Error("Negative should return true")
|
||||
}
|
||||
|
||||
if Negative(mockT, 1) {
|
||||
t.Error("Negative should return false")
|
||||
}
|
||||
|
||||
if Negative(mockT, 1.23) {
|
||||
t.Error("Negative should return false")
|
||||
}
|
||||
|
||||
// Check error report
|
||||
for _, currCase := range []struct {
|
||||
e interface{}
|
||||
msg string
|
||||
}{
|
||||
{e: int(1), msg: `"1" is not negative`},
|
||||
{e: int8(1), msg: `"1" is not negative`},
|
||||
{e: int16(1), msg: `"1" is not negative`},
|
||||
{e: int32(1), msg: `"1" is not negative`},
|
||||
{e: int64(1), msg: `"1" is not negative`},
|
||||
{e: float32(1.23), msg: `"1.23" is not negative`},
|
||||
{e: float64(1.23), msg: `"1.23" is not negative`},
|
||||
} {
|
||||
out := &outputT{buf: bytes.NewBuffer(nil)}
|
||||
False(t, Negative(out, currCase.e))
|
||||
Contains(t, out.buf.String(), currCase.msg)
|
||||
Contains(t, out.helpers, "gitea.psichedelico.com/go/bson/internal/assert.Negative")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_compareTwoValuesDifferentValuesTypes(t *testing.T) {
|
||||
mockT := new(testing.T)
|
||||
|
||||
for _, currCase := range []struct {
|
||||
v1 interface{}
|
||||
v2 interface{}
|
||||
compareResult bool
|
||||
}{
|
||||
{v1: 123, v2: "abc"},
|
||||
{v1: "abc", v2: 123456},
|
||||
{v1: float64(12), v2: "123"},
|
||||
{v1: "float(12)", v2: float64(1)},
|
||||
} {
|
||||
compareResult := compareTwoValues(mockT, currCase.v1, currCase.v2, []CompareType{compareLess, compareEqual, compareGreater}, "testFailMessage")
|
||||
False(t, compareResult)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_compareTwoValuesNotComparableValues(t *testing.T) {
|
||||
mockT := new(testing.T)
|
||||
|
||||
type CompareStruct struct {
|
||||
}
|
||||
|
||||
for _, currCase := range []struct {
|
||||
v1 interface{}
|
||||
v2 interface{}
|
||||
}{
|
||||
{v1: CompareStruct{}, v2: CompareStruct{}},
|
||||
{v1: map[string]int{}, v2: map[string]int{}},
|
||||
{v1: make([]int, 5), v2: make([]int, 5)},
|
||||
} {
|
||||
compareResult := compareTwoValues(mockT, currCase.v1, currCase.v2, []CompareType{compareLess, compareEqual, compareGreater}, "testFailMessage")
|
||||
False(t, compareResult)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_compareTwoValuesCorrectCompareResult(t *testing.T) {
|
||||
mockT := new(testing.T)
|
||||
|
||||
for _, currCase := range []struct {
|
||||
v1 interface{}
|
||||
v2 interface{}
|
||||
compareTypes []CompareType
|
||||
}{
|
||||
{v1: 1, v2: 2, compareTypes: []CompareType{compareLess}},
|
||||
{v1: 1, v2: 2, compareTypes: []CompareType{compareLess, compareEqual}},
|
||||
{v1: 2, v2: 2, compareTypes: []CompareType{compareGreater, compareEqual}},
|
||||
{v1: 2, v2: 2, compareTypes: []CompareType{compareEqual}},
|
||||
{v1: 2, v2: 1, compareTypes: []CompareType{compareEqual, compareGreater}},
|
||||
{v1: 2, v2: 1, compareTypes: []CompareType{compareGreater}},
|
||||
} {
|
||||
compareResult := compareTwoValues(mockT, currCase.v1, currCase.v2, currCase.compareTypes, "testFailMessage")
|
||||
True(t, compareResult)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_containsValue(t *testing.T) {
|
||||
for _, currCase := range []struct {
|
||||
values []CompareType
|
||||
value CompareType
|
||||
result bool
|
||||
}{
|
||||
{values: []CompareType{compareGreater}, value: compareGreater, result: true},
|
||||
{values: []CompareType{compareGreater, compareLess}, value: compareGreater, result: true},
|
||||
{values: []CompareType{compareGreater, compareLess}, value: compareLess, result: true},
|
||||
{values: []CompareType{compareGreater, compareLess}, value: compareEqual, result: false},
|
||||
} {
|
||||
compareResult := containsValue(currCase.values, currCase.value)
|
||||
Equal(t, currCase.result, compareResult)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComparingMsgAndArgsForwarding(t *testing.T) {
|
||||
msgAndArgs := []interface{}{"format %s %x", "this", 0xc001}
|
||||
expectedOutput := "format this c001\n"
|
||||
funcs := []func(t TestingT){
|
||||
func(t TestingT) { Greater(t, 1, 2, msgAndArgs...) },
|
||||
func(t TestingT) { GreaterOrEqual(t, 1, 2, msgAndArgs...) },
|
||||
func(t TestingT) { Less(t, 2, 1, msgAndArgs...) },
|
||||
func(t TestingT) { LessOrEqual(t, 2, 1, msgAndArgs...) },
|
||||
func(t TestingT) { Positive(t, 0, msgAndArgs...) },
|
||||
func(t TestingT) { Negative(t, 0, msgAndArgs...) },
|
||||
}
|
||||
for _, f := range funcs {
|
||||
out := &outputT{buf: bytes.NewBuffer(nil)}
|
||||
f(out)
|
||||
Contains(t, out.buf.String(), expectedOutput)
|
||||
}
|
||||
}
|
325
internal/assert/assertion_format.go
Normal file
325
internal/assert/assertion_format.go
Normal file
@ -0,0 +1,325 @@
|
||||
// Copied from https://github.com/stretchr/testify/blob/1333b5d3bda8cf5aedcf3e1aaa95cac28aaab892/assert/assertion_format.go
|
||||
|
||||
// Copyright 2020 Mat Ryer, Tyler Bunnell and all contributors. All rights reserved.
|
||||
// Use of this source code is governed by an MIT-style license that can be found in
|
||||
// the THIRD-PARTY-NOTICES file.
|
||||
|
||||
package assert
|
||||
|
||||
import (
|
||||
time "time"
|
||||
)
|
||||
|
||||
// Containsf asserts that the specified string, list(array, slice...) or map contains the
|
||||
// specified substring or element.
|
||||
//
|
||||
// assert.Containsf(t, "Hello World", "World", "error message %s", "formatted")
|
||||
// assert.Containsf(t, ["Hello", "World"], "World", "error message %s", "formatted")
|
||||
// assert.Containsf(t, {"Hello": "World"}, "Hello", "error message %s", "formatted")
|
||||
func Containsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Contains(t, s, contains, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// ElementsMatchf asserts that the specified listA(array, slice...) is equal to specified
|
||||
// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
|
||||
// the number of appearances of each of them in both lists should match.
|
||||
//
|
||||
// assert.ElementsMatchf(t, [1, 3, 2, 3], [1, 3, 3, 2], "error message %s", "formatted")
|
||||
func ElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return ElementsMatch(t, listA, listB, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Equalf asserts that two objects are equal.
|
||||
//
|
||||
// assert.Equalf(t, 123, 123, "error message %s", "formatted")
|
||||
//
|
||||
// Pointer variable equality is determined based on the equality of the
|
||||
// referenced values (as opposed to the memory addresses). Function equality
|
||||
// cannot be determined and will always fail.
|
||||
func Equalf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Equal(t, expected, actual, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// EqualErrorf asserts that a function returned an error (i.e. not `nil`)
|
||||
// and that it is equal to the provided error.
|
||||
//
|
||||
// actualObj, err := SomeFunction()
|
||||
// assert.EqualErrorf(t, err, expectedErrorString, "error message %s", "formatted")
|
||||
func EqualErrorf(t TestingT, theError error, errString string, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return EqualError(t, theError, errString, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// EqualValuesf asserts that two objects are equal or convertible to the same types
|
||||
// and equal.
|
||||
//
|
||||
// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted")
|
||||
func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return EqualValues(t, expected, actual, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Errorf asserts that a function returned an error (i.e. not `nil`).
|
||||
//
|
||||
// actualObj, err := SomeFunction()
|
||||
// if assert.Errorf(t, err, "error message %s", "formatted") {
|
||||
// assert.Equal(t, expectedErrorf, err)
|
||||
// }
|
||||
func Errorf(t TestingT, err error, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Error(t, err, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
|
||||
// and that the error contains the specified substring.
|
||||
//
|
||||
// actualObj, err := SomeFunction()
|
||||
// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted")
|
||||
func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return ErrorContains(t, theError, contains, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Eventuallyf asserts that given condition will be met in waitFor time,
|
||||
// periodically checking target function each tick.
|
||||
//
|
||||
// assert.Eventuallyf(t, func() bool { return true; }, time.Second, 10*time.Millisecond, "error message %s", "formatted")
|
||||
func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Eventually(t, condition, waitFor, tick, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Failf reports a failure through
|
||||
func Failf(t TestingT, failureMessage string, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Fail(t, failureMessage, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// FailNowf fails test
|
||||
func FailNowf(t TestingT, failureMessage string, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return FailNow(t, failureMessage, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Falsef asserts that the specified value is false.
|
||||
//
|
||||
// assert.Falsef(t, myBool, "error message %s", "formatted")
|
||||
func Falsef(t TestingT, value bool, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return False(t, value, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Greaterf asserts that the first element is greater than the second
|
||||
//
|
||||
// assert.Greaterf(t, 2, 1, "error message %s", "formatted")
|
||||
// assert.Greaterf(t, float64(2), float64(1), "error message %s", "formatted")
|
||||
// assert.Greaterf(t, "b", "a", "error message %s", "formatted")
|
||||
func Greaterf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Greater(t, e1, e2, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// GreaterOrEqualf asserts that the first element is greater than or equal to the second
|
||||
//
|
||||
// assert.GreaterOrEqualf(t, 2, 1, "error message %s", "formatted")
|
||||
// assert.GreaterOrEqualf(t, 2, 2, "error message %s", "formatted")
|
||||
// assert.GreaterOrEqualf(t, "b", "a", "error message %s", "formatted")
|
||||
// assert.GreaterOrEqualf(t, "b", "b", "error message %s", "formatted")
|
||||
func GreaterOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return GreaterOrEqual(t, e1, e2, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// InDeltaf asserts that the two numerals are within delta of each other.
|
||||
//
|
||||
// assert.InDeltaf(t, math.Pi, 22/7.0, 0.01, "error message %s", "formatted")
|
||||
func InDeltaf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return InDelta(t, expected, actual, delta, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// IsTypef asserts that the specified objects are of the same type.
|
||||
func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return IsType(t, expectedType, object, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Lenf asserts that the specified object has specific length.
|
||||
// Lenf also fails if the object has a type that len() not accept.
|
||||
//
|
||||
// assert.Lenf(t, mySlice, 3, "error message %s", "formatted")
|
||||
func Lenf(t TestingT, object interface{}, length int, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Len(t, object, length, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Lessf asserts that the first element is less than the second
|
||||
//
|
||||
// assert.Lessf(t, 1, 2, "error message %s", "formatted")
|
||||
// assert.Lessf(t, float64(1), float64(2), "error message %s", "formatted")
|
||||
// assert.Lessf(t, "a", "b", "error message %s", "formatted")
|
||||
func Lessf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Less(t, e1, e2, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// LessOrEqualf asserts that the first element is less than or equal to the second
|
||||
//
|
||||
// assert.LessOrEqualf(t, 1, 2, "error message %s", "formatted")
|
||||
// assert.LessOrEqualf(t, 2, 2, "error message %s", "formatted")
|
||||
// assert.LessOrEqualf(t, "a", "b", "error message %s", "formatted")
|
||||
// assert.LessOrEqualf(t, "b", "b", "error message %s", "formatted")
|
||||
func LessOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return LessOrEqual(t, e1, e2, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Negativef asserts that the specified element is negative
|
||||
//
|
||||
// assert.Negativef(t, -1, "error message %s", "formatted")
|
||||
// assert.Negativef(t, -1.23, "error message %s", "formatted")
|
||||
func Negativef(t TestingT, e interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Negative(t, e, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Nilf asserts that the specified object is nil.
|
||||
//
|
||||
// assert.Nilf(t, err, "error message %s", "formatted")
|
||||
func Nilf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Nil(t, object, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NoErrorf asserts that a function returned no error (i.e. `nil`).
|
||||
//
|
||||
// actualObj, err := SomeFunction()
|
||||
// if assert.NoErrorf(t, err, "error message %s", "formatted") {
|
||||
// assert.Equal(t, expectedObj, actualObj)
|
||||
// }
|
||||
func NoErrorf(t TestingT, err error, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NoError(t, err, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NotContainsf asserts that the specified string, list(array, slice...) or map does NOT contain the
|
||||
// specified substring or element.
|
||||
//
|
||||
// assert.NotContainsf(t, "Hello World", "Earth", "error message %s", "formatted")
|
||||
// assert.NotContainsf(t, ["Hello", "World"], "Earth", "error message %s", "formatted")
|
||||
// assert.NotContainsf(t, {"Hello": "World"}, "Earth", "error message %s", "formatted")
|
||||
func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NotContains(t, s, contains, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NotEqualf asserts that the specified values are NOT equal.
|
||||
//
|
||||
// assert.NotEqualf(t, obj1, obj2, "error message %s", "formatted")
|
||||
//
|
||||
// Pointer variable equality is determined based on the equality of the
|
||||
// referenced values (as opposed to the memory addresses).
|
||||
func NotEqualf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NotEqual(t, expected, actual, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NotEqualValuesf asserts that two objects are not equal even when converted to the same type
|
||||
//
|
||||
// assert.NotEqualValuesf(t, obj1, obj2, "error message %s", "formatted")
|
||||
func NotEqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NotEqualValues(t, expected, actual, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// NotNilf asserts that the specified object is not nil.
|
||||
//
|
||||
// assert.NotNilf(t, err, "error message %s", "formatted")
|
||||
func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return NotNil(t, object, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Positivef asserts that the specified element is positive
|
||||
//
|
||||
// assert.Positivef(t, 1, "error message %s", "formatted")
|
||||
// assert.Positivef(t, 1.23, "error message %s", "formatted")
|
||||
func Positivef(t TestingT, e interface{}, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return Positive(t, e, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// Truef asserts that the specified value is true.
|
||||
//
|
||||
// assert.Truef(t, myBool, "error message %s", "formatted")
|
||||
func Truef(t TestingT, value bool, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return True(t, value, append([]interface{}{msg}, args...)...)
|
||||
}
|
||||
|
||||
// WithinDurationf asserts that the two times are within duration delta of each other.
|
||||
//
|
||||
// assert.WithinDurationf(t, time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted")
|
||||
func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
return WithinDuration(t, expected, actual, delta, append([]interface{}{msg}, args...)...)
|
||||
}
|
126
internal/assert/assertion_mongo.go
Normal file
126
internal/assert/assertion_mongo.go
Normal file
@ -0,0 +1,126 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// assertion_mongo.go contains MongoDB-specific extensions to the "assert"
|
||||
// package.
|
||||
|
||||
package assert
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// DifferentAddressRanges asserts that two byte slices reference distinct memory
|
||||
// address ranges, meaning they reference different underlying byte arrays.
|
||||
func DifferentAddressRanges(t TestingT, a, b []byte) (ok bool) {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
|
||||
if len(a) == 0 || len(b) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Find the start and end memory addresses for the underlying byte array for
|
||||
// each input byte slice.
|
||||
sliceAddrRange := func(b []byte) (uintptr, uintptr) {
|
||||
sh := (*reflect.SliceHeader)(unsafe.Pointer(&b))
|
||||
return sh.Data, sh.Data + uintptr(sh.Cap-1)
|
||||
}
|
||||
aStart, aEnd := sliceAddrRange(a)
|
||||
bStart, bEnd := sliceAddrRange(b)
|
||||
|
||||
// If "b" starts after "a" ends or "a" starts after "b" ends, there is no
|
||||
// overlap.
|
||||
if bStart > aEnd || aStart > bEnd {
|
||||
return true
|
||||
}
|
||||
|
||||
// Otherwise, calculate the overlap start and end and print the memory
|
||||
// overlap error message.
|
||||
min := func(a, b uintptr) uintptr {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
max := func(a, b uintptr) uintptr {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
overlapLow := max(aStart, bStart)
|
||||
overlapHigh := min(aEnd, bEnd)
|
||||
|
||||
t.Errorf("Byte slices point to the same underlying byte array:\n"+
|
||||
"\ta addresses:\t%d ... %d\n"+
|
||||
"\tb addresses:\t%d ... %d\n"+
|
||||
"\toverlap:\t%d ... %d",
|
||||
aStart, aEnd,
|
||||
bStart, bEnd,
|
||||
overlapLow, overlapHigh)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// EqualBSON asserts that the expected and actual BSON binary values are equal.
|
||||
// If the values are not equal, it prints both the binary and Extended JSON diff
|
||||
// of the BSON values. The provided BSON value types must implement the
|
||||
// fmt.Stringer interface.
|
||||
func EqualBSON(t TestingT, expected, actual interface{}) bool {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
|
||||
return Equal(t,
|
||||
expected,
|
||||
actual,
|
||||
`expected and actual BSON values do not match
|
||||
As Extended JSON:
|
||||
Expected: %s
|
||||
Actual : %s`,
|
||||
expected.(fmt.Stringer).String(),
|
||||
actual.(fmt.Stringer).String())
|
||||
}
|
||||
|
||||
// Soon runs the provided callback and fails the passed-in test if the callback
|
||||
// does not complete within timeout. The provided callback should respect the
|
||||
// passed-in context and cease execution when it has expired.
|
||||
//
|
||||
// Deprecated: This function will be removed with GODRIVER-2667, use
|
||||
// assert.Eventually instead.
|
||||
func Soon(t TestingT, callback func(ctx context.Context), timeout time.Duration) {
|
||||
if h, ok := t.(tHelper); ok {
|
||||
h.Helper()
|
||||
}
|
||||
|
||||
// Create context to manually cancel callback after Soon assertion.
|
||||
callbackCtx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
fullCallback := func() {
|
||||
callback(callbackCtx)
|
||||
done <- struct{}{}
|
||||
}
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
go fullCallback()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-timer.C:
|
||||
t.Errorf("timed out in %s waiting for callback", timeout)
|
||||
}
|
||||
}
|
125
internal/assert/assertion_mongo_test.go
Normal file
125
internal/assert/assertion_mongo_test.go
Normal file
@ -0,0 +1,125 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package assert
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gitea.psichedelico.com/go/bson"
|
||||
)
|
||||
|
||||
func TestDifferentAddressRanges(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
slice := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
a []byte
|
||||
b []byte
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "distinct byte slices",
|
||||
a: []byte{0, 1, 2, 3},
|
||||
b: []byte{0, 1, 2, 3},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "same byte slice",
|
||||
a: slice,
|
||||
b: slice,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "whole and subslice",
|
||||
a: slice,
|
||||
b: slice[:4],
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "two subslices",
|
||||
a: slice[1:2],
|
||||
b: slice[3:4],
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
a: []byte{0, 1, 2, 3},
|
||||
b: []byte{},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "nil",
|
||||
a: []byte{0, 1, 2, 3},
|
||||
b: nil,
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // Capture range variable.
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := DifferentAddressRanges(new(testing.T), tc.a, tc.b)
|
||||
if got != tc.want {
|
||||
t.Errorf("DifferentAddressRanges(%p, %p) = %v, want %v", tc.a, tc.b, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEqualBSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
expected interface{}
|
||||
actual interface{}
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "equal bson.Raw",
|
||||
expected: bson.Raw{5, 0, 0, 0, 0},
|
||||
actual: bson.Raw{5, 0, 0, 0, 0},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "different bson.Raw",
|
||||
expected: bson.Raw{8, 0, 0, 0, 10, 120, 0, 0},
|
||||
actual: bson.Raw{5, 0, 0, 0, 0},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid bson.Raw",
|
||||
expected: bson.Raw{99, 99, 99, 99},
|
||||
actual: bson.Raw{5, 0, 0, 0, 0},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nil bson.Raw",
|
||||
expected: bson.Raw(nil),
|
||||
actual: bson.Raw(nil),
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // Capture range variable.
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := EqualBSON(new(testing.T), tc.expected, tc.actual)
|
||||
if got != tc.want {
|
||||
t.Errorf("EqualBSON(%#v, %#v) = %v, want %v", tc.expected, tc.actual, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
1075
internal/assert/assertions.go
Normal file
1075
internal/assert/assertions.go
Normal file
File diff suppressed because it is too large
Load Diff
1231
internal/assert/assertions_test.go
Normal file
1231
internal/assert/assertions_test.go
Normal file
File diff suppressed because it is too large
Load Diff
766
internal/assert/difflib.go
Normal file
766
internal/assert/difflib.go
Normal file
@ -0,0 +1,766 @@
|
||||
// Copied from https://github.com/pmezard/go-difflib/blob/5d4384ee4fb2527b0a1256a821ebfc92f91efefc/difflib/difflib.go
|
||||
|
||||
// Copyright 2013 Patrick Mezard. All rights reserved. Use of this source code is
|
||||
// governed by a license that can be found in the THIRD-PARTY-NOTICES file.
|
||||
|
||||
package assert
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func calculateRatio(matches, length int) float64 {
|
||||
if length > 0 {
|
||||
return 2.0 * float64(matches) / float64(length)
|
||||
}
|
||||
return 1.0
|
||||
}
|
||||
|
||||
type Match struct {
|
||||
A int
|
||||
B int
|
||||
Size int
|
||||
}
|
||||
|
||||
type OpCode struct {
|
||||
Tag byte
|
||||
I1 int
|
||||
I2 int
|
||||
J1 int
|
||||
J2 int
|
||||
}
|
||||
|
||||
// SequenceMatcher compares sequence of strings. The basic
|
||||
// algorithm predates, and is a little fancier than, an algorithm
|
||||
// published in the late 1980's by Ratcliff and Obershelp under the
|
||||
// hyperbolic name "gestalt pattern matching". The basic idea is to find
|
||||
// the longest contiguous matching subsequence that contains no "junk"
|
||||
// elements (R-O doesn't address junk). The same idea is then applied
|
||||
// recursively to the pieces of the sequences to the left and to the right
|
||||
// of the matching subsequence. This does not yield minimal edit
|
||||
// sequences, but does tend to yield matches that "look right" to people.
|
||||
//
|
||||
// SequenceMatcher tries to compute a "human-friendly diff" between two
|
||||
// sequences. Unlike e.g. UNIX(tm) diff, the fundamental notion is the
|
||||
// longest *contiguous* & junk-free matching subsequence. That's what
|
||||
// catches peoples' eyes. The Windows(tm) windiff has another interesting
|
||||
// notion, pairing up elements that appear uniquely in each sequence.
|
||||
// That, and the method here, appear to yield more intuitive difference
|
||||
// reports than does diff. This method appears to be the least vulnerable
|
||||
// to syncing up on blocks of "junk lines", though (like blank lines in
|
||||
// ordinary text files, or maybe "<P>" lines in HTML files). That may be
|
||||
// because this is the only method of the 3 that has a *concept* of
|
||||
// "junk" <wink>.
|
||||
//
|
||||
// Timing: Basic R-O is cubic time worst case and quadratic time expected
|
||||
// case. SequenceMatcher is quadratic time for the worst case and has
|
||||
// expected-case behavior dependent in a complicated way on how many
|
||||
// elements the sequences have in common; best case time is linear.
|
||||
type SequenceMatcher struct {
|
||||
a []string
|
||||
b []string
|
||||
b2j map[string][]int
|
||||
IsJunk func(string) bool
|
||||
autoJunk bool
|
||||
bJunk map[string]struct{}
|
||||
matchingBlocks []Match
|
||||
fullBCount map[string]int
|
||||
bPopular map[string]struct{}
|
||||
opCodes []OpCode
|
||||
}
|
||||
|
||||
func NewMatcher(a, b []string) *SequenceMatcher {
|
||||
m := SequenceMatcher{autoJunk: true}
|
||||
m.SetSeqs(a, b)
|
||||
return &m
|
||||
}
|
||||
|
||||
func NewMatcherWithJunk(a, b []string, autoJunk bool,
|
||||
isJunk func(string) bool) *SequenceMatcher {
|
||||
|
||||
m := SequenceMatcher{IsJunk: isJunk, autoJunk: autoJunk}
|
||||
m.SetSeqs(a, b)
|
||||
return &m
|
||||
}
|
||||
|
||||
// SetSeqs sets the two sequences to be compared.
|
||||
func (m *SequenceMatcher) SetSeqs(a, b []string) {
|
||||
m.SetSeq1(a)
|
||||
m.SetSeq2(b)
|
||||
}
|
||||
|
||||
// SetSeq1 sets the first sequence to be compared. The second sequence to be compared is
|
||||
// not changed.
|
||||
//
|
||||
// SequenceMatcher computes and caches detailed information about the second
|
||||
// sequence, so if you want to compare one sequence S against many sequences,
|
||||
// use .SetSeq2(s) once and call .SetSeq1(x) repeatedly for each of the other
|
||||
// sequences.
|
||||
//
|
||||
// See also SetSeqs() and SetSeq2().
|
||||
func (m *SequenceMatcher) SetSeq1(a []string) {
|
||||
if &a == &m.a {
|
||||
return
|
||||
}
|
||||
m.a = a
|
||||
m.matchingBlocks = nil
|
||||
m.opCodes = nil
|
||||
}
|
||||
|
||||
// SetSeq2 sets the second sequence to be compared. The first sequence to be compared is
|
||||
// not changed.
|
||||
func (m *SequenceMatcher) SetSeq2(b []string) {
|
||||
if &b == &m.b {
|
||||
return
|
||||
}
|
||||
m.b = b
|
||||
m.matchingBlocks = nil
|
||||
m.opCodes = nil
|
||||
m.fullBCount = nil
|
||||
m.chainB()
|
||||
}
|
||||
|
||||
func (m *SequenceMatcher) chainB() {
|
||||
// Populate line -> index mapping
|
||||
b2j := map[string][]int{}
|
||||
for i, s := range m.b {
|
||||
indices := b2j[s]
|
||||
indices = append(indices, i)
|
||||
b2j[s] = indices
|
||||
}
|
||||
|
||||
// Purge junk elements
|
||||
m.bJunk = map[string]struct{}{}
|
||||
if m.IsJunk != nil {
|
||||
junk := m.bJunk
|
||||
for s := range b2j {
|
||||
if m.IsJunk(s) {
|
||||
junk[s] = struct{}{}
|
||||
}
|
||||
}
|
||||
for s := range junk {
|
||||
delete(b2j, s)
|
||||
}
|
||||
}
|
||||
|
||||
// Purge remaining popular elements
|
||||
popular := map[string]struct{}{}
|
||||
n := len(m.b)
|
||||
if m.autoJunk && n >= 200 {
|
||||
ntest := n/100 + 1
|
||||
for s, indices := range b2j {
|
||||
if len(indices) > ntest {
|
||||
popular[s] = struct{}{}
|
||||
}
|
||||
}
|
||||
for s := range popular {
|
||||
delete(b2j, s)
|
||||
}
|
||||
}
|
||||
m.bPopular = popular
|
||||
m.b2j = b2j
|
||||
}
|
||||
|
||||
func (m *SequenceMatcher) isBJunk(s string) bool {
|
||||
_, ok := m.bJunk[s]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Find longest matching block in a[alo:ahi] and b[blo:bhi].
|
||||
//
|
||||
// If IsJunk is not defined:
|
||||
//
|
||||
// Return (i,j,k) such that a[i:i+k] is equal to b[j:j+k], where
|
||||
//
|
||||
// alo <= i <= i+k <= ahi
|
||||
// blo <= j <= j+k <= bhi
|
||||
//
|
||||
// and for all (i',j',k') meeting those conditions,
|
||||
//
|
||||
// k >= k'
|
||||
// i <= i'
|
||||
// and if i == i', j <= j'
|
||||
//
|
||||
// In other words, of all maximal matching blocks, return one that
|
||||
// starts earliest in a, and of all those maximal matching blocks that
|
||||
// start earliest in a, return the one that starts earliest in b.
|
||||
//
|
||||
// If IsJunk is defined, first the longest matching block is
|
||||
// determined as above, but with the additional restriction that no
|
||||
// junk element appears in the block. Then that block is extended as
|
||||
// far as possible by matching (only) junk elements on both sides. So
|
||||
// the resulting block never matches on junk except as identical junk
|
||||
// happens to be adjacent to an "interesting" match.
|
||||
//
|
||||
// If no blocks match, return (alo, blo, 0).
|
||||
func (m *SequenceMatcher) findLongestMatch(alo, ahi, blo, bhi int) Match {
|
||||
// CAUTION: stripping common prefix or suffix would be incorrect.
|
||||
// E.g.,
|
||||
// ab
|
||||
// acab
|
||||
// Longest matching block is "ab", but if common prefix is
|
||||
// stripped, it's "a" (tied with "b"). UNIX(tm) diff does so
|
||||
// strip, so ends up claiming that ab is changed to acab by
|
||||
// inserting "ca" in the middle. That's minimal but unintuitive:
|
||||
// "it's obvious" that someone inserted "ac" at the front.
|
||||
// Windiff ends up at the same place as diff, but by pairing up
|
||||
// the unique 'b's and then matching the first two 'a's.
|
||||
besti, bestj, bestsize := alo, blo, 0
|
||||
|
||||
// find longest junk-free match
|
||||
// during an iteration of the loop, j2len[j] = length of longest
|
||||
// junk-free match ending with a[i-1] and b[j]
|
||||
j2len := map[int]int{}
|
||||
for i := alo; i != ahi; i++ {
|
||||
// look at all instances of a[i] in b; note that because
|
||||
// b2j has no junk keys, the loop is skipped if a[i] is junk
|
||||
newj2len := map[int]int{}
|
||||
for _, j := range m.b2j[m.a[i]] {
|
||||
// a[i] matches b[j]
|
||||
if j < blo {
|
||||
continue
|
||||
}
|
||||
if j >= bhi {
|
||||
break
|
||||
}
|
||||
k := j2len[j-1] + 1
|
||||
newj2len[j] = k
|
||||
if k > bestsize {
|
||||
besti, bestj, bestsize = i-k+1, j-k+1, k
|
||||
}
|
||||
}
|
||||
j2len = newj2len
|
||||
}
|
||||
|
||||
// Extend the best by non-junk elements on each end. In particular,
|
||||
// "popular" non-junk elements aren't in b2j, which greatly speeds
|
||||
// the inner loop above, but also means "the best" match so far
|
||||
// doesn't contain any junk *or* popular non-junk elements.
|
||||
for besti > alo && bestj > blo && !m.isBJunk(m.b[bestj-1]) &&
|
||||
m.a[besti-1] == m.b[bestj-1] {
|
||||
besti, bestj, bestsize = besti-1, bestj-1, bestsize+1
|
||||
}
|
||||
for besti+bestsize < ahi && bestj+bestsize < bhi &&
|
||||
!m.isBJunk(m.b[bestj+bestsize]) &&
|
||||
m.a[besti+bestsize] == m.b[bestj+bestsize] {
|
||||
bestsize++
|
||||
}
|
||||
|
||||
// Now that we have a wholly interesting match (albeit possibly
|
||||
// empty!), we may as well suck up the matching junk on each
|
||||
// side of it too. Can't think of a good reason not to, and it
|
||||
// saves post-processing the (possibly considerable) expense of
|
||||
// figuring out what to do with it. In the case of an empty
|
||||
// interesting match, this is clearly the right thing to do,
|
||||
// because no other kind of match is possible in the regions.
|
||||
for besti > alo && bestj > blo && m.isBJunk(m.b[bestj-1]) &&
|
||||
m.a[besti-1] == m.b[bestj-1] {
|
||||
besti, bestj, bestsize = besti-1, bestj-1, bestsize+1
|
||||
}
|
||||
for besti+bestsize < ahi && bestj+bestsize < bhi &&
|
||||
m.isBJunk(m.b[bestj+bestsize]) &&
|
||||
m.a[besti+bestsize] == m.b[bestj+bestsize] {
|
||||
bestsize++
|
||||
}
|
||||
|
||||
return Match{A: besti, B: bestj, Size: bestsize}
|
||||
}
|
||||
|
||||
// GetMatchingBlocks returns list of triples describing matching subsequences.
|
||||
//
|
||||
// Each triple is of the form (i, j, n), and means that
|
||||
// a[i:i+n] == b[j:j+n]. The triples are monotonically increasing in
|
||||
// i and in j. It's also guaranteed that if (i, j, n) and (i', j', n') are
|
||||
// adjacent triples in the list, and the second is not the last triple in the
|
||||
// list, then i+n != i' or j+n != j'. IOW, adjacent triples never describe
|
||||
// adjacent equal blocks.
|
||||
//
|
||||
// The last triple is a dummy, (len(a), len(b), 0), and is the only
|
||||
// triple with n==0.
|
||||
func (m *SequenceMatcher) GetMatchingBlocks() []Match {
|
||||
if m.matchingBlocks != nil {
|
||||
return m.matchingBlocks
|
||||
}
|
||||
|
||||
var matchBlocks func(alo, ahi, blo, bhi int, matched []Match) []Match
|
||||
matchBlocks = func(alo, ahi, blo, bhi int, matched []Match) []Match {
|
||||
match := m.findLongestMatch(alo, ahi, blo, bhi)
|
||||
i, j, k := match.A, match.B, match.Size
|
||||
if match.Size > 0 {
|
||||
if alo < i && blo < j {
|
||||
matched = matchBlocks(alo, i, blo, j, matched)
|
||||
}
|
||||
matched = append(matched, match)
|
||||
if i+k < ahi && j+k < bhi {
|
||||
matched = matchBlocks(i+k, ahi, j+k, bhi, matched)
|
||||
}
|
||||
}
|
||||
return matched
|
||||
}
|
||||
matched := matchBlocks(0, len(m.a), 0, len(m.b), nil)
|
||||
|
||||
// It's possible that we have adjacent equal blocks in the
|
||||
// matching_blocks list now.
|
||||
nonAdjacent := []Match{}
|
||||
i1, j1, k1 := 0, 0, 0
|
||||
for _, b := range matched {
|
||||
// Is this block adjacent to i1, j1, k1?
|
||||
i2, j2, k2 := b.A, b.B, b.Size
|
||||
if i1+k1 == i2 && j1+k1 == j2 {
|
||||
// Yes, so collapse them -- this just increases the length of
|
||||
// the first block by the length of the second, and the first
|
||||
// block so lengthened remains the block to compare against.
|
||||
k1 += k2
|
||||
} else {
|
||||
// Not adjacent. Remember the first block (k1==0 means it's
|
||||
// the dummy we started with), and make the second block the
|
||||
// new block to compare against.
|
||||
if k1 > 0 {
|
||||
nonAdjacent = append(nonAdjacent, Match{i1, j1, k1})
|
||||
}
|
||||
i1, j1, k1 = i2, j2, k2
|
||||
}
|
||||
}
|
||||
if k1 > 0 {
|
||||
nonAdjacent = append(nonAdjacent, Match{i1, j1, k1})
|
||||
}
|
||||
|
||||
nonAdjacent = append(nonAdjacent, Match{len(m.a), len(m.b), 0})
|
||||
m.matchingBlocks = nonAdjacent
|
||||
return m.matchingBlocks
|
||||
}
|
||||
|
||||
// GetOpCodes returns a list of 5-tuples describing how to turn a into b.
|
||||
//
|
||||
// Each tuple is of the form (tag, i1, i2, j1, j2). The first tuple
|
||||
// has i1 == j1 == 0, and remaining tuples have i1 == the i2 from the
|
||||
// tuple preceding it, and likewise for j1 == the previous j2.
|
||||
//
|
||||
// The tags are characters, with these meanings:
|
||||
//
|
||||
// 'r' (replace): a[i1:i2] should be replaced by b[j1:j2]
|
||||
//
|
||||
// 'd' (delete): a[i1:i2] should be deleted, j1==j2 in this case.
|
||||
//
|
||||
// 'i' (insert): b[j1:j2] should be inserted at a[i1:i1], i1==i2 in this case.
|
||||
//
|
||||
// 'e' (equal): a[i1:i2] == b[j1:j2]
|
||||
func (m *SequenceMatcher) GetOpCodes() []OpCode {
|
||||
if m.opCodes != nil {
|
||||
return m.opCodes
|
||||
}
|
||||
i, j := 0, 0
|
||||
matching := m.GetMatchingBlocks()
|
||||
opCodes := make([]OpCode, 0, len(matching))
|
||||
for _, m := range matching {
|
||||
// invariant: we've pumped out correct diffs to change
|
||||
// a[:i] into b[:j], and the next matching block is
|
||||
// a[ai:ai+size] == b[bj:bj+size]. So we need to pump
|
||||
// out a diff to change a[i:ai] into b[j:bj], pump out
|
||||
// the matching block, and move (i,j) beyond the match
|
||||
ai, bj, size := m.A, m.B, m.Size
|
||||
tag := byte(0)
|
||||
if i < ai && j < bj {
|
||||
tag = 'r'
|
||||
} else if i < ai {
|
||||
tag = 'd'
|
||||
} else if j < bj {
|
||||
tag = 'i'
|
||||
}
|
||||
if tag > 0 {
|
||||
opCodes = append(opCodes, OpCode{tag, i, ai, j, bj})
|
||||
}
|
||||
i, j = ai+size, bj+size
|
||||
// the list of matching blocks is terminated by a
|
||||
// sentinel with size 0
|
||||
if size > 0 {
|
||||
opCodes = append(opCodes, OpCode{'e', ai, i, bj, j})
|
||||
}
|
||||
}
|
||||
m.opCodes = opCodes
|
||||
return m.opCodes
|
||||
}
|
||||
|
||||
// GetGroupedOpCodes isolates change clusters by eliminating ranges with no changes.
|
||||
//
|
||||
// Returns a generator of groups with up to n lines of context.
|
||||
// Each group is in the same format as returned by GetOpCodes().
|
||||
func (m *SequenceMatcher) GetGroupedOpCodes(n int) [][]OpCode {
|
||||
if n < 0 {
|
||||
n = 3
|
||||
}
|
||||
codes := m.GetOpCodes()
|
||||
if len(codes) == 0 {
|
||||
codes = []OpCode{{'e', 0, 1, 0, 1}}
|
||||
}
|
||||
// Fixup leading and trailing groups if they show no changes.
|
||||
if codes[0].Tag == 'e' {
|
||||
c := codes[0]
|
||||
i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
|
||||
codes[0] = OpCode{c.Tag, max(i1, i2-n), i2, max(j1, j2-n), j2}
|
||||
}
|
||||
if codes[len(codes)-1].Tag == 'e' {
|
||||
c := codes[len(codes)-1]
|
||||
i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
|
||||
codes[len(codes)-1] = OpCode{c.Tag, i1, min(i2, i1+n), j1, min(j2, j1+n)}
|
||||
}
|
||||
nn := n + n
|
||||
groups := [][]OpCode{}
|
||||
group := []OpCode{}
|
||||
for _, c := range codes {
|
||||
i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
|
||||
// End the current group and start a new one whenever
|
||||
// there is a large range with no changes.
|
||||
if c.Tag == 'e' && i2-i1 > nn {
|
||||
group = append(group, OpCode{c.Tag, i1, min(i2, i1+n),
|
||||
j1, min(j2, j1+n)})
|
||||
groups = append(groups, group)
|
||||
group = []OpCode{}
|
||||
i1, j1 = max(i1, i2-n), max(j1, j2-n)
|
||||
}
|
||||
group = append(group, OpCode{c.Tag, i1, i2, j1, j2})
|
||||
}
|
||||
if len(group) > 0 && !(len(group) == 1 && group[0].Tag == 'e') {
|
||||
groups = append(groups, group)
|
||||
}
|
||||
return groups
|
||||
}
|
||||
|
||||
// Ratio returns a measure of the sequences' similarity (float in [0,1]).
|
||||
//
|
||||
// Where T is the total number of elements in both sequences, and
|
||||
// M is the number of matches, this is 2.0*M / T.
|
||||
// Note that this is 1 if the sequences are identical, and 0 if
|
||||
// they have nothing in common.
|
||||
//
|
||||
// .Ratio() is expensive to compute if you haven't already computed
|
||||
// .GetMatchingBlocks() or .GetOpCodes(), in which case you may
|
||||
// want to try .QuickRatio() or .RealQuickRation() first to get an
|
||||
// upper bound.
|
||||
func (m *SequenceMatcher) Ratio() float64 {
|
||||
matches := 0
|
||||
for _, m := range m.GetMatchingBlocks() {
|
||||
matches += m.Size
|
||||
}
|
||||
return calculateRatio(matches, len(m.a)+len(m.b))
|
||||
}
|
||||
|
||||
// QuickRatio returns an upper bound on ratio() relatively quickly.
|
||||
//
|
||||
// This isn't defined beyond that it is an upper bound on .Ratio(), and
|
||||
// is faster to compute.
|
||||
func (m *SequenceMatcher) QuickRatio() float64 {
|
||||
// viewing a and b as multisets, set matches to the cardinality
|
||||
// of their intersection; this counts the number of matches
|
||||
// without regard to order, so is clearly an upper bound
|
||||
if m.fullBCount == nil {
|
||||
m.fullBCount = map[string]int{}
|
||||
for _, s := range m.b {
|
||||
m.fullBCount[s] = m.fullBCount[s] + 1
|
||||
}
|
||||
}
|
||||
|
||||
// avail[x] is the number of times x appears in 'b' less the
|
||||
// number of times we've seen it in 'a' so far ... kinda
|
||||
avail := map[string]int{}
|
||||
matches := 0
|
||||
for _, s := range m.a {
|
||||
n, ok := avail[s]
|
||||
if !ok {
|
||||
n = m.fullBCount[s]
|
||||
}
|
||||
avail[s] = n - 1
|
||||
if n > 0 {
|
||||
matches++
|
||||
}
|
||||
}
|
||||
return calculateRatio(matches, len(m.a)+len(m.b))
|
||||
}
|
||||
|
||||
// RealQuickRatio returns an upper bound on ratio() very quickly.
|
||||
//
|
||||
// This isn't defined beyond that it is an upper bound on .Ratio(), and
|
||||
// is faster to compute than either .Ratio() or .QuickRatio().
|
||||
func (m *SequenceMatcher) RealQuickRatio() float64 {
|
||||
la, lb := len(m.a), len(m.b)
|
||||
return calculateRatio(min(la, lb), la+lb)
|
||||
}
|
||||
|
||||
// Convert range to the "ed" format
|
||||
func formatRangeUnified(start, stop int) string {
|
||||
// Per the diff spec at http://www.unix.org/single_unix_specification/
|
||||
beginning := start + 1 // lines start numbering with one
|
||||
length := stop - start
|
||||
if length == 1 {
|
||||
return fmt.Sprintf("%d", beginning)
|
||||
}
|
||||
if length == 0 {
|
||||
beginning-- // empty ranges begin at line just before the range
|
||||
}
|
||||
return fmt.Sprintf("%d,%d", beginning, length)
|
||||
}
|
||||
|
||||
// UnifiedDiff represents the unified diff parameters.
|
||||
type UnifiedDiff struct {
|
||||
A []string // First sequence lines
|
||||
FromFile string // First file name
|
||||
FromDate string // First file time
|
||||
B []string // Second sequence lines
|
||||
ToFile string // Second file name
|
||||
ToDate string // Second file time
|
||||
Eol string // Headers end of line, defaults to LF
|
||||
Context int // Number of context lines
|
||||
}
|
||||
|
||||
// WriteUnifiedDiff compares two sequences of lines; generates the delta as
|
||||
// a unified diff.
|
||||
//
|
||||
// Unified diffs are a compact way of showing line changes and a few
|
||||
// lines of context. The number of context lines is set by 'n' which
|
||||
// defaults to three.
|
||||
//
|
||||
// By default, the diff control lines (those with ---, +++, or @@) are
|
||||
// created with a trailing newline. This is helpful so that inputs
|
||||
// created from file.readlines() result in diffs that are suitable for
|
||||
// file.writelines() since both the inputs and outputs have trailing
|
||||
// newlines.
|
||||
//
|
||||
// For inputs that do not have trailing newlines, set the lineterm
|
||||
// argument to "" so that the output will be uniformly newline free.
|
||||
//
|
||||
// The unidiff format normally has a header for filenames and modification
|
||||
// times. Any or all of these may be specified using strings for
|
||||
// 'fromfile', 'tofile', 'fromfiledate', and 'tofiledate'.
|
||||
// The modification times are normally expressed in the ISO 8601 format.
|
||||
func WriteUnifiedDiff(writer io.Writer, diff UnifiedDiff) error {
|
||||
buf := bufio.NewWriter(writer)
|
||||
defer buf.Flush()
|
||||
wf := func(format string, args ...interface{}) error {
|
||||
_, err := buf.WriteString(fmt.Sprintf(format, args...))
|
||||
return err
|
||||
}
|
||||
ws := func(s string) error {
|
||||
_, err := buf.WriteString(s)
|
||||
return err
|
||||
}
|
||||
|
||||
if len(diff.Eol) == 0 {
|
||||
diff.Eol = "\n"
|
||||
}
|
||||
|
||||
started := false
|
||||
m := NewMatcher(diff.A, diff.B)
|
||||
for _, g := range m.GetGroupedOpCodes(diff.Context) {
|
||||
if !started {
|
||||
started = true
|
||||
fromDate := ""
|
||||
if len(diff.FromDate) > 0 {
|
||||
fromDate = "\t" + diff.FromDate
|
||||
}
|
||||
toDate := ""
|
||||
if len(diff.ToDate) > 0 {
|
||||
toDate = "\t" + diff.ToDate
|
||||
}
|
||||
if diff.FromFile != "" || diff.ToFile != "" {
|
||||
err := wf("--- %s%s%s", diff.FromFile, fromDate, diff.Eol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = wf("+++ %s%s%s", diff.ToFile, toDate, diff.Eol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
first, last := g[0], g[len(g)-1]
|
||||
range1 := formatRangeUnified(first.I1, last.I2)
|
||||
range2 := formatRangeUnified(first.J1, last.J2)
|
||||
if err := wf("@@ -%s +%s @@%s", range1, range2, diff.Eol); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, c := range g {
|
||||
i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
|
||||
if c.Tag == 'e' {
|
||||
for _, line := range diff.A[i1:i2] {
|
||||
if err := ws(" " + line); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if c.Tag == 'r' || c.Tag == 'd' {
|
||||
for _, line := range diff.A[i1:i2] {
|
||||
if err := ws("-" + line); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if c.Tag == 'r' || c.Tag == 'i' {
|
||||
for _, line := range diff.B[j1:j2] {
|
||||
if err := ws("+" + line); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUnifiedDiffString is like WriteUnifiedDiff but returns the diff as a string.
|
||||
func GetUnifiedDiffString(diff UnifiedDiff) (string, error) {
|
||||
w := &bytes.Buffer{}
|
||||
err := WriteUnifiedDiff(w, diff)
|
||||
return w.String(), err
|
||||
}
|
||||
|
||||
// Convert range to the "ed" format.
|
||||
func formatRangeContext(start, stop int) string {
|
||||
// Per the diff spec at http://www.unix.org/single_unix_specification/
|
||||
beginning := start + 1 // lines start numbering with one
|
||||
length := stop - start
|
||||
if length == 0 {
|
||||
beginning-- // empty ranges begin at line just before the range
|
||||
}
|
||||
if length <= 1 {
|
||||
return fmt.Sprintf("%d", beginning)
|
||||
}
|
||||
return fmt.Sprintf("%d,%d", beginning, beginning+length-1)
|
||||
}
|
||||
|
||||
type ContextDiff UnifiedDiff
|
||||
|
||||
// WriteContextDiff compares two sequences of lines; generates the delta as a context diff.
|
||||
//
|
||||
// Context diffs are a compact way of showing line changes and a few
|
||||
// lines of context. The number of context lines is set by diff.Context
|
||||
// which defaults to three.
|
||||
//
|
||||
// By default, the diff control lines (those with *** or ---) are
|
||||
// created with a trailing newline.
|
||||
//
|
||||
// For inputs that do not have trailing newlines, set the diff.Eol
|
||||
// argument to "" so that the output will be uniformly newline free.
|
||||
//
|
||||
// The context diff format normally has a header for filenames and
|
||||
// modification times. Any or all of these may be specified using
|
||||
// strings for diff.FromFile, diff.ToFile, diff.FromDate, diff.ToDate.
|
||||
// The modification times are normally expressed in the ISO 8601 format.
|
||||
// If not specified, the strings default to blanks.
|
||||
func WriteContextDiff(writer io.Writer, diff ContextDiff) error {
|
||||
buf := bufio.NewWriter(writer)
|
||||
defer buf.Flush()
|
||||
var diffErr error
|
||||
wf := func(format string, args ...interface{}) {
|
||||
_, err := buf.WriteString(fmt.Sprintf(format, args...))
|
||||
if diffErr == nil && err != nil {
|
||||
diffErr = err
|
||||
}
|
||||
}
|
||||
ws := func(s string) {
|
||||
_, err := buf.WriteString(s)
|
||||
if diffErr == nil && err != nil {
|
||||
diffErr = err
|
||||
}
|
||||
}
|
||||
|
||||
if len(diff.Eol) == 0 {
|
||||
diff.Eol = "\n"
|
||||
}
|
||||
|
||||
prefix := map[byte]string{
|
||||
'i': "+ ",
|
||||
'd': "- ",
|
||||
'r': "! ",
|
||||
'e': " ",
|
||||
}
|
||||
|
||||
started := false
|
||||
m := NewMatcher(diff.A, diff.B)
|
||||
for _, g := range m.GetGroupedOpCodes(diff.Context) {
|
||||
if !started {
|
||||
started = true
|
||||
fromDate := ""
|
||||
if len(diff.FromDate) > 0 {
|
||||
fromDate = "\t" + diff.FromDate
|
||||
}
|
||||
toDate := ""
|
||||
if len(diff.ToDate) > 0 {
|
||||
toDate = "\t" + diff.ToDate
|
||||
}
|
||||
if diff.FromFile != "" || diff.ToFile != "" {
|
||||
wf("*** %s%s%s", diff.FromFile, fromDate, diff.Eol)
|
||||
wf("--- %s%s%s", diff.ToFile, toDate, diff.Eol)
|
||||
}
|
||||
}
|
||||
|
||||
first, last := g[0], g[len(g)-1]
|
||||
ws("***************" + diff.Eol)
|
||||
|
||||
range1 := formatRangeContext(first.I1, last.I2)
|
||||
wf("*** %s ****%s", range1, diff.Eol)
|
||||
for _, c := range g {
|
||||
if c.Tag == 'r' || c.Tag == 'd' {
|
||||
for _, cc := range g {
|
||||
if cc.Tag == 'i' {
|
||||
continue
|
||||
}
|
||||
for _, line := range diff.A[cc.I1:cc.I2] {
|
||||
ws(prefix[cc.Tag] + line)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
range2 := formatRangeContext(first.J1, last.J2)
|
||||
wf("--- %s ----%s", range2, diff.Eol)
|
||||
for _, c := range g {
|
||||
if c.Tag == 'r' || c.Tag == 'i' {
|
||||
for _, cc := range g {
|
||||
if cc.Tag == 'd' {
|
||||
continue
|
||||
}
|
||||
for _, line := range diff.B[cc.J1:cc.J2] {
|
||||
ws(prefix[cc.Tag] + line)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return diffErr
|
||||
}
|
||||
|
||||
// GetContextDiffString is like WriteContextDiff but returns the diff as a string.
|
||||
func GetContextDiffString(diff ContextDiff) (string, error) {
|
||||
w := &bytes.Buffer{}
|
||||
err := WriteContextDiff(w, diff)
|
||||
return w.String(), err
|
||||
}
|
||||
|
||||
// SplitLines splits a string on "\n" while preserving them. The output can be used
|
||||
// as input for UnifiedDiff and ContextDiff structures.
|
||||
func SplitLines(s string) []string {
|
||||
lines := strings.SplitAfter(s, "\n")
|
||||
lines[len(lines)-1] += "\n"
|
||||
return lines
|
||||
}
|
326
internal/assert/difflib_test.go
Normal file
326
internal/assert/difflib_test.go
Normal file
@ -0,0 +1,326 @@
|
||||
// Copied from https://github.com/pmezard/go-difflib/blob/5d4384ee4fb2527b0a1256a821ebfc92f91efefc/difflib/difflib_test.go
|
||||
|
||||
// Copyright 2013 Patrick Mezard. All rights reserved. Use of this source code is
|
||||
// governed by a license that can be found in the THIRD-PARTY-NOTICES file.
|
||||
|
||||
package assert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func assertAlmostEqual(t *testing.T, a, b float64, places int) {
|
||||
if math.Abs(a-b) > math.Pow10(-places) {
|
||||
t.Errorf("%.7f != %.7f", a, b)
|
||||
}
|
||||
}
|
||||
|
||||
func assertEqual(t *testing.T, a, b interface{}) {
|
||||
if !reflect.DeepEqual(a, b) {
|
||||
t.Errorf("%v != %v", a, b)
|
||||
}
|
||||
}
|
||||
|
||||
func splitChars(s string) []string {
|
||||
chars := make([]string, 0, len(s))
|
||||
// Assume ASCII inputs
|
||||
for i := 0; i != len(s); i++ {
|
||||
chars = append(chars, string(s[i]))
|
||||
}
|
||||
return chars
|
||||
}
|
||||
|
||||
func TestSequenceMatcherRatio(t *testing.T) {
|
||||
s := NewMatcher(splitChars("abcd"), splitChars("bcde"))
|
||||
assertEqual(t, s.Ratio(), 0.75)
|
||||
assertEqual(t, s.QuickRatio(), 0.75)
|
||||
assertEqual(t, s.RealQuickRatio(), 1.0)
|
||||
}
|
||||
|
||||
func TestGetOptCodes(t *testing.T) {
|
||||
a := "qabxcd"
|
||||
b := "abycdf"
|
||||
s := NewMatcher(splitChars(a), splitChars(b))
|
||||
w := &bytes.Buffer{}
|
||||
for _, op := range s.GetOpCodes() {
|
||||
fmt.Fprintf(w, "%s a[%d:%d], (%s) b[%d:%d] (%s)\n", string(op.Tag),
|
||||
op.I1, op.I2, a[op.I1:op.I2], op.J1, op.J2, b[op.J1:op.J2])
|
||||
}
|
||||
result := w.String()
|
||||
expected := `d a[0:1], (q) b[0:0] ()
|
||||
e a[1:3], (ab) b[0:2] (ab)
|
||||
r a[3:4], (x) b[2:3] (y)
|
||||
e a[4:6], (cd) b[3:5] (cd)
|
||||
i a[6:6], () b[5:6] (f)
|
||||
`
|
||||
if expected != result {
|
||||
t.Errorf("unexpected op codes: \n%s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupedOpCodes(t *testing.T) {
|
||||
a := []string{}
|
||||
for i := 0; i != 39; i++ {
|
||||
a = append(a, fmt.Sprintf("%02d", i))
|
||||
}
|
||||
b := []string{}
|
||||
b = append(b, a[:8]...)
|
||||
b = append(b, " i")
|
||||
b = append(b, a[8:19]...)
|
||||
b = append(b, " x")
|
||||
b = append(b, a[20:22]...)
|
||||
b = append(b, a[27:34]...)
|
||||
b = append(b, " y")
|
||||
b = append(b, a[35:]...)
|
||||
s := NewMatcher(a, b)
|
||||
w := &bytes.Buffer{}
|
||||
for _, g := range s.GetGroupedOpCodes(-1) {
|
||||
fmt.Fprintf(w, "group\n")
|
||||
for _, op := range g {
|
||||
fmt.Fprintf(w, " %s, %d, %d, %d, %d\n", string(op.Tag),
|
||||
op.I1, op.I2, op.J1, op.J2)
|
||||
}
|
||||
}
|
||||
result := w.String()
|
||||
expected := `group
|
||||
e, 5, 8, 5, 8
|
||||
i, 8, 8, 8, 9
|
||||
e, 8, 11, 9, 12
|
||||
group
|
||||
e, 16, 19, 17, 20
|
||||
r, 19, 20, 20, 21
|
||||
e, 20, 22, 21, 23
|
||||
d, 22, 27, 23, 23
|
||||
e, 27, 30, 23, 26
|
||||
group
|
||||
e, 31, 34, 27, 30
|
||||
r, 34, 35, 30, 31
|
||||
e, 35, 38, 31, 34
|
||||
`
|
||||
if expected != result {
|
||||
t.Errorf("unexpected op codes: \n%s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func rep(s string, count int) string {
|
||||
return strings.Repeat(s, count)
|
||||
}
|
||||
|
||||
func TestWithAsciiOneInsert(t *testing.T) {
|
||||
sm := NewMatcher(splitChars(rep("b", 100)),
|
||||
splitChars("a"+rep("b", 100)))
|
||||
assertAlmostEqual(t, sm.Ratio(), 0.995, 3)
|
||||
assertEqual(t, sm.GetOpCodes(),
|
||||
[]OpCode{{'i', 0, 0, 0, 1}, {'e', 0, 100, 1, 101}})
|
||||
assertEqual(t, len(sm.bPopular), 0)
|
||||
|
||||
sm = NewMatcher(splitChars(rep("b", 100)),
|
||||
splitChars(rep("b", 50)+"a"+rep("b", 50)))
|
||||
assertAlmostEqual(t, sm.Ratio(), 0.995, 3)
|
||||
assertEqual(t, sm.GetOpCodes(),
|
||||
[]OpCode{{'e', 0, 50, 0, 50}, {'i', 50, 50, 50, 51}, {'e', 50, 100, 51, 101}})
|
||||
assertEqual(t, len(sm.bPopular), 0)
|
||||
}
|
||||
|
||||
func TestWithAsciiOnDelete(t *testing.T) {
|
||||
sm := NewMatcher(splitChars(rep("a", 40)+"c"+rep("b", 40)),
|
||||
splitChars(rep("a", 40)+rep("b", 40)))
|
||||
assertAlmostEqual(t, sm.Ratio(), 0.994, 3)
|
||||
assertEqual(t, sm.GetOpCodes(),
|
||||
[]OpCode{{'e', 0, 40, 0, 40}, {'d', 40, 41, 40, 40}, {'e', 41, 81, 40, 80}})
|
||||
}
|
||||
|
||||
func TestWithAsciiBJunk(t *testing.T) {
|
||||
isJunk := func(s string) bool {
|
||||
return s == " "
|
||||
}
|
||||
sm := NewMatcherWithJunk(splitChars(rep("a", 40)+rep("b", 40)),
|
||||
splitChars(rep("a", 44)+rep("b", 40)), true, isJunk)
|
||||
assertEqual(t, sm.bJunk, map[string]struct{}{})
|
||||
|
||||
sm = NewMatcherWithJunk(splitChars(rep("a", 40)+rep("b", 40)),
|
||||
splitChars(rep("a", 44)+rep("b", 40)+rep(" ", 20)), false, isJunk)
|
||||
assertEqual(t, sm.bJunk, map[string]struct{}{" ": struct{}{}})
|
||||
|
||||
isJunk = func(s string) bool {
|
||||
return s == " " || s == "b"
|
||||
}
|
||||
sm = NewMatcherWithJunk(splitChars(rep("a", 40)+rep("b", 40)),
|
||||
splitChars(rep("a", 44)+rep("b", 40)+rep(" ", 20)), false, isJunk)
|
||||
assertEqual(t, sm.bJunk, map[string]struct{}{" ": struct{}{}, "b": struct{}{}})
|
||||
}
|
||||
|
||||
func TestSFBugsRatioForNullSeqn(t *testing.T) {
|
||||
sm := NewMatcher(nil, nil)
|
||||
assertEqual(t, sm.Ratio(), 1.0)
|
||||
assertEqual(t, sm.QuickRatio(), 1.0)
|
||||
assertEqual(t, sm.RealQuickRatio(), 1.0)
|
||||
}
|
||||
|
||||
func TestSFBugsComparingEmptyLists(t *testing.T) {
|
||||
groups := NewMatcher(nil, nil).GetGroupedOpCodes(-1)
|
||||
assertEqual(t, len(groups), 0)
|
||||
diff := UnifiedDiff{
|
||||
FromFile: "Original",
|
||||
ToFile: "Current",
|
||||
Context: 3,
|
||||
}
|
||||
result, err := GetUnifiedDiffString(diff)
|
||||
assertEqual(t, err, nil)
|
||||
assertEqual(t, result, "")
|
||||
}
|
||||
|
||||
func TestOutputFormatRangeFormatUnified(t *testing.T) {
|
||||
// Per the diff spec at http://www.unix.org/single_unix_specification/
|
||||
//
|
||||
// Each <range> field shall be of the form:
|
||||
// %1d", <beginning line number> if the range contains exactly one line,
|
||||
// and:
|
||||
// "%1d,%1d", <beginning line number>, <number of lines> otherwise.
|
||||
// If a range is empty, its beginning line number shall be the number of
|
||||
// the line just before the range, or 0 if the empty range starts the file.
|
||||
fm := formatRangeUnified
|
||||
assertEqual(t, fm(3, 3), "3,0")
|
||||
assertEqual(t, fm(3, 4), "4")
|
||||
assertEqual(t, fm(3, 5), "4,2")
|
||||
assertEqual(t, fm(3, 6), "4,3")
|
||||
assertEqual(t, fm(0, 0), "0,0")
|
||||
}
|
||||
|
||||
func TestOutputFormatRangeFormatContext(t *testing.T) {
|
||||
// Per the diff spec at http://www.unix.org/single_unix_specification/
|
||||
//
|
||||
// The range of lines in file1 shall be written in the following format
|
||||
// if the range contains two or more lines:
|
||||
// "*** %d,%d ****\n", <beginning line number>, <ending line number>
|
||||
// and the following format otherwise:
|
||||
// "*** %d ****\n", <ending line number>
|
||||
// The ending line number of an empty range shall be the number of the preceding line,
|
||||
// or 0 if the range is at the start of the file.
|
||||
//
|
||||
// Next, the range of lines in file2 shall be written in the following format
|
||||
// if the range contains two or more lines:
|
||||
// "--- %d,%d ----\n", <beginning line number>, <ending line number>
|
||||
// and the following format otherwise:
|
||||
// "--- %d ----\n", <ending line number>
|
||||
fm := formatRangeContext
|
||||
assertEqual(t, fm(3, 3), "3")
|
||||
assertEqual(t, fm(3, 4), "4")
|
||||
assertEqual(t, fm(3, 5), "4,5")
|
||||
assertEqual(t, fm(3, 6), "4,6")
|
||||
assertEqual(t, fm(0, 0), "0")
|
||||
}
|
||||
|
||||
func TestOutputFormatTabDelimiter(t *testing.T) {
|
||||
diff := UnifiedDiff{
|
||||
A: splitChars("one"),
|
||||
B: splitChars("two"),
|
||||
FromFile: "Original",
|
||||
FromDate: "2005-01-26 23:30:50",
|
||||
ToFile: "Current",
|
||||
ToDate: "2010-04-12 10:20:52",
|
||||
Eol: "\n",
|
||||
}
|
||||
ud, err := GetUnifiedDiffString(diff)
|
||||
assertEqual(t, err, nil)
|
||||
assertEqual(t, SplitLines(ud)[:2], []string{
|
||||
"--- Original\t2005-01-26 23:30:50\n",
|
||||
"+++ Current\t2010-04-12 10:20:52\n",
|
||||
})
|
||||
cd, err := GetContextDiffString(ContextDiff(diff))
|
||||
assertEqual(t, err, nil)
|
||||
assertEqual(t, SplitLines(cd)[:2], []string{
|
||||
"*** Original\t2005-01-26 23:30:50\n",
|
||||
"--- Current\t2010-04-12 10:20:52\n",
|
||||
})
|
||||
}
|
||||
|
||||
func TestOutputFormatNoTrailingTabOnEmptyFiledate(t *testing.T) {
|
||||
diff := UnifiedDiff{
|
||||
A: splitChars("one"),
|
||||
B: splitChars("two"),
|
||||
FromFile: "Original",
|
||||
ToFile: "Current",
|
||||
Eol: "\n",
|
||||
}
|
||||
ud, err := GetUnifiedDiffString(diff)
|
||||
assertEqual(t, err, nil)
|
||||
assertEqual(t, SplitLines(ud)[:2], []string{"--- Original\n", "+++ Current\n"})
|
||||
|
||||
cd, err := GetContextDiffString(ContextDiff(diff))
|
||||
assertEqual(t, err, nil)
|
||||
assertEqual(t, SplitLines(cd)[:2], []string{"*** Original\n", "--- Current\n"})
|
||||
}
|
||||
|
||||
func TestOmitFilenames(t *testing.T) {
|
||||
diff := UnifiedDiff{
|
||||
A: SplitLines("o\nn\ne\n"),
|
||||
B: SplitLines("t\nw\no\n"),
|
||||
Eol: "\n",
|
||||
}
|
||||
ud, err := GetUnifiedDiffString(diff)
|
||||
assertEqual(t, err, nil)
|
||||
assertEqual(t, SplitLines(ud), []string{
|
||||
"@@ -0,0 +1,2 @@\n",
|
||||
"+t\n",
|
||||
"+w\n",
|
||||
"@@ -2,2 +3,0 @@\n",
|
||||
"-n\n",
|
||||
"-e\n",
|
||||
"\n",
|
||||
})
|
||||
|
||||
cd, err := GetContextDiffString(ContextDiff(diff))
|
||||
assertEqual(t, err, nil)
|
||||
assertEqual(t, SplitLines(cd), []string{
|
||||
"***************\n",
|
||||
"*** 0 ****\n",
|
||||
"--- 1,2 ----\n",
|
||||
"+ t\n",
|
||||
"+ w\n",
|
||||
"***************\n",
|
||||
"*** 2,3 ****\n",
|
||||
"- n\n",
|
||||
"- e\n",
|
||||
"--- 3 ----\n",
|
||||
"\n",
|
||||
})
|
||||
}
|
||||
|
||||
func TestSplitLines(t *testing.T) {
|
||||
allTests := []struct {
|
||||
input string
|
||||
want []string
|
||||
}{
|
||||
{"foo", []string{"foo\n"}},
|
||||
{"foo\nbar", []string{"foo\n", "bar\n"}},
|
||||
{"foo\nbar\n", []string{"foo\n", "bar\n", "\n"}},
|
||||
}
|
||||
for _, test := range allTests {
|
||||
assertEqual(t, SplitLines(test.input), test.want)
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkSplitLines(b *testing.B, count int) {
|
||||
str := strings.Repeat("foo\n", count)
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
n := 0
|
||||
for i := 0; i < b.N; i++ {
|
||||
n += len(SplitLines(str))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSplitLines100(b *testing.B) {
|
||||
benchmarkSplitLines(b, 100)
|
||||
}
|
||||
|
||||
func BenchmarkSplitLines10000(b *testing.B) {
|
||||
benchmarkSplitLines(b, 10000)
|
||||
}
|
60
internal/aws/awserr/error.go
Normal file
60
internal/aws/awserr/error.go
Normal file
@ -0,0 +1,60 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/awserr/error.go
|
||||
// See THIRD-PARTY-NOTICES for original license terms
|
||||
|
||||
// Package awserr represents API error interface accessors for the SDK.
|
||||
package awserr
|
||||
|
||||
// An Error wraps lower level errors with code, message and an original error.
|
||||
// The underlying concrete error type may also satisfy other interfaces which
|
||||
// can be to used to obtain more specific information about the error.
|
||||
type Error interface {
|
||||
// Satisfy the generic error interface.
|
||||
error
|
||||
|
||||
// Returns the short phrase depicting the classification of the error.
|
||||
Code() string
|
||||
|
||||
// Returns the error details message.
|
||||
Message() string
|
||||
|
||||
// Returns the original error if one was set. Nil is returned if not set.
|
||||
OrigErr() error
|
||||
}
|
||||
|
||||
// BatchedErrors is a batch of errors which also wraps lower level errors with
|
||||
// code, message, and original errors. Calling Error() will include all errors
|
||||
// that occurred in the batch.
|
||||
//
|
||||
// Replaces BatchError
|
||||
type BatchedErrors interface {
|
||||
// Satisfy the base Error interface.
|
||||
Error
|
||||
|
||||
// Returns the original error if one was set. Nil is returned if not set.
|
||||
OrigErrs() []error
|
||||
}
|
||||
|
||||
// New returns an Error object described by the code, message, and origErr.
|
||||
//
|
||||
// If origErr satisfies the Error interface it will not be wrapped within a new
|
||||
// Error object and will instead be returned.
|
||||
func New(code, message string, origErr error) Error {
|
||||
var errs []error
|
||||
if origErr != nil {
|
||||
errs = append(errs, origErr)
|
||||
}
|
||||
return newBaseError(code, message, errs)
|
||||
}
|
||||
|
||||
// NewBatchError returns an BatchedErrors with a collection of errors as an
|
||||
// array of errors.
|
||||
func NewBatchError(code, message string, errs []error) BatchedErrors {
|
||||
return newBaseError(code, message, errs)
|
||||
}
|
144
internal/aws/awserr/types.go
Normal file
144
internal/aws/awserr/types.go
Normal file
@ -0,0 +1,144 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/awserr/types.go
|
||||
// See THIRD-PARTY-NOTICES for original license terms
|
||||
|
||||
package awserr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// SprintError returns a string of the formatted error code.
|
||||
//
|
||||
// Both extra and origErr are optional. If they are included their lines
|
||||
// will be added, but if they are not included their lines will be ignored.
|
||||
func SprintError(code, message, extra string, origErr error) string {
|
||||
msg := fmt.Sprintf("%s: %s", code, message)
|
||||
if extra != "" {
|
||||
msg = fmt.Sprintf("%s\n\t%s", msg, extra)
|
||||
}
|
||||
if origErr != nil {
|
||||
msg = fmt.Sprintf("%s\ncaused by: %s", msg, origErr.Error())
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
// A baseError wraps the code and message which defines an error. It also
|
||||
// can be used to wrap an original error object.
|
||||
//
|
||||
// Should be used as the root for errors satisfying the awserr.Error. Also
|
||||
// for any error which does not fit into a specific error wrapper type.
|
||||
type baseError struct {
|
||||
// Classification of error
|
||||
code string
|
||||
|
||||
// Detailed information about error
|
||||
message string
|
||||
|
||||
// Optional original error this error is based off of. Allows building
|
||||
// chained errors.
|
||||
errs []error
|
||||
}
|
||||
|
||||
// newBaseError returns an error object for the code, message, and errors.
|
||||
//
|
||||
// code is a short no whitespace phrase depicting the classification of
|
||||
// the error that is being created.
|
||||
//
|
||||
// message is the free flow string containing detailed information about the
|
||||
// error.
|
||||
//
|
||||
// origErrs is the error objects which will be nested under the new errors to
|
||||
// be returned.
|
||||
func newBaseError(code, message string, origErrs []error) *baseError {
|
||||
b := &baseError{
|
||||
code: code,
|
||||
message: message,
|
||||
errs: origErrs,
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// Error returns the string representation of the error.
|
||||
//
|
||||
// See ErrorWithExtra for formatting.
|
||||
//
|
||||
// Satisfies the error interface.
|
||||
func (b baseError) Error() string {
|
||||
size := len(b.errs)
|
||||
if size > 0 {
|
||||
return SprintError(b.code, b.message, "", errorList(b.errs))
|
||||
}
|
||||
|
||||
return SprintError(b.code, b.message, "", nil)
|
||||
}
|
||||
|
||||
// String returns the string representation of the error.
|
||||
// Alias for Error to satisfy the stringer interface.
|
||||
func (b baseError) String() string {
|
||||
return b.Error()
|
||||
}
|
||||
|
||||
// Code returns the short phrase depicting the classification of the error.
|
||||
func (b baseError) Code() string {
|
||||
return b.code
|
||||
}
|
||||
|
||||
// Message returns the error details message.
|
||||
func (b baseError) Message() string {
|
||||
return b.message
|
||||
}
|
||||
|
||||
// OrigErr returns the original error if one was set. Nil is returned if no
|
||||
// error was set. This only returns the first element in the list. If the full
|
||||
// list is needed, use BatchedErrors.
|
||||
func (b baseError) OrigErr() error {
|
||||
switch len(b.errs) {
|
||||
case 0:
|
||||
return nil
|
||||
case 1:
|
||||
return b.errs[0]
|
||||
default:
|
||||
if err, ok := b.errs[0].(Error); ok {
|
||||
return NewBatchError(err.Code(), err.Message(), b.errs[1:])
|
||||
}
|
||||
return NewBatchError("BatchedErrors",
|
||||
"multiple errors occurred", b.errs)
|
||||
}
|
||||
}
|
||||
|
||||
// OrigErrs returns the original errors if one was set. An empty slice is
|
||||
// returned if no error was set.
|
||||
func (b baseError) OrigErrs() []error {
|
||||
return b.errs
|
||||
}
|
||||
|
||||
// An error list that satisfies the golang interface
|
||||
type errorList []error
|
||||
|
||||
// Error returns the string representation of the error.
|
||||
//
|
||||
// Satisfies the error interface.
|
||||
func (e errorList) Error() string {
|
||||
msg := ""
|
||||
// How do we want to handle the array size being zero
|
||||
if size := len(e); size > 0 {
|
||||
for i := 0; i < size; i++ {
|
||||
msg += e[i].Error()
|
||||
// We check the next index to see if it is within the slice.
|
||||
// If it is, then we append a newline. We do this, because unit tests
|
||||
// could be broken with the additional '\n'
|
||||
if i+1 < size {
|
||||
msg += "\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
return msg
|
||||
}
|
72
internal/aws/credentials/chain_provider.go
Normal file
72
internal/aws/credentials/chain_provider.go
Normal file
@ -0,0 +1,72 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/chain_provider.go
|
||||
// See THIRD-PARTY-NOTICES for original license terms
|
||||
|
||||
package credentials
|
||||
|
||||
import (
|
||||
"gitea.psichedelico.com/go/bson/internal/aws/awserr"
|
||||
)
|
||||
|
||||
// A ChainProvider will search for a provider which returns credentials
|
||||
// and cache that provider until Retrieve is called again.
|
||||
//
|
||||
// The ChainProvider provides a way of chaining multiple providers together
|
||||
// which will pick the first available using priority order of the Providers
|
||||
// in the list.
|
||||
//
|
||||
// If none of the Providers retrieve valid credentials Value, ChainProvider's
|
||||
// Retrieve() will return the error ErrNoValidProvidersFoundInChain.
|
||||
//
|
||||
// If a Provider is found which returns valid credentials Value ChainProvider
|
||||
// will cache that Provider for all calls to IsExpired(), until Retrieve is
|
||||
// called again.
|
||||
type ChainProvider struct {
|
||||
Providers []Provider
|
||||
curr Provider
|
||||
}
|
||||
|
||||
// NewChainCredentials returns a pointer to a new Credentials object
|
||||
// wrapping a chain of providers.
|
||||
func NewChainCredentials(providers []Provider) *Credentials {
|
||||
return NewCredentials(&ChainProvider{
|
||||
Providers: append([]Provider{}, providers...),
|
||||
})
|
||||
}
|
||||
|
||||
// Retrieve returns the credentials value or error if no provider returned
|
||||
// without error.
|
||||
//
|
||||
// If a provider is found it will be cached and any calls to IsExpired()
|
||||
// will return the expired state of the cached provider.
|
||||
func (c *ChainProvider) Retrieve() (Value, error) {
|
||||
var errs = make([]error, 0, len(c.Providers))
|
||||
for _, p := range c.Providers {
|
||||
creds, err := p.Retrieve()
|
||||
if err == nil {
|
||||
c.curr = p
|
||||
return creds, nil
|
||||
}
|
||||
errs = append(errs, err)
|
||||
}
|
||||
c.curr = nil
|
||||
|
||||
var err = awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs)
|
||||
return Value{}, err
|
||||
}
|
||||
|
||||
// IsExpired will returned the expired state of the currently cached provider
|
||||
// if there is one. If there is no current provider, true will be returned.
|
||||
func (c *ChainProvider) IsExpired() bool {
|
||||
if c.curr != nil {
|
||||
return c.curr.IsExpired()
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
176
internal/aws/credentials/chain_provider_test.go
Normal file
176
internal/aws/credentials/chain_provider_test.go
Normal file
@ -0,0 +1,176 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/chain_provider_test.go
|
||||
// See THIRD-PARTY-NOTICES for original license terms
|
||||
|
||||
package credentials
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/aws/awserr"
|
||||
)
|
||||
|
||||
type secondStubProvider struct {
|
||||
creds Value
|
||||
expired bool
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *secondStubProvider) Retrieve() (Value, error) {
|
||||
s.expired = false
|
||||
s.creds.ProviderName = "secondStubProvider"
|
||||
return s.creds, s.err
|
||||
}
|
||||
func (s *secondStubProvider) IsExpired() bool {
|
||||
return s.expired
|
||||
}
|
||||
|
||||
func TestChainProviderWithNames(t *testing.T) {
|
||||
p := &ChainProvider{
|
||||
Providers: []Provider{
|
||||
&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
|
||||
&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
|
||||
&secondStubProvider{
|
||||
creds: Value{
|
||||
AccessKeyID: "AKIF",
|
||||
SecretAccessKey: "NOSECRET",
|
||||
SessionToken: "",
|
||||
},
|
||||
},
|
||||
&stubProvider{
|
||||
creds: Value{
|
||||
AccessKeyID: "AKID",
|
||||
SecretAccessKey: "SECRET",
|
||||
SessionToken: "",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
creds, err := p.Retrieve()
|
||||
if err != nil {
|
||||
t.Errorf("Expect no error, got %v", err)
|
||||
}
|
||||
if e, a := "secondStubProvider", creds.ProviderName; e != a {
|
||||
t.Errorf("Expect provider name to match, %v got, %v", e, a)
|
||||
}
|
||||
|
||||
// Also check credentials
|
||||
if e, a := "AKIF", creds.AccessKeyID; e != a {
|
||||
t.Errorf("Expect access key ID to match, %v got %v", e, a)
|
||||
}
|
||||
if e, a := "NOSECRET", creds.SecretAccessKey; e != a {
|
||||
t.Errorf("Expect secret access key to match, %v got %v", e, a)
|
||||
}
|
||||
if v := creds.SessionToken; len(v) != 0 {
|
||||
t.Errorf("Expect session token to be empty, %v", v)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestChainProviderGet(t *testing.T) {
|
||||
p := &ChainProvider{
|
||||
Providers: []Provider{
|
||||
&stubProvider{err: awserr.New("FirstError", "first provider error", nil)},
|
||||
&stubProvider{err: awserr.New("SecondError", "second provider error", nil)},
|
||||
&stubProvider{
|
||||
creds: Value{
|
||||
AccessKeyID: "AKID",
|
||||
SecretAccessKey: "SECRET",
|
||||
SessionToken: "",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
creds, err := p.Retrieve()
|
||||
if err != nil {
|
||||
t.Errorf("Expect no error, got %v", err)
|
||||
}
|
||||
if e, a := "AKID", creds.AccessKeyID; e != a {
|
||||
t.Errorf("Expect access key ID to match, %v got %v", e, a)
|
||||
}
|
||||
if e, a := "SECRET", creds.SecretAccessKey; e != a {
|
||||
t.Errorf("Expect secret access key to match, %v got %v", e, a)
|
||||
}
|
||||
if v := creds.SessionToken; len(v) != 0 {
|
||||
t.Errorf("Expect session token to be empty, %v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainProviderIsExpired(t *testing.T) {
|
||||
stubProvider := &stubProvider{expired: true}
|
||||
p := &ChainProvider{
|
||||
Providers: []Provider{
|
||||
stubProvider,
|
||||
},
|
||||
}
|
||||
|
||||
if !p.IsExpired() {
|
||||
t.Errorf("Expect expired to be true before any Retrieve")
|
||||
}
|
||||
_, err := p.Retrieve()
|
||||
if err != nil {
|
||||
t.Errorf("Expect no error, got %v", err)
|
||||
}
|
||||
if p.IsExpired() {
|
||||
t.Errorf("Expect not expired after retrieve")
|
||||
}
|
||||
|
||||
stubProvider.expired = true
|
||||
if !p.IsExpired() {
|
||||
t.Errorf("Expect return of expired provider")
|
||||
}
|
||||
|
||||
_, err = p.Retrieve()
|
||||
if err != nil {
|
||||
t.Errorf("Expect no error, got %v", err)
|
||||
}
|
||||
if p.IsExpired() {
|
||||
t.Errorf("Expect not expired after retrieve")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainProviderWithNoProvider(t *testing.T) {
|
||||
p := &ChainProvider{
|
||||
Providers: []Provider{},
|
||||
}
|
||||
|
||||
if !p.IsExpired() {
|
||||
t.Errorf("Expect expired with no providers")
|
||||
}
|
||||
_, err := p.Retrieve()
|
||||
if err.Error() != "NoCredentialProviders: no valid providers in chain" {
|
||||
t.Errorf("Expect no providers error returned, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainProviderWithNoValidProvider(t *testing.T) {
|
||||
errs := []error{
|
||||
awserr.New("FirstError", "first provider error", nil),
|
||||
awserr.New("SecondError", "second provider error", nil),
|
||||
}
|
||||
p := &ChainProvider{
|
||||
Providers: []Provider{
|
||||
&stubProvider{err: errs[0]},
|
||||
&stubProvider{err: errs[1]},
|
||||
},
|
||||
}
|
||||
|
||||
if !p.IsExpired() {
|
||||
t.Errorf("Expect expired with no providers")
|
||||
}
|
||||
_, err := p.Retrieve()
|
||||
|
||||
expectErr := awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs)
|
||||
if e, a := expectErr, err; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("Expect no providers error returned, %v, got %v", e, a)
|
||||
}
|
||||
}
|
197
internal/aws/credentials/credentials.go
Normal file
197
internal/aws/credentials/credentials.go
Normal file
@ -0,0 +1,197 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/credentials.go
|
||||
// See THIRD-PARTY-NOTICES for original license terms
|
||||
|
||||
package credentials
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/aws/awserr"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// A Value is the AWS credentials value for individual credential fields.
|
||||
//
|
||||
// A Value is also used to represent Azure credentials.
|
||||
// Azure credentials only consist of an access token, which is stored in the `SessionToken` field.
|
||||
type Value struct {
|
||||
// AWS Access key ID
|
||||
AccessKeyID string
|
||||
|
||||
// AWS Secret Access Key
|
||||
SecretAccessKey string
|
||||
|
||||
// AWS Session Token
|
||||
SessionToken string
|
||||
|
||||
// Provider used to get credentials
|
||||
ProviderName string
|
||||
}
|
||||
|
||||
// HasKeys returns if the credentials Value has both AccessKeyID and
|
||||
// SecretAccessKey value set.
|
||||
func (v Value) HasKeys() bool {
|
||||
return len(v.AccessKeyID) != 0 && len(v.SecretAccessKey) != 0
|
||||
}
|
||||
|
||||
// A Provider is the interface for any component which will provide credentials
|
||||
// Value. A provider is required to manage its own Expired state, and what to
|
||||
// be expired means.
|
||||
//
|
||||
// The Provider should not need to implement its own mutexes, because
|
||||
// that will be managed by Credentials.
|
||||
type Provider interface {
|
||||
// Retrieve returns nil if it successfully retrieved the value.
|
||||
// Error is returned if the value were not obtainable, or empty.
|
||||
Retrieve() (Value, error)
|
||||
|
||||
// IsExpired returns if the credentials are no longer valid, and need
|
||||
// to be retrieved.
|
||||
IsExpired() bool
|
||||
}
|
||||
|
||||
// ProviderWithContext is a Provider that can retrieve credentials with a Context
|
||||
type ProviderWithContext interface {
|
||||
Provider
|
||||
|
||||
RetrieveWithContext(context.Context) (Value, error)
|
||||
}
|
||||
|
||||
// A Credentials provides concurrency safe retrieval of AWS credentials Value.
|
||||
//
|
||||
// A Credentials is also used to fetch Azure credentials Value.
|
||||
//
|
||||
// Credentials will cache the credentials value until they expire. Once the value
|
||||
// expires the next Get will attempt to retrieve valid credentials.
|
||||
//
|
||||
// Credentials is safe to use across multiple goroutines and will manage the
|
||||
// synchronous state so the Providers do not need to implement their own
|
||||
// synchronization.
|
||||
//
|
||||
// The first Credentials.Get() will always call Provider.Retrieve() to get the
|
||||
// first instance of the credentials Value. All calls to Get() after that
|
||||
// will return the cached credentials Value until IsExpired() returns true.
|
||||
type Credentials struct {
|
||||
sf singleflight.Group
|
||||
|
||||
m sync.RWMutex
|
||||
creds Value
|
||||
provider Provider
|
||||
}
|
||||
|
||||
// NewCredentials returns a pointer to a new Credentials with the provider set.
|
||||
func NewCredentials(provider Provider) *Credentials {
|
||||
c := &Credentials{
|
||||
provider: provider,
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// GetWithContext returns the credentials value, or error if the credentials
|
||||
// Value failed to be retrieved. Will return early if the passed in context is
|
||||
// canceled.
|
||||
//
|
||||
// Will return the cached credentials Value if it has not expired. If the
|
||||
// credentials Value has expired the Provider's Retrieve() will be called
|
||||
// to refresh the credentials.
|
||||
//
|
||||
// If Credentials.Expire() was called the credentials Value will be force
|
||||
// expired, and the next call to Get() will cause them to be refreshed.
|
||||
func (c *Credentials) GetWithContext(ctx context.Context) (Value, error) {
|
||||
// Check if credentials are cached, and not expired.
|
||||
select {
|
||||
case curCreds, ok := <-c.asyncIsExpired():
|
||||
// ok will only be true, of the credentials were not expired. ok will
|
||||
// be false and have no value if the credentials are expired.
|
||||
if ok {
|
||||
return curCreds, nil
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return Value{}, awserr.New("RequestCanceled",
|
||||
"request context canceled", ctx.Err())
|
||||
}
|
||||
|
||||
// Cannot pass context down to the actual retrieve, because the first
|
||||
// context would cancel the whole group when there is not direct
|
||||
// association of items in the group.
|
||||
resCh := c.sf.DoChan("", func() (interface{}, error) {
|
||||
return c.singleRetrieve(&suppressedContext{ctx})
|
||||
})
|
||||
select {
|
||||
case res := <-resCh:
|
||||
return res.Val.(Value), res.Err
|
||||
case <-ctx.Done():
|
||||
return Value{}, awserr.New("RequestCanceled",
|
||||
"request context canceled", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Credentials) singleRetrieve(ctx context.Context) (interface{}, error) {
|
||||
c.m.Lock()
|
||||
defer c.m.Unlock()
|
||||
|
||||
if curCreds := c.creds; !c.isExpiredLocked(curCreds) {
|
||||
return curCreds, nil
|
||||
}
|
||||
|
||||
var creds Value
|
||||
var err error
|
||||
if p, ok := c.provider.(ProviderWithContext); ok {
|
||||
creds, err = p.RetrieveWithContext(ctx)
|
||||
} else {
|
||||
creds, err = c.provider.Retrieve()
|
||||
}
|
||||
if err == nil {
|
||||
c.creds = creds
|
||||
}
|
||||
|
||||
return creds, err
|
||||
}
|
||||
|
||||
// asyncIsExpired returns a channel of credentials Value. If the channel is
|
||||
// closed the credentials are expired and credentials value are not empty.
|
||||
func (c *Credentials) asyncIsExpired() <-chan Value {
|
||||
ch := make(chan Value, 1)
|
||||
go func() {
|
||||
c.m.RLock()
|
||||
defer c.m.RUnlock()
|
||||
|
||||
if curCreds := c.creds; !c.isExpiredLocked(curCreds) {
|
||||
ch <- curCreds
|
||||
}
|
||||
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
return ch
|
||||
}
|
||||
|
||||
// isExpiredLocked helper method wrapping the definition of expired credentials.
|
||||
func (c *Credentials) isExpiredLocked(creds interface{}) bool {
|
||||
return creds == nil || creds.(Value) == Value{} || c.provider.IsExpired()
|
||||
}
|
||||
|
||||
type suppressedContext struct {
|
||||
context.Context
|
||||
}
|
||||
|
||||
func (s *suppressedContext) Deadline() (deadline time.Time, ok bool) {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
func (s *suppressedContext) Done() <-chan struct{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *suppressedContext) Err() error {
|
||||
return nil
|
||||
}
|
192
internal/aws/credentials/credentials_test.go
Normal file
192
internal/aws/credentials/credentials_test.go
Normal file
@ -0,0 +1,192 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/credentials/credentials_test.go
|
||||
// See THIRD-PARTY-NOTICES for original license terms
|
||||
|
||||
package credentials
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/aws/awserr"
|
||||
)
|
||||
|
||||
func isExpired(c *Credentials) bool {
|
||||
c.m.RLock()
|
||||
defer c.m.RUnlock()
|
||||
|
||||
return c.isExpiredLocked(c.creds)
|
||||
}
|
||||
|
||||
type stubProvider struct {
|
||||
creds Value
|
||||
retrievedCount int
|
||||
expired bool
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *stubProvider) Retrieve() (Value, error) {
|
||||
s.retrievedCount++
|
||||
s.expired = false
|
||||
s.creds.ProviderName = "stubProvider"
|
||||
return s.creds, s.err
|
||||
}
|
||||
func (s *stubProvider) IsExpired() bool {
|
||||
return s.expired
|
||||
}
|
||||
|
||||
func TestCredentialsGet(t *testing.T) {
|
||||
c := NewCredentials(&stubProvider{
|
||||
creds: Value{
|
||||
AccessKeyID: "AKID",
|
||||
SecretAccessKey: "SECRET",
|
||||
SessionToken: "",
|
||||
},
|
||||
expired: true,
|
||||
})
|
||||
|
||||
creds, err := c.GetWithContext(context.Background())
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if e, a := "AKID", creds.AccessKeyID; e != a {
|
||||
t.Errorf("Expect access key ID to match, %v got %v", e, a)
|
||||
}
|
||||
if e, a := "SECRET", creds.SecretAccessKey; e != a {
|
||||
t.Errorf("Expect secret access key to match, %v got %v", e, a)
|
||||
}
|
||||
if v := creds.SessionToken; len(v) != 0 {
|
||||
t.Errorf("Expect session token to be empty, %v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCredentialsGetWithError(t *testing.T) {
|
||||
c := NewCredentials(&stubProvider{err: awserr.New("provider error", "", nil), expired: true})
|
||||
|
||||
_, err := c.GetWithContext(context.Background())
|
||||
if e, a := "provider error", err.(awserr.Error).Code(); e != a {
|
||||
t.Errorf("Expected provider error, %v got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCredentialsExpire(t *testing.T) {
|
||||
stub := &stubProvider{}
|
||||
c := NewCredentials(stub)
|
||||
|
||||
stub.expired = false
|
||||
if !isExpired(c) {
|
||||
t.Errorf("Expected to start out expired")
|
||||
}
|
||||
|
||||
_, err := c.GetWithContext(context.Background())
|
||||
if err != nil {
|
||||
t.Errorf("Expected no err, got %v", err)
|
||||
}
|
||||
if isExpired(c) {
|
||||
t.Errorf("Expected not to be expired")
|
||||
}
|
||||
|
||||
stub.expired = true
|
||||
if !isExpired(c) {
|
||||
t.Errorf("Expected to be expired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCredentialsGetWithProviderName(t *testing.T) {
|
||||
stub := &stubProvider{}
|
||||
|
||||
c := NewCredentials(stub)
|
||||
|
||||
creds, err := c.GetWithContext(context.Background())
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if e, a := creds.ProviderName, "stubProvider"; e != a {
|
||||
t.Errorf("Expected provider name to match, %v got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
type MockProvider struct {
|
||||
// The date/time when to expire on
|
||||
expiration time.Time
|
||||
|
||||
// If set will be used by IsExpired to determine the current time.
|
||||
// Defaults to time.Now if CurrentTime is not set. Available for testing
|
||||
// to be able to mock out the current time.
|
||||
CurrentTime func() time.Time
|
||||
}
|
||||
|
||||
// IsExpired returns if the credentials are expired.
|
||||
func (e *MockProvider) IsExpired() bool {
|
||||
curTime := e.CurrentTime
|
||||
if curTime == nil {
|
||||
curTime = time.Now
|
||||
}
|
||||
return e.expiration.Before(curTime())
|
||||
}
|
||||
|
||||
func (*MockProvider) Retrieve() (Value, error) {
|
||||
return Value{}, nil
|
||||
}
|
||||
|
||||
func TestCredentialsIsExpired_Race(_ *testing.T) {
|
||||
creds := NewChainCredentials([]Provider{&MockProvider{}})
|
||||
|
||||
starter := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(10)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-starter
|
||||
for i := 0; i < 100; i++ {
|
||||
isExpired(creds)
|
||||
}
|
||||
}()
|
||||
}
|
||||
close(starter)
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
type stubProviderConcurrent struct {
|
||||
stubProvider
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func (s *stubProviderConcurrent) Retrieve() (Value, error) {
|
||||
<-s.done
|
||||
return s.stubProvider.Retrieve()
|
||||
}
|
||||
|
||||
func TestCredentialsGetConcurrent(t *testing.T) {
|
||||
stub := &stubProviderConcurrent{
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
c := NewCredentials(stub)
|
||||
done := make(chan struct{})
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
go func() {
|
||||
_, err := c.GetWithContext(context.Background())
|
||||
if err != nil {
|
||||
t.Errorf("Expected no err, got %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
}
|
||||
|
||||
// Validates that a single call to Retrieve is shared between two calls to Get
|
||||
stub.done <- struct{}{}
|
||||
<-done
|
||||
<-done
|
||||
}
|
51
internal/aws/signer/v4/header_rules.go
Normal file
51
internal/aws/signer/v4/header_rules.go
Normal file
@ -0,0 +1,51 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/signer/v4/header_rules.go
|
||||
// See THIRD-PARTY-NOTICES for original license terms
|
||||
|
||||
package v4
|
||||
|
||||
// validator houses a set of rule needed for validation of a
|
||||
// string value
|
||||
type rules []rule
|
||||
|
||||
// rule interface allows for more flexible rules and just simply
|
||||
// checks whether or not a value adheres to that rule
|
||||
type rule interface {
|
||||
IsValid(value string) bool
|
||||
}
|
||||
|
||||
// IsValid will iterate through all rules and see if any rules
|
||||
// apply to the value and supports nested rules
|
||||
func (r rules) IsValid(value string) bool {
|
||||
for _, rule := range r {
|
||||
if rule.IsValid(value) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// mapRule generic rule for maps
|
||||
type mapRule map[string]struct{}
|
||||
|
||||
// IsValid for the map rule satisfies whether it exists in the map
|
||||
func (m mapRule) IsValid(value string) bool {
|
||||
_, ok := m[value]
|
||||
return ok
|
||||
}
|
||||
|
||||
// excludeList is a generic rule for exclude listing
|
||||
type excludeList struct {
|
||||
rule
|
||||
}
|
||||
|
||||
// IsValid for exclude list checks if the value is within the exclude list
|
||||
func (b excludeList) IsValid(value string) bool {
|
||||
return !b.rule.IsValid(value)
|
||||
}
|
80
internal/aws/signer/v4/request.go
Normal file
80
internal/aws/signer/v4/request.go
Normal file
@ -0,0 +1,80 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/request/request.go
|
||||
// See THIRD-PARTY-NOTICES for original license terms
|
||||
|
||||
package v4
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Returns host from request
|
||||
func getHost(r *http.Request) string {
|
||||
if r.Host != "" {
|
||||
return r.Host
|
||||
}
|
||||
|
||||
if r.URL == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return r.URL.Host
|
||||
}
|
||||
|
||||
// Hostname returns u.Host, without any port number.
|
||||
//
|
||||
// If Host is an IPv6 literal with a port number, Hostname returns the
|
||||
// IPv6 literal without the square brackets. IPv6 literals may include
|
||||
// a zone identifier.
|
||||
//
|
||||
// Copied from the Go 1.8 standard library (net/url)
|
||||
func stripPort(hostport string) string {
|
||||
colon := strings.IndexByte(hostport, ':')
|
||||
if colon == -1 {
|
||||
return hostport
|
||||
}
|
||||
if i := strings.IndexByte(hostport, ']'); i != -1 {
|
||||
return strings.TrimPrefix(hostport[:i], "[")
|
||||
}
|
||||
return hostport[:colon]
|
||||
}
|
||||
|
||||
// Port returns the port part of u.Host, without the leading colon.
|
||||
// If u.Host doesn't contain a port, Port returns an empty string.
|
||||
//
|
||||
// Copied from the Go 1.8 standard library (net/url)
|
||||
func portOnly(hostport string) string {
|
||||
colon := strings.IndexByte(hostport, ':')
|
||||
if colon == -1 {
|
||||
return ""
|
||||
}
|
||||
if i := strings.Index(hostport, "]:"); i != -1 {
|
||||
return hostport[i+len("]:"):]
|
||||
}
|
||||
if strings.Contains(hostport, "]") {
|
||||
return ""
|
||||
}
|
||||
return hostport[colon+len(":"):]
|
||||
}
|
||||
|
||||
// Returns true if the specified URI is using the standard port
|
||||
// (i.e. port 80 for HTTP URIs or 443 for HTTPS URIs)
|
||||
func isDefaultPort(scheme, port string) bool {
|
||||
if port == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
lowerCaseScheme := strings.ToLower(scheme)
|
||||
if (lowerCaseScheme == "http" && port == "80") || (lowerCaseScheme == "https" && port == "443") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
65
internal/aws/signer/v4/uri_path.go
Normal file
65
internal/aws/signer/v4/uri_path.go
Normal file
@ -0,0 +1,65 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/signer/v4/uri_path.go
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.44.225/private/protocol/rest/build.go
|
||||
// See THIRD-PARTY-NOTICES for original license terms
|
||||
|
||||
package v4
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Whether the byte value can be sent without escaping in AWS URLs
|
||||
var noEscape [256]bool
|
||||
|
||||
func init() {
|
||||
for i := 0; i < len(noEscape); i++ {
|
||||
// AWS expects every character except these to be escaped
|
||||
noEscape[i] = (i >= 'A' && i <= 'Z') ||
|
||||
(i >= 'a' && i <= 'z') ||
|
||||
(i >= '0' && i <= '9') ||
|
||||
i == '-' ||
|
||||
i == '.' ||
|
||||
i == '_' ||
|
||||
i == '~'
|
||||
}
|
||||
}
|
||||
|
||||
func getURIPath(u *url.URL) string {
|
||||
var uri string
|
||||
|
||||
if len(u.Opaque) > 0 {
|
||||
uri = "/" + strings.Join(strings.Split(u.Opaque, "/")[3:], "/")
|
||||
} else {
|
||||
uri = u.EscapedPath()
|
||||
}
|
||||
|
||||
if len(uri) == 0 {
|
||||
uri = "/"
|
||||
}
|
||||
|
||||
return uri
|
||||
}
|
||||
|
||||
// EscapePath escapes part of a URL path in Amazon style
|
||||
func EscapePath(path string, encodeSep bool) string {
|
||||
var buf bytes.Buffer
|
||||
for i := 0; i < len(path); i++ {
|
||||
c := path[i]
|
||||
if noEscape[c] || (c == '/' && !encodeSep) {
|
||||
buf.WriteByte(c)
|
||||
} else {
|
||||
fmt.Fprintf(&buf, "%%%02X", c)
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
421
internal/aws/signer/v4/v4.go
Normal file
421
internal/aws/signer/v4/v4.go
Normal file
@ -0,0 +1,421 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/signer/v4/v4.go
|
||||
// See THIRD-PARTY-NOTICES for original license terms
|
||||
|
||||
package v4
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/aws"
|
||||
"gitea.psichedelico.com/go/bson/internal/aws/credentials"
|
||||
)
|
||||
|
||||
const (
|
||||
authorizationHeader = "Authorization"
|
||||
authHeaderSignatureElem = "Signature="
|
||||
|
||||
authHeaderPrefix = "AWS4-HMAC-SHA256"
|
||||
timeFormat = "20060102T150405Z"
|
||||
shortTimeFormat = "20060102"
|
||||
awsV4Request = "aws4_request"
|
||||
|
||||
// emptyStringSHA256 is a SHA256 of an empty string
|
||||
emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
|
||||
)
|
||||
|
||||
var ignoredHeaders = rules{
|
||||
excludeList{
|
||||
mapRule{
|
||||
authorizationHeader: struct{}{},
|
||||
"User-Agent": struct{}{},
|
||||
"X-Amzn-Trace-Id": struct{}{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Signer applies AWS v4 signing to given request. Use this to sign requests
|
||||
// that need to be signed with AWS V4 Signatures.
|
||||
type Signer struct {
|
||||
// The authentication credentials the request will be signed against.
|
||||
// This value must be set to sign requests.
|
||||
Credentials *credentials.Credentials
|
||||
}
|
||||
|
||||
// NewSigner returns a Signer pointer configured with the credentials provided.
|
||||
func NewSigner(credentials *credentials.Credentials) *Signer {
|
||||
v4 := &Signer{
|
||||
Credentials: credentials,
|
||||
}
|
||||
|
||||
return v4
|
||||
}
|
||||
|
||||
type signingCtx struct {
|
||||
ServiceName string
|
||||
Region string
|
||||
Request *http.Request
|
||||
Body io.ReadSeeker
|
||||
Query url.Values
|
||||
Time time.Time
|
||||
SignedHeaderVals http.Header
|
||||
|
||||
credValues credentials.Value
|
||||
|
||||
bodyDigest string
|
||||
signedHeaders string
|
||||
canonicalHeaders string
|
||||
canonicalString string
|
||||
credentialString string
|
||||
stringToSign string
|
||||
signature string
|
||||
}
|
||||
|
||||
// Sign signs AWS v4 requests with the provided body, service name, region the
|
||||
// request is made to, and time the request is signed at. The signTime allows
|
||||
// you to specify that a request is signed for the future, and cannot be
|
||||
// used until then.
|
||||
//
|
||||
// Returns a list of HTTP headers that were included in the signature or an
|
||||
// error if signing the request failed. Generally for signed requests this value
|
||||
// is not needed as the full request context will be captured by the http.Request
|
||||
// value. It is included for reference though.
|
||||
//
|
||||
// Sign will set the request's Body to be the `body` parameter passed in. If
|
||||
// the body is not already an io.ReadCloser, it will be wrapped within one. If
|
||||
// a `nil` body parameter passed to Sign, the request's Body field will be
|
||||
// also set to nil. Its important to note that this functionality will not
|
||||
// change the request's ContentLength of the request.
|
||||
//
|
||||
// Sign differs from Presign in that it will sign the request using HTTP
|
||||
// header values. This type of signing is intended for http.Request values that
|
||||
// will not be shared, or are shared in a way the header values on the request
|
||||
// will not be lost.
|
||||
//
|
||||
// The requests body is an io.ReadSeeker so the SHA256 of the body can be
|
||||
// generated. To bypass the signer computing the hash you can set the
|
||||
// "X-Amz-Content-Sha256" header with a precomputed value. The signer will
|
||||
// only compute the hash if the request header value is empty.
|
||||
func (v4 Signer) Sign(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
|
||||
return v4.signWithBody(r, body, service, region, signTime)
|
||||
}
|
||||
|
||||
func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
|
||||
ctx := &signingCtx{
|
||||
Request: r,
|
||||
Body: body,
|
||||
Query: r.URL.Query(),
|
||||
Time: signTime,
|
||||
ServiceName: service,
|
||||
Region: region,
|
||||
}
|
||||
|
||||
for key := range ctx.Query {
|
||||
sort.Strings(ctx.Query[key])
|
||||
}
|
||||
|
||||
if ctx.isRequestSigned() {
|
||||
ctx.Time = time.Now()
|
||||
}
|
||||
|
||||
var err error
|
||||
ctx.credValues, err = v4.Credentials.GetWithContext(r.Context())
|
||||
if err != nil {
|
||||
return http.Header{}, err
|
||||
}
|
||||
|
||||
ctx.sanitizeHostForHeader()
|
||||
ctx.assignAmzQueryValues()
|
||||
if err := ctx.build(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var reader io.ReadCloser
|
||||
if body != nil {
|
||||
var ok bool
|
||||
if reader, ok = body.(io.ReadCloser); !ok {
|
||||
reader = ioutil.NopCloser(body)
|
||||
}
|
||||
}
|
||||
r.Body = reader
|
||||
|
||||
return ctx.SignedHeaderVals, nil
|
||||
}
|
||||
|
||||
// sanitizeHostForHeader removes default port from host and updates request.Host
|
||||
func (ctx *signingCtx) sanitizeHostForHeader() {
|
||||
r := ctx.Request
|
||||
host := getHost(r)
|
||||
port := portOnly(host)
|
||||
if port != "" && isDefaultPort(r.URL.Scheme, port) {
|
||||
r.Host = stripPort(host)
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) assignAmzQueryValues() {
|
||||
if ctx.credValues.SessionToken != "" {
|
||||
ctx.Request.Header.Set("X-Amz-Security-Token", ctx.credValues.SessionToken)
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) build() error {
|
||||
ctx.buildTime() // no depends
|
||||
ctx.buildCredentialString() // no depends
|
||||
|
||||
if err := ctx.buildBodyDigest(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unsignedHeaders := ctx.Request.Header
|
||||
|
||||
ctx.buildCanonicalHeaders(ignoredHeaders, unsignedHeaders)
|
||||
ctx.buildCanonicalString() // depends on canon headers / signed headers
|
||||
ctx.buildStringToSign() // depends on canon string
|
||||
ctx.buildSignature() // depends on string to sign
|
||||
|
||||
parts := []string{
|
||||
authHeaderPrefix + " Credential=" + ctx.credValues.AccessKeyID + "/" + ctx.credentialString,
|
||||
"SignedHeaders=" + ctx.signedHeaders,
|
||||
authHeaderSignatureElem + ctx.signature,
|
||||
}
|
||||
ctx.Request.Header.Set(authorizationHeader, strings.Join(parts, ", "))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) buildTime() {
|
||||
ctx.Request.Header.Set("X-Amz-Date", formatTime(ctx.Time))
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) buildCredentialString() {
|
||||
ctx.credentialString = buildSigningScope(ctx.Region, ctx.ServiceName, ctx.Time)
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) {
|
||||
headers := make([]string, 0, len(header)+1)
|
||||
headers = append(headers, "host")
|
||||
for k, v := range header {
|
||||
if !r.IsValid(k) {
|
||||
continue // ignored header
|
||||
}
|
||||
if ctx.SignedHeaderVals == nil {
|
||||
ctx.SignedHeaderVals = make(http.Header)
|
||||
}
|
||||
|
||||
lowerCaseKey := strings.ToLower(k)
|
||||
if _, ok := ctx.SignedHeaderVals[lowerCaseKey]; ok {
|
||||
// include additional values
|
||||
ctx.SignedHeaderVals[lowerCaseKey] = append(ctx.SignedHeaderVals[lowerCaseKey], v...)
|
||||
continue
|
||||
}
|
||||
|
||||
headers = append(headers, lowerCaseKey)
|
||||
ctx.SignedHeaderVals[lowerCaseKey] = v
|
||||
}
|
||||
sort.Strings(headers)
|
||||
|
||||
ctx.signedHeaders = strings.Join(headers, ";")
|
||||
|
||||
headerItems := make([]string, len(headers))
|
||||
for i, k := range headers {
|
||||
if k == "host" {
|
||||
if ctx.Request.Host != "" {
|
||||
headerItems[i] = "host:" + ctx.Request.Host
|
||||
} else {
|
||||
headerItems[i] = "host:" + ctx.Request.URL.Host
|
||||
}
|
||||
} else {
|
||||
headerValues := make([]string, len(ctx.SignedHeaderVals[k]))
|
||||
for i, v := range ctx.SignedHeaderVals[k] {
|
||||
headerValues[i] = strings.TrimSpace(v)
|
||||
}
|
||||
headerItems[i] = k + ":" +
|
||||
strings.Join(headerValues, ",")
|
||||
}
|
||||
}
|
||||
stripExcessSpaces(headerItems)
|
||||
ctx.canonicalHeaders = strings.Join(headerItems, "\n")
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) buildCanonicalString() {
|
||||
ctx.Request.URL.RawQuery = strings.Replace(ctx.Query.Encode(), "+", "%20", -1)
|
||||
|
||||
uri := getURIPath(ctx.Request.URL)
|
||||
|
||||
uri = EscapePath(uri, false)
|
||||
|
||||
ctx.canonicalString = strings.Join([]string{
|
||||
ctx.Request.Method,
|
||||
uri,
|
||||
ctx.Request.URL.RawQuery,
|
||||
ctx.canonicalHeaders + "\n",
|
||||
ctx.signedHeaders,
|
||||
ctx.bodyDigest,
|
||||
}, "\n")
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) buildStringToSign() {
|
||||
ctx.stringToSign = strings.Join([]string{
|
||||
authHeaderPrefix,
|
||||
formatTime(ctx.Time),
|
||||
ctx.credentialString,
|
||||
hex.EncodeToString(hashSHA256([]byte(ctx.canonicalString))),
|
||||
}, "\n")
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) buildSignature() {
|
||||
creds := deriveSigningKey(ctx.Region, ctx.ServiceName, ctx.credValues.SecretAccessKey, ctx.Time)
|
||||
signature := hmacSHA256(creds, []byte(ctx.stringToSign))
|
||||
ctx.signature = hex.EncodeToString(signature)
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) buildBodyDigest() error {
|
||||
hash := ctx.Request.Header.Get("X-Amz-Content-Sha256")
|
||||
if hash == "" {
|
||||
if ctx.Body == nil {
|
||||
hash = emptyStringSHA256
|
||||
} else {
|
||||
if !aws.IsReaderSeekable(ctx.Body) {
|
||||
return fmt.Errorf("cannot use unseekable request body %T, for signed request with body", ctx.Body)
|
||||
}
|
||||
hashBytes, err := makeSha256Reader(ctx.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hash = hex.EncodeToString(hashBytes)
|
||||
}
|
||||
}
|
||||
ctx.bodyDigest = hash
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isRequestSigned returns if the request is currently signed or presigned
|
||||
func (ctx *signingCtx) isRequestSigned() bool {
|
||||
return ctx.Request.Header.Get("Authorization") != ""
|
||||
}
|
||||
|
||||
func hmacSHA256(key []byte, data []byte) []byte {
|
||||
hash := hmac.New(sha256.New, key)
|
||||
hash.Write(data)
|
||||
return hash.Sum(nil)
|
||||
}
|
||||
|
||||
func hashSHA256(data []byte) []byte {
|
||||
hash := sha256.New()
|
||||
hash.Write(data)
|
||||
return hash.Sum(nil)
|
||||
}
|
||||
|
||||
func makeSha256Reader(reader io.ReadSeeker) (hashBytes []byte, err error) {
|
||||
hash := sha256.New()
|
||||
start, err := reader.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
// ensure error is return if unable to seek back to start of payload.
|
||||
_, err = reader.Seek(start, io.SeekStart)
|
||||
}()
|
||||
|
||||
// Use CopyN to avoid allocating the 32KB buffer in io.Copy for bodies
|
||||
// smaller than 32KB. Fall back to io.Copy if we fail to determine the size.
|
||||
size, err := aws.SeekerLen(reader)
|
||||
if err != nil {
|
||||
_, _ = io.Copy(hash, reader)
|
||||
} else {
|
||||
_, _ = io.CopyN(hash, reader, size)
|
||||
}
|
||||
|
||||
return hash.Sum(nil), nil
|
||||
}
|
||||
|
||||
const doubleSpace = " "
|
||||
|
||||
// stripExcessSpaces will rewrite the passed in slice's string values to not
|
||||
// contain multiple side-by-side spaces.
|
||||
func stripExcessSpaces(vals []string) {
|
||||
var j, k, l, m, spaces int
|
||||
for i, str := range vals {
|
||||
// revive:disable:empty-block
|
||||
|
||||
// Trim trailing spaces
|
||||
for j = len(str) - 1; j >= 0 && str[j] == ' '; j-- {
|
||||
}
|
||||
|
||||
// Trim leading spaces
|
||||
for k = 0; k < j && str[k] == ' '; k++ {
|
||||
}
|
||||
|
||||
// revive:enable:empty-block
|
||||
|
||||
str = str[k : j+1]
|
||||
|
||||
// Strip multiple spaces.
|
||||
j = strings.Index(str, doubleSpace)
|
||||
if j < 0 {
|
||||
vals[i] = str
|
||||
continue
|
||||
}
|
||||
|
||||
buf := []byte(str)
|
||||
for k, m, l = j, j, len(buf); k < l; k++ {
|
||||
if buf[k] == ' ' {
|
||||
if spaces == 0 {
|
||||
// First space.
|
||||
buf[m] = buf[k]
|
||||
m++
|
||||
}
|
||||
spaces++
|
||||
} else {
|
||||
// End of multiple spaces.
|
||||
spaces = 0
|
||||
buf[m] = buf[k]
|
||||
m++
|
||||
}
|
||||
}
|
||||
|
||||
vals[i] = string(buf[:m])
|
||||
}
|
||||
}
|
||||
|
||||
func buildSigningScope(region, service string, dt time.Time) string {
|
||||
return strings.Join([]string{
|
||||
formatShortTime(dt),
|
||||
region,
|
||||
service,
|
||||
awsV4Request,
|
||||
}, "/")
|
||||
}
|
||||
|
||||
func deriveSigningKey(region, service, secretKey string, dt time.Time) []byte {
|
||||
keyDate := hmacSHA256([]byte("AWS4"+secretKey), []byte(formatShortTime(dt)))
|
||||
keyRegion := hmacSHA256(keyDate, []byte(region))
|
||||
keyService := hmacSHA256(keyRegion, []byte(service))
|
||||
signingKey := hmacSHA256(keyService, []byte(awsV4Request))
|
||||
return signingKey
|
||||
}
|
||||
|
||||
func formatShortTime(dt time.Time) string {
|
||||
return dt.UTC().Format(shortTimeFormat)
|
||||
}
|
||||
|
||||
func formatTime(dt time.Time) string {
|
||||
return dt.UTC().Format(timeFormat)
|
||||
}
|
434
internal/aws/signer/v4/v4_test.go
Normal file
434
internal/aws/signer/v4/v4_test.go
Normal file
@ -0,0 +1,434 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/signer/v4/v4_test.go
|
||||
// See THIRD-PARTY-NOTICES for original license terms
|
||||
|
||||
package v4
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/aws"
|
||||
"gitea.psichedelico.com/go/bson/internal/aws/credentials"
|
||||
"gitea.psichedelico.com/go/bson/internal/credproviders"
|
||||
)
|
||||
|
||||
func epochTime() time.Time { return time.Unix(0, 0) }
|
||||
|
||||
func TestStripExcessHeaders(t *testing.T) {
|
||||
vals := []string{
|
||||
"",
|
||||
"123",
|
||||
"1 2 3",
|
||||
"1 2 3 ",
|
||||
" 1 2 3",
|
||||
"1 2 3",
|
||||
"1 23",
|
||||
"1 2 3",
|
||||
"1 2 ",
|
||||
" 1 2 ",
|
||||
"12 3",
|
||||
"12 3 1",
|
||||
"12 3 1",
|
||||
"12 3 1abc123",
|
||||
}
|
||||
|
||||
expected := []string{
|
||||
"",
|
||||
"123",
|
||||
"1 2 3",
|
||||
"1 2 3",
|
||||
"1 2 3",
|
||||
"1 2 3",
|
||||
"1 23",
|
||||
"1 2 3",
|
||||
"1 2",
|
||||
"1 2",
|
||||
"12 3",
|
||||
"12 3 1",
|
||||
"12 3 1",
|
||||
"12 3 1abc123",
|
||||
}
|
||||
|
||||
stripExcessSpaces(vals)
|
||||
for i := 0; i < len(vals); i++ {
|
||||
if e, a := expected[i], vals[i]; e != a {
|
||||
t.Errorf("%d, expect %v, got %v", i, e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildRequest(body string) (*http.Request, io.ReadSeeker) {
|
||||
reader := strings.NewReader(body)
|
||||
return buildRequestWithBodyReader("dynamodb", "us-east-1", reader)
|
||||
}
|
||||
|
||||
func buildRequestReaderSeeker(serviceName, region, body string) (*http.Request, io.ReadSeeker) {
|
||||
reader := &readerSeekerWrapper{strings.NewReader(body)}
|
||||
return buildRequestWithBodyReader(serviceName, region, reader)
|
||||
}
|
||||
|
||||
func buildRequestWithBodyReader(serviceName, region string, body io.Reader) (*http.Request, io.ReadSeeker) {
|
||||
var bodyLen int
|
||||
|
||||
type lenner interface {
|
||||
Len() int
|
||||
}
|
||||
if lr, ok := body.(lenner); ok {
|
||||
bodyLen = lr.Len()
|
||||
}
|
||||
|
||||
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
|
||||
req, _ := http.NewRequest("POST", endpoint, body)
|
||||
req.URL.Opaque = "//example.org/bucket/key-._~,!@#$%^&*()"
|
||||
req.Header.Set("X-Amz-Target", "prefix.Operation")
|
||||
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||
|
||||
if bodyLen > 0 {
|
||||
req.Header.Set("Content-Length", strconv.Itoa(bodyLen))
|
||||
}
|
||||
|
||||
req.Header.Set("X-Amz-Meta-Other-Header", "some-value=!@#$%^&* (+)")
|
||||
req.Header.Add("X-Amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
|
||||
req.Header.Add("X-amz-Meta-Other-Header_With_Underscore", "some-value=!@#$%^&* (+)")
|
||||
|
||||
var seeker io.ReadSeeker
|
||||
if sr, ok := body.(io.ReadSeeker); ok {
|
||||
seeker = sr
|
||||
} else {
|
||||
seeker = aws.ReadSeekCloser(body)
|
||||
}
|
||||
|
||||
return req, seeker
|
||||
}
|
||||
|
||||
func buildSigner() Signer {
|
||||
return Signer{
|
||||
Credentials: newTestStaticCredentials(),
|
||||
}
|
||||
}
|
||||
|
||||
func newTestStaticCredentials() *credentials.Credentials {
|
||||
return credentials.NewCredentials(&credproviders.StaticProvider{Value: credentials.Value{
|
||||
AccessKeyID: "AKID",
|
||||
SecretAccessKey: "SECRET",
|
||||
SessionToken: "SESSION",
|
||||
}})
|
||||
}
|
||||
|
||||
func TestSignRequest(t *testing.T) {
|
||||
req, body := buildRequest("{}")
|
||||
signer := buildSigner()
|
||||
_, err := signer.Sign(req, body, "dynamodb", "us-east-1", epochTime())
|
||||
if err != nil {
|
||||
t.Errorf("Expected no err, got %v", err)
|
||||
}
|
||||
|
||||
expectedDate := "19700101T000000Z"
|
||||
expectedSig := "AWS4-HMAC-SHA256 Credential=AKID/19700101/us-east-1/dynamodb/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date;x-amz-meta-other-header;x-amz-meta-other-header_with_underscore;x-amz-security-token;x-amz-target, Signature=a518299330494908a70222cec6899f6f32f297f8595f6df1776d998936652ad9"
|
||||
|
||||
q := req.Header
|
||||
if e, a := expectedSig, q.Get("Authorization"); e != a {
|
||||
t.Errorf("expect\n%v\nactual\n%v\n", e, a)
|
||||
}
|
||||
if e, a := expectedDate, q.Get("X-Amz-Date"); e != a {
|
||||
t.Errorf("expect\n%v\nactual\n%v\n", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignUnseekableBody(t *testing.T) {
|
||||
req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello")))
|
||||
signer := buildSigner()
|
||||
_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
|
||||
if err == nil {
|
||||
t.Fatalf("expect error signing request")
|
||||
}
|
||||
|
||||
if e, a := "unseekable request body", err.Error(); !strings.Contains(a, e) {
|
||||
t.Errorf("expect %q to be in %q", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignPreComputedHashUnseekableBody(t *testing.T) {
|
||||
req, body := buildRequestWithBodyReader("mock-service", "mock-region", bytes.NewBuffer([]byte("hello")))
|
||||
|
||||
signer := buildSigner()
|
||||
|
||||
req.Header.Set("X-Amz-Content-Sha256", "some-content-sha256")
|
||||
_, err := signer.Sign(req, body, "mock-service", "mock-region", time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
hash := req.Header.Get("X-Amz-Content-Sha256")
|
||||
if e, a := "some-content-sha256", hash; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignPrecomputedBodyChecksum(t *testing.T) {
|
||||
req, body := buildRequest("hello")
|
||||
req.Header.Set("X-Amz-Content-Sha256", "PRECOMPUTED")
|
||||
signer := buildSigner()
|
||||
_, err := signer.Sign(req, body, "dynamodb", "us-east-1", time.Now())
|
||||
if err != nil {
|
||||
t.Errorf("Expected no err, got %v", err)
|
||||
}
|
||||
hash := req.Header.Get("X-Amz-Content-Sha256")
|
||||
if e, a := "PRECOMPUTED", hash; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignWithRequestBody(t *testing.T) {
|
||||
creds := newTestStaticCredentials()
|
||||
signer := NewSigner(creds)
|
||||
|
||||
expectBody := []byte("abc123")
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
b, err := ioutil.ReadAll(r.Body)
|
||||
r.Body.Close()
|
||||
if err != nil {
|
||||
t.Errorf("expect no error, got %v", err)
|
||||
}
|
||||
if e, a := expectBody, b; !reflect.DeepEqual(e, a) {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req, err := http.NewRequest("POST", server.URL, nil)
|
||||
if err != nil {
|
||||
t.Errorf("expect not no error, got %v", err)
|
||||
}
|
||||
|
||||
_, err = signer.Sign(req, bytes.NewReader(expectBody), "service", "region", time.Now())
|
||||
if err != nil {
|
||||
t.Errorf("expect not no error, got %v", err)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Errorf("expect not no error, got %v", err)
|
||||
}
|
||||
if e, a := http.StatusOK, resp.StatusCode; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignWithRequestBody_Overwrite(t *testing.T) {
|
||||
creds := newTestStaticCredentials()
|
||||
signer := NewSigner(creds)
|
||||
|
||||
var expectBody []byte
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
b, err := ioutil.ReadAll(r.Body)
|
||||
r.Body.Close()
|
||||
if err != nil {
|
||||
t.Errorf("expect not no error, got %v", err)
|
||||
}
|
||||
if e, a := len(expectBody), len(b); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req, err := http.NewRequest("GET", server.URL, strings.NewReader("invalid body"))
|
||||
if err != nil {
|
||||
t.Errorf("expect not no error, got %v", err)
|
||||
}
|
||||
|
||||
_, err = signer.Sign(req, nil, "service", "region", time.Now())
|
||||
req.ContentLength = 0
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expect not no error, got %v", err)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Errorf("expect not no error, got %v", err)
|
||||
}
|
||||
if e, a := http.StatusOK, resp.StatusCode; e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCanonicalRequest(t *testing.T) {
|
||||
req, body := buildRequest("{}")
|
||||
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
|
||||
ctx := &signingCtx{
|
||||
ServiceName: "dynamodb",
|
||||
Region: "us-east-1",
|
||||
Request: req,
|
||||
Body: body,
|
||||
Query: req.URL.Query(),
|
||||
Time: time.Now(),
|
||||
}
|
||||
|
||||
ctx.buildCanonicalString()
|
||||
expected := "https://example.org/bucket/key-._~,!@#$%^&*()?Foo=z&Foo=o&Foo=m&Foo=a"
|
||||
if e, a := expected, ctx.Request.URL.String(); e != a {
|
||||
t.Errorf("expect %v, got %v", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignWithBody_ReplaceRequestBody(t *testing.T) {
|
||||
creds := newTestStaticCredentials()
|
||||
req, seekerBody := buildRequest("{}")
|
||||
req.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
|
||||
|
||||
s := NewSigner(creds)
|
||||
origBody := req.Body
|
||||
|
||||
_, err := s.Sign(req, seekerBody, "dynamodb", "us-east-1", time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("expect no error, got %v", err)
|
||||
}
|
||||
|
||||
if req.Body == origBody {
|
||||
t.Errorf("expect request body to not be origBody")
|
||||
}
|
||||
|
||||
if req.Body == nil {
|
||||
t.Errorf("expect request body to be changed but was nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestHost(t *testing.T) {
|
||||
req, body := buildRequest("{}")
|
||||
req.URL.RawQuery = "Foo=z&Foo=o&Foo=m&Foo=a"
|
||||
req.Host = "myhost"
|
||||
ctx := &signingCtx{
|
||||
ServiceName: "dynamodb",
|
||||
Region: "us-east-1",
|
||||
Request: req,
|
||||
Body: body,
|
||||
Query: req.URL.Query(),
|
||||
Time: time.Now(),
|
||||
}
|
||||
|
||||
ctx.buildCanonicalHeaders(ignoredHeaders, ctx.Request.Header)
|
||||
if !strings.Contains(ctx.canonicalHeaders, "host:"+req.Host) {
|
||||
t.Errorf("canonical host header invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSign_buildCanonicalHeaders(t *testing.T) {
|
||||
serviceName := "mockAPI"
|
||||
region := "mock-region"
|
||||
endpoint := "https://" + serviceName + "." + region + ".amazonaws.com"
|
||||
|
||||
req, err := http.NewRequest("POST", endpoint, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request, %v", err)
|
||||
}
|
||||
|
||||
req.Header.Set("FooInnerSpace", " inner space ")
|
||||
req.Header.Set("FooLeadingSpace", " leading-space")
|
||||
req.Header.Add("FooMultipleSpace", "no-space")
|
||||
req.Header.Add("FooMultipleSpace", "\ttab-space")
|
||||
req.Header.Add("FooMultipleSpace", "trailing-space ")
|
||||
req.Header.Set("FooNoSpace", "no-space")
|
||||
req.Header.Set("FooTabSpace", "\ttab-space\t")
|
||||
req.Header.Set("FooTrailingSpace", "trailing-space ")
|
||||
req.Header.Set("FooWrappedSpace", " wrapped-space ")
|
||||
|
||||
ctx := &signingCtx{
|
||||
ServiceName: serviceName,
|
||||
Region: region,
|
||||
Request: req,
|
||||
Body: nil,
|
||||
Query: req.URL.Query(),
|
||||
Time: time.Now(),
|
||||
}
|
||||
|
||||
ctx.buildCanonicalHeaders(ignoredHeaders, ctx.Request.Header)
|
||||
|
||||
expectCanonicalHeaders := strings.Join([]string{
|
||||
`fooinnerspace:inner space`,
|
||||
`fooleadingspace:leading-space`,
|
||||
`foomultiplespace:no-space,tab-space,trailing-space`,
|
||||
`foonospace:no-space`,
|
||||
`footabspace:tab-space`,
|
||||
`footrailingspace:trailing-space`,
|
||||
`foowrappedspace:wrapped-space`,
|
||||
`host:mockAPI.mock-region.amazonaws.com`,
|
||||
}, "\n")
|
||||
if e, a := expectCanonicalHeaders, ctx.canonicalHeaders; e != a {
|
||||
t.Errorf("expect:\n%s\n\nactual:\n%s", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSignRequest(b *testing.B) {
|
||||
signer := buildSigner()
|
||||
req, body := buildRequestReaderSeeker("dynamodb", "us-east-1", "{}")
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := signer.Sign(req, body, "dynamodb", "us-east-1", time.Now())
|
||||
if err != nil {
|
||||
b.Errorf("Expected no err, got %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var stripExcessSpaceCases = []string{
|
||||
`AWS4-HMAC-SHA256 Credential=AKIDFAKEIDFAKEID/20160628/us-west-2/s3/aws4_request, SignedHeaders=host;x-amz-date, Signature=1234567890abcdef1234567890abcdef1234567890abcdef`,
|
||||
`123 321 123 321`,
|
||||
` 123 321 123 321 `,
|
||||
` 123 321 123 321 `,
|
||||
"123",
|
||||
"1 2 3",
|
||||
" 1 2 3",
|
||||
"1 2 3",
|
||||
"1 23",
|
||||
"1 2 3",
|
||||
"1 2 ",
|
||||
" 1 2 ",
|
||||
"12 3",
|
||||
"12 3 1",
|
||||
"12 3 1",
|
||||
"12 3 1abc123",
|
||||
}
|
||||
|
||||
func BenchmarkStripExcessSpaces(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Make sure to start with a copy of the cases
|
||||
cases := append([]string{}, stripExcessSpaceCases...)
|
||||
stripExcessSpaces(cases)
|
||||
}
|
||||
}
|
||||
|
||||
// readerSeekerWrapper mimics the interface provided by request.offsetReader
|
||||
type readerSeekerWrapper struct {
|
||||
r *strings.Reader
|
||||
}
|
||||
|
||||
func (r *readerSeekerWrapper) Read(p []byte) (n int, err error) {
|
||||
return r.r.Read(p)
|
||||
}
|
||||
|
||||
func (r *readerSeekerWrapper) Seek(offset int64, whence int) (int64, error) {
|
||||
return r.r.Seek(offset, whence)
|
||||
}
|
||||
|
||||
func (r *readerSeekerWrapper) Len() int {
|
||||
return r.r.Len()
|
||||
}
|
153
internal/aws/types.go
Normal file
153
internal/aws/types.go
Normal file
@ -0,0 +1,153 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.44.225/aws/types.go
|
||||
// See THIRD-PARTY-NOTICES for original license terms
|
||||
|
||||
package aws
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
// ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Allows the
|
||||
// SDK to accept an io.Reader that is not also an io.Seeker for unsigned
|
||||
// streaming payload API operations.
|
||||
//
|
||||
// A ReadSeekCloser wrapping an nonseekable io.Reader used in an API
|
||||
// operation's input will prevent that operation being retried in the case of
|
||||
// network errors, and cause operation requests to fail if the operation
|
||||
// requires payload signing.
|
||||
//
|
||||
// Note: If using With S3 PutObject to stream an object upload The SDK's S3
|
||||
// Upload manager (s3manager.Uploader) provides support for streaming with the
|
||||
// ability to retry network errors.
|
||||
func ReadSeekCloser(r io.Reader) ReaderSeekerCloser {
|
||||
return ReaderSeekerCloser{r}
|
||||
}
|
||||
|
||||
// ReaderSeekerCloser represents a reader that can also delegate io.Seeker and
|
||||
// io.Closer interfaces to the underlying object if they are available.
|
||||
type ReaderSeekerCloser struct {
|
||||
r io.Reader
|
||||
}
|
||||
|
||||
// IsReaderSeekable returns if the underlying reader type can be seeked. A
|
||||
// io.Reader might not actually be seekable if it is the ReaderSeekerCloser
|
||||
// type.
|
||||
func IsReaderSeekable(r io.Reader) bool {
|
||||
switch v := r.(type) {
|
||||
case ReaderSeekerCloser:
|
||||
return v.IsSeeker()
|
||||
case *ReaderSeekerCloser:
|
||||
return v.IsSeeker()
|
||||
case io.ReadSeeker:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Read reads from the reader up to size of p. The number of bytes read, and
|
||||
// error if it occurred will be returned.
|
||||
//
|
||||
// If the reader is not an io.Reader zero bytes read, and nil error will be
|
||||
// returned.
|
||||
//
|
||||
// Performs the same functionality as io.Reader Read
|
||||
func (r ReaderSeekerCloser) Read(p []byte) (int, error) {
|
||||
switch t := r.r.(type) {
|
||||
case io.Reader:
|
||||
return t.Read(p)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Seek sets the offset for the next Read to offset, interpreted according to
|
||||
// whence: 0 means relative to the origin of the file, 1 means relative to the
|
||||
// current offset, and 2 means relative to the end. Seek returns the new offset
|
||||
// and an error, if any.
|
||||
//
|
||||
// If the ReaderSeekerCloser is not an io.Seeker nothing will be done.
|
||||
func (r ReaderSeekerCloser) Seek(offset int64, whence int) (int64, error) {
|
||||
switch t := r.r.(type) {
|
||||
case io.Seeker:
|
||||
return t.Seek(offset, whence)
|
||||
}
|
||||
return int64(0), nil
|
||||
}
|
||||
|
||||
// IsSeeker returns if the underlying reader is also a seeker.
|
||||
func (r ReaderSeekerCloser) IsSeeker() bool {
|
||||
_, ok := r.r.(io.Seeker)
|
||||
return ok
|
||||
}
|
||||
|
||||
// HasLen returns the length of the underlying reader if the value implements
|
||||
// the Len() int method.
|
||||
func (r ReaderSeekerCloser) HasLen() (int, bool) {
|
||||
type lenner interface {
|
||||
Len() int
|
||||
}
|
||||
|
||||
if lr, ok := r.r.(lenner); ok {
|
||||
return lr.Len(), true
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetLen returns the length of the bytes remaining in the underlying reader.
|
||||
// Checks first for Len(), then io.Seeker to determine the size of the
|
||||
// underlying reader.
|
||||
//
|
||||
// Will return -1 if the length cannot be determined.
|
||||
func (r ReaderSeekerCloser) GetLen() (int64, error) {
|
||||
if l, ok := r.HasLen(); ok {
|
||||
return int64(l), nil
|
||||
}
|
||||
|
||||
if s, ok := r.r.(io.Seeker); ok {
|
||||
return seekerLen(s)
|
||||
}
|
||||
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
// SeekerLen attempts to get the number of bytes remaining at the seeker's
|
||||
// current position. Returns the number of bytes remaining or error.
|
||||
func SeekerLen(s io.Seeker) (int64, error) {
|
||||
// Determine if the seeker is actually seekable. ReaderSeekerCloser
|
||||
// hides the fact that a io.Readers might not actually be seekable.
|
||||
switch v := s.(type) {
|
||||
case ReaderSeekerCloser:
|
||||
return v.GetLen()
|
||||
case *ReaderSeekerCloser:
|
||||
return v.GetLen()
|
||||
}
|
||||
|
||||
return seekerLen(s)
|
||||
}
|
||||
|
||||
func seekerLen(s io.Seeker) (int64, error) {
|
||||
curOffset, err := s.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
endOffset, err := s.Seek(0, io.SeekEnd)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
_, err = s.Seek(curOffset, io.SeekStart)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return endOffset - curOffset, nil
|
||||
}
|
40
internal/bsoncoreutil/bsoncoreutil.go
Normal file
40
internal/bsoncoreutil/bsoncoreutil.go
Normal file
@ -0,0 +1,40 @@
|
||||
// Copyright (C) MongoDB, Inc. 2024-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsoncoreutil
|
||||
|
||||
// Truncate truncates a given string for a certain width
|
||||
func Truncate(str string, width int) string {
|
||||
if width == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if len(str) <= width {
|
||||
return str
|
||||
}
|
||||
|
||||
// Truncate the byte slice of the string to the given width.
|
||||
newStr := str[:width]
|
||||
|
||||
// Check if the last byte is at the beginning of a multi-byte character.
|
||||
// If it is, then remove the last byte.
|
||||
if newStr[len(newStr)-1]&0xC0 == 0xC0 {
|
||||
return newStr[:len(newStr)-1]
|
||||
}
|
||||
|
||||
// Check if the last byte is a multi-byte character
|
||||
if newStr[len(newStr)-1]&0xC0 == 0x80 {
|
||||
// If it is, step back until you we are at the start of a character
|
||||
for i := len(newStr) - 1; i >= 0; i-- {
|
||||
if newStr[i]&0xC0 == 0xC0 {
|
||||
// Truncate at the end of the character before the character we stepped back to
|
||||
return newStr[:i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return newStr
|
||||
}
|
59
internal/bsoncoreutil/bsoncoreutil_test.go
Normal file
59
internal/bsoncoreutil/bsoncoreutil_test.go
Normal file
@ -0,0 +1,59 @@
|
||||
// Copyright (C) MongoDB, Inc. 2024-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsoncoreutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/assert"
|
||||
)
|
||||
|
||||
func TestTruncate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tcase := range []struct {
|
||||
name string
|
||||
arg string
|
||||
width int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
arg: "",
|
||||
width: 0,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "short",
|
||||
arg: "foo",
|
||||
width: 1000,
|
||||
expected: "foo",
|
||||
},
|
||||
{
|
||||
name: "long",
|
||||
arg: "foo bar baz",
|
||||
width: 9,
|
||||
expected: "foo bar b",
|
||||
},
|
||||
{
|
||||
name: "multi-byte",
|
||||
arg: "你好",
|
||||
width: 4,
|
||||
expected: "你",
|
||||
},
|
||||
} {
|
||||
tcase := tcase
|
||||
|
||||
t.Run(tcase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actual := Truncate(tcase.arg, tcase.width)
|
||||
assert.Equal(t, tcase.expected, actual)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
62
internal/bsonutil/bsonutil.go
Normal file
62
internal/bsonutil/bsonutil.go
Normal file
@ -0,0 +1,62 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsonutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gitea.psichedelico.com/go/bson"
|
||||
)
|
||||
|
||||
// StringSliceFromRawValue decodes the provided BSON value into a []string. This function returns an error if the value
|
||||
// is not an array or any of the elements in the array are not strings. The name parameter is used to add context to
|
||||
// error messages.
|
||||
func StringSliceFromRawValue(name string, val bson.RawValue) ([]string, error) {
|
||||
arr, ok := val.ArrayOK()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected '%s' to be an array but it's a BSON %s", name, val.Type)
|
||||
}
|
||||
|
||||
arrayValues, err := arr.Values()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
strs := make([]string, 0, len(arrayValues))
|
||||
for _, arrayVal := range arrayValues {
|
||||
str, ok := arrayVal.StringValueOK()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected '%s' to be an array of strings, but found a BSON %s", name, arrayVal.Type)
|
||||
}
|
||||
strs = append(strs, str)
|
||||
}
|
||||
return strs, nil
|
||||
}
|
||||
|
||||
// RawArrayToDocuments converts an array of documents to []bson.Raw.
|
||||
func RawArrayToDocuments(arr bson.RawArray) []bson.Raw {
|
||||
values, err := arr.Values()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error converting BSON document to values: %v", err))
|
||||
}
|
||||
|
||||
out := make([]bson.Raw, len(values))
|
||||
for i := range values {
|
||||
out[i] = values[i].Document()
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// RawToInterfaces takes one or many bson.Raw documents and returns them as a []interface{}.
|
||||
func RawToInterfaces(docs ...bson.Raw) []interface{} {
|
||||
out := make([]interface{}, len(docs))
|
||||
for i := range docs {
|
||||
out[i] = docs[i]
|
||||
}
|
||||
return out
|
||||
}
|
62
internal/codecutil/encoding.go
Normal file
62
internal/codecutil/encoding.go
Normal file
@ -0,0 +1,62 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package codecutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
|
||||
"gitea.psichedelico.com/go/bson"
|
||||
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
var ErrNilValue = errors.New("value is nil")
|
||||
|
||||
// MarshalError is returned when attempting to transform a value into a document
|
||||
// results in an error.
|
||||
type MarshalError struct {
|
||||
Value interface{}
|
||||
Err error
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e MarshalError) Error() string {
|
||||
return fmt.Sprintf("cannot transform type %s to a BSON Document: %v",
|
||||
reflect.TypeOf(e.Value), e.Err)
|
||||
}
|
||||
|
||||
// EncoderFn is used to functionally construct an encoder for marshaling values.
|
||||
type EncoderFn func(io.Writer) *bson.Encoder
|
||||
|
||||
// MarshalValue will attempt to encode the value with the encoder returned by
|
||||
// the encoder function.
|
||||
func MarshalValue(val interface{}, encFn EncoderFn) (bsoncore.Value, error) {
|
||||
// If the val is already a bsoncore.Value, then do nothing.
|
||||
if bval, ok := val.(bsoncore.Value); ok {
|
||||
return bval, nil
|
||||
}
|
||||
|
||||
if val == nil {
|
||||
return bsoncore.Value{}, ErrNilValue
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
enc := encFn(buf)
|
||||
|
||||
// Encode the value in a single-element document with an empty key. Use
|
||||
// bsoncore to extract the first element and return the BSON value.
|
||||
err := enc.Encode(bson.D{{Key: "", Value: val}})
|
||||
if err != nil {
|
||||
return bsoncore.Value{}, MarshalError{Value: val, Err: err}
|
||||
}
|
||||
|
||||
return bsoncore.Document(buf.Bytes()).Index(0).Value(), nil
|
||||
}
|
82
internal/codecutil/encoding_test.go
Normal file
82
internal/codecutil/encoding_test.go
Normal file
@ -0,0 +1,82 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package codecutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"gitea.psichedelico.com/go/bson"
|
||||
"gitea.psichedelico.com/go/bson/internal/assert"
|
||||
)
|
||||
|
||||
func testEncFn(t *testing.T) EncoderFn {
|
||||
t.Helper()
|
||||
|
||||
return func(w io.Writer) *bson.Encoder {
|
||||
rw := bson.NewDocumentWriter(w)
|
||||
return bson.NewEncoder(rw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
val interface{}
|
||||
registry *bson.Registry
|
||||
encFn EncoderFn
|
||||
want string
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
val: nil,
|
||||
want: "",
|
||||
wantErr: ErrNilValue,
|
||||
encFn: testEncFn(t),
|
||||
},
|
||||
{
|
||||
name: "bson.D",
|
||||
val: bson.D{{"foo", "bar"}},
|
||||
want: `{"foo": "bar"}`,
|
||||
encFn: testEncFn(t),
|
||||
},
|
||||
{
|
||||
name: "map",
|
||||
val: map[string]interface{}{"foo": "bar"},
|
||||
want: `{"foo": "bar"}`,
|
||||
encFn: testEncFn(t),
|
||||
},
|
||||
{
|
||||
name: "struct",
|
||||
val: struct{ Foo string }{Foo: "bar"},
|
||||
want: `{"foo": "bar"}`,
|
||||
encFn: testEncFn(t),
|
||||
},
|
||||
{
|
||||
name: "non-document type",
|
||||
val: "foo: bar",
|
||||
want: `"foo: bar"`,
|
||||
encFn: testEncFn(t),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
value, err := MarshalValue(test.val, test.encFn)
|
||||
|
||||
assert.Equal(t, test.wantErr, err, "expected and actual error do not match")
|
||||
assert.Equal(t, test.want, value.String(), "expected and actual comments are different")
|
||||
})
|
||||
}
|
||||
}
|
148
internal/credproviders/assume_role_provider.go
Normal file
148
internal/credproviders/assume_role_provider.go
Normal file
@ -0,0 +1,148 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package credproviders
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/aws/credentials"
|
||||
"gitea.psichedelico.com/go/bson/internal/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
// assumeRoleProviderName provides a name of assume role provider
|
||||
assumeRoleProviderName = "AssumeRoleProvider"
|
||||
|
||||
stsURI = `https://sts.amazonaws.com/?Action=AssumeRoleWithWebIdentity&RoleSessionName=%s&RoleArn=%s&WebIdentityToken=%s&Version=2011-06-15`
|
||||
)
|
||||
|
||||
// An AssumeRoleProvider retrieves credentials for assume role with web identity.
|
||||
type AssumeRoleProvider struct {
|
||||
AwsRoleArnEnv EnvVar
|
||||
AwsWebIdentityTokenFileEnv EnvVar
|
||||
AwsRoleSessionNameEnv EnvVar
|
||||
|
||||
httpClient *http.Client
|
||||
expiration time.Time
|
||||
|
||||
// expiryWindow will allow the credentials to trigger refreshing prior to the credentials actually expiring.
|
||||
// This is beneficial so expiring credentials do not cause request to fail unexpectedly due to exceptions.
|
||||
//
|
||||
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
|
||||
// 10 seconds before the credentials are actually expired.
|
||||
expiryWindow time.Duration
|
||||
}
|
||||
|
||||
// NewAssumeRoleProvider returns a pointer to an assume role provider.
|
||||
func NewAssumeRoleProvider(httpClient *http.Client, expiryWindow time.Duration) *AssumeRoleProvider {
|
||||
return &AssumeRoleProvider{
|
||||
// AwsRoleArnEnv is the environment variable for AWS_ROLE_ARN
|
||||
AwsRoleArnEnv: EnvVar("AWS_ROLE_ARN"),
|
||||
// AwsWebIdentityTokenFileEnv is the environment variable for AWS_WEB_IDENTITY_TOKEN_FILE
|
||||
AwsWebIdentityTokenFileEnv: EnvVar("AWS_WEB_IDENTITY_TOKEN_FILE"),
|
||||
// AwsRoleSessionNameEnv is the environment variable for AWS_ROLE_SESSION_NAME
|
||||
AwsRoleSessionNameEnv: EnvVar("AWS_ROLE_SESSION_NAME"),
|
||||
httpClient: httpClient,
|
||||
expiryWindow: expiryWindow,
|
||||
}
|
||||
}
|
||||
|
||||
// RetrieveWithContext retrieves the keys from the AWS service.
|
||||
func (a *AssumeRoleProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
|
||||
const defaultHTTPTimeout = 10 * time.Second
|
||||
|
||||
v := credentials.Value{ProviderName: assumeRoleProviderName}
|
||||
|
||||
roleArn := a.AwsRoleArnEnv.Get()
|
||||
tokenFile := a.AwsWebIdentityTokenFileEnv.Get()
|
||||
if tokenFile == "" && roleArn == "" {
|
||||
return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_ARN are missing")
|
||||
}
|
||||
if tokenFile != "" && roleArn == "" {
|
||||
return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE is set, but AWS_ROLE_ARN is missing")
|
||||
}
|
||||
if tokenFile == "" && roleArn != "" {
|
||||
return v, errors.New("AWS_ROLE_ARN is set, but AWS_WEB_IDENTITY_TOKEN_FILE is missing")
|
||||
}
|
||||
token, err := ioutil.ReadFile(tokenFile)
|
||||
if err != nil {
|
||||
return v, err
|
||||
}
|
||||
|
||||
sessionName := a.AwsRoleSessionNameEnv.Get()
|
||||
if sessionName == "" {
|
||||
// Use a UUID if the RoleSessionName is not given.
|
||||
id, err := uuid.New()
|
||||
if err != nil {
|
||||
return v, err
|
||||
}
|
||||
sessionName = id.String()
|
||||
}
|
||||
|
||||
fullURI := fmt.Sprintf(stsURI, sessionName, roleArn, string(token))
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, fullURI, nil)
|
||||
if err != nil {
|
||||
return v, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
|
||||
defer cancel()
|
||||
resp, err := a.httpClient.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return v, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return v, fmt.Errorf("response failure: %s", resp.Status)
|
||||
}
|
||||
|
||||
var stsResp struct {
|
||||
Response struct {
|
||||
Result struct {
|
||||
Credentials struct {
|
||||
AccessKeyID string `json:"AccessKeyId"`
|
||||
SecretAccessKey string `json:"SecretAccessKey"`
|
||||
Token string `json:"SessionToken"`
|
||||
Expiration float64 `json:"Expiration"`
|
||||
} `json:"Credentials"`
|
||||
} `json:"AssumeRoleWithWebIdentityResult"`
|
||||
} `json:"AssumeRoleWithWebIdentityResponse"`
|
||||
}
|
||||
|
||||
err = json.NewDecoder(resp.Body).Decode(&stsResp)
|
||||
if err != nil {
|
||||
return v, err
|
||||
}
|
||||
v.AccessKeyID = stsResp.Response.Result.Credentials.AccessKeyID
|
||||
v.SecretAccessKey = stsResp.Response.Result.Credentials.SecretAccessKey
|
||||
v.SessionToken = stsResp.Response.Result.Credentials.Token
|
||||
if !v.HasKeys() {
|
||||
return v, errors.New("failed to retrieve web identity keys")
|
||||
}
|
||||
sec := int64(stsResp.Response.Result.Credentials.Expiration)
|
||||
a.expiration = time.Unix(sec, 0).Add(-a.expiryWindow)
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// Retrieve retrieves the keys from the AWS service.
|
||||
func (a *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
|
||||
return a.RetrieveWithContext(context.Background())
|
||||
}
|
||||
|
||||
// IsExpired returns true if the credentials are expired.
|
||||
func (a *AssumeRoleProvider) IsExpired() bool {
|
||||
return a.expiration.Before(time.Now())
|
||||
}
|
183
internal/credproviders/ec2_provider.go
Normal file
183
internal/credproviders/ec2_provider.go
Normal file
@ -0,0 +1,183 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package credproviders
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/aws/credentials"
|
||||
)
|
||||
|
||||
const (
|
||||
// ec2ProviderName provides a name of EC2 provider
|
||||
ec2ProviderName = "EC2Provider"
|
||||
|
||||
awsEC2URI = "http://169.254.169.254/"
|
||||
awsEC2RolePath = "latest/meta-data/iam/security-credentials/"
|
||||
awsEC2TokenPath = "latest/api/token"
|
||||
|
||||
defaultHTTPTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// An EC2Provider retrieves credentials from EC2 metadata.
|
||||
type EC2Provider struct {
|
||||
httpClient *http.Client
|
||||
expiration time.Time
|
||||
|
||||
// expiryWindow will allow the credentials to trigger refreshing prior to the credentials actually expiring.
|
||||
// This is beneficial so expiring credentials do not cause request to fail unexpectedly due to exceptions.
|
||||
//
|
||||
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
|
||||
// 10 seconds before the credentials are actually expired.
|
||||
expiryWindow time.Duration
|
||||
}
|
||||
|
||||
// NewEC2Provider returns a pointer to an EC2 credential provider.
|
||||
func NewEC2Provider(httpClient *http.Client, expiryWindow time.Duration) *EC2Provider {
|
||||
return &EC2Provider{
|
||||
httpClient: httpClient,
|
||||
expiryWindow: expiryWindow,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EC2Provider) getToken(ctx context.Context) (string, error) {
|
||||
req, err := http.NewRequest(http.MethodPut, awsEC2URI+awsEC2TokenPath, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
const defaultEC2TTLSeconds = "30"
|
||||
req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", defaultEC2TTLSeconds)
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
|
||||
defer cancel()
|
||||
resp, err := e.httpClient.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("%s %s failed: %s", req.Method, req.URL.String(), resp.Status)
|
||||
}
|
||||
|
||||
token, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(token) == 0 {
|
||||
return "", errors.New("unable to retrieve token from EC2 metadata")
|
||||
}
|
||||
return string(token), nil
|
||||
}
|
||||
|
||||
func (e *EC2Provider) getRoleName(ctx context.Context, token string) (string, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, awsEC2URI+awsEC2RolePath, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("X-aws-ec2-metadata-token", token)
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
|
||||
defer cancel()
|
||||
resp, err := e.httpClient.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("%s %s failed: %s", req.Method, req.URL.String(), resp.Status)
|
||||
}
|
||||
|
||||
role, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(role) == 0 {
|
||||
return "", errors.New("unable to retrieve role_name from EC2 metadata")
|
||||
}
|
||||
return string(role), nil
|
||||
}
|
||||
|
||||
func (e *EC2Provider) getCredentials(ctx context.Context, token string, role string) (credentials.Value, time.Time, error) {
|
||||
v := credentials.Value{ProviderName: ec2ProviderName}
|
||||
|
||||
pathWithRole := awsEC2URI + awsEC2RolePath + role
|
||||
req, err := http.NewRequest(http.MethodGet, pathWithRole, nil)
|
||||
if err != nil {
|
||||
return v, time.Time{}, err
|
||||
}
|
||||
req.Header.Set("X-aws-ec2-metadata-token", token)
|
||||
ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
|
||||
defer cancel()
|
||||
resp, err := e.httpClient.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return v, time.Time{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return v, time.Time{}, fmt.Errorf("%s %s failed: %s", req.Method, req.URL.String(), resp.Status)
|
||||
}
|
||||
|
||||
var ec2Resp struct {
|
||||
AccessKeyID string `json:"AccessKeyId"`
|
||||
SecretAccessKey string `json:"SecretAccessKey"`
|
||||
Token string `json:"Token"`
|
||||
Expiration time.Time `json:"Expiration"`
|
||||
}
|
||||
|
||||
err = json.NewDecoder(resp.Body).Decode(&ec2Resp)
|
||||
if err != nil {
|
||||
return v, time.Time{}, err
|
||||
}
|
||||
|
||||
v.AccessKeyID = ec2Resp.AccessKeyID
|
||||
v.SecretAccessKey = ec2Resp.SecretAccessKey
|
||||
v.SessionToken = ec2Resp.Token
|
||||
|
||||
return v, ec2Resp.Expiration, nil
|
||||
}
|
||||
|
||||
// RetrieveWithContext retrieves the keys from the AWS service.
|
||||
func (e *EC2Provider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
|
||||
v := credentials.Value{ProviderName: ec2ProviderName}
|
||||
|
||||
token, err := e.getToken(ctx)
|
||||
if err != nil {
|
||||
return v, err
|
||||
}
|
||||
|
||||
role, err := e.getRoleName(ctx, token)
|
||||
if err != nil {
|
||||
return v, err
|
||||
}
|
||||
|
||||
v, exp, err := e.getCredentials(ctx, token, role)
|
||||
if err != nil {
|
||||
return v, err
|
||||
}
|
||||
if !v.HasKeys() {
|
||||
return v, errors.New("failed to retrieve EC2 keys")
|
||||
}
|
||||
e.expiration = exp.Add(-e.expiryWindow)
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// Retrieve retrieves the keys from the AWS service.
|
||||
func (e *EC2Provider) Retrieve() (credentials.Value, error) {
|
||||
return e.RetrieveWithContext(context.Background())
|
||||
}
|
||||
|
||||
// IsExpired returns true if the credentials are expired.
|
||||
func (e *EC2Provider) IsExpired() bool {
|
||||
return e.expiration.Before(time.Now())
|
||||
}
|
112
internal/credproviders/ecs_provider.go
Normal file
112
internal/credproviders/ecs_provider.go
Normal file
@ -0,0 +1,112 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package credproviders
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/aws/credentials"
|
||||
)
|
||||
|
||||
const (
|
||||
// ecsProviderName provides a name of ECS provider
|
||||
ecsProviderName = "ECSProvider"
|
||||
|
||||
awsRelativeURI = "http://169.254.170.2/"
|
||||
)
|
||||
|
||||
// An ECSProvider retrieves credentials from ECS metadata.
|
||||
type ECSProvider struct {
|
||||
AwsContainerCredentialsRelativeURIEnv EnvVar
|
||||
|
||||
httpClient *http.Client
|
||||
expiration time.Time
|
||||
|
||||
// expiryWindow will allow the credentials to trigger refreshing prior to the credentials actually expiring.
|
||||
// This is beneficial so expiring credentials do not cause request to fail unexpectedly due to exceptions.
|
||||
//
|
||||
// So a ExpiryWindow of 10s would cause calls to IsExpired() to return true
|
||||
// 10 seconds before the credentials are actually expired.
|
||||
expiryWindow time.Duration
|
||||
}
|
||||
|
||||
// NewECSProvider returns a pointer to an ECS credential provider.
|
||||
func NewECSProvider(httpClient *http.Client, expiryWindow time.Duration) *ECSProvider {
|
||||
return &ECSProvider{
|
||||
// AwsContainerCredentialsRelativeURIEnv is the environment variable for AWS_CONTAINER_CREDENTIALS_RELATIVE_URI
|
||||
AwsContainerCredentialsRelativeURIEnv: EnvVar("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"),
|
||||
httpClient: httpClient,
|
||||
expiryWindow: expiryWindow,
|
||||
}
|
||||
}
|
||||
|
||||
// RetrieveWithContext retrieves the keys from the AWS service.
|
||||
func (e *ECSProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
|
||||
const defaultHTTPTimeout = 10 * time.Second
|
||||
|
||||
v := credentials.Value{ProviderName: ecsProviderName}
|
||||
|
||||
relativeEcsURI := e.AwsContainerCredentialsRelativeURIEnv.Get()
|
||||
if len(relativeEcsURI) == 0 {
|
||||
return v, errors.New("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI is missing")
|
||||
}
|
||||
fullURI := awsRelativeURI + relativeEcsURI
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, fullURI, nil)
|
||||
if err != nil {
|
||||
return v, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
|
||||
defer cancel()
|
||||
resp, err := e.httpClient.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return v, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return v, fmt.Errorf("response failure: %s", resp.Status)
|
||||
}
|
||||
|
||||
var ecsResp struct {
|
||||
AccessKeyID string `json:"AccessKeyId"`
|
||||
SecretAccessKey string `json:"SecretAccessKey"`
|
||||
Token string `json:"Token"`
|
||||
Expiration time.Time `json:"Expiration"`
|
||||
}
|
||||
|
||||
err = json.NewDecoder(resp.Body).Decode(&ecsResp)
|
||||
if err != nil {
|
||||
return v, err
|
||||
}
|
||||
|
||||
v.AccessKeyID = ecsResp.AccessKeyID
|
||||
v.SecretAccessKey = ecsResp.SecretAccessKey
|
||||
v.SessionToken = ecsResp.Token
|
||||
if !v.HasKeys() {
|
||||
return v, errors.New("failed to retrieve ECS keys")
|
||||
}
|
||||
e.expiration = ecsResp.Expiration.Add(-e.expiryWindow)
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// Retrieve retrieves the keys from the AWS service.
|
||||
func (e *ECSProvider) Retrieve() (credentials.Value, error) {
|
||||
return e.RetrieveWithContext(context.Background())
|
||||
}
|
||||
|
||||
// IsExpired returns true if the credentials are expired.
|
||||
func (e *ECSProvider) IsExpired() bool {
|
||||
return e.expiration.Before(time.Now())
|
||||
}
|
69
internal/credproviders/env_provider.go
Normal file
69
internal/credproviders/env_provider.go
Normal file
@ -0,0 +1,69 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package credproviders
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/aws/credentials"
|
||||
)
|
||||
|
||||
// envProviderName provides a name of Env provider
|
||||
const envProviderName = "EnvProvider"
|
||||
|
||||
// EnvVar is an environment variable
|
||||
type EnvVar string
|
||||
|
||||
// Get retrieves the environment variable
|
||||
func (ev EnvVar) Get() string {
|
||||
return os.Getenv(string(ev))
|
||||
}
|
||||
|
||||
// A EnvProvider retrieves credentials from the environment variables of the
|
||||
// running process. Environment credentials never expire.
|
||||
type EnvProvider struct {
|
||||
AwsAccessKeyIDEnv EnvVar
|
||||
AwsSecretAccessKeyEnv EnvVar
|
||||
AwsSessionTokenEnv EnvVar
|
||||
|
||||
retrieved bool
|
||||
}
|
||||
|
||||
// NewEnvProvider returns a pointer to an ECS credential provider.
|
||||
func NewEnvProvider() *EnvProvider {
|
||||
return &EnvProvider{
|
||||
// AwsAccessKeyIDEnv is the environment variable for AWS_ACCESS_KEY_ID
|
||||
AwsAccessKeyIDEnv: EnvVar("AWS_ACCESS_KEY_ID"),
|
||||
// AwsSecretAccessKeyEnv is the environment variable for AWS_SECRET_ACCESS_KEY
|
||||
AwsSecretAccessKeyEnv: EnvVar("AWS_SECRET_ACCESS_KEY"),
|
||||
// AwsSessionTokenEnv is the environment variable for AWS_SESSION_TOKEN
|
||||
AwsSessionTokenEnv: EnvVar("AWS_SESSION_TOKEN"),
|
||||
}
|
||||
}
|
||||
|
||||
// Retrieve retrieves the keys from the environment.
|
||||
func (e *EnvProvider) Retrieve() (credentials.Value, error) {
|
||||
e.retrieved = false
|
||||
|
||||
v := credentials.Value{
|
||||
AccessKeyID: e.AwsAccessKeyIDEnv.Get(),
|
||||
SecretAccessKey: e.AwsSecretAccessKeyEnv.Get(),
|
||||
SessionToken: e.AwsSessionTokenEnv.Get(),
|
||||
ProviderName: envProviderName,
|
||||
}
|
||||
err := verify(v)
|
||||
if err == nil {
|
||||
e.retrieved = true
|
||||
}
|
||||
|
||||
return v, err
|
||||
}
|
||||
|
||||
// IsExpired returns true if the credentials have not been retrieved.
|
||||
func (e *EnvProvider) IsExpired() bool {
|
||||
return !e.retrieved
|
||||
}
|
103
internal/credproviders/imds_provider.go
Normal file
103
internal/credproviders/imds_provider.go
Normal file
@ -0,0 +1,103 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package credproviders
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/aws/credentials"
|
||||
)
|
||||
|
||||
const (
|
||||
// AzureProviderName provides a name of Azure provider
|
||||
AzureProviderName = "AzureProvider"
|
||||
|
||||
azureURI = "http://169.254.169.254/metadata/identity/oauth2/token"
|
||||
)
|
||||
|
||||
// An AzureProvider retrieves credentials from Azure IMDS.
|
||||
type AzureProvider struct {
|
||||
httpClient *http.Client
|
||||
expiration time.Time
|
||||
expiryWindow time.Duration
|
||||
}
|
||||
|
||||
// NewAzureProvider returns a pointer to an Azure credential provider.
|
||||
func NewAzureProvider(httpClient *http.Client, expiryWindow time.Duration) *AzureProvider {
|
||||
return &AzureProvider{
|
||||
httpClient: httpClient,
|
||||
expiration: time.Time{},
|
||||
expiryWindow: expiryWindow,
|
||||
}
|
||||
}
|
||||
|
||||
// RetrieveWithContext retrieves the keys from the Azure service.
|
||||
func (a *AzureProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
|
||||
v := credentials.Value{ProviderName: AzureProviderName}
|
||||
req, err := http.NewRequest(http.MethodGet, azureURI, nil)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("unable to retrieve Azure credentials: %w", err)
|
||||
}
|
||||
q := make(url.Values)
|
||||
q.Set("api-version", "2018-02-01")
|
||||
q.Set("resource", "https://vault.azure.net")
|
||||
req.URL.RawQuery = q.Encode()
|
||||
req.Header.Set("Metadata", "true")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := a.httpClient.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("unable to retrieve Azure credentials: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("unable to retrieve Azure credentials: error reading response body: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return v, fmt.Errorf("unable to retrieve Azure credentials: expected StatusCode 200, got StatusCode: %v. Response body: %s", resp.StatusCode, body)
|
||||
}
|
||||
var tokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn string `json:"expires_in"`
|
||||
}
|
||||
// Attempt to read body as JSON
|
||||
err = json.Unmarshal(body, &tokenResponse)
|
||||
if err != nil {
|
||||
return v, fmt.Errorf("unable to retrieve Azure credentials: error reading body JSON: %w (response body: %s)", err, body)
|
||||
}
|
||||
if tokenResponse.AccessToken == "" {
|
||||
return v, fmt.Errorf("unable to retrieve Azure credentials: got unexpected empty accessToken from Azure Metadata Server. Response body: %s", body)
|
||||
}
|
||||
v.SessionToken = tokenResponse.AccessToken
|
||||
|
||||
expiresIn, err := time.ParseDuration(tokenResponse.ExpiresIn + "s")
|
||||
if err != nil {
|
||||
return v, err
|
||||
}
|
||||
if expiration := expiresIn - a.expiryWindow; expiration > 0 {
|
||||
a.expiration = time.Now().Add(expiration)
|
||||
}
|
||||
|
||||
return v, err
|
||||
}
|
||||
|
||||
// Retrieve retrieves the keys from the Azure service.
|
||||
func (a *AzureProvider) Retrieve() (credentials.Value, error) {
|
||||
return a.RetrieveWithContext(context.Background())
|
||||
}
|
||||
|
||||
// IsExpired returns if the credentials have been retrieved.
|
||||
func (a *AzureProvider) IsExpired() bool {
|
||||
return a.expiration.Before(time.Now())
|
||||
}
|
59
internal/credproviders/static_provider.go
Normal file
59
internal/credproviders/static_provider.go
Normal file
@ -0,0 +1,59 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package credproviders
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/aws/credentials"
|
||||
)
|
||||
|
||||
// staticProviderName provides a name of Static provider
|
||||
const staticProviderName = "StaticProvider"
|
||||
|
||||
// A StaticProvider is a set of credentials which are set programmatically,
|
||||
// and will never expire.
|
||||
type StaticProvider struct {
|
||||
credentials.Value
|
||||
|
||||
verified bool
|
||||
err error
|
||||
}
|
||||
|
||||
func verify(v credentials.Value) error {
|
||||
if !v.HasKeys() {
|
||||
return errors.New("failed to retrieve ACCESS_KEY_ID and SECRET_ACCESS_KEY")
|
||||
}
|
||||
if v.AccessKeyID != "" && v.SecretAccessKey == "" {
|
||||
return errors.New("ACCESS_KEY_ID is set, but SECRET_ACCESS_KEY is missing")
|
||||
}
|
||||
if v.AccessKeyID == "" && v.SecretAccessKey != "" {
|
||||
return errors.New("SECRET_ACCESS_KEY is set, but ACCESS_KEY_ID is missing")
|
||||
}
|
||||
if v.AccessKeyID == "" && v.SecretAccessKey == "" && v.SessionToken != "" {
|
||||
return errors.New("AWS_SESSION_TOKEN is set, but ACCESS_KEY_ID and SECRET_ACCESS_KEY are missing")
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// Retrieve returns the credentials or error if the credentials are invalid.
|
||||
func (s *StaticProvider) Retrieve() (credentials.Value, error) {
|
||||
if !s.verified {
|
||||
s.err = verify(s.Value)
|
||||
s.Value.ProviderName = staticProviderName
|
||||
s.verified = true
|
||||
}
|
||||
return s.Value, s.err
|
||||
}
|
||||
|
||||
// IsExpired returns if the credentials are expired.
|
||||
//
|
||||
// For StaticProvider, the credentials never expired.
|
||||
func (s *StaticProvider) IsExpired() bool {
|
||||
return false
|
||||
}
|
40
internal/csfle/csfle.go
Normal file
40
internal/csfle/csfle.go
Normal file
@ -0,0 +1,40 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package csfle
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
const (
|
||||
EncryptedCacheCollection = "ecc"
|
||||
EncryptedStateCollection = "esc"
|
||||
EncryptedCompactionCollection = "ecoc"
|
||||
)
|
||||
|
||||
// GetEncryptedStateCollectionName returns the encrypted state collection name associated with dataCollectionName.
|
||||
func GetEncryptedStateCollectionName(efBSON bsoncore.Document, dataCollectionName string, stateCollection string) (string, error) {
|
||||
fieldName := stateCollection + "Collection"
|
||||
val, err := efBSON.LookupErr(fieldName)
|
||||
if err != nil {
|
||||
if !errors.Is(err, bsoncore.ErrElementNotFound) {
|
||||
return "", err
|
||||
}
|
||||
// Return default name.
|
||||
defaultName := "enxcol_." + dataCollectionName + "." + stateCollection
|
||||
return defaultName, nil
|
||||
}
|
||||
|
||||
stateCollectionName, ok := val.StringValueOK()
|
||||
if !ok {
|
||||
return "", fmt.Errorf("expected string for '%v', got: %v", fieldName, val.Type)
|
||||
}
|
||||
return stateCollectionName, nil
|
||||
}
|
106
internal/csot/csot.go
Normal file
106
internal/csot/csot.go
Normal file
@ -0,0 +1,106 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package csot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type clientLevel struct{}
|
||||
|
||||
func isClientLevel(ctx context.Context) bool {
|
||||
val := ctx.Value(clientLevel{})
|
||||
if val == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return val.(bool)
|
||||
}
|
||||
|
||||
// IsTimeoutContext checks if the provided context has been assigned a deadline
|
||||
// or has unlimited retries.
|
||||
func IsTimeoutContext(ctx context.Context) bool {
|
||||
_, ok := ctx.Deadline()
|
||||
|
||||
return ok || isClientLevel(ctx)
|
||||
}
|
||||
|
||||
// WithTimeout will set the given timeout on the context, if no deadline has
|
||||
// already been set.
|
||||
//
|
||||
// This function assumes that the timeout field is static, given that the
|
||||
// timeout should be sourced from the client. Therefore, once a timeout function
|
||||
// parameter has been applied to the context, it will remain for the lifetime
|
||||
// of the context.
|
||||
func WithTimeout(parent context.Context, timeout *time.Duration) (context.Context, context.CancelFunc) {
|
||||
cancel := func() {}
|
||||
|
||||
if timeout == nil || IsTimeoutContext(parent) {
|
||||
// In the following conditions, do nothing:
|
||||
// 1. The parent already has a deadline
|
||||
// 2. The parent does not have a deadline, but a client-level timeout has
|
||||
// been applied.
|
||||
// 3. The parent does not have a deadline, there is not client-level
|
||||
// timeout, and the timeout parameter DNE.
|
||||
return parent, cancel
|
||||
}
|
||||
|
||||
// If a client-level timeout has not been applied, then apply it.
|
||||
parent = context.WithValue(parent, clientLevel{}, true)
|
||||
|
||||
dur := *timeout
|
||||
|
||||
if dur == 0 {
|
||||
// If the parent does not have a deadline and the timeout is zero, then
|
||||
// do nothing.
|
||||
return parent, cancel
|
||||
}
|
||||
|
||||
// If the parent does not have a dealine and the timeout is non-zero, then
|
||||
// apply the timeout.
|
||||
return context.WithTimeout(parent, dur)
|
||||
}
|
||||
|
||||
// WithServerSelectionTimeout creates a context with a timeout that is the
|
||||
// minimum of serverSelectionTimeoutMS and context deadline. The usage of
|
||||
// non-positive values for serverSelectionTimeoutMS are an anti-pattern and are
|
||||
// not considered in this calculation.
|
||||
func WithServerSelectionTimeout(
|
||||
parent context.Context,
|
||||
serverSelectionTimeout time.Duration,
|
||||
) (context.Context, context.CancelFunc) {
|
||||
if serverSelectionTimeout <= 0 {
|
||||
return parent, func() {}
|
||||
}
|
||||
|
||||
return context.WithTimeout(parent, serverSelectionTimeout)
|
||||
}
|
||||
|
||||
// ZeroRTTMonitor implements the RTTMonitor interface and is used internally for testing. It returns 0 for all
|
||||
// RTT calculations and an empty string for RTT statistics.
|
||||
type ZeroRTTMonitor struct{}
|
||||
|
||||
// EWMA implements the RTT monitor interface.
|
||||
func (zrm *ZeroRTTMonitor) EWMA() time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Min implements the RTT monitor interface.
|
||||
func (zrm *ZeroRTTMonitor) Min() time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
// P90 implements the RTT monitor interface.
|
||||
func (zrm *ZeroRTTMonitor) P90() time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Stats implements the RTT monitor interface.
|
||||
func (zrm *ZeroRTTMonitor) Stats() string {
|
||||
return ""
|
||||
}
|
249
internal/csot/csot_test.go
Normal file
249
internal/csot/csot_test.go
Normal file
@ -0,0 +1,249 @@
|
||||
// Copyright (C) MongoDB, Inc. 2024-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package csot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/assert"
|
||||
"gitea.psichedelico.com/go/bson/internal/ptrutil"
|
||||
)
|
||||
|
||||
func newTestContext(t *testing.T, timeout time.Duration) context.Context {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func TestWithServerSelectionTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
parent context.Context
|
||||
serverSelectionTimeout time.Duration
|
||||
wantTimeout time.Duration
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "no context deadine and ssto is zero",
|
||||
parent: context.Background(),
|
||||
serverSelectionTimeout: 0,
|
||||
wantTimeout: 0,
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "no context deadline and ssto is positive",
|
||||
parent: context.Background(),
|
||||
serverSelectionTimeout: 1,
|
||||
wantTimeout: 1,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "no context deadline and ssto is negative",
|
||||
parent: context.Background(),
|
||||
serverSelectionTimeout: -1,
|
||||
wantTimeout: 0,
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "context deadline is zero and ssto is positive",
|
||||
parent: newTestContext(t, 0),
|
||||
serverSelectionTimeout: 1,
|
||||
wantTimeout: 1,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "context deadline is zero and ssto is negative",
|
||||
parent: newTestContext(t, 0),
|
||||
serverSelectionTimeout: -1,
|
||||
wantTimeout: 0,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "context deadline is negative and ssto is zero",
|
||||
parent: newTestContext(t, -1),
|
||||
serverSelectionTimeout: 0,
|
||||
wantTimeout: -1,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "context deadline is negative and ssto is positive",
|
||||
parent: newTestContext(t, -1),
|
||||
serverSelectionTimeout: 1,
|
||||
wantTimeout: 1,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "context deadline is negative and ssto is negative",
|
||||
parent: newTestContext(t, -1),
|
||||
serverSelectionTimeout: -1,
|
||||
wantTimeout: -1,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "context deadline is positive and ssto is zero",
|
||||
parent: newTestContext(t, 1),
|
||||
serverSelectionTimeout: 0,
|
||||
wantTimeout: 1,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "context deadline is positive and equal to ssto",
|
||||
parent: newTestContext(t, 1),
|
||||
serverSelectionTimeout: 1,
|
||||
wantTimeout: 1,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "context deadline is positive lt ssto",
|
||||
parent: newTestContext(t, 1),
|
||||
serverSelectionTimeout: 2,
|
||||
wantTimeout: 2,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "context deadline is positive gt ssto",
|
||||
parent: newTestContext(t, 2),
|
||||
serverSelectionTimeout: 1,
|
||||
wantTimeout: 2,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "context deadline is positive and ssto is negative",
|
||||
parent: newTestContext(t, -1),
|
||||
serverSelectionTimeout: -1,
|
||||
wantTimeout: 1,
|
||||
wantOk: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test // Capture the range variable
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := WithServerSelectionTimeout(test.parent, test.serverSelectionTimeout)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
deadline, gotOk := ctx.Deadline()
|
||||
assert.Equal(t, test.wantOk, gotOk)
|
||||
|
||||
if gotOk {
|
||||
delta := time.Until(deadline) - test.wantTimeout
|
||||
tolerance := 10 * time.Millisecond
|
||||
|
||||
assert.True(t, delta > -1*tolerance, "expected delta=%d > %d", delta, -1*tolerance)
|
||||
assert.True(t, delta <= tolerance, "expected delta=%d <= %d", delta, tolerance)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
parent context.Context
|
||||
timeout *time.Duration
|
||||
wantTimeout time.Duration
|
||||
wantDeadline bool
|
||||
wantValues []interface{}
|
||||
}{
|
||||
{
|
||||
name: "deadline set with non-zero timeout",
|
||||
parent: newTestContext(t, 1),
|
||||
timeout: ptrutil.Ptr(time.Duration(2)),
|
||||
wantTimeout: 1,
|
||||
wantDeadline: true,
|
||||
wantValues: []interface{}{},
|
||||
},
|
||||
{
|
||||
name: "deadline set with zero timeout",
|
||||
parent: newTestContext(t, 1),
|
||||
timeout: ptrutil.Ptr(time.Duration(0)),
|
||||
wantTimeout: 1,
|
||||
wantDeadline: true,
|
||||
wantValues: []interface{}{},
|
||||
},
|
||||
{
|
||||
name: "deadline set with nil timeout",
|
||||
parent: newTestContext(t, 1),
|
||||
timeout: nil,
|
||||
wantTimeout: 1,
|
||||
wantDeadline: true,
|
||||
wantValues: []interface{}{},
|
||||
},
|
||||
{
|
||||
name: "deadline unset with non-zero timeout",
|
||||
parent: context.Background(),
|
||||
timeout: ptrutil.Ptr(time.Duration(1)),
|
||||
wantTimeout: 1,
|
||||
wantDeadline: true,
|
||||
wantValues: []interface{}{},
|
||||
},
|
||||
{
|
||||
name: "deadline unset with zero timeout",
|
||||
parent: context.Background(),
|
||||
timeout: ptrutil.Ptr(time.Duration(0)),
|
||||
wantTimeout: 0,
|
||||
wantDeadline: false,
|
||||
wantValues: []interface{}{clientLevel{}},
|
||||
},
|
||||
{
|
||||
name: "deadline unset with nil timeout",
|
||||
parent: context.Background(),
|
||||
timeout: nil,
|
||||
wantTimeout: 0,
|
||||
wantDeadline: false,
|
||||
wantValues: []interface{}{},
|
||||
},
|
||||
{
|
||||
// If "clientLevel" has been set, but a new timeout is applied
|
||||
// to the context, then the constructed context should retain the old
|
||||
// timeout. To simplify the code, we assume the first timeout is static.
|
||||
name: "deadline unset with non-zero timeout at clientLevel",
|
||||
parent: context.WithValue(context.Background(), clientLevel{}, true),
|
||||
timeout: ptrutil.Ptr(time.Duration(1)),
|
||||
wantTimeout: 0,
|
||||
wantDeadline: false,
|
||||
wantValues: []interface{}{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test // Capture the range variable
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := WithTimeout(test.parent, test.timeout)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
deadline, gotDeadline := ctx.Deadline()
|
||||
assert.Equal(t, test.wantDeadline, gotDeadline)
|
||||
|
||||
if gotDeadline {
|
||||
delta := time.Until(deadline) - test.wantTimeout
|
||||
tolerance := 10 * time.Millisecond
|
||||
|
||||
assert.True(t, delta > -1*tolerance, "expected delta=%d > %d", delta, -1*tolerance)
|
||||
assert.True(t, delta <= tolerance, "expected delta=%d <= %d", delta, tolerance)
|
||||
}
|
||||
|
||||
for _, wantValue := range test.wantValues {
|
||||
assert.NotNil(t, ctx.Value(wantValue), "expected context to have value %v", wantValue)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
117
internal/decimal128/decinal128.go
Normal file
117
internal/decimal128/decinal128.go
Normal file
@ -0,0 +1,117 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package decimal128
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// These constants are the maximum and minimum values for the exponent field in a decimal128 value.
|
||||
const (
|
||||
MaxDecimal128Exp = 6111
|
||||
MinDecimal128Exp = -6176
|
||||
)
|
||||
|
||||
func divmod(h, l uint64, div uint32) (qh, ql uint64, rem uint32) {
|
||||
div64 := uint64(div)
|
||||
a := h >> 32
|
||||
aq := a / div64
|
||||
ar := a % div64
|
||||
b := ar<<32 + h&(1<<32-1)
|
||||
bq := b / div64
|
||||
br := b % div64
|
||||
c := br<<32 + l>>32
|
||||
cq := c / div64
|
||||
cr := c % div64
|
||||
d := cr<<32 + l&(1<<32-1)
|
||||
dq := d / div64
|
||||
dr := d % div64
|
||||
return (aq<<32 | bq), (cq<<32 | dq), uint32(dr)
|
||||
}
|
||||
|
||||
// String returns a string representation of the decimal value.
|
||||
func String(h, l uint64) string {
|
||||
var posSign int // positive sign
|
||||
var exp int // exponent
|
||||
var high, low uint64 // significand high/low
|
||||
|
||||
if h>>63&1 == 0 {
|
||||
posSign = 1
|
||||
}
|
||||
|
||||
switch h >> 58 & (1<<5 - 1) {
|
||||
case 0x1F:
|
||||
return "NaN"
|
||||
case 0x1E:
|
||||
return "-Infinity"[posSign:]
|
||||
}
|
||||
|
||||
low = l
|
||||
if h>>61&3 == 3 {
|
||||
// Bits: 1*sign 2*ignored 14*exponent 111*significand.
|
||||
// Implicit 0b100 prefix in significand.
|
||||
exp = int(h >> 47 & (1<<14 - 1))
|
||||
// Spec says all of these values are out of range.
|
||||
high, low = 0, 0
|
||||
} else {
|
||||
// Bits: 1*sign 14*exponent 113*significand
|
||||
exp = int(h >> 49 & (1<<14 - 1))
|
||||
high = h & (1<<49 - 1)
|
||||
}
|
||||
exp += MinDecimal128Exp
|
||||
|
||||
// Would be handled by the logic below, but that's trivial and common.
|
||||
if high == 0 && low == 0 && exp == 0 {
|
||||
return "-0"[posSign:]
|
||||
}
|
||||
|
||||
var repr [48]byte // Loop 5 times over 9 digits plus dot, negative sign, and leading zero.
|
||||
var last = len(repr)
|
||||
var i = len(repr)
|
||||
var dot = len(repr) + exp
|
||||
var rem uint32
|
||||
Loop:
|
||||
for d9 := 0; d9 < 5; d9++ {
|
||||
high, low, rem = divmod(high, low, 1e9)
|
||||
for d1 := 0; d1 < 9; d1++ {
|
||||
// Handle "-0.0", "0.00123400", "-1.00E-6", "1.050E+3", etc.
|
||||
if i < len(repr) && (dot == i || low == 0 && high == 0 && rem > 0 && rem < 10 && (dot < i-6 || exp > 0)) {
|
||||
exp += len(repr) - i
|
||||
i--
|
||||
repr[i] = '.'
|
||||
last = i - 1
|
||||
dot = len(repr) // Unmark.
|
||||
}
|
||||
c := '0' + byte(rem%10)
|
||||
rem /= 10
|
||||
i--
|
||||
repr[i] = c
|
||||
// Handle "0E+3", "1E+3", etc.
|
||||
if low == 0 && high == 0 && rem == 0 && i == len(repr)-1 && (dot < i-5 || exp > 0) {
|
||||
last = i
|
||||
break Loop
|
||||
}
|
||||
if c != '0' {
|
||||
last = i
|
||||
}
|
||||
// Break early. Works without it, but why.
|
||||
if dot > i && low == 0 && high == 0 && rem == 0 {
|
||||
break Loop
|
||||
}
|
||||
}
|
||||
}
|
||||
repr[last-1] = '-'
|
||||
last--
|
||||
|
||||
if exp > 0 {
|
||||
return string(repr[last+posSign:]) + "E+" + strconv.Itoa(exp)
|
||||
}
|
||||
if exp < 0 {
|
||||
return string(repr[last+posSign:]) + "E" + strconv.Itoa(exp)
|
||||
}
|
||||
return string(repr[last+posSign:])
|
||||
}
|
85
internal/errutil/join.go
Normal file
85
internal/errutil/join.go
Normal file
@ -0,0 +1,85 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package errutil
|
||||
|
||||
import "errors"
|
||||
|
||||
// join is a Go 1.13-1.19 compatible version of [errors.Join]. It is only called
|
||||
// by Join in join_go1.19.go. It is included here in a file without build
|
||||
// constraints only for testing purposes.
|
||||
//
|
||||
// It is heavily based on Join from
|
||||
// https://cs.opensource.google/go/go/+/refs/tags/go1.21.0:src/errors/join.go
|
||||
func join(errs ...error) error {
|
||||
n := 0
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
n++
|
||||
}
|
||||
}
|
||||
if n == 0 {
|
||||
return nil
|
||||
}
|
||||
e := &joinError{
|
||||
errs: make([]error, 0, n),
|
||||
}
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
e.errs = append(e.errs, err)
|
||||
}
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// joinError is a Go 1.13-1.19 compatible joinable error type. Its error
|
||||
// message is identical to [errors.Join], but it implements "Unwrap() error"
|
||||
// instead of "Unwrap() []error".
|
||||
//
|
||||
// It is heavily based on the joinError from
|
||||
// https://cs.opensource.google/go/go/+/refs/tags/go1.21.0:src/errors/join.go
|
||||
type joinError struct {
|
||||
errs []error
|
||||
}
|
||||
|
||||
func (e *joinError) Error() string {
|
||||
var b []byte
|
||||
for i, err := range e.errs {
|
||||
if i > 0 {
|
||||
b = append(b, '\n')
|
||||
}
|
||||
b = append(b, err.Error()...)
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// Unwrap returns another joinError with the same errors as the current
|
||||
// joinError except the first error in the slice. Continuing to call Unwrap
|
||||
// on each returned error will increment through every error in the slice. The
|
||||
// resulting behavior when using [errors.Is] and [errors.As] is similar to an
|
||||
// error created using [errors.Join] in Go 1.20+.
|
||||
func (e *joinError) Unwrap() error {
|
||||
if len(e.errs) == 1 {
|
||||
return e.errs[0]
|
||||
}
|
||||
return &joinError{errs: e.errs[1:]}
|
||||
}
|
||||
|
||||
// Is calls [errors.Is] with the first error in the slice.
|
||||
func (e *joinError) Is(target error) bool {
|
||||
if len(e.errs) == 0 {
|
||||
return false
|
||||
}
|
||||
return errors.Is(e.errs[0], target)
|
||||
}
|
||||
|
||||
// As calls [errors.As] with the first error in the slice.
|
||||
func (e *joinError) As(target interface{}) bool {
|
||||
if len(e.errs) == 0 {
|
||||
return false
|
||||
}
|
||||
return errors.As(e.errs[0], target)
|
||||
}
|
20
internal/errutil/join_go1.19.go
Normal file
20
internal/errutil/join_go1.19.go
Normal file
@ -0,0 +1,20 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
//go:build !go1.20
|
||||
// +build !go1.20
|
||||
|
||||
package errutil
|
||||
|
||||
// Join returns an error that wraps the given errors. Any nil error values are
|
||||
// discarded. Join returns nil if every value in errs is nil. The error formats
|
||||
// as the concatenation of the strings obtained by calling the Error method of
|
||||
// each element of errs, with a newline between each string.
|
||||
//
|
||||
// A non-nil error returned by Join implements the "Unwrap() error" method.
|
||||
func Join(errs ...error) error {
|
||||
return join(errs...)
|
||||
}
|
17
internal/errutil/join_go1.20.go
Normal file
17
internal/errutil/join_go1.20.go
Normal file
@ -0,0 +1,17 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
//go:build go1.20
|
||||
// +build go1.20
|
||||
|
||||
package errutil
|
||||
|
||||
import "errors"
|
||||
|
||||
// Join calls [errors.Join].
|
||||
func Join(errs ...error) error {
|
||||
return errors.Join(errs...)
|
||||
}
|
243
internal/errutil/join_test.go
Normal file
243
internal/errutil/join_test.go
Normal file
@ -0,0 +1,243 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package errutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/assert"
|
||||
)
|
||||
|
||||
// TestJoin_Nil asserts that join returns a nil error for the same inputs that
|
||||
// [errors.Join] returns a nil error.
|
||||
func TestJoin_Nil(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, errors.Join(), join(), "errors.Join() != join()")
|
||||
assert.Equal(t, errors.Join(nil), join(nil), "errors.Join(nil) != join(nil)")
|
||||
assert.Equal(t, errors.Join(nil, nil), join(nil, nil), "errors.Join(nil, nil) != join(nil, nil)")
|
||||
}
|
||||
|
||||
// TestJoin_Error asserts that join returns an error with the same error message
|
||||
// as the error returned by [errors.Join].
|
||||
func TestJoin_Error(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err1 := errors.New("err1")
|
||||
err2 := errors.New("err2")
|
||||
|
||||
tests := []struct {
|
||||
desc string
|
||||
errs []error
|
||||
}{{
|
||||
desc: "single error",
|
||||
errs: []error{err1},
|
||||
}, {
|
||||
desc: "two errors",
|
||||
errs: []error{err1, err2},
|
||||
}, {
|
||||
desc: "two errors and a nil value",
|
||||
errs: []error{err1, nil, err2},
|
||||
}}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test // Capture range variable.
|
||||
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
want := errors.Join(test.errs...).Error()
|
||||
got := join(test.errs...).Error()
|
||||
assert.Equal(t,
|
||||
want,
|
||||
got,
|
||||
"errors.Join().Error() != join().Error() for input %v",
|
||||
test.errs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJoin_ErrorsIs asserts that join returns an error that behaves identically
|
||||
// to the error returned by [errors.Join] when passed to [errors.Is].
|
||||
func TestJoin_ErrorsIs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err1 := errors.New("err1")
|
||||
err2 := errors.New("err2")
|
||||
|
||||
tests := []struct {
|
||||
desc string
|
||||
errs []error
|
||||
target error
|
||||
}{{
|
||||
desc: "one error with a matching target",
|
||||
errs: []error{err1},
|
||||
target: err1,
|
||||
}, {
|
||||
desc: "one error with a non-matching target",
|
||||
errs: []error{err1},
|
||||
target: err2,
|
||||
}, {
|
||||
desc: "nil error",
|
||||
errs: []error{nil},
|
||||
target: err1,
|
||||
}, {
|
||||
desc: "no errors",
|
||||
errs: []error{},
|
||||
target: err1,
|
||||
}, {
|
||||
desc: "two different errors with a matching target",
|
||||
errs: []error{err1, err2},
|
||||
target: err2,
|
||||
}, {
|
||||
desc: "two identical errors with a matching target",
|
||||
errs: []error{err1, err1},
|
||||
target: err1,
|
||||
}, {
|
||||
desc: "wrapped error with a matching target",
|
||||
errs: []error{fmt.Errorf("error: %w", err1)},
|
||||
target: err1,
|
||||
}, {
|
||||
desc: "nested joined error with a matching target",
|
||||
errs: []error{err1, join(err2, errors.New("nope"))},
|
||||
target: err2,
|
||||
}, {
|
||||
desc: "nested joined error with no matching targets",
|
||||
errs: []error{err1, join(errors.New("nope"), errors.New("nope 2"))},
|
||||
target: err2,
|
||||
}, {
|
||||
desc: "nested joined error with a wrapped matching target",
|
||||
errs: []error{join(fmt.Errorf("error: %w", err1), errors.New("nope")), err2},
|
||||
target: err1,
|
||||
}, {
|
||||
desc: "context.DeadlineExceeded",
|
||||
errs: []error{err1, nil, context.DeadlineExceeded, err2},
|
||||
target: context.DeadlineExceeded,
|
||||
}, {
|
||||
desc: "wrapped context.DeadlineExceeded",
|
||||
errs: []error{err1, nil, fmt.Errorf("error: %w", context.DeadlineExceeded), err2},
|
||||
target: context.DeadlineExceeded,
|
||||
}}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test // Capture range variable.
|
||||
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
// Assert that top-level errors returned by errors.Join and join
|
||||
// behave the same with errors.Is.
|
||||
want := errors.Join(test.errs...)
|
||||
got := join(test.errs...)
|
||||
assert.Equal(t,
|
||||
errors.Is(want, test.target),
|
||||
errors.Is(got, test.target),
|
||||
"errors.Join() and join() behave differently with errors.Is")
|
||||
|
||||
// Assert that wrapped errors returned by errors.Join and join
|
||||
// behave the same with errors.Is.
|
||||
want = fmt.Errorf("error: %w", errors.Join(test.errs...))
|
||||
got = fmt.Errorf("error: %w", join(test.errs...))
|
||||
assert.Equal(t,
|
||||
errors.Is(want, test.target),
|
||||
errors.Is(got, test.target),
|
||||
"errors.Join() and join(), when wrapped, behave differently with errors.Is")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type errType1 struct{}
|
||||
|
||||
func (errType1) Error() string { return "" }
|
||||
|
||||
type errType2 struct{}
|
||||
|
||||
func (errType2) Error() string { return "" }
|
||||
|
||||
// TestJoin_ErrorsIs asserts that join returns an error that behaves identically
|
||||
// to the error returned by [errors.Join] when passed to [errors.As].
|
||||
func TestJoin_ErrorsAs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err1 := errType1{}
|
||||
err2 := errType2{}
|
||||
|
||||
tests := []struct {
|
||||
desc string
|
||||
errs []error
|
||||
target interface{}
|
||||
}{{
|
||||
desc: "one error with a matching target",
|
||||
errs: []error{err1},
|
||||
target: &errType1{},
|
||||
}, {
|
||||
desc: "one error with a non-matching target",
|
||||
errs: []error{err1},
|
||||
target: &errType2{},
|
||||
}, {
|
||||
desc: "nil error",
|
||||
errs: []error{nil},
|
||||
target: &errType1{},
|
||||
}, {
|
||||
desc: "no errors",
|
||||
errs: []error{},
|
||||
target: &errType1{},
|
||||
}, {
|
||||
desc: "two different errors with a matching target",
|
||||
errs: []error{err1, err2},
|
||||
target: &errType2{},
|
||||
}, {
|
||||
desc: "two identical errors with a matching target",
|
||||
errs: []error{err1, err1},
|
||||
target: &errType1{},
|
||||
}, {
|
||||
desc: "wrapped error with a matching target",
|
||||
errs: []error{fmt.Errorf("error: %w", err1)},
|
||||
target: &errType1{},
|
||||
}, {
|
||||
desc: "nested joined error with a matching target",
|
||||
errs: []error{err1, join(err2, errors.New("nope"))},
|
||||
target: &errType2{},
|
||||
}, {
|
||||
desc: "nested joined error with no matching targets",
|
||||
errs: []error{err1, join(errors.New("nope"), errors.New("nope 2"))},
|
||||
target: &errType2{},
|
||||
}, {
|
||||
desc: "nested joined error with a wrapped matching target",
|
||||
errs: []error{join(fmt.Errorf("error: %w", err1), errors.New("nope")), err2},
|
||||
target: &errType1{},
|
||||
}, {
|
||||
desc: "context.DeadlineExceeded",
|
||||
errs: []error{err1, nil, context.DeadlineExceeded, err2},
|
||||
target: &errType2{},
|
||||
}}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test // Capture range variable.
|
||||
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
// Assert that top-level errors returned by errors.Join and join
|
||||
// behave the same with errors.As.
|
||||
want := errors.Join(test.errs...)
|
||||
got := join(test.errs...)
|
||||
assert.Equal(t,
|
||||
errors.As(want, test.target),
|
||||
errors.As(got, test.target),
|
||||
"errors.Join() and join() behave differently with errors.As")
|
||||
|
||||
// Assert that wrapped errors returned by errors.Join and join
|
||||
// behave the same with errors.As.
|
||||
want = fmt.Errorf("error: %w", errors.Join(test.errs...))
|
||||
got = fmt.Errorf("error: %w", join(test.errs...))
|
||||
assert.Equal(t,
|
||||
errors.As(want, test.target),
|
||||
errors.As(got, test.target),
|
||||
"errors.Join() and join(), when wrapped, behave differently with errors.As")
|
||||
})
|
||||
}
|
||||
}
|
63
internal/failpoint/failpoint.go
Normal file
63
internal/failpoint/failpoint.go
Normal file
@ -0,0 +1,63 @@
|
||||
// Copyright (C) MongoDB, Inc. 2024-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package failpoint
|
||||
|
||||
import (
|
||||
"gitea.psichedelico.com/go/bson"
|
||||
)
|
||||
|
||||
const (
|
||||
// ModeAlwaysOn is the fail point mode that enables the fail point for an
|
||||
// indefinite number of matching commands.
|
||||
ModeAlwaysOn = "alwaysOn"
|
||||
|
||||
// ModeOff is the fail point mode that disables the fail point.
|
||||
ModeOff = "off"
|
||||
)
|
||||
|
||||
// FailPoint is used to configure a server fail point. It is intended to be
|
||||
// passed as the command argument to RunCommand.
|
||||
//
|
||||
// For more information about fail points, see
|
||||
// https://github.com/mongodb/specifications/tree/HEAD/source/transactions/tests#server-fail-point
|
||||
type FailPoint struct {
|
||||
ConfigureFailPoint string `bson:"configureFailPoint"`
|
||||
// Mode should be a string, FailPointMode, or map[string]interface{}
|
||||
Mode interface{} `bson:"mode"`
|
||||
Data Data `bson:"data"`
|
||||
}
|
||||
|
||||
// Mode configures when a fail point will be enabled. It is used to set the
|
||||
// FailPoint.Mode field.
|
||||
type Mode struct {
|
||||
Times int32 `bson:"times"`
|
||||
Skip int32 `bson:"skip"`
|
||||
}
|
||||
|
||||
// Data configures how a fail point will behave. It is used to set the
|
||||
// FailPoint.Data field.
|
||||
type Data struct {
|
||||
FailCommands []string `bson:"failCommands,omitempty"`
|
||||
CloseConnection bool `bson:"closeConnection,omitempty"`
|
||||
ErrorCode int32 `bson:"errorCode,omitempty"`
|
||||
FailBeforeCommitExceptionCode int32 `bson:"failBeforeCommitExceptionCode,omitempty"`
|
||||
ErrorLabels *[]string `bson:"errorLabels,omitempty"`
|
||||
WriteConcernError *WriteConcernError `bson:"writeConcernError,omitempty"`
|
||||
BlockConnection bool `bson:"blockConnection,omitempty"`
|
||||
BlockTimeMS int32 `bson:"blockTimeMS,omitempty"`
|
||||
AppName string `bson:"appName,omitempty"`
|
||||
}
|
||||
|
||||
// WriteConcernError is the write concern error to return when the fail point is
|
||||
// triggered. It is used to set the FailPoint.Data.WriteConcernError field.
|
||||
type WriteConcernError struct {
|
||||
Code int32 `bson:"code"`
|
||||
Name string `bson:"codeName"`
|
||||
Errmsg string `bson:"errmsg"`
|
||||
ErrorLabels *[]string `bson:"errorLabels,omitempty"`
|
||||
ErrInfo bson.Raw `bson:"errInfo,omitempty"`
|
||||
}
|
13
internal/handshake/handshake.go
Normal file
13
internal/handshake/handshake.go
Normal file
@ -0,0 +1,13 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package handshake
|
||||
|
||||
// LegacyHello is the legacy version of the hello command.
|
||||
var LegacyHello = "isMaster"
|
||||
|
||||
// LegacyHelloLowercase is the lowercase, legacy version of the hello command.
|
||||
var LegacyHelloLowercase = "ismaster"
|
30
internal/httputil/httputil.go
Normal file
30
internal/httputil/httputil.go
Normal file
@ -0,0 +1,30 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package httputil
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// DefaultHTTPClient is the default HTTP client used across the driver.
|
||||
var DefaultHTTPClient = &http.Client{
|
||||
Transport: http.DefaultTransport.(*http.Transport).Clone(),
|
||||
}
|
||||
|
||||
// CloseIdleHTTPConnections closes any connections which were previously
|
||||
// connected from previous requests but are now sitting idle in a "keep-alive"
|
||||
// state. It does not interrupt any connections currently in use.
|
||||
//
|
||||
// Borrowed from the Go standard library.
|
||||
func CloseIdleHTTPConnections(client *http.Client) {
|
||||
type closeIdler interface {
|
||||
CloseIdleConnections()
|
||||
}
|
||||
if tr, ok := client.Transport.(closeIdler); ok {
|
||||
tr.CloseIdleConnections()
|
||||
}
|
||||
}
|
14
internal/israce/norace.go
Normal file
14
internal/israce/norace.go
Normal file
@ -0,0 +1,14 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
//go:build !race
|
||||
// +build !race
|
||||
|
||||
// Package israce reports if the Go race detector is enabled.
|
||||
package israce
|
||||
|
||||
// Enabled reports if the race detector is enabled.
|
||||
const Enabled = false
|
14
internal/israce/race.go
Normal file
14
internal/israce/race.go
Normal file
@ -0,0 +1,14 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
//go:build race
|
||||
// +build race
|
||||
|
||||
// Package israce reports if the Go race detector is enabled.
|
||||
package israce
|
||||
|
||||
// Enabled reports if the race detector is enabled.
|
||||
const Enabled = true
|
313
internal/logger/component.go
Normal file
313
internal/logger/component.go
Normal file
@ -0,0 +1,313 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"gitea.psichedelico.com/go/bson"
|
||||
)
|
||||
|
||||
const (
|
||||
CommandFailed = "Command failed"
|
||||
CommandStarted = "Command started"
|
||||
CommandSucceeded = "Command succeeded"
|
||||
ConnectionPoolCreated = "Connection pool created"
|
||||
ConnectionPoolReady = "Connection pool ready"
|
||||
ConnectionPoolCleared = "Connection pool cleared"
|
||||
ConnectionPoolClosed = "Connection pool closed"
|
||||
ConnectionCreated = "Connection created"
|
||||
ConnectionReady = "Connection ready"
|
||||
ConnectionClosed = "Connection closed"
|
||||
ConnectionCheckoutStarted = "Connection checkout started"
|
||||
ConnectionCheckoutFailed = "Connection checkout failed"
|
||||
ConnectionCheckedOut = "Connection checked out"
|
||||
ConnectionCheckedIn = "Connection checked in"
|
||||
ServerSelectionFailed = "Server selection failed"
|
||||
ServerSelectionStarted = "Server selection started"
|
||||
ServerSelectionSucceeded = "Server selection succeeded"
|
||||
ServerSelectionWaiting = "Waiting for suitable server to become available"
|
||||
TopologyClosed = "Stopped topology monitoring"
|
||||
TopologyDescriptionChanged = "Topology description changed"
|
||||
TopologyOpening = "Starting topology monitoring"
|
||||
TopologyServerClosed = "Stopped server monitoring"
|
||||
TopologyServerHeartbeatFailed = "Server heartbeat failed"
|
||||
TopologyServerHeartbeatStarted = "Server heartbeat started"
|
||||
TopologyServerHeartbeatSucceeded = "Server heartbeat succeeded"
|
||||
TopologyServerOpening = "Starting server monitoring"
|
||||
)
|
||||
|
||||
const (
|
||||
KeyAwaited = "awaited"
|
||||
KeyCommand = "command"
|
||||
KeyCommandName = "commandName"
|
||||
KeyDatabaseName = "databaseName"
|
||||
KeyDriverConnectionID = "driverConnectionId"
|
||||
KeyDurationMS = "durationMS"
|
||||
KeyError = "error"
|
||||
KeyFailure = "failure"
|
||||
KeyMaxConnecting = "maxConnecting"
|
||||
KeyMaxIdleTimeMS = "maxIdleTimeMS"
|
||||
KeyMaxPoolSize = "maxPoolSize"
|
||||
KeyMessage = "message"
|
||||
KeyMinPoolSize = "minPoolSize"
|
||||
KeyNewDescription = "newDescription"
|
||||
KeyOperation = "operation"
|
||||
KeyOperationID = "operationId"
|
||||
KeyPreviousDescription = "previousDescription"
|
||||
KeyRemainingTimeMS = "remainingTimeMS"
|
||||
KeyReason = "reason"
|
||||
KeyReply = "reply"
|
||||
KeyRequestID = "requestId"
|
||||
KeySelector = "selector"
|
||||
KeyServerConnectionID = "serverConnectionId"
|
||||
KeyServerHost = "serverHost"
|
||||
KeyServerPort = "serverPort"
|
||||
KeyServiceID = "serviceId"
|
||||
KeyTimestamp = "timestamp"
|
||||
KeyTopologyDescription = "topologyDescription"
|
||||
KeyTopologyID = "topologyId"
|
||||
)
|
||||
|
||||
// KeyValues is a list of key-value pairs.
|
||||
type KeyValues []interface{}
|
||||
|
||||
// Add adds a key-value pair to an instance of a KeyValues list.
|
||||
func (kvs *KeyValues) Add(key string, value interface{}) {
|
||||
*kvs = append(*kvs, key, value)
|
||||
}
|
||||
|
||||
const (
|
||||
ReasonConnClosedStale = "Connection became stale because the pool was cleared"
|
||||
ReasonConnClosedIdle = "Connection has been available but unused for longer than the configured max idle time"
|
||||
ReasonConnClosedError = "An error occurred while using the connection"
|
||||
ReasonConnClosedPoolClosed = "Connection pool was closed"
|
||||
ReasonConnCheckoutFailedTimout = "Wait queue timeout elapsed without a connection becoming available"
|
||||
ReasonConnCheckoutFailedError = "An error occurred while trying to establish a new connection"
|
||||
ReasonConnCheckoutFailedPoolClosed = "Connection pool was closed"
|
||||
)
|
||||
|
||||
// Component is an enumeration representing the "components" which can be
|
||||
// logged against. A LogLevel can be configured on a per-component basis.
|
||||
type Component int
|
||||
|
||||
const (
|
||||
// ComponentAll enables logging for all components.
|
||||
ComponentAll Component = iota
|
||||
|
||||
// ComponentCommand enables command monitor logging.
|
||||
ComponentCommand
|
||||
|
||||
// ComponentTopology enables topology logging.
|
||||
ComponentTopology
|
||||
|
||||
// ComponentServerSelection enables server selection logging.
|
||||
ComponentServerSelection
|
||||
|
||||
// ComponentConnection enables connection services logging.
|
||||
ComponentConnection
|
||||
)
|
||||
|
||||
const (
|
||||
mongoDBLogAllEnvVar = "MONGODB_LOG_ALL"
|
||||
mongoDBLogCommandEnvVar = "MONGODB_LOG_COMMAND"
|
||||
mongoDBLogTopologyEnvVar = "MONGODB_LOG_TOPOLOGY"
|
||||
mongoDBLogServerSelectionEnvVar = "MONGODB_LOG_SERVER_SELECTION"
|
||||
mongoDBLogConnectionEnvVar = "MONGODB_LOG_CONNECTION"
|
||||
)
|
||||
|
||||
var componentEnvVarMap = map[string]Component{
|
||||
mongoDBLogAllEnvVar: ComponentAll,
|
||||
mongoDBLogCommandEnvVar: ComponentCommand,
|
||||
mongoDBLogTopologyEnvVar: ComponentTopology,
|
||||
mongoDBLogServerSelectionEnvVar: ComponentServerSelection,
|
||||
mongoDBLogConnectionEnvVar: ComponentConnection,
|
||||
}
|
||||
|
||||
// EnvHasComponentVariables returns true if the environment contains any of the
|
||||
// component environment variables.
|
||||
func EnvHasComponentVariables() bool {
|
||||
for envVar := range componentEnvVarMap {
|
||||
if os.Getenv(envVar) != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Command is a struct defining common fields that must be included in all
|
||||
// commands.
|
||||
type Command struct {
|
||||
DriverConnectionID int64 // Driver's ID for the connection
|
||||
Name string // Command name
|
||||
DatabaseName string // Database name
|
||||
Message string // Message associated with the command
|
||||
OperationID int32 // Driver-generated operation ID
|
||||
RequestID int64 // Driver-generated request ID
|
||||
ServerConnectionID *int64 // Server's ID for the connection used for the command
|
||||
ServerHost string // Hostname or IP address for the server
|
||||
ServerPort string // Port for the server
|
||||
ServiceID *bson.ObjectID // ID for the command in load balancer mode
|
||||
}
|
||||
|
||||
// SerializeCommand takes a command and a variable number of key-value pairs and
|
||||
// returns a slice of interface{} that can be passed to the logger for
|
||||
// structured logging.
|
||||
func SerializeCommand(cmd Command, extraKeysAndValues ...interface{}) KeyValues {
|
||||
// Initialize the boilerplate keys and values.
|
||||
keysAndValues := KeyValues{
|
||||
KeyCommandName, cmd.Name,
|
||||
KeyDatabaseName, cmd.DatabaseName,
|
||||
KeyDriverConnectionID, cmd.DriverConnectionID,
|
||||
KeyMessage, cmd.Message,
|
||||
KeyOperationID, cmd.OperationID,
|
||||
KeyRequestID, cmd.RequestID,
|
||||
KeyServerHost, cmd.ServerHost,
|
||||
}
|
||||
|
||||
// Add the extra keys and values.
|
||||
for i := 0; i < len(extraKeysAndValues); i += 2 {
|
||||
keysAndValues.Add(extraKeysAndValues[i].(string), extraKeysAndValues[i+1])
|
||||
}
|
||||
|
||||
port, err := strconv.ParseInt(cmd.ServerPort, 10, 32)
|
||||
if err == nil {
|
||||
keysAndValues.Add(KeyServerPort, port)
|
||||
}
|
||||
|
||||
// Add the "serverConnectionId" if it is not nil.
|
||||
if cmd.ServerConnectionID != nil {
|
||||
keysAndValues.Add(KeyServerConnectionID, *cmd.ServerConnectionID)
|
||||
}
|
||||
|
||||
// Add the "serviceId" if it is not nil.
|
||||
if cmd.ServiceID != nil {
|
||||
keysAndValues.Add(KeyServiceID, cmd.ServiceID.Hex())
|
||||
}
|
||||
|
||||
return keysAndValues
|
||||
}
|
||||
|
||||
// Connection contains data that all connection log messages MUST contain.
|
||||
type Connection struct {
|
||||
Message string // Message associated with the connection
|
||||
ServerHost string // Hostname or IP address for the server
|
||||
ServerPort string // Port for the server
|
||||
}
|
||||
|
||||
// SerializeConnection serializes a Connection message into a slice of keys and
|
||||
// values that can be passed to a logger.
|
||||
func SerializeConnection(conn Connection, extraKeysAndValues ...interface{}) KeyValues {
|
||||
// Initialize the boilerplate keys and values.
|
||||
keysAndValues := KeyValues{
|
||||
KeyMessage, conn.Message,
|
||||
KeyServerHost, conn.ServerHost,
|
||||
}
|
||||
|
||||
// Add the optional keys and values.
|
||||
for i := 0; i < len(extraKeysAndValues); i += 2 {
|
||||
keysAndValues.Add(extraKeysAndValues[i].(string), extraKeysAndValues[i+1])
|
||||
}
|
||||
|
||||
port, err := strconv.ParseInt(conn.ServerPort, 10, 32)
|
||||
if err == nil {
|
||||
keysAndValues.Add(KeyServerPort, port)
|
||||
}
|
||||
|
||||
return keysAndValues
|
||||
}
|
||||
|
||||
// Server contains data that all server messages MAY contain.
|
||||
type Server struct {
|
||||
DriverConnectionID int64 // Driver's ID for the connection
|
||||
TopologyID bson.ObjectID // Driver's unique ID for this topology
|
||||
Message string // Message associated with the topology
|
||||
ServerConnectionID *int64 // Server's ID for the connection
|
||||
ServerHost string // Hostname or IP address for the server
|
||||
ServerPort string // Port for the server
|
||||
}
|
||||
|
||||
// SerializeServer serializes a Server message into a slice of keys and
|
||||
// values that can be passed to a logger.
|
||||
func SerializeServer(srv Server, extraKV ...interface{}) KeyValues {
|
||||
// Initialize the boilerplate keys and values.
|
||||
keysAndValues := KeyValues{
|
||||
KeyDriverConnectionID, srv.DriverConnectionID,
|
||||
KeyMessage, srv.Message,
|
||||
KeyServerHost, srv.ServerHost,
|
||||
KeyTopologyID, srv.TopologyID.Hex(),
|
||||
}
|
||||
|
||||
if connID := srv.ServerConnectionID; connID != nil {
|
||||
keysAndValues.Add(KeyServerConnectionID, *connID)
|
||||
}
|
||||
|
||||
port, err := strconv.ParseInt(srv.ServerPort, 10, 32)
|
||||
if err == nil {
|
||||
keysAndValues.Add(KeyServerPort, port)
|
||||
}
|
||||
|
||||
// Add the optional keys and values.
|
||||
for i := 0; i < len(extraKV); i += 2 {
|
||||
keysAndValues.Add(extraKV[i].(string), extraKV[i+1])
|
||||
}
|
||||
|
||||
return keysAndValues
|
||||
}
|
||||
|
||||
// ServerSelection contains data that all server selection messages MUST
|
||||
// contain.
|
||||
type ServerSelection struct {
|
||||
Selector string
|
||||
OperationID *int32
|
||||
Operation string
|
||||
TopologyDescription string
|
||||
}
|
||||
|
||||
// SerializeServerSelection serializes a Topology message into a slice of keys
|
||||
// and values that can be passed to a logger.
|
||||
func SerializeServerSelection(srvSelection ServerSelection, extraKV ...interface{}) KeyValues {
|
||||
keysAndValues := KeyValues{
|
||||
KeySelector, srvSelection.Selector,
|
||||
KeyOperation, srvSelection.Operation,
|
||||
KeyTopologyDescription, srvSelection.TopologyDescription,
|
||||
}
|
||||
|
||||
if srvSelection.OperationID != nil {
|
||||
keysAndValues.Add(KeyOperationID, *srvSelection.OperationID)
|
||||
}
|
||||
|
||||
// Add the optional keys and values.
|
||||
for i := 0; i < len(extraKV); i += 2 {
|
||||
keysAndValues.Add(extraKV[i].(string), extraKV[i+1])
|
||||
}
|
||||
|
||||
return keysAndValues
|
||||
}
|
||||
|
||||
// Topology contains data that all topology messages MAY contain.
|
||||
type Topology struct {
|
||||
ID bson.ObjectID // Driver's unique ID for this topology
|
||||
Message string // Message associated with the topology
|
||||
}
|
||||
|
||||
// SerializeTopology serializes a Topology message into a slice of keys and
|
||||
// values that can be passed to a logger.
|
||||
func SerializeTopology(topo Topology, extraKV ...interface{}) KeyValues {
|
||||
keysAndValues := KeyValues{
|
||||
KeyTopologyID, topo.ID.Hex(),
|
||||
}
|
||||
|
||||
// Add the optional keys and values.
|
||||
for i := 0; i < len(extraKV); i += 2 {
|
||||
keysAndValues.Add(extraKV[i].(string), extraKV[i+1])
|
||||
}
|
||||
|
||||
return keysAndValues
|
||||
}
|
229
internal/logger/component_test.go
Normal file
229
internal/logger/component_test.go
Normal file
@ -0,0 +1,229 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package logger
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gitea.psichedelico.com/go/bson"
|
||||
"gitea.psichedelico.com/go/bson/internal/assert"
|
||||
)
|
||||
|
||||
func verifySerialization(t *testing.T, got, want KeyValues) {
|
||||
t.Helper()
|
||||
|
||||
for i := 0; i < len(got); i += 2 {
|
||||
assert.Equal(t, want[i], got[i], "key position mismatch")
|
||||
assert.Equal(t, want[i+1], got[i+1], "value position mismatch for %q", want[i])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeCommand(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serverConnectionID := int64(100)
|
||||
serviceID := bson.NewObjectID()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cmd Command
|
||||
extraKeysAndValues []interface{}
|
||||
want KeyValues
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
want: KeyValues{
|
||||
KeyCommandName, "",
|
||||
KeyDatabaseName, "",
|
||||
KeyDriverConnectionID, int64(0),
|
||||
KeyMessage, "",
|
||||
KeyOperationID, int32(0),
|
||||
KeyRequestID, int64(0),
|
||||
KeyServerHost, "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "complete Command object",
|
||||
cmd: Command{
|
||||
DriverConnectionID: 1,
|
||||
Name: "foo",
|
||||
DatabaseName: "db",
|
||||
Message: "bar",
|
||||
OperationID: 2,
|
||||
RequestID: 3,
|
||||
ServerHost: "localhost",
|
||||
ServerPort: "27017",
|
||||
ServerConnectionID: &serverConnectionID,
|
||||
ServiceID: &serviceID,
|
||||
},
|
||||
want: KeyValues{
|
||||
KeyCommandName, "foo",
|
||||
KeyDatabaseName, "db",
|
||||
KeyDriverConnectionID, int64(1),
|
||||
KeyMessage, "bar",
|
||||
KeyOperationID, int32(2),
|
||||
KeyRequestID, int64(3),
|
||||
KeyServerHost, "localhost",
|
||||
KeyServerPort, int64(27017),
|
||||
KeyServerConnectionID, serverConnectionID,
|
||||
KeyServiceID, serviceID.Hex(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := SerializeCommand(test.cmd, test.extraKeysAndValues...)
|
||||
verifySerialization(t, got, test.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
conn Connection
|
||||
extraKeysAndValues []interface{}
|
||||
want KeyValues
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
want: KeyValues{
|
||||
KeyMessage, "",
|
||||
KeyServerHost, "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "complete Connection object",
|
||||
conn: Connection{
|
||||
Message: "foo",
|
||||
ServerHost: "localhost",
|
||||
ServerPort: "27017",
|
||||
},
|
||||
want: KeyValues{
|
||||
"message", "foo",
|
||||
"serverHost", "localhost",
|
||||
"serverPort", int64(27017),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := SerializeConnection(test.conn, test.extraKeysAndValues...)
|
||||
verifySerialization(t, got, test.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeServer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
topologyID := bson.NewObjectID()
|
||||
serverConnectionID := int64(100)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
srv Server
|
||||
extraKeysAndValues []interface{}
|
||||
want KeyValues
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
want: KeyValues{
|
||||
KeyDriverConnectionID, int64(0),
|
||||
KeyMessage, "",
|
||||
KeyServerHost, "",
|
||||
KeyTopologyID, bson.ObjectID{}.Hex(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "complete Server object",
|
||||
srv: Server{
|
||||
DriverConnectionID: 1,
|
||||
TopologyID: topologyID,
|
||||
Message: "foo",
|
||||
ServerConnectionID: &serverConnectionID,
|
||||
ServerHost: "localhost",
|
||||
ServerPort: "27017",
|
||||
},
|
||||
want: KeyValues{
|
||||
KeyDriverConnectionID, int64(1),
|
||||
KeyMessage, "foo",
|
||||
KeyServerHost, "localhost",
|
||||
KeyTopologyID, topologyID.Hex(),
|
||||
KeyServerConnectionID, serverConnectionID,
|
||||
KeyServerPort, int64(27017),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := SerializeServer(test.srv, test.extraKeysAndValues...)
|
||||
verifySerialization(t, got, test.want)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestSerializeTopology(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
topologyID := bson.NewObjectID()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
topo Topology
|
||||
extraKeysAndValues []interface{}
|
||||
want KeyValues
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
want: KeyValues{
|
||||
KeyTopologyID, bson.ObjectID{}.Hex(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "complete Server object",
|
||||
topo: Topology{
|
||||
ID: topologyID,
|
||||
Message: "foo",
|
||||
},
|
||||
want: KeyValues{
|
||||
KeyTopologyID, topologyID.Hex(),
|
||||
KeyMessage, "foo",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := SerializeTopology(test.topo, test.extraKeysAndValues...)
|
||||
verifySerialization(t, got, test.want)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
48
internal/logger/context.go
Normal file
48
internal/logger/context.go
Normal file
@ -0,0 +1,48 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package logger
|
||||
|
||||
import "context"
|
||||
|
||||
// contextKey is a custom type used to prevent key collisions when using the
|
||||
// context package.
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
contextKeyOperation contextKey = "operation"
|
||||
contextKeyOperationID contextKey = "operationID"
|
||||
)
|
||||
|
||||
// WithOperationName adds the operation name to the context.
|
||||
func WithOperationName(ctx context.Context, operation string) context.Context {
|
||||
return context.WithValue(ctx, contextKeyOperation, operation)
|
||||
}
|
||||
|
||||
// WithOperationID adds the operation ID to the context.
|
||||
func WithOperationID(ctx context.Context, operationID int32) context.Context {
|
||||
return context.WithValue(ctx, contextKeyOperationID, operationID)
|
||||
}
|
||||
|
||||
// OperationName returns the operation name from the context.
|
||||
func OperationName(ctx context.Context) (string, bool) {
|
||||
operationName := ctx.Value(contextKeyOperation)
|
||||
if operationName == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return operationName.(string), true
|
||||
}
|
||||
|
||||
// OperationID returns the operation ID from the context.
|
||||
func OperationID(ctx context.Context) (int32, bool) {
|
||||
operationID := ctx.Value(contextKeyOperationID)
|
||||
if operationID == nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return operationID.(int32), true
|
||||
}
|
187
internal/logger/context_test.go
Normal file
187
internal/logger/context_test.go
Normal file
@ -0,0 +1,187 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package logger_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"gitea.psichedelico.com/go/bson/internal/assert"
|
||||
"gitea.psichedelico.com/go/bson/internal/logger"
|
||||
)
|
||||
|
||||
func TestContext_WithOperationName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
opName string
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
name: "simple",
|
||||
ctx: context.Background(),
|
||||
opName: "foo",
|
||||
ok: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt // Capture the range variable.
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := logger.WithOperationName(tt.ctx, tt.opName)
|
||||
|
||||
opName, ok := logger.OperationName(ctx)
|
||||
assert.Equal(t, tt.ok, ok)
|
||||
|
||||
if ok {
|
||||
assert.Equal(t, tt.opName, opName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContext_OperationName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
opName interface{}
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
ctx: context.Background(),
|
||||
opName: nil,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
name: "string type",
|
||||
ctx: context.Background(),
|
||||
opName: "foo",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "non-string type",
|
||||
ctx: context.Background(),
|
||||
opName: int32(1),
|
||||
ok: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt // Capture the range variable.
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
if opNameStr, ok := tt.opName.(string); ok {
|
||||
ctx = logger.WithOperationName(tt.ctx, opNameStr)
|
||||
}
|
||||
|
||||
opName, ok := logger.OperationName(ctx)
|
||||
assert.Equal(t, tt.ok, ok)
|
||||
|
||||
if ok {
|
||||
assert.Equal(t, tt.opName, opName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContext_WithOperationID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
opID int32
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
name: "non-zero",
|
||||
ctx: context.Background(),
|
||||
opID: 1,
|
||||
ok: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt // Capture the range variable.
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := logger.WithOperationID(tt.ctx, tt.opID)
|
||||
|
||||
opID, ok := logger.OperationID(ctx)
|
||||
assert.Equal(t, tt.ok, ok)
|
||||
|
||||
if ok {
|
||||
assert.Equal(t, tt.opID, opID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContext_OperationID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
opID interface{}
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
ctx: context.Background(),
|
||||
opID: nil,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
name: "i32 type",
|
||||
ctx: context.Background(),
|
||||
opID: int32(1),
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "non-i32 type",
|
||||
ctx: context.Background(),
|
||||
opID: "foo",
|
||||
ok: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt // Capture the range variable.
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
if opIDI32, ok := tt.opID.(int32); ok {
|
||||
ctx = logger.WithOperationID(tt.ctx, opIDI32)
|
||||
}
|
||||
|
||||
opName, ok := logger.OperationID(ctx)
|
||||
assert.Equal(t, tt.ok, ok)
|
||||
|
||||
if ok {
|
||||
assert.Equal(t, tt.opID, opName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
63
internal/logger/io_sink.go
Normal file
63
internal/logger/io_sink.go
Normal file
@ -0,0 +1,63 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package logger
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IOSink writes a JSON-encoded message to the io.Writer.
|
||||
type IOSink struct {
|
||||
enc *json.Encoder
|
||||
|
||||
// encMu protects the encoder from concurrent writes. While the logger
|
||||
// itself does not concurrently write to the sink, the sink may be used
|
||||
// concurrently within the driver.
|
||||
encMu sync.Mutex
|
||||
}
|
||||
|
||||
// Compile-time check to ensure IOSink implements the LogSink interface.
|
||||
var _ LogSink = &IOSink{}
|
||||
|
||||
// NewIOSink will create an IOSink object that writes JSON messages to the
|
||||
// provided io.Writer.
|
||||
func NewIOSink(out io.Writer) *IOSink {
|
||||
return &IOSink{
|
||||
enc: json.NewEncoder(out),
|
||||
}
|
||||
}
|
||||
|
||||
// Info will write a JSON-encoded message to the io.Writer.
|
||||
func (sink *IOSink) Info(_ int, msg string, keysAndValues ...interface{}) {
|
||||
mapSize := len(keysAndValues) / 2
|
||||
if math.MaxInt-mapSize >= 2 {
|
||||
mapSize += 2
|
||||
}
|
||||
kvMap := make(map[string]interface{}, mapSize)
|
||||
|
||||
kvMap[KeyTimestamp] = time.Now().UnixNano()
|
||||
kvMap[KeyMessage] = msg
|
||||
|
||||
for i := 0; i < len(keysAndValues); i += 2 {
|
||||
kvMap[keysAndValues[i].(string)] = keysAndValues[i+1]
|
||||
}
|
||||
|
||||
sink.encMu.Lock()
|
||||
defer sink.encMu.Unlock()
|
||||
|
||||
_ = sink.enc.Encode(kvMap)
|
||||
}
|
||||
|
||||
// Error will write a JSON-encoded error message to the io.Writer.
|
||||
func (sink *IOSink) Error(err error, msg string, kv ...interface{}) {
|
||||
kv = append(kv, KeyError, err.Error())
|
||||
sink.Info(0, msg, kv...)
|
||||
}
|
74
internal/logger/level.go
Normal file
74
internal/logger/level.go
Normal file
@ -0,0 +1,74 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package logger
|
||||
|
||||
import "strings"
|
||||
|
||||
// DiffToInfo is the number of levels in the Go Driver that come before the
|
||||
// "Info" level. This should ensure that "Info" is the 0th level passed to the
|
||||
// sink.
|
||||
const DiffToInfo = 1
|
||||
|
||||
// Level is an enumeration representing the log severity levels supported by
|
||||
// the driver. The order of the logging levels is important. The driver expects
|
||||
// that a user will likely use the "logr" package to create a LogSink, which
|
||||
// defaults InfoLevel as 0. Any additions to the Level enumeration before the
|
||||
// InfoLevel will need to also update the "diffToInfo" constant.
|
||||
type Level int
|
||||
|
||||
const (
|
||||
// LevelOff suppresses logging.
|
||||
LevelOff Level = iota
|
||||
|
||||
// LevelInfo enables logging of informational messages. These logs are
|
||||
// high-level information about normal driver behavior.
|
||||
LevelInfo
|
||||
|
||||
// LevelDebug enables logging of debug messages. These logs can be
|
||||
// voluminous and are intended for detailed information that may be
|
||||
// helpful when debugging an application.
|
||||
LevelDebug
|
||||
)
|
||||
|
||||
const (
|
||||
levelLiteralOff = "off"
|
||||
levelLiteralEmergency = "emergency"
|
||||
levelLiteralAlert = "alert"
|
||||
levelLiteralCritical = "critical"
|
||||
levelLiteralError = "error"
|
||||
levelLiteralWarning = "warning"
|
||||
levelLiteralNotice = "notice"
|
||||
levelLiteralInfo = "info"
|
||||
levelLiteralDebug = "debug"
|
||||
levelLiteralTrace = "trace"
|
||||
)
|
||||
|
||||
var LevelLiteralMap = map[string]Level{
|
||||
levelLiteralOff: LevelOff,
|
||||
levelLiteralEmergency: LevelInfo,
|
||||
levelLiteralAlert: LevelInfo,
|
||||
levelLiteralCritical: LevelInfo,
|
||||
levelLiteralError: LevelInfo,
|
||||
levelLiteralWarning: LevelInfo,
|
||||
levelLiteralNotice: LevelInfo,
|
||||
levelLiteralInfo: LevelInfo,
|
||||
levelLiteralDebug: LevelDebug,
|
||||
levelLiteralTrace: LevelDebug,
|
||||
}
|
||||
|
||||
// ParseLevel will check if the given string is a valid environment variable
|
||||
// for a logging severity level. If it is, then it will return the associated
|
||||
// driver's Level. The default Level is “LevelOff”.
|
||||
func ParseLevel(str string) Level {
|
||||
for literal, level := range LevelLiteralMap {
|
||||
if strings.EqualFold(literal, str) {
|
||||
return level
|
||||
}
|
||||
}
|
||||
|
||||
return LevelOff
|
||||
}
|
265
internal/logger/logger.go
Normal file
265
internal/logger/logger.go
Normal file
@ -0,0 +1,265 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Package logger provides the internal logging solution for the MongoDB Go
|
||||
// Driver.
|
||||
package logger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gitea.psichedelico.com/go/bson"
|
||||
"gitea.psichedelico.com/go/bson/internal/bsoncoreutil"
|
||||
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
// DefaultMaxDocumentLength is the default maximum number of bytes that can be
|
||||
// logged for a stringified BSON document.
|
||||
const DefaultMaxDocumentLength = 1000
|
||||
|
||||
// TruncationSuffix are trailing ellipsis "..." appended to a message to
|
||||
// indicate to the user that truncation occurred. This constant does not count
|
||||
// toward the max document length.
|
||||
const TruncationSuffix = "..."
|
||||
|
||||
const logSinkPathEnvVar = "MONGODB_LOG_PATH"
|
||||
const maxDocumentLengthEnvVar = "MONGODB_LOG_MAX_DOCUMENT_LENGTH"
|
||||
|
||||
// LogSink represents a logging implementation, this interface should be 1-1
|
||||
// with the exported "LogSink" interface in the mongo/options package.
|
||||
type LogSink interface {
|
||||
// Info logs a non-error message with the given key/value pairs. The
|
||||
// level argument is provided for optional logging.
|
||||
Info(level int, msg string, keysAndValues ...interface{})
|
||||
|
||||
// Error logs an error, with the given message and key/value pairs.
|
||||
Error(err error, msg string, keysAndValues ...interface{})
|
||||
}
|
||||
|
||||
// Logger represents the configuration for the internal logger.
|
||||
type Logger struct {
|
||||
ComponentLevels map[Component]Level // Log levels for each component.
|
||||
Sink LogSink // LogSink for log printing.
|
||||
MaxDocumentLength uint // Command truncation width.
|
||||
logFile *os.File // File to write logs to.
|
||||
}
|
||||
|
||||
// New will construct a new logger. If any of the given options are the
|
||||
// zero-value of the argument type, then the constructor will attempt to
|
||||
// source the data from the environment. If the environment has not been set,
|
||||
// then the constructor will the respective default values.
|
||||
func New(sink LogSink, maxDocLen uint, compLevels map[Component]Level) (*Logger, error) {
|
||||
logger := &Logger{
|
||||
ComponentLevels: selectComponentLevels(compLevels),
|
||||
MaxDocumentLength: selectMaxDocumentLength(maxDocLen),
|
||||
}
|
||||
|
||||
sink, logFile, err := selectLogSink(sink)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Sink = sink
|
||||
logger.logFile = logFile
|
||||
|
||||
return logger, nil
|
||||
}
|
||||
|
||||
// Close will close the logger's log file, if it exists.
|
||||
func (logger *Logger) Close() error {
|
||||
if logger.logFile != nil {
|
||||
return logger.logFile.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LevelComponentEnabled will return true if the given LogLevel is enabled for
|
||||
// the given LogComponent. If the ComponentLevels on the logger are enabled for
|
||||
// "ComponentAll", then this function will return true for any level bound by
|
||||
// the level assigned to "ComponentAll".
|
||||
//
|
||||
// If the level is not enabled (i.e. LevelOff), then false is returned. This is
|
||||
// to avoid false positives, such as returning "true" for a component that is
|
||||
// not enabled. For example, without this condition, an empty LevelComponent
|
||||
// would be considered "enabled" for "LevelOff".
|
||||
func (logger *Logger) LevelComponentEnabled(level Level, component Component) bool {
|
||||
if level == LevelOff {
|
||||
return false
|
||||
}
|
||||
|
||||
if logger.ComponentLevels == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return logger.ComponentLevels[component] >= level ||
|
||||
logger.ComponentLevels[ComponentAll] >= level
|
||||
}
|
||||
|
||||
// Print will synchronously print the given message to the configured LogSink.
|
||||
// If the LogSink is nil, then this method will do nothing. Future work could be done to make
|
||||
// this method asynchronous, see buffer management in libraries such as log4j.
|
||||
//
|
||||
// It's worth noting that many structured logs defined by DBX-wide
|
||||
// specifications include a "message" field, which is often shared with the
|
||||
// message arguments passed to this print function. The "Info" method used by
|
||||
// this function is implemented based on the go-logr/logr LogSink interface,
|
||||
// which is why "Print" has a message parameter. Any duplication in code is
|
||||
// intentional to adhere to the logr pattern.
|
||||
func (logger *Logger) Print(level Level, component Component, msg string, keysAndValues ...interface{}) {
|
||||
// If the level is not enabled for the component, then
|
||||
// skip the message.
|
||||
if !logger.LevelComponentEnabled(level, component) {
|
||||
return
|
||||
}
|
||||
|
||||
// If the sink is nil, then skip the message.
|
||||
if logger.Sink == nil {
|
||||
return
|
||||
}
|
||||
|
||||
logger.Sink.Info(int(level)-DiffToInfo, msg, keysAndValues...)
|
||||
}
|
||||
|
||||
// Error logs an error, with the given message and key/value pairs.
|
||||
// It functions similarly to Print, but may have unique behavior, and should be
|
||||
// preferred for logging errors.
|
||||
func (logger *Logger) Error(err error, msg string, keysAndValues ...interface{}) {
|
||||
if logger.Sink == nil {
|
||||
return
|
||||
}
|
||||
|
||||
logger.Sink.Error(err, msg, keysAndValues...)
|
||||
}
|
||||
|
||||
// selectMaxDocumentLength will return the integer value of the first non-zero
|
||||
// function, with the user-defined function taking priority over the environment
|
||||
// variables. For the environment, the function will attempt to get the value of
|
||||
// "MONGODB_LOG_MAX_DOCUMENT_LENGTH" and parse it as an unsigned integer. If the
|
||||
// environment variable is not set or is not an unsigned integer, then this
|
||||
// function will return the default max document length.
|
||||
func selectMaxDocumentLength(maxDocLen uint) uint {
|
||||
if maxDocLen != 0 {
|
||||
return maxDocLen
|
||||
}
|
||||
|
||||
maxDocLenEnv := os.Getenv(maxDocumentLengthEnvVar)
|
||||
if maxDocLenEnv != "" {
|
||||
maxDocLenEnvInt, err := strconv.ParseUint(maxDocLenEnv, 10, 32)
|
||||
if err == nil {
|
||||
return uint(maxDocLenEnvInt)
|
||||
}
|
||||
}
|
||||
|
||||
return DefaultMaxDocumentLength
|
||||
}
|
||||
|
||||
const (
|
||||
logSinkPathStdout = "stdout"
|
||||
logSinkPathStderr = "stderr"
|
||||
)
|
||||
|
||||
// selectLogSink will return the first non-nil LogSink, with the user-defined
|
||||
// LogSink taking precedence over the environment-defined LogSink. If no LogSink
|
||||
// is defined, then this function will return a LogSink that writes to stderr.
|
||||
func selectLogSink(sink LogSink) (LogSink, *os.File, error) {
|
||||
if sink != nil {
|
||||
return sink, nil, nil
|
||||
}
|
||||
|
||||
path := os.Getenv(logSinkPathEnvVar)
|
||||
lowerPath := strings.ToLower(path)
|
||||
|
||||
if lowerPath == string(logSinkPathStderr) {
|
||||
return NewIOSink(os.Stderr), nil, nil
|
||||
}
|
||||
|
||||
if lowerPath == string(logSinkPathStdout) {
|
||||
return NewIOSink(os.Stdout), nil, nil
|
||||
}
|
||||
|
||||
if path != "" {
|
||||
logFile, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0666)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unable to open log file: %w", err)
|
||||
}
|
||||
|
||||
return NewIOSink(logFile), logFile, nil
|
||||
}
|
||||
|
||||
return NewIOSink(os.Stderr), nil, nil
|
||||
}
|
||||
|
||||
// selectComponentLevels returns a new map of LogComponents to LogLevels that is
|
||||
// the result of merging the user-defined data with the environment, with the
|
||||
// user-defined data taking priority.
|
||||
func selectComponentLevels(componentLevels map[Component]Level) map[Component]Level {
|
||||
selected := make(map[Component]Level)
|
||||
|
||||
// Determine if the "MONGODB_LOG_ALL" environment variable is set.
|
||||
var globalEnvLevel *Level
|
||||
if all := os.Getenv(mongoDBLogAllEnvVar); all != "" {
|
||||
level := ParseLevel(all)
|
||||
globalEnvLevel = &level
|
||||
}
|
||||
|
||||
for envVar, component := range componentEnvVarMap {
|
||||
// If the component already has a level, then skip it.
|
||||
if _, ok := componentLevels[component]; ok {
|
||||
selected[component] = componentLevels[component]
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// If the "MONGODB_LOG_ALL" environment variable is set, then
|
||||
// set the level for the component to the value of the
|
||||
// environment variable.
|
||||
if globalEnvLevel != nil {
|
||||
selected[component] = *globalEnvLevel
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Otherwise, set the level for the component to the value of
|
||||
// the environment variable.
|
||||
selected[component] = ParseLevel(os.Getenv(envVar))
|
||||
}
|
||||
|
||||
return selected
|
||||
}
|
||||
|
||||
// FormatDocument formats a BSON document or RawValue for logging. The document is truncated
|
||||
// to the given width.
|
||||
func FormatDocument(msg bson.Raw, width uint) string {
|
||||
if len(msg) == 0 {
|
||||
return "{}"
|
||||
}
|
||||
|
||||
str := bsoncore.Document(msg).StringN(int(width))
|
||||
|
||||
// If the last byte is not a closing bracket, then the document was truncated
|
||||
if len(str) > 0 && str[len(str)-1] != '}' {
|
||||
str += TruncationSuffix
|
||||
}
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
// FormatString formats a String for logging. The string is truncated
|
||||
// to the given width.
|
||||
func FormatString(str string, width uint) string {
|
||||
strTrunc := bsoncoreutil.Truncate(str, int(width))
|
||||
|
||||
// Checks if the string was truncating by comparing the lengths of the two strings.
|
||||
if len(strTrunc) < len(str) {
|
||||
strTrunc += TruncationSuffix
|
||||
}
|
||||
|
||||
return strTrunc
|
||||
}
|
576
internal/logger/logger_test.go
Normal file
576
internal/logger/logger_test.go
Normal file
@ -0,0 +1,576 @@
|
||||
// Copyright (C) MongoDB, Inc. 2023-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package logger
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"gitea.psichedelico.com/go/bson"
|
||||
"gitea.psichedelico.com/go/bson/internal/assert"
|
||||
"gitea.psichedelico.com/go/bson/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
type mockLogSink struct{}
|
||||
|
||||
func (mockLogSink) Info(int, string, ...interface{}) {}
|
||||
func (mockLogSink) Error(error, string, ...interface{}) {}
|
||||
|
||||
func BenchmarkLoggerWithLargeDocuments(b *testing.B) {
|
||||
// Define the large document test cases
|
||||
testCases := []struct {
|
||||
name string
|
||||
create func() bson.D
|
||||
}{
|
||||
{
|
||||
name: "LargeStrings",
|
||||
create: func() bson.D { return createLargeStringsDocument(10) },
|
||||
},
|
||||
{
|
||||
name: "MassiveArrays",
|
||||
create: func() bson.D { return createMassiveArraysDocument(100000) },
|
||||
},
|
||||
{
|
||||
name: "VeryVoluminousDocument",
|
||||
create: func() bson.D { return createVoluminousDocument(100000) },
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
// Run benchmark with logging and truncation enabled
|
||||
b.Run("LoggingWithTruncation", func(b *testing.B) {
|
||||
logger, err := New(mockLogSink{}, 0, map[Component]Level{
|
||||
ComponentCommand: LevelDebug,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
bs, err := bson.Marshal(tc.create())
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logger.Print(LevelInfo, ComponentCommand, FormatDocument(bs, 1024), "foo", "bar", "baz")
|
||||
|
||||
}
|
||||
})
|
||||
|
||||
// Run benchmark with logging enabled without truncation
|
||||
b.Run("LoggingWithoutTruncation", func(b *testing.B) {
|
||||
logger, err := New(mockLogSink{}, 0, map[Component]Level{
|
||||
ComponentCommand: LevelDebug,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
bs, err := bson.Marshal(tc.create())
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
msg := bsoncore.Document(bs).String()
|
||||
logger.Print(LevelInfo, ComponentCommand, msg, "foo", "bar", "baz")
|
||||
|
||||
}
|
||||
})
|
||||
|
||||
// Run benchmark without logging or truncation
|
||||
b.Run("WithoutLoggingOrTruncation", func(b *testing.B) {
|
||||
bs, err := bson.Marshal(tc.create())
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = bsoncore.Document(bs).String()
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions to create large documents
|
||||
func createVoluminousDocument(numKeys int) bson.D {
|
||||
d := make(bson.D, numKeys)
|
||||
for i := 0; i < numKeys; i++ {
|
||||
d = append(d, bson.E{Key: fmt.Sprintf("key%d", i), Value: "value"})
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func createLargeStringsDocument(sizeMB int) bson.D {
|
||||
largeString := strings.Repeat("a", sizeMB*1024*1024)
|
||||
return bson.D{
|
||||
{Key: "largeString1", Value: largeString},
|
||||
{Key: "largeString2", Value: largeString},
|
||||
{Key: "largeString3", Value: largeString},
|
||||
{Key: "largeString4", Value: largeString},
|
||||
}
|
||||
}
|
||||
|
||||
func createMassiveArraysDocument(arraySize int) bson.D {
|
||||
massiveArray := make([]string, arraySize)
|
||||
for i := 0; i < arraySize; i++ {
|
||||
massiveArray[i] = "value"
|
||||
}
|
||||
return bson.D{
|
||||
{Key: "massiveArray1", Value: massiveArray},
|
||||
{Key: "massiveArray2", Value: massiveArray},
|
||||
{Key: "massiveArray3", Value: massiveArray},
|
||||
{Key: "massiveArray4", Value: massiveArray},
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLogger(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
b.Run("Print", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
logger, err := New(mockLogSink{}, 0, map[Component]Level{
|
||||
ComponentCommand: LevelDebug,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
logger.Print(LevelInfo, ComponentCommand, "foo", "bar", "baz")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func mockKeyValues(length int) (KeyValues, map[string]interface{}) {
|
||||
keysAndValues := KeyValues{}
|
||||
m := map[string]interface{}{}
|
||||
|
||||
for i := 0; i < length; i++ {
|
||||
keyName := fmt.Sprintf("key%d", i)
|
||||
valueName := fmt.Sprintf("value%d", i)
|
||||
|
||||
keysAndValues.Add(keyName, valueName)
|
||||
m[keyName] = valueName
|
||||
}
|
||||
|
||||
return keysAndValues, m
|
||||
}
|
||||
|
||||
func BenchmarkIOSinkInfo(b *testing.B) {
|
||||
keysAndValues, _ := mockKeyValues(10)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
sink := NewIOSink(bytes.NewBuffer(nil))
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
sink.Info(0, "foo", keysAndValues...)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIOSinkInfo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const threshold = 1000
|
||||
|
||||
mockKeyValues, kvmap := mockKeyValues(10)
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
sink := NewIOSink(buf)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(threshold)
|
||||
|
||||
for i := 0; i < threshold; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
sink.Info(0, "foo", mockKeyValues...)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
dec := json.NewDecoder(buf)
|
||||
for dec.More() {
|
||||
var m map[string]interface{}
|
||||
if err := dec.Decode(&m); err != nil {
|
||||
t.Fatalf("error unmarshaling JSON: %v", err)
|
||||
}
|
||||
|
||||
delete(m, KeyTimestamp)
|
||||
delete(m, KeyMessage)
|
||||
|
||||
if !reflect.DeepEqual(m, kvmap) {
|
||||
t.Fatalf("expected %v, got %v", kvmap, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectMaxDocumentLength(t *testing.T) {
|
||||
for _, tcase := range []struct {
|
||||
name string
|
||||
arg uint
|
||||
expected uint
|
||||
env map[string]string
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
arg: 0,
|
||||
expected: DefaultMaxDocumentLength,
|
||||
},
|
||||
{
|
||||
name: "non-zero",
|
||||
arg: 100,
|
||||
expected: 100,
|
||||
},
|
||||
{
|
||||
name: "valid env",
|
||||
arg: 0,
|
||||
expected: 100,
|
||||
env: map[string]string{
|
||||
maxDocumentLengthEnvVar: "100",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid env",
|
||||
arg: 0,
|
||||
expected: DefaultMaxDocumentLength,
|
||||
env: map[string]string{
|
||||
maxDocumentLengthEnvVar: "foo",
|
||||
},
|
||||
},
|
||||
} {
|
||||
tcase := tcase
|
||||
|
||||
t.Run(tcase.name, func(t *testing.T) {
|
||||
for k, v := range tcase.env {
|
||||
t.Setenv(k, v)
|
||||
}
|
||||
|
||||
actual := selectMaxDocumentLength(tcase.arg)
|
||||
if actual != tcase.expected {
|
||||
t.Errorf("expected %d, got %d", tcase.expected, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectLogSink(t *testing.T) {
|
||||
for _, tcase := range []struct {
|
||||
name string
|
||||
arg LogSink
|
||||
expected LogSink
|
||||
env map[string]string
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
arg: nil,
|
||||
expected: NewIOSink(os.Stderr),
|
||||
},
|
||||
{
|
||||
name: "non-nil",
|
||||
arg: mockLogSink{},
|
||||
expected: mockLogSink{},
|
||||
},
|
||||
{
|
||||
name: "stdout",
|
||||
arg: nil,
|
||||
expected: NewIOSink(os.Stdout),
|
||||
env: map[string]string{
|
||||
logSinkPathEnvVar: logSinkPathStdout,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "stderr",
|
||||
arg: nil,
|
||||
expected: NewIOSink(os.Stderr),
|
||||
env: map[string]string{
|
||||
logSinkPathEnvVar: logSinkPathStderr,
|
||||
},
|
||||
},
|
||||
} {
|
||||
tcase := tcase
|
||||
|
||||
t.Run(tcase.name, func(t *testing.T) {
|
||||
for k, v := range tcase.env {
|
||||
t.Setenv(k, v)
|
||||
}
|
||||
|
||||
actual, _, _ := selectLogSink(tcase.arg)
|
||||
if !reflect.DeepEqual(actual, tcase.expected) {
|
||||
t.Errorf("expected %+v, got %+v", tcase.expected, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectedComponentLevels(t *testing.T) {
|
||||
for _, tcase := range []struct {
|
||||
name string
|
||||
arg map[Component]Level
|
||||
expected map[Component]Level
|
||||
env map[string]string
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
arg: nil,
|
||||
expected: map[Component]Level{
|
||||
ComponentCommand: LevelOff,
|
||||
ComponentTopology: LevelOff,
|
||||
ComponentServerSelection: LevelOff,
|
||||
ComponentConnection: LevelOff,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-nil",
|
||||
arg: map[Component]Level{
|
||||
ComponentCommand: LevelDebug,
|
||||
},
|
||||
expected: map[Component]Level{
|
||||
ComponentCommand: LevelDebug,
|
||||
ComponentTopology: LevelOff,
|
||||
ComponentServerSelection: LevelOff,
|
||||
ComponentConnection: LevelOff,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid env",
|
||||
arg: nil,
|
||||
expected: map[Component]Level{
|
||||
ComponentCommand: LevelDebug,
|
||||
ComponentTopology: LevelInfo,
|
||||
ComponentServerSelection: LevelOff,
|
||||
ComponentConnection: LevelOff,
|
||||
},
|
||||
env: map[string]string{
|
||||
mongoDBLogCommandEnvVar: levelLiteralDebug,
|
||||
mongoDBLogTopologyEnvVar: levelLiteralInfo,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid env",
|
||||
arg: nil,
|
||||
expected: map[Component]Level{
|
||||
ComponentCommand: LevelOff,
|
||||
ComponentTopology: LevelOff,
|
||||
ComponentServerSelection: LevelOff,
|
||||
ComponentConnection: LevelOff,
|
||||
},
|
||||
env: map[string]string{
|
||||
mongoDBLogCommandEnvVar: "foo",
|
||||
mongoDBLogTopologyEnvVar: "bar",
|
||||
},
|
||||
},
|
||||
} {
|
||||
tcase := tcase
|
||||
|
||||
t.Run(tcase.name, func(t *testing.T) {
|
||||
for k, v := range tcase.env {
|
||||
t.Setenv(k, v)
|
||||
}
|
||||
|
||||
actual := selectComponentLevels(tcase.arg)
|
||||
for k, v := range tcase.expected {
|
||||
if actual[k] != v {
|
||||
t.Errorf("expected %d, got %d", v, actual[k])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger_LevelComponentEnabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
logger Logger
|
||||
level Level
|
||||
component Component
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "zero",
|
||||
logger: Logger{},
|
||||
level: LevelOff,
|
||||
component: ComponentCommand,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
logger: Logger{
|
||||
ComponentLevels: map[Component]Level{},
|
||||
},
|
||||
level: LevelOff,
|
||||
component: ComponentCommand,
|
||||
want: false, // LevelOff should never be considered enabled.
|
||||
},
|
||||
{
|
||||
name: "one level below",
|
||||
logger: Logger{
|
||||
ComponentLevels: map[Component]Level{
|
||||
ComponentCommand: LevelDebug,
|
||||
},
|
||||
},
|
||||
level: LevelInfo,
|
||||
component: ComponentCommand,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "equal levels",
|
||||
logger: Logger{
|
||||
ComponentLevels: map[Component]Level{
|
||||
ComponentCommand: LevelDebug,
|
||||
},
|
||||
},
|
||||
level: LevelDebug,
|
||||
component: ComponentCommand,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "one level above",
|
||||
logger: Logger{
|
||||
ComponentLevels: map[Component]Level{
|
||||
ComponentCommand: LevelInfo,
|
||||
},
|
||||
},
|
||||
level: LevelDebug,
|
||||
component: ComponentCommand,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "component mismatch",
|
||||
logger: Logger{
|
||||
ComponentLevels: map[Component]Level{
|
||||
ComponentCommand: LevelDebug,
|
||||
},
|
||||
},
|
||||
level: LevelDebug,
|
||||
component: ComponentTopology,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "component all enables with topology",
|
||||
logger: Logger{
|
||||
ComponentLevels: map[Component]Level{
|
||||
ComponentAll: LevelDebug,
|
||||
},
|
||||
},
|
||||
level: LevelDebug,
|
||||
component: ComponentTopology,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "component all enables with server selection",
|
||||
logger: Logger{
|
||||
ComponentLevels: map[Component]Level{
|
||||
ComponentAll: LevelDebug,
|
||||
},
|
||||
},
|
||||
level: LevelDebug,
|
||||
component: ComponentServerSelection,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "component all enables with connection",
|
||||
logger: Logger{
|
||||
ComponentLevels: map[Component]Level{
|
||||
ComponentAll: LevelDebug,
|
||||
},
|
||||
},
|
||||
level: LevelDebug,
|
||||
component: ComponentConnection,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "component all enables with command",
|
||||
logger: Logger{
|
||||
ComponentLevels: map[Component]Level{
|
||||
ComponentAll: LevelDebug,
|
||||
},
|
||||
},
|
||||
level: LevelDebug,
|
||||
component: ComponentCommand,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "component all enables with all",
|
||||
logger: Logger{
|
||||
ComponentLevels: map[Component]Level{
|
||||
ComponentAll: LevelDebug,
|
||||
},
|
||||
},
|
||||
level: LevelDebug,
|
||||
component: ComponentAll,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "component all does not enable with lower level",
|
||||
logger: Logger{
|
||||
ComponentLevels: map[Component]Level{
|
||||
ComponentAll: LevelInfo,
|
||||
},
|
||||
},
|
||||
level: LevelDebug,
|
||||
component: ComponentCommand,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "component all has a lower log level than command",
|
||||
logger: Logger{
|
||||
ComponentLevels: map[Component]Level{
|
||||
ComponentAll: LevelInfo,
|
||||
ComponentCommand: LevelDebug,
|
||||
},
|
||||
},
|
||||
level: LevelDebug,
|
||||
component: ComponentCommand,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "component all has a higher log level than command",
|
||||
logger: Logger{
|
||||
ComponentLevels: map[Component]Level{
|
||||
ComponentAll: LevelDebug,
|
||||
ComponentCommand: LevelInfo,
|
||||
},
|
||||
},
|
||||
level: LevelDebug,
|
||||
component: ComponentCommand,
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tcase := range tests {
|
||||
tcase := tcase // Capture the range variable.
|
||||
|
||||
t.Run(tcase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := tcase.logger.LevelComponentEnabled(tcase.level, tcase.component)
|
||||
assert.Equal(t, tcase.want, got, "unexpected result for LevelComponentEnabled")
|
||||
})
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user