diff --git a/reflect.go b/reflect.go index 39acb320..7f7a7275 100644 --- a/reflect.go +++ b/reflect.go @@ -270,7 +270,7 @@ func _createEncoderOfType(ctx *ctx, typ reflect2.Type) ValEncoder { kind := typ.Kind() switch kind { case reflect.Interface: - return &dynamicEncoder{typ} + return &dynamicEncoder{valType: typ, seen: make(map[unsafe.Pointer]bool, 1)} case reflect.Struct: return encoderOfStruct(ctx, typ) case reflect.Array: diff --git a/reflect_dynamic.go b/reflect_dynamic.go index 8b6bc8b4..fce985d1 100644 --- a/reflect_dynamic.go +++ b/reflect_dynamic.go @@ -1,16 +1,25 @@ package jsoniter import ( - "github.com/modern-go/reflect2" "reflect" "unsafe" + + "github.com/modern-go/reflect2" ) type dynamicEncoder struct { valType reflect2.Type + seen map[unsafe.Pointer]bool } func (encoder *dynamicEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { + if encoder.seen[ptr] { + stream.Error = ErrEncounterCycle + return + } + encoder.seen[ptr] = true + defer delete(encoder.seen, ptr) + obj := encoder.valType.UnsafeIndirect(ptr) stream.WriteVal(obj) } diff --git a/reflect_extension.go b/reflect_extension.go index 74a97bfe..b79951bf 100644 --- a/reflect_extension.go +++ b/reflect_extension.go @@ -2,12 +2,13 @@ package jsoniter import ( "fmt" - "github.com/modern-go/reflect2" "reflect" "sort" "strings" "unicode" "unsafe" + + "github.com/modern-go/reflect2" ) var typeDecoders = map[string]ValDecoder{} @@ -325,7 +326,7 @@ func _getTypeEncoderFromExtension(ctx *ctx, typ reflect2.Type) ValEncoder { typePtr := typ.(*reflect2.UnsafePtrType) encoder := typeEncoders[typePtr.Elem().String()] if encoder != nil { - return &OptionalEncoder{encoder} + return &OptionalEncoder{ValueEncoder: encoder, seen: make(map[unsafe.Pointer]bool, 1)} } } return nil diff --git a/reflect_optional.go b/reflect_optional.go index fa71f474..c6422fca 100644 --- a/reflect_optional.go +++ b/reflect_optional.go @@ -1,8 +1,9 @@ package jsoniter import ( - "github.com/modern-go/reflect2" "unsafe" + + "github.com/modern-go/reflect2" ) func decoderOfOptional(ctx *ctx, typ reflect2.Type) ValDecoder { @@ -16,7 +17,7 @@ func encoderOfOptional(ctx *ctx, typ reflect2.Type) ValEncoder { ptrType := typ.(*reflect2.UnsafePtrType) elemType := ptrType.Elem() elemEncoder := encoderOfType(ctx, elemType) - encoder := &OptionalEncoder{elemEncoder} + encoder := &OptionalEncoder{ValueEncoder: elemEncoder, seen: make(map[unsafe.Pointer]bool, 1)} return encoder } @@ -61,13 +62,22 @@ func (decoder *dereferenceDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { type OptionalEncoder struct { ValueEncoder ValEncoder + seen map[unsafe.Pointer]bool } func (encoder *OptionalEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { - if *((*unsafe.Pointer)(ptr)) == nil { + ptr = *((*unsafe.Pointer)(ptr)) + if encoder.seen[ptr] { + stream.Error = ErrEncounterCycle + return + } + encoder.seen[ptr] = true + defer delete(encoder.seen, ptr) + + if ptr == nil { stream.WriteNil() } else { - encoder.ValueEncoder.Encode(*((*unsafe.Pointer)(ptr)), stream) + encoder.ValueEncoder.Encode(ptr, stream) } } diff --git a/reflect_struct_encoder.go b/reflect_struct_encoder.go index 152e3ef5..6347907a 100644 --- a/reflect_struct_encoder.go +++ b/reflect_struct_encoder.go @@ -2,10 +2,11 @@ package jsoniter import ( "fmt" - "github.com/modern-go/reflect2" "io" "reflect" "unsafe" + + "github.com/modern-go/reflect2" ) func encoderOfStruct(ctx *ctx, typ reflect2.Type) ValEncoder { @@ -54,7 +55,7 @@ func createCheckIsEmpty(ctx *ctx, typ reflect2.Type) checkIsEmpty { kind := typ.Kind() switch kind { case reflect.Interface: - return &dynamicEncoder{typ} + return &dynamicEncoder{valType: typ, seen: make(map[unsafe.Pointer]bool, 1)} case reflect.Struct: return &structEncoder{typ: typ} case reflect.Array: diff --git a/stream.go b/stream.go index 23d8a3ad..7c8c41ed 100644 --- a/stream.go +++ b/stream.go @@ -1,9 +1,12 @@ package jsoniter import ( + "errors" "io" ) +var ErrEncounterCycle = errors.New("encountered a cycle") + // stream is a io.Writer like object, with JSON specific write functions. // Error is not returned as return value, but stored as Error member on this stream instance. type Stream struct { diff --git a/value_tests/map_test.go b/value_tests/map_test.go index 02a1895a..c36451a6 100644 --- a/value_tests/map_test.go +++ b/value_tests/map_test.go @@ -57,9 +57,13 @@ func init() { "2018-12-14": true }`, }, unmarshalCase{ - ptr: (*map[customKey]string)(nil), + ptr: (*map[customKey]string)(nil), input: `{"foo": "bar"}`, }) + + selfRecursive := map[string]interface{}{} + selfRecursive["me"] = selfRecursive + marshalSelfRecursiveCases = append(marshalSelfRecursiveCases, selfRecursive) } type MyInterface interface { diff --git a/value_tests/slice_test.go b/value_tests/slice_test.go index 3731cbe1..d60494c2 100644 --- a/value_tests/slice_test.go +++ b/value_tests/slice_test.go @@ -24,4 +24,8 @@ func init() { ptr: (*[]byte)(nil), input: `"c3ViamVjdHM\/X2Q9MQ=="`, }) + + selfRecursive := []interface{}{nil} + selfRecursive[0] = selfRecursive + marshalSelfRecursiveCases = append(marshalSelfRecursiveCases, selfRecursive) } diff --git a/value_tests/struct_test.go b/value_tests/struct_test.go index 181e12fc..24bf2eb0 100644 --- a/value_tests/struct_test.go +++ b/value_tests/struct_test.go @@ -205,6 +205,10 @@ func init() { "should not marshal", }, ) + + selfRecursive := &structRecursive{} + selfRecursive.Me = selfRecursive + marshalSelfRecursiveCases = append(marshalSelfRecursiveCases, selfRecursive) } type StructVarious struct { diff --git a/value_tests/value_test.go b/value_tests/value_test.go index 95cfdd56..79a93b43 100644 --- a/value_tests/value_test.go +++ b/value_tests/value_test.go @@ -3,10 +3,11 @@ package test import ( "encoding/json" "fmt" - "github.com/json-iterator/go" + "testing" + + jsoniter "github.com/json-iterator/go" "github.com/modern-go/reflect2" "github.com/stretchr/testify/require" - "testing" ) type unmarshalCase struct { @@ -22,6 +23,8 @@ var marshalCases = []interface{}{ nil, } +var marshalSelfRecursiveCases = []interface{}{} + type selectedMarshalCase struct { marshalCase interface{} } @@ -78,3 +81,15 @@ func Test_marshal(t *testing.T) { }) } } + +func Test_marshal_self_recursive(t *testing.T) { + for i, testCase := range marshalSelfRecursiveCases { + t.Run(fmt.Sprintf("[%v]%s", i, reflect2.TypeOf(testCase).String()), func(t *testing.T) { + should := require.New(t) + _, err1 := json.Marshal(testCase) + should.ErrorContains(err1, "encountered a cycle") + _, err2 := jsoniter.ConfigCompatibleWithStandardLibrary.Marshal(testCase) + should.ErrorContains(err2, "encountered a cycle") + }) + } +}