Skip to content

Commit ef82168

Browse files
committed
encoding/xml: make use of reflect.TypeAssert
To simplify the code and make Unmarshal use less memory: goos: darwin goarch: arm64 pkg: encoding/xml cpu: Apple M4 │ old │ new │ │ sec/op │ sec/op vs base │ Unmarshal-10 3.818µ ± 1% 3.808µ ± 2% ~ (p=0.869 n=10) │ old │ new │ │ B/op │ B/op vs base │ Unmarshal-10 7.586Ki ± 0% 7.555Ki ± 0% -0.41% (p=0.000 n=10) │ old │ new │ │ allocs/op │ allocs/op vs base │ Unmarshal-10 185.0 ± 0% 184.0 ± 0% -0.54% (p=0.000 n=10) Updates #62121
1 parent ca0e035 commit ef82168

File tree

2 files changed

+45
-24
lines changed

2 files changed

+45
-24
lines changed

src/encoding/xml/marshal.go

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -451,23 +451,27 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
451451

452452
// Check for marshaler.
453453
if val.CanInterface() && typ.Implements(marshalerType) {
454-
return p.marshalInterface(val.Interface().(Marshaler), defaultStart(typ, finfo, startTemplate))
454+
marshaler, _ := reflect.TypeAssert[Marshaler](val)
455+
return p.marshalInterface(marshaler, defaultStart(typ, finfo, startTemplate))
455456
}
456457
if val.CanAddr() {
457458
pv := val.Addr()
458459
if pv.CanInterface() && pv.Type().Implements(marshalerType) {
459-
return p.marshalInterface(pv.Interface().(Marshaler), defaultStart(pv.Type(), finfo, startTemplate))
460+
marshaler, _ := reflect.TypeAssert[Marshaler](pv)
461+
return p.marshalInterface(marshaler, defaultStart(pv.Type(), finfo, startTemplate))
460462
}
461463
}
462464

463465
// Check for text marshaler.
464466
if val.CanInterface() && typ.Implements(textMarshalerType) {
465-
return p.marshalTextInterface(val.Interface().(encoding.TextMarshaler), defaultStart(typ, finfo, startTemplate))
467+
textMarshaler, _ := reflect.TypeAssert[encoding.TextMarshaler](val)
468+
return p.marshalTextInterface(textMarshaler, defaultStart(typ, finfo, startTemplate))
466469
}
467470
if val.CanAddr() {
468471
pv := val.Addr()
469472
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
470-
return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), defaultStart(pv.Type(), finfo, startTemplate))
473+
textMarshaler, _ := reflect.TypeAssert[encoding.TextMarshaler](pv)
474+
return p.marshalTextInterface(textMarshaler, defaultStart(pv.Type(), finfo, startTemplate))
471475
}
472476
}
473477

@@ -503,7 +507,7 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
503507
start.Name.Space, start.Name.Local = xmlname.xmlns, xmlname.name
504508
} else {
505509
fv := xmlname.value(val, dontInitNilPointers)
506-
if v, ok := fv.Interface().(Name); ok && v.Local != "" {
510+
if v, ok := reflect.TypeAssert[Name](fv); ok && v.Local != "" {
507511
start.Name = v
508512
}
509513
}
@@ -581,7 +585,8 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
581585
// marshalAttr marshals an attribute with the given name and value, adding to start.Attr.
582586
func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value) error {
583587
if val.CanInterface() && val.Type().Implements(marshalerAttrType) {
584-
attr, err := val.Interface().(MarshalerAttr).MarshalXMLAttr(name)
588+
marshaler, _ := reflect.TypeAssert[MarshalerAttr](val)
589+
attr, err := marshaler.MarshalXMLAttr(name)
585590
if err != nil {
586591
return err
587592
}
@@ -594,7 +599,8 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value)
594599
if val.CanAddr() {
595600
pv := val.Addr()
596601
if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) {
597-
attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
602+
marshaler, _ := reflect.TypeAssert[MarshalerAttr](pv)
603+
attr, err := marshaler.MarshalXMLAttr(name)
598604
if err != nil {
599605
return err
600606
}
@@ -606,7 +612,8 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value)
606612
}
607613

608614
if val.CanInterface() && val.Type().Implements(textMarshalerType) {
609-
text, err := val.Interface().(encoding.TextMarshaler).MarshalText()
615+
textMarshaler, _ := reflect.TypeAssert[encoding.TextMarshaler](val)
616+
text, err := textMarshaler.MarshalText()
610617
if err != nil {
611618
return err
612619
}
@@ -617,7 +624,8 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value)
617624
if val.CanAddr() {
618625
pv := val.Addr()
619626
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
620-
text, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
627+
textMarshaler, _ := reflect.TypeAssert[encoding.TextMarshaler](pv)
628+
text, err := textMarshaler.MarshalText()
621629
if err != nil {
622630
return err
623631
}
@@ -647,7 +655,8 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value)
647655
}
648656

649657
if val.Type() == attrType {
650-
start.Attr = append(start.Attr, val.Interface().(Attr))
658+
attr, _ := reflect.TypeAssert[Attr](val)
659+
start.Attr = append(start.Attr, attr)
651660
return nil
652661
}
653662

@@ -855,7 +864,8 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
855864
return err
856865
}
857866
if vf.CanInterface() && vf.Type().Implements(textMarshalerType) {
858-
data, err := vf.Interface().(encoding.TextMarshaler).MarshalText()
867+
textMarshaler, _ := reflect.TypeAssert[encoding.TextMarshaler](vf)
868+
data, err := textMarshaler.MarshalText()
859869
if err != nil {
860870
return err
861871
}
@@ -867,7 +877,8 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
867877
if vf.CanAddr() {
868878
pv := vf.Addr()
869879
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
870-
data, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
880+
textMarshaler, _ := reflect.TypeAssert[encoding.TextMarshaler](pv)
881+
data, err := textMarshaler.MarshalText()
871882
if err != nil {
872883
return err
873884
}
@@ -902,7 +913,7 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
902913
return err
903914
}
904915
case reflect.Slice:
905-
if elem, ok := vf.Interface().([]byte); ok {
916+
if elem, ok := reflect.TypeAssert[[]byte](vf); ok {
906917
if err := emit(p, elem); err != nil {
907918
return err
908919
}

src/encoding/xml/read.go

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -258,25 +258,29 @@ func (d *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error {
258258
if val.CanInterface() && val.Type().Implements(unmarshalerAttrType) {
259259
// This is an unmarshaler with a non-pointer receiver,
260260
// so it's likely to be incorrect, but we do what we're told.
261-
return val.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr)
261+
unmarshaler, _ := reflect.TypeAssert[UnmarshalerAttr](val)
262+
return unmarshaler.UnmarshalXMLAttr(attr)
262263
}
263264
if val.CanAddr() {
264265
pv := val.Addr()
265266
if pv.CanInterface() && pv.Type().Implements(unmarshalerAttrType) {
266-
return pv.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr)
267+
unmarshaler, _ := reflect.TypeAssert[UnmarshalerAttr](pv)
268+
return unmarshaler.UnmarshalXMLAttr(attr)
267269
}
268270
}
269271

270272
// Not an UnmarshalerAttr; try encoding.TextUnmarshaler.
271273
if val.CanInterface() && val.Type().Implements(textUnmarshalerType) {
272274
// This is an unmarshaler with a non-pointer receiver,
273275
// so it's likely to be incorrect, but we do what we're told.
274-
return val.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value))
276+
textUnmarshaler, _ := reflect.TypeAssert[encoding.TextUnmarshaler](val)
277+
return textUnmarshaler.UnmarshalText([]byte(attr.Value))
275278
}
276279
if val.CanAddr() {
277280
pv := val.Addr()
278281
if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
279-
return pv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value))
282+
textUnmarshaler, _ := reflect.TypeAssert[encoding.TextUnmarshaler](pv)
283+
return textUnmarshaler.UnmarshalText([]byte(attr.Value))
280284
}
281285
}
282286

@@ -355,24 +359,28 @@ func (d *Decoder) unmarshal(val reflect.Value, start *StartElement, depth int) e
355359
if val.CanInterface() && val.Type().Implements(unmarshalerType) {
356360
// This is an unmarshaler with a non-pointer receiver,
357361
// so it's likely to be incorrect, but we do what we're told.
358-
return d.unmarshalInterface(val.Interface().(Unmarshaler), start)
362+
unmarshaler, _ := reflect.TypeAssert[Unmarshaler](val)
363+
return d.unmarshalInterface(unmarshaler, start)
359364
}
360365

361366
if val.CanAddr() {
362367
pv := val.Addr()
363368
if pv.CanInterface() && pv.Type().Implements(unmarshalerType) {
364-
return d.unmarshalInterface(pv.Interface().(Unmarshaler), start)
369+
unmarshaler, _ := reflect.TypeAssert[Unmarshaler](pv)
370+
return d.unmarshalInterface(unmarshaler, start)
365371
}
366372
}
367373

368374
if val.CanInterface() && val.Type().Implements(textUnmarshalerType) {
369-
return d.unmarshalTextInterface(val.Interface().(encoding.TextUnmarshaler))
375+
textUnmarshaler, _ := reflect.TypeAssert[encoding.TextUnmarshaler](val)
376+
return d.unmarshalTextInterface(textUnmarshaler)
370377
}
371378

372379
if val.CanAddr() {
373380
pv := val.Addr()
374381
if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
375-
return d.unmarshalTextInterface(pv.Interface().(encoding.TextUnmarshaler))
382+
textUnmarshaler, _ := reflect.TypeAssert[encoding.TextUnmarshaler](pv)
383+
return d.unmarshalTextInterface(textUnmarshaler)
376384
}
377385
}
378386

@@ -453,7 +461,7 @@ func (d *Decoder) unmarshal(val reflect.Value, start *StartElement, depth int) e
453461
return UnmarshalError(e)
454462
}
455463
fv := finfo.value(sv, initNilPointers)
456-
if _, ok := fv.Interface().(Name); ok {
464+
if _, ok := reflect.TypeAssert[Name](fv); ok {
457465
fv.Set(reflect.ValueOf(start.Name))
458466
}
459467
}
@@ -579,7 +587,8 @@ Loop:
579587
}
580588

581589
if saveData.IsValid() && saveData.CanInterface() && saveData.Type().Implements(textUnmarshalerType) {
582-
if err := saveData.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
590+
textUnmarshaler, _ := reflect.TypeAssert[encoding.TextUnmarshaler](saveData)
591+
if err := textUnmarshaler.UnmarshalText(data); err != nil {
583592
return err
584593
}
585594
saveData = reflect.Value{}
@@ -588,7 +597,8 @@ Loop:
588597
if saveData.IsValid() && saveData.CanAddr() {
589598
pv := saveData.Addr()
590599
if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
591-
if err := pv.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
600+
textUnmarshaler, _ := reflect.TypeAssert[encoding.TextUnmarshaler](pv)
601+
if err := textUnmarshaler.UnmarshalText(data); err != nil {
592602
return err
593603
}
594604
saveData = reflect.Value{}

0 commit comments

Comments
 (0)