Skip to content

Commit

Permalink
badjson: Add context marshaler/unmarshaler
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Nov 1, 2024
1 parent 94f0582 commit 8452992
Show file tree
Hide file tree
Showing 13 changed files with 285 additions and 60 deletions.
5 changes: 3 additions & 2 deletions common/json/badjson/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package badjson

import (
"bytes"
"context"

E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
)

func Decode(content []byte) (any, error) {
decoder := json.NewDecoder(bytes.NewReader(content))
func Decode(ctx context.Context, content []byte) (any, error) {
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
return decodeJSON(decoder)
}

Expand Down
45 changes: 23 additions & 22 deletions common/json/badjson/merge.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package badjson

import (
"context"
"os"
"reflect"

Expand All @@ -9,100 +10,100 @@ import (
"github.com/sagernet/sing/common/json"
)

func Omitempty[T any](value T) (T, error) {
func Omitempty[T any](ctx context.Context, value T) (T, error) {
objectContent, err := json.Marshal(value)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal object")
}
rawNewObject, err := Decode(objectContent)
rawNewObject, err := Decode(ctx, objectContent)
if err != nil {
return common.DefaultValue[T](), err
}
newObjectContent, err := json.Marshal(rawNewObject)
newObjectContent, err := json.MarshalContext(ctx, rawNewObject)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal new object")
}
var newObject T
err = json.Unmarshal(newObjectContent, &newObject)
err = json.UnmarshalContext(ctx, newObjectContent, &newObject)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "unmarshal new object")
}
return newObject, nil
}

func Merge[T any](source T, destination T, disableAppend bool) (T, error) {
rawSource, err := json.Marshal(source)
func Merge[T any](ctx context.Context, source T, destination T, disableAppend bool) (T, error) {
rawSource, err := json.MarshalContext(ctx, source)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal source")
}
rawDestination, err := json.Marshal(destination)
rawDestination, err := json.MarshalContext(ctx, destination)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
}
return MergeFrom[T](rawSource, rawDestination, disableAppend)
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
}

func MergeFromSource[T any](rawSource json.RawMessage, destination T, disableAppend bool) (T, error) {
func MergeFromSource[T any](ctx context.Context, rawSource json.RawMessage, destination T, disableAppend bool) (T, error) {
if rawSource == nil {
return destination, nil
}
rawDestination, err := json.Marshal(destination)
rawDestination, err := json.MarshalContext(ctx, destination)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal destination")
}
return MergeFrom[T](rawSource, rawDestination, disableAppend)
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
}

func MergeFromDestination[T any](source T, rawDestination json.RawMessage, disableAppend bool) (T, error) {
func MergeFromDestination[T any](ctx context.Context, source T, rawDestination json.RawMessage, disableAppend bool) (T, error) {
if rawDestination == nil {
return source, nil
}
rawSource, err := json.Marshal(source)
rawSource, err := json.MarshalContext(ctx, source)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "marshal source")
}
return MergeFrom[T](rawSource, rawDestination, disableAppend)
return MergeFrom[T](ctx, rawSource, rawDestination, disableAppend)
}

func MergeFrom[T any](rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) {
rawMerged, err := MergeJSON(rawSource, rawDestination, disableAppend)
func MergeFrom[T any](ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (T, error) {
rawMerged, err := MergeJSON(ctx, rawSource, rawDestination, disableAppend)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "merge options")
}
var merged T
err = json.Unmarshal(rawMerged, &merged)
err = json.UnmarshalContext(ctx, rawMerged, &merged)
if err != nil {
return common.DefaultValue[T](), E.Cause(err, "unmarshal merged options")
}
return merged, nil
}

func MergeJSON(rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) {
func MergeJSON(ctx context.Context, rawSource json.RawMessage, rawDestination json.RawMessage, disableAppend bool) (json.RawMessage, error) {
if rawSource == nil && rawDestination == nil {
return nil, os.ErrInvalid
} else if rawSource == nil {
return rawDestination, nil
} else if rawDestination == nil {
return rawSource, nil
}
source, err := Decode(rawSource)
source, err := Decode(ctx, rawSource)
if err != nil {
return nil, E.Cause(err, "decode source")
}
destination, err := Decode(rawDestination)
destination, err := Decode(ctx, rawDestination)
if err != nil {
return nil, E.Cause(err, "decode destination")
}
if source == nil {
return json.Marshal(destination)
return json.MarshalContext(ctx, destination)
} else if destination == nil {
return json.Marshal(source)
}
merged, err := mergeJSON(source, destination, disableAppend)
if err != nil {
return nil, err
}
return json.Marshal(merged)
return json.MarshalContext(ctx, merged)
}

func mergeJSON(anySource any, anyDestination any, disableAppend bool) (any, error) {
Expand Down
28 changes: 19 additions & 9 deletions common/json/badjson/merge_objects.go
Original file line number Diff line number Diff line change
@@ -1,32 +1,42 @@
package badjson

import (
"context"

E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
)

func MarshallObjects(objects ...any) ([]byte, error) {
return MarshallObjectsContext(context.Background(), objects...)
}

func MarshallObjectsContext(ctx context.Context, objects ...any) ([]byte, error) {
if len(objects) == 1 {
return json.Marshal(objects[0])
}
var content JSONObject
for _, object := range objects {
objectMap, err := newJSONObject(object)
objectMap, err := newJSONObject(ctx, object)
if err != nil {
return nil, err
}
content.PutAll(objectMap)
}
return content.MarshalJSON()
return content.MarshalJSONContext(ctx)
}

func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error {
parentContent, err := newJSONObject(parentObject)
return UnmarshallExcludedContext(context.Background(), inputContent, parentObject, object)
}

func UnmarshallExcludedContext(ctx context.Context, inputContent []byte, parentObject any, object any) error {
parentContent, err := newJSONObject(ctx, parentObject)
if err != nil {
return err
}
var content JSONObject
err = content.UnmarshalJSON(inputContent)
err = content.UnmarshalJSONContext(ctx, inputContent)
if err != nil {
return err
}
Expand All @@ -39,20 +49,20 @@ func UnmarshallExcluded(inputContent []byte, parentObject any, object any) error
}
return E.New("unexpected key: ", content.Keys()[0])
}
inputContent, err = content.MarshalJSON()
inputContent, err = content.MarshalJSONContext(ctx)
if err != nil {
return err
}
return json.UnmarshalDisallowUnknownFields(inputContent, object)
return json.UnmarshalContextDisallowUnknownFields(ctx, inputContent, object)
}

func newJSONObject(object any) (*JSONObject, error) {
inputContent, err := json.Marshal(object)
func newJSONObject(ctx context.Context, object any) (*JSONObject, error) {
inputContent, err := json.MarshalContext(ctx, object)
if err != nil {
return nil, err
}
var content JSONObject
err = content.UnmarshalJSON(inputContent)
err = content.UnmarshalJSONContext(ctx, inputContent)
if err != nil {
return nil, err
}
Expand Down
15 changes: 12 additions & 3 deletions common/json/badjson/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package badjson

import (
"bytes"
"context"
"strings"

"github.com/sagernet/sing/common"
Expand All @@ -28,6 +29,10 @@ func (m *JSONObject) IsEmpty() bool {
}

func (m *JSONObject) MarshalJSON() ([]byte, error) {
return m.MarshalJSONContext(context.Background())
}

func (m *JSONObject) MarshalJSONContext(ctx context.Context) ([]byte, error) {
buffer := new(bytes.Buffer)
buffer.WriteString("{")
items := common.Filter(m.Entries(), func(it collections.MapEntry[string, any]) bool {
Expand All @@ -38,13 +43,13 @@ func (m *JSONObject) MarshalJSON() ([]byte, error) {
})
iLen := len(items)
for i, entry := range items {
keyContent, err := json.Marshal(entry.Key)
keyContent, err := json.MarshalContext(ctx, entry.Key)
if err != nil {
return nil, err
}
buffer.WriteString(strings.TrimSpace(string(keyContent)))
buffer.WriteString(": ")
valueContent, err := json.Marshal(entry.Value)
valueContent, err := json.MarshalContext(ctx, entry.Value)
if err != nil {
return nil, err
}
Expand All @@ -58,7 +63,11 @@ func (m *JSONObject) MarshalJSON() ([]byte, error) {
}

func (m *JSONObject) UnmarshalJSON(content []byte) error {
decoder := json.NewDecoder(bytes.NewReader(content))
return m.UnmarshalJSONContext(context.Background(), content)
}

func (m *JSONObject) UnmarshalJSONContext(ctx context.Context, content []byte) error {
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
m.Clear()
objectStart, err := decoder.Token()
if err != nil {
Expand Down
23 changes: 16 additions & 7 deletions common/json/badjson/typed.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package badjson

import (
"bytes"
"context"
"strings"

E "github.com/sagernet/sing/common/exceptions"
Expand All @@ -14,18 +15,22 @@ type TypedMap[K comparable, V any] struct {
}

func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
return m.MarshalJSONContext(context.Background())
}

func (m TypedMap[K, V]) MarshalJSONContext(ctx context.Context) ([]byte, error) {
buffer := new(bytes.Buffer)
buffer.WriteString("{")
items := m.Entries()
iLen := len(items)
for i, entry := range items {
keyContent, err := json.Marshal(entry.Key)
keyContent, err := json.MarshalContext(ctx, entry.Key)
if err != nil {
return nil, err
}
buffer.WriteString(strings.TrimSpace(string(keyContent)))
buffer.WriteString(": ")
valueContent, err := json.Marshal(entry.Value)
valueContent, err := json.MarshalContext(ctx, entry.Value)
if err != nil {
return nil, err
}
Expand All @@ -39,15 +44,19 @@ func (m TypedMap[K, V]) MarshalJSON() ([]byte, error) {
}

func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
decoder := json.NewDecoder(bytes.NewReader(content))
return m.UnmarshalJSONContext(context.Background(), content)
}

func (m *TypedMap[K, V]) UnmarshalJSONContext(ctx context.Context, content []byte) error {
decoder := json.NewDecoderContext(ctx, bytes.NewReader(content))
m.Clear()
objectStart, err := decoder.Token()
if err != nil {
return err
} else if objectStart != json.Delim('{') {
return E.New("expected json object start, but starts with ", objectStart)
}
err = m.decodeJSON(decoder)
err = m.decodeJSON(ctx, decoder)
if err != nil {
return E.Cause(err, "decode json object content")
}
Expand All @@ -60,18 +69,18 @@ func (m *TypedMap[K, V]) UnmarshalJSON(content []byte) error {
return nil
}

func (m *TypedMap[K, V]) decodeJSON(decoder *json.Decoder) error {
func (m *TypedMap[K, V]) decodeJSON(ctx context.Context, decoder *json.Decoder) error {
for decoder.More() {
keyToken, err := decoder.Token()
if err != nil {
return err
}
keyContent, err := json.Marshal(keyToken)
keyContent, err := json.MarshalContext(ctx, keyToken)
if err != nil {
return err
}
var entryKey K
err = json.Unmarshal(keyContent, &entryKey)
err = json.UnmarshalContext(ctx, keyContent, &entryKey)
if err != nil {
return err
}
Expand Down
23 changes: 23 additions & 0 deletions common/json/context_ext.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package json

import (
"context"

"github.com/sagernet/sing/common/json/internal/contextjson"
)

var (
MarshalContext = json.MarshalContext
UnmarshalContext = json.UnmarshalContext
NewEncoderContext = json.NewEncoderContext
NewDecoderContext = json.NewDecoderContext
UnmarshalContextDisallowUnknownFields = json.UnmarshalContextDisallowUnknownFields
)

type ContextMarshaler interface {
MarshalJSONContext(ctx context.Context) ([]byte, error)
}

type ContextUnmarshaler interface {
UnmarshalJSONContext(ctx context.Context, content []byte) error
}
11 changes: 11 additions & 0 deletions common/json/internal/contextjson/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package json

import "context"

type ContextMarshaler interface {
MarshalJSONContext(ctx context.Context) ([]byte, error)
}

type ContextUnmarshaler interface {
UnmarshalJSONContext(ctx context.Context, content []byte) error
}
Loading

0 comments on commit 8452992

Please sign in to comment.