Skip to content

Commit 7dfe2bb

Browse files
committed
Add APIGatewayProxyStreamingResponse
1 parent 45c22d5 commit 7dfe2bb

File tree

5 files changed

+183
-12
lines changed

5 files changed

+183
-12
lines changed

events/apigw.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22

33
package events
44

5+
import (
6+
"bytes"
7+
"encoding/json"
8+
"errors"
9+
"io"
10+
)
11+
512
// APIGatewayProxyRequest contains data coming from the API Gateway proxy
613
type APIGatewayProxyRequest struct {
714
Resource string `json:"resource"` // The resource path defined in API Gateway
@@ -27,6 +34,64 @@ type APIGatewayProxyResponse struct {
2734
IsBase64Encoded bool `json:"isBase64Encoded,omitempty"`
2835
}
2936

37+
// APIGatewayProxyStreamingResponse configures the response to be returned by API Gateway for the request.
38+
// - integration type must be AWS_PROXY
39+
// - integration uri must be arn:<partition>:apigateway:<region>:lambda:path/2021-11-15/functions/<function-arn>/response-streaming-invocations
40+
// - integration response transfer mode must be STREAM
41+
//
42+
// If not using the above streaming integration, use APIGatewayProxyResponse instead
43+
type APIGatewayProxyStreamingResponse struct {
44+
prelude *bytes.Buffer
45+
46+
StatusCode int
47+
Headers map[string]string
48+
MultiValueHeaders map[string][]string
49+
Body io.Reader
50+
Cookies []string
51+
}
52+
53+
func (r *APIGatewayProxyStreamingResponse) Read(p []byte) (n int, err error) {
54+
if r.prelude == nil {
55+
b, err := json.Marshal(struct {
56+
StatusCode int `json:"statusCode,omitempty"`
57+
Headers map[string]string `json:"headers,omitempty"`
58+
MultiValueHeaders map[string][]string `json:"multiValueHeaders,omitempty"`
59+
Cookies []string `json:"cookies,omitempty"`
60+
}{
61+
StatusCode: r.StatusCode,
62+
Headers: r.Headers,
63+
MultiValueHeaders: r.MultiValueHeaders,
64+
Cookies: r.Cookies,
65+
})
66+
if err != nil {
67+
return 0, err
68+
}
69+
r.prelude = bytes.NewBuffer(append(b, 0, 0, 0, 0, 0, 0, 0, 0))
70+
}
71+
if r.prelude.Len() > 0 {
72+
return r.prelude.Read(p)
73+
}
74+
if r.Body == nil {
75+
return 0, io.EOF
76+
}
77+
return r.Body.Read(p)
78+
}
79+
80+
func (r *APIGatewayProxyStreamingResponse) Close() error {
81+
if closer, ok := r.Body.(io.ReadCloser); ok {
82+
return closer.Close()
83+
}
84+
return nil
85+
}
86+
87+
func (r *APIGatewayProxyStreamingResponse) MarshalJSON() ([]byte, error) {
88+
return nil, errors.New("not json")
89+
}
90+
91+
func (r *APIGatewayProxyStreamingResponse) ContentType() string {
92+
return "application/vnd.awslambda.http-integration-response"
93+
}
94+
3095
// APIGatewayProxyRequestContext contains the information to identify the AWS account and resources invoking the
3196
// Lambda function. It also includes Cognito identity information for the caller.
3297
type APIGatewayProxyRequestContext struct {

events/apigw_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@ package events
44

55
import (
66
"encoding/json"
7+
"errors"
78
"io/ioutil" //nolint: staticcheck
9+
"net/http"
10+
"strings"
811
"testing"
912

1013
"github.com/aws/aws-lambda-go/events/test"
1114
"github.com/stretchr/testify/assert"
15+
"github.com/stretchr/testify/require"
1216
)
1317

1418
func TestApiGatewayRequestMarshaling(t *testing.T) {
@@ -83,6 +87,80 @@ func TestApiGatewayResponseMalformedJson(t *testing.T) {
8387
test.TestMalformedJson(t, APIGatewayProxyResponse{})
8488
}
8589

90+
func TestAPIGatewayProxyStreamingResponseMarshaling(t *testing.T) {
91+
for _, test := range []struct {
92+
name string
93+
response *APIGatewayProxyStreamingResponse
94+
expectedHead string
95+
expectedBody string
96+
}{
97+
{
98+
"empty",
99+
&APIGatewayProxyStreamingResponse{},
100+
`{}`,
101+
"",
102+
},
103+
{
104+
"just the status code",
105+
&APIGatewayProxyStreamingResponse{
106+
StatusCode: http.StatusTeapot,
107+
},
108+
`{"statusCode":418}`,
109+
"",
110+
},
111+
{
112+
"status and headers and cookies and body",
113+
&APIGatewayProxyStreamingResponse{
114+
StatusCode: http.StatusTeapot,
115+
Headers: map[string]string{"hello": "world"},
116+
MultiValueHeaders: map[string][]string{"hi": {"1", "2"}},
117+
Cookies: []string{"cookies", "are", "yummy"},
118+
Body: strings.NewReader(`<html>Hello Hello</html>`),
119+
},
120+
`{"statusCode":418, "headers":{"hello":"world"}, "multiValueHeaders":{"hi":["1","2"]}, "cookies":["cookies","are","yummy"]}`,
121+
`<html>Hello Hello</html>`,
122+
},
123+
} {
124+
t.Run(test.name, func(t *testing.T) {
125+
response, err := ioutil.ReadAll(test.response)
126+
require.NoError(t, err)
127+
sep := "\x00\x00\x00\x00\x00\x00\x00\x00"
128+
responseParts := strings.Split(string(response), sep)
129+
require.Len(t, responseParts, 2)
130+
head := string(responseParts[0])
131+
body := string(responseParts[1])
132+
assert.JSONEq(t, test.expectedHead, head)
133+
assert.Equal(t, test.expectedBody, body)
134+
assert.NoError(t, test.response.Close())
135+
})
136+
}
137+
}
138+
139+
func TestAPIGatewayProxyStreamingResponsePropogatesInnerClose(t *testing.T) {
140+
for _, test := range []struct {
141+
name string
142+
closer *readCloser
143+
err error
144+
}{
145+
{
146+
"closer no err",
147+
&readCloser{},
148+
nil,
149+
},
150+
{
151+
"closer with err",
152+
&readCloser{err: errors.New("yolo")},
153+
errors.New("yolo"),
154+
},
155+
} {
156+
t.Run(test.name, func(t *testing.T) {
157+
response := &APIGatewayProxyStreamingResponse{Body: test.closer}
158+
assert.Equal(t, test.err, response.Close())
159+
assert.True(t, test.closer.closed)
160+
})
161+
}
162+
}
163+
86164
func TestApiGatewayCustomAuthorizerRequestMarshaling(t *testing.T) {
87165

88166
// read json from file

events/example_apigw_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package events_test
2+
3+
import (
4+
"strings"
5+
6+
"github.com/aws/aws-lambda-go/events"
7+
"github.com/aws/aws-lambda-go/lambda"
8+
)
9+
10+
func ExampleAPIGatewayProxyStreamingResponse() {
11+
lambda.Start(func() (*events.APIGatewayProxyStreamingResponse, error) {
12+
return &events.APIGatewayProxyStreamingResponse{
13+
StatusCode: 200,
14+
Headers: map[string]string{
15+
"Content-Type": "text/html",
16+
},
17+
Body: strings.NewReader("<html><body>Hello World!</body></html>"),
18+
}, nil
19+
})
20+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package events_test
2+
3+
import (
4+
"strings"
5+
6+
"github.com/aws/aws-lambda-go/events"
7+
"github.com/aws/aws-lambda-go/lambda"
8+
)
9+
10+
func ExampleLambdaFunctionURLStreamingResponse() {
11+
lambda.Start(func() (*events.LambdaFunctionURLStreamingResponse, error) {
12+
return &events.LambdaFunctionURLStreamingResponse{
13+
StatusCode: 200,
14+
Headers: map[string]string{
15+
"Content-Type": "text/html",
16+
},
17+
Body: strings.NewReader("<html><body>Hello World!</body></html>"),
18+
}, nil
19+
})
20+
}

events/lambda_function_urls.go

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,6 @@ type LambdaFunctionURLResponse struct {
7171
// LambdaFunctionURLStreamingResponse models the response to a Lambda Function URL when InvokeMode is RESPONSE_STREAM.
7272
// If the InvokeMode of the Function URL is BUFFERED (default), use LambdaFunctionURLResponse instead.
7373
//
74-
// Example:
75-
//
76-
// lambda.Start(func() (*events.LambdaFunctionURLStreamingResponse, error) {
77-
// return &events.LambdaFunctionURLStreamingResponse{
78-
// StatusCode: 200,
79-
// Headers: map[string]string{
80-
// "Content-Type": "text/html",
81-
// },
82-
// Body: strings.NewReader("<html><body>Hello World!</body></html>"),
83-
// }, nil
84-
// })
85-
//
8674
// Note: This response type requires compiling with `-tags lambda.norpc`, or choosing the `provided` or `provided.al2` runtime.
8775
type LambdaFunctionURLStreamingResponse struct {
8876
prelude *bytes.Buffer

0 commit comments

Comments
 (0)