diff --git a/propagation/baggage.go b/propagation/baggage.go index 552263ba734..5b341071229 100644 --- a/propagation/baggage.go +++ b/propagation/baggage.go @@ -29,6 +29,19 @@ func (b Baggage) Inject(ctx context.Context, carrier TextMapCarrier) { // Extract returns a copy of parent with the baggage from the carrier added. func (b Baggage) Extract(parent context.Context, carrier TextMapCarrier) context.Context { + multiCarrier, isMultiCarrier := carrier.(MultiTextMapCarrier) + if isMultiCarrier { + return extractMultiBaggage(parent, multiCarrier) + } + return extractSingleBaggage(parent, carrier) +} + +// Fields returns the keys who's values are set with Inject. +func (b Baggage) Fields() []string { + return []string{baggageHeader} +} + +func extractSingleBaggage(parent context.Context, carrier TextMapCarrier) context.Context { bStr := carrier.Get(baggageHeader) if bStr == "" { return parent @@ -41,7 +54,20 @@ func (b Baggage) Extract(parent context.Context, carrier TextMapCarrier) context return baggage.ContextWithBaggage(parent, bag) } -// Fields returns the keys who's values are set with Inject. -func (b Baggage) Fields() []string { - return []string{baggageHeader} +func extractMultiBaggage(parent context.Context, carrier MultiTextMapCarrier) context.Context { + bVals := carrier.GetAll(baggageHeader) + members := make([]baggage.Member, 0) + for _, bStr := range bVals { + currBag, err := baggage.Parse(bStr) + if err != nil { + continue + } + members = append(members, currBag.Members()...) + } + + b, err := baggage.New(members...) + if err != nil || b.Len() == 0 { + return parent + } + return baggage.ContextWithBaggage(parent, b) } diff --git a/propagation/baggage_test.go b/propagation/baggage_test.go index e6a71540fa2..0d7d81b5fe3 100644 --- a/propagation/baggage_test.go +++ b/propagation/baggage_test.go @@ -128,6 +128,55 @@ func TestExtractValidBaggageFromHTTPReq(t *testing.T) { } } +func TestExtractValidMultipleBaggageHeaders(t *testing.T) { + prop := propagation.TextMapPropagator(propagation.Baggage{}) + tests := []struct { + name string + headers []string + want members + }{ + { + name: "non conflicting headers", + headers: []string{"key1=val1", "key2=val2"}, + want: members{ + {Key: "key1", Value: "val1"}, + {Key: "key2", Value: "val2"}, + }, + }, + { + name: "conflicting keys, uses last val", + headers: []string{"key1=val1", "key1=val2"}, + want: members{ + {Key: "key1", Value: "val2"}, + }, + }, + { + name: "single empty", + headers: []string{"", "key1=val1"}, + want: members{ + {Key: "key1", Value: "val1"}, + }, + }, + { + name: "all empty", + headers: []string{"", ""}, + want: members{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, _ := http.NewRequest("GET", "http://example.com", nil) + req.Header["Baggage"] = tt.headers + + ctx := context.Background() + ctx = prop.Extract(ctx, propagation.HeaderCarrier(req.Header)) + expected := tt.want.Baggage(t) + assert.Equal(t, expected, baggage.FromContext(ctx)) + }) + } +} + func TestExtractInvalidDistributedContextFromHTTPReq(t *testing.T) { prop := propagation.TextMapPropagator(propagation.Baggage{}) tests := []struct { diff --git a/propagation/propagation.go b/propagation/propagation.go index 8c8286aab4d..9d4a729533c 100644 --- a/propagation/propagation.go +++ b/propagation/propagation.go @@ -29,6 +29,15 @@ type TextMapCarrier interface { // must never be done outside of a new major release. } +// MultiTextMapCarrier is a TextMapCarrier that can return multiple values for a single key. +type MultiTextMapCarrier interface { + TextMapCarrier + // GetAll returns all values associated with the passed key. + GetAll(key string) []string + // DO NOT CHANGE: any modification will not be backwards compatible and + // must never be done outside of a new major release. +} + // MapCarrier is a TextMapCarrier that uses a map held in memory as a storage // medium for propagated key-value pairs. type MapCarrier map[string]string @@ -58,11 +67,16 @@ func (c MapCarrier) Keys() []string { // HeaderCarrier adapts http.Header to satisfy the TextMapCarrier interface. type HeaderCarrier http.Header -// Get returns the value associated with the passed key. +// Get returns the first value associated with the passed key. func (hc HeaderCarrier) Get(key string) string { return http.Header(hc).Get(key) } +// GetAll returns all values associated with the passed key. +func (hc HeaderCarrier) GetAll(key string) []string { + return http.Header(hc).Values(key) +} + // Set stores the key-value pair. func (hc HeaderCarrier) Set(key string, value string) { http.Header(hc).Set(key, value)