diff --git a/codegen/source.go.tmpl b/codegen/source.go.tmpl index 3a08ef5..f5026a3 100644 --- a/codegen/source.go.tmpl +++ b/codegen/source.go.tmpl @@ -301,7 +301,7 @@ var _ types.ENUM = {{capitalise .Name}}("") // {{capitalise .Name}} is a {{.RawType}} type type {{capitalise .Name}} struct { {{- range $field := .Fields}} - {{capitalise $field.Name}} {{$field.Type.GoType}} `json:"{{$field.Name}}"{{if $field.IsBytesHex}} hex:"bytes16"{{else if $field.IsUint32}} hex:"uint32"{{else if $field.IsUint32List}} hex:"[]uint32"{{end}}` + {{capitalise $field.Name}} {{$field.Type.GoType}} `json:"{{$field.Name}}"{{if $field.IsOptional}} hex:"optional"{{else if $field.IsBytesHex}} hex:"bytes16"{{else if $field.IsUint32}} hex:"uint32"{{else if $field.IsUint32List}} hex:"[]uint32"{{end}}` {{- end}} } @@ -359,7 +359,7 @@ func (t *{{capitalise .Name}}) UnmarshalHex(data string) error { type {{capitalise .Name}}MCMSParams struct { {{- range .Fields}} {{- if not (isCallerField .Name)}} - {{capitalise .Name}} {{.Type.GoType}} `json:"{{.Name}}"{{if .IsBytesHex}} hex:"bytes16"{{else if .IsUint32}} hex:"uint32"{{else if .IsUint32List}} hex:"[]uint32"{{end}}` + {{capitalise .Name}} {{.Type.GoType}} `json:"{{.Name}}"{{if .IsOptional}} hex:"optional"{{else if .IsBytesHex}} hex:"bytes16"{{else if .IsUint32}} hex:"uint32"{{else if .IsUint32List}} hex:"[]uint32"{{end}}` {{- end}} {{- end}} } diff --git a/examples/codegen/all_kinds_of_1_0_0.go b/examples/codegen/all_kinds_of_1_0_0.go index e3107d4..c22439b 100644 --- a/examples/codegen/all_kinds_of_1_0_0.go +++ b/examples/codegen/all_kinds_of_1_0_0.go @@ -243,8 +243,8 @@ type OneOfEverything struct { SomeBoolean types.BOOL `json:"someBoolean"` SomeInteger types.INT64 `json:"someInteger"` SomeDecimal types.NUMERIC `json:"someDecimal"` - SomeMaybe *types.INT64 `json:"someMaybe"` - SomeMaybeNot *types.INT64 `json:"someMaybeNot"` + SomeMaybe *types.INT64 `json:"someMaybe" hex:"optional"` + SomeMaybeNot *types.INT64 `json:"someMaybeNot" hex:"optional"` SomeText types.TEXT `json:"someText"` SomeDate types.DATE `json:"someDate"` SomeDatetime types.TIMESTAMP `json:"someDatetime"` diff --git a/pkg/codec/hex_codec.go b/pkg/codec/hex_codec.go index 62d3d49..2a04890 100644 --- a/pkg/codec/hex_codec.go +++ b/pkg/codec/hex_codec.go @@ -387,6 +387,20 @@ func (c *HexCodec) encodeStruct(rv reflect.Value) ([]byte, error) { if err != nil { return nil, fmt.Errorf("failed to encode bytes16 field %s: %w", fieldType.Name, err) } + case "optional": + // hex:"optional" - encode Daml Optional type: 0x00 for None, 0x01 + value for Some + if field.Kind() != reflect.Ptr { + return nil, fmt.Errorf("hex:\"optional\" tag only valid on pointer fields, got %v", field.Kind()) + } + if field.IsNil() { + encoded = []byte{0x00} + } else { + valueEncoded, encErr := c.encode(field.Elem().Interface()) + if encErr != nil { + return nil, fmt.Errorf("failed to encode optional field %s: %w", fieldType.Name, encErr) + } + encoded = append([]byte{0x01}, valueEncoded...) + } default: // No tag or unknown tag - use default encoding encoded, err = c.encode(field.Interface()) @@ -874,6 +888,27 @@ func (c *HexCodec) decodeStruct(data []byte, offset int, target reflect.Value) ( rawBytes := data[offset : offset+byteCount] field.SetString(hex.EncodeToString(rawBytes)) offset += byteCount + case "optional": + // hex:"optional" - decode Daml Optional type: 0x00 for None, 0x01 + value for Some + if field.Kind() != reflect.Ptr { + return offset, fmt.Errorf("hex:\"optional\" tag only valid on pointer fields, got %v", field.Kind()) + } + if offset >= len(data) { + return offset, fmt.Errorf("not enough data for optional flag at offset %d", offset) + } + flag := data[offset] + offset++ + if flag == 0x01 { + newVal := reflect.New(field.Type().Elem()) + offset, err = c.decodeValue(data, offset, newVal.Elem()) + if err != nil { + return offset, fmt.Errorf("failed to decode optional field %s: %w", fieldType.Name, err) + } + field.Set(newVal) + } else if flag != 0x00 { + return offset, fmt.Errorf("invalid optional flag 0x%02x for field %s", flag, fieldType.Name) + } + // flag == 0x00: leave field as nil (zero value) default: // No tag or unknown tag - use default decoding offset, err = c.decodeValue(data, offset, field) diff --git a/pkg/codec/hex_codec_test.go b/pkg/codec/hex_codec_test.go index 70ad924..855620f 100644 --- a/pkg/codec/hex_codec_test.go +++ b/pkg/codec/hex_codec_test.go @@ -1072,3 +1072,447 @@ func TestHexCodec_RoundTrip_Bytes16(t *testing.T) { }) } } + +// Tests for hex:"optional" tag (Daml Optional encoding) + +type TestOptionalStruct struct { + Name string + Value *int `hex:"optional"` +} + +func TestHexCodec_EncodeOptional_Nil(t *testing.T) { + c := NewHexCodec() + s := TestOptionalStruct{ + Name: "test", + Value: nil, + } + result, err := c.Marshal(s) + require.NoError(t, err) + // Name: len=4 (04) + "test" (74657374) = 0474657374 + // Value: None = 00 + assert.Equal(t, "0474657374"+"00", result) +} + +func TestHexCodec_EncodeOptional_Some(t *testing.T) { + c := NewHexCodec() + val := 42 + s := TestOptionalStruct{ + Name: "test", + Value: &val, + } + result, err := c.Marshal(s) + require.NoError(t, err) + // Name: len=4 (04) + "test" (74657374) = 0474657374 + // Value: Some = 01 + int32(42) (0000002a) = 010000002a + assert.Equal(t, "0474657374"+"01"+"0000002a", result) +} + +func TestHexCodec_DecodeOptional_Nil(t *testing.T) { + c := NewHexCodec() + // Name: len=4 (04) + "test" (74657374) = 0474657374 + // Value: None = 00 + hexStr := "047465737400" + var s TestOptionalStruct + err := c.Unmarshal(hexStr, &s) + require.NoError(t, err) + assert.Equal(t, "test", s.Name) + assert.Nil(t, s.Value) +} + +func TestHexCodec_DecodeOptional_Some(t *testing.T) { + c := NewHexCodec() + // Name: len=4 (04) + "test" (74657374) = 0474657374 + // Value: Some = 01 + int32(42) = 010000002a + hexStr := "0474657374010000002a" + var s TestOptionalStruct + err := c.Unmarshal(hexStr, &s) + require.NoError(t, err) + assert.Equal(t, "test", s.Name) + require.NotNil(t, s.Value) + assert.Equal(t, 42, *s.Value) +} + +func TestHexCodec_RoundTrip_Optional_Nil(t *testing.T) { + c := NewHexCodec() + original := TestOptionalStruct{ + Name: "niltest", + Value: nil, + } + + encoded, err := c.Marshal(original) + require.NoError(t, err) + + var decoded TestOptionalStruct + err = c.Unmarshal(encoded, &decoded) + require.NoError(t, err) + + assert.Equal(t, original.Name, decoded.Name) + assert.Nil(t, decoded.Value) +} + +func TestHexCodec_RoundTrip_Optional_Some(t *testing.T) { + c := NewHexCodec() + val := 12345 + original := TestOptionalStruct{ + Name: "sometest", + Value: &val, + } + + encoded, err := c.Marshal(original) + require.NoError(t, err) + + var decoded TestOptionalStruct + err = c.Unmarshal(encoded, &decoded) + require.NoError(t, err) + + assert.Equal(t, original.Name, decoded.Name) + require.NotNil(t, decoded.Value) + assert.Equal(t, *original.Value, *decoded.Value) +} + +// Test optional with PARTY type +type TestOptionalPartyStruct struct { + Owner types.PARTY + Admin *types.PARTY `hex:"optional"` +} + +func TestHexCodec_EncodeOptional_Party_Nil(t *testing.T) { + c := NewHexCodec() + s := TestOptionalPartyStruct{ + Owner: types.PARTY("alice"), + Admin: nil, + } + result, err := c.Marshal(s) + require.NoError(t, err) + // Owner: len=5 (05) + "alice" (616c696365) = 05616c696365 + // Admin: None = 00 + assert.Equal(t, "05616c696365"+"00", result) +} + +func TestHexCodec_EncodeOptional_Party_Some(t *testing.T) { + c := NewHexCodec() + admin := types.PARTY("bob") + s := TestOptionalPartyStruct{ + Owner: types.PARTY("alice"), + Admin: &admin, + } + result, err := c.Marshal(s) + require.NoError(t, err) + // Owner: len=5 (05) + "alice" (616c696365) = 05616c696365 + // Admin: Some = 01 + len=3 (03) + "bob" (626f62) = 0103626f62 + assert.Equal(t, "05616c696365"+"01"+"03626f62", result) +} + +func TestHexCodec_RoundTrip_Optional_Party(t *testing.T) { + c := NewHexCodec() + admin := types.PARTY("bob") + original := TestOptionalPartyStruct{ + Owner: types.PARTY("alice"), + Admin: &admin, + } + + encoded, err := c.Marshal(original) + require.NoError(t, err) + + var decoded TestOptionalPartyStruct + err = c.Unmarshal(encoded, &decoded) + require.NoError(t, err) + + assert.Equal(t, original.Owner, decoded.Owner) + require.NotNil(t, decoded.Admin) + assert.Equal(t, *original.Admin, *decoded.Admin) +} + +// Test optional with INT64 type +type TestOptionalInt64Struct struct { + Count types.INT64 + Limit *types.INT64 `hex:"optional"` +} + +func TestHexCodec_EncodeOptional_INT64_Nil(t *testing.T) { + c := NewHexCodec() + s := TestOptionalInt64Struct{ + Count: types.INT64(100), + Limit: nil, + } + result, err := c.Marshal(s) + require.NoError(t, err) + // Count: int64(100) = 0000000000000064 + // Limit: None = 00 + assert.Equal(t, "0000000000000064"+"00", result) +} + +func TestHexCodec_EncodeOptional_INT64_Some(t *testing.T) { + c := NewHexCodec() + limit := types.INT64(500) + s := TestOptionalInt64Struct{ + Count: types.INT64(100), + Limit: &limit, + } + result, err := c.Marshal(s) + require.NoError(t, err) + // Count: int64(100) = 0000000000000064 + // Limit: Some = 01 + int64(500) = 0100000000000001f4 + assert.Equal(t, "0000000000000064"+"01"+"00000000000001f4", result) +} + +func TestHexCodec_RoundTrip_Optional_INT64(t *testing.T) { + c := NewHexCodec() + limit := types.INT64(999) + original := TestOptionalInt64Struct{ + Count: types.INT64(123), + Limit: &limit, + } + + encoded, err := c.Marshal(original) + require.NoError(t, err) + + var decoded TestOptionalInt64Struct + err = c.Unmarshal(encoded, &decoded) + require.NoError(t, err) + + assert.Equal(t, original.Count, decoded.Count) + require.NotNil(t, decoded.Limit) + assert.Equal(t, *original.Limit, *decoded.Limit) +} + +// Test optional with TEXT type +type TestOptionalTextStruct struct { + Title types.TEXT + Description *types.TEXT `hex:"optional"` +} + +func TestHexCodec_RoundTrip_Optional_TEXT(t *testing.T) { + c := NewHexCodec() + desc := types.TEXT("hello world") + original := TestOptionalTextStruct{ + Title: types.TEXT("test"), + Description: &desc, + } + + encoded, err := c.Marshal(original) + require.NoError(t, err) + + var decoded TestOptionalTextStruct + err = c.Unmarshal(encoded, &decoded) + require.NoError(t, err) + + assert.Equal(t, original.Title, decoded.Title) + require.NotNil(t, decoded.Description) + assert.Equal(t, *original.Description, *decoded.Description) +} + +// Test optional with nested struct +type InnerStruct struct { + X int + Y int +} + +type TestOptionalNestedStruct struct { + Name string + Inner *InnerStruct `hex:"optional"` +} + +func TestHexCodec_EncodeOptional_NestedStruct_Nil(t *testing.T) { + c := NewHexCodec() + s := TestOptionalNestedStruct{ + Name: "a", + Inner: nil, + } + result, err := c.Marshal(s) + require.NoError(t, err) + // Name: len=1 (01) + "a" (61) = 0161 + // Inner: None = 00 + assert.Equal(t, "0161"+"00", result) +} + +func TestHexCodec_EncodeOptional_NestedStruct_Some(t *testing.T) { + c := NewHexCodec() + s := TestOptionalNestedStruct{ + Name: "a", + Inner: &InnerStruct{ + X: 10, + Y: 20, + }, + } + result, err := c.Marshal(s) + require.NoError(t, err) + // Name: len=1 (01) + "a" (61) = 0161 + // Inner: Some = 01 + X(0000000a) + Y(00000014) = 010000000a00000014 + assert.Equal(t, "0161"+"01"+"0000000a"+"00000014", result) +} + +func TestHexCodec_RoundTrip_Optional_NestedStruct(t *testing.T) { + c := NewHexCodec() + original := TestOptionalNestedStruct{ + Name: "nested", + Inner: &InnerStruct{ + X: 100, + Y: 200, + }, + } + + encoded, err := c.Marshal(original) + require.NoError(t, err) + + var decoded TestOptionalNestedStruct + err = c.Unmarshal(encoded, &decoded) + require.NoError(t, err) + + assert.Equal(t, original.Name, decoded.Name) + require.NotNil(t, decoded.Inner) + assert.Equal(t, original.Inner.X, decoded.Inner.X) + assert.Equal(t, original.Inner.Y, decoded.Inner.Y) +} + +// Test optional with slice pointer +type TestOptionalSliceStruct struct { + Name string + Items *[]int `hex:"optional"` +} + +func TestHexCodec_EncodeOptional_Slice_Nil(t *testing.T) { + c := NewHexCodec() + s := TestOptionalSliceStruct{ + Name: "s", + Items: nil, + } + result, err := c.Marshal(s) + require.NoError(t, err) + // Name: len=1 (01) + "s" (73) = 0173 + // Items: None = 00 + assert.Equal(t, "0173"+"00", result) +} + +func TestHexCodec_EncodeOptional_Slice_Some(t *testing.T) { + c := NewHexCodec() + items := []int{1, 2, 3} + s := TestOptionalSliceStruct{ + Name: "s", + Items: &items, + } + result, err := c.Marshal(s) + require.NoError(t, err) + // Name: len=1 (01) + "s" (73) = 0173 + // Items: Some = 01 + len=3 (03) + 1 + 2 + 3 = 0103000000010000000200000003 + assert.Equal(t, "0173"+"01"+"03"+"00000001"+"00000002"+"00000003", result) +} + +func TestHexCodec_RoundTrip_Optional_Slice(t *testing.T) { + c := NewHexCodec() + items := []int{10, 20, 30} + original := TestOptionalSliceStruct{ + Name: "slice", + Items: &items, + } + + encoded, err := c.Marshal(original) + require.NoError(t, err) + + var decoded TestOptionalSliceStruct + err = c.Unmarshal(encoded, &decoded) + require.NoError(t, err) + + assert.Equal(t, original.Name, decoded.Name) + require.NotNil(t, decoded.Items) + assert.Equal(t, *original.Items, *decoded.Items) +} + +// Test mixed optional and non-optional fields +type TestMixedOptionalStruct struct { + Required1 string + Optional1 *int `hex:"optional"` + Required2 int + Optional2 *string `hex:"optional"` +} + +func TestHexCodec_RoundTrip_MixedOptional(t *testing.T) { + c := NewHexCodec() + + // Both optional fields set + val1 := 42 + val2 := "optional" + original := TestMixedOptionalStruct{ + Required1: "req1", + Optional1: &val1, + Required2: 100, + Optional2: &val2, + } + + encoded, err := c.Marshal(original) + require.NoError(t, err) + + var decoded TestMixedOptionalStruct + err = c.Unmarshal(encoded, &decoded) + require.NoError(t, err) + + assert.Equal(t, original.Required1, decoded.Required1) + require.NotNil(t, decoded.Optional1) + assert.Equal(t, *original.Optional1, *decoded.Optional1) + assert.Equal(t, original.Required2, decoded.Required2) + require.NotNil(t, decoded.Optional2) + assert.Equal(t, *original.Optional2, *decoded.Optional2) +} + +func TestHexCodec_RoundTrip_MixedOptional_SomeNil(t *testing.T) { + c := NewHexCodec() + + // One optional nil, one set + val2 := "only this" + original := TestMixedOptionalStruct{ + Required1: "req1", + Optional1: nil, + Required2: 200, + Optional2: &val2, + } + + encoded, err := c.Marshal(original) + require.NoError(t, err) + + var decoded TestMixedOptionalStruct + err = c.Unmarshal(encoded, &decoded) + require.NoError(t, err) + + assert.Equal(t, original.Required1, decoded.Required1) + assert.Nil(t, decoded.Optional1) + assert.Equal(t, original.Required2, decoded.Required2) + require.NotNil(t, decoded.Optional2) + assert.Equal(t, *original.Optional2, *decoded.Optional2) +} + +// Test error cases +func TestHexCodec_DecodeOptional_InvalidFlag(t *testing.T) { + c := NewHexCodec() + // Name: len=4 (04) + "test" (74657374) = 0474657374 + // Value: Invalid flag 0x02 + hexStr := "04746573740200000001" + var s TestOptionalStruct + err := c.Unmarshal(hexStr, &s) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid optional flag 0x02") +} + +func TestHexCodec_EncodeOptional_NonPointerField(t *testing.T) { + type BadOptionalStruct struct { + Value string `hex:"optional"` + } + + c := NewHexCodec() + s := BadOptionalStruct{Value: "test"} + _, err := c.Marshal(s) + require.Error(t, err) + assert.Contains(t, err.Error(), "hex:\"optional\" tag only valid on pointer fields") +} + +func TestHexCodec_DecodeOptional_NonPointerField(t *testing.T) { + type BadOptionalStruct struct { + Value string `hex:"optional"` + } + + c := NewHexCodec() + hexStr := "00" + var s BadOptionalStruct + err := c.Unmarshal(hexStr, &s) + require.Error(t, err) + assert.Contains(t, err.Error(), "hex:\"optional\" tag only valid on pointer fields") +} diff --git a/test-data/all_kinds_of_1_0_0.go_gen b/test-data/all_kinds_of_1_0_0.go_gen index e3107d4..c22439b 100644 --- a/test-data/all_kinds_of_1_0_0.go_gen +++ b/test-data/all_kinds_of_1_0_0.go_gen @@ -243,8 +243,8 @@ type OneOfEverything struct { SomeBoolean types.BOOL `json:"someBoolean"` SomeInteger types.INT64 `json:"someInteger"` SomeDecimal types.NUMERIC `json:"someDecimal"` - SomeMaybe *types.INT64 `json:"someMaybe"` - SomeMaybeNot *types.INT64 `json:"someMaybeNot"` + SomeMaybe *types.INT64 `json:"someMaybe" hex:"optional"` + SomeMaybeNot *types.INT64 `json:"someMaybeNot" hex:"optional"` SomeText types.TEXT `json:"someText"` SomeDate types.DATE `json:"someDate"` SomeDatetime types.TIMESTAMP `json:"someDatetime"`