Skip to content

Commit 78e13dd

Browse files
committed
checked that v is a pointer or nil before doing the request
1 parent 4a802af commit 78e13dd

File tree

3 files changed

+34
-11
lines changed

3 files changed

+34
-11
lines changed

client.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"io"
88
"net/http"
9+
"reflect"
910
"strings"
1011

1112
"github.com/vtopc/go-rest/defaults"
@@ -35,11 +36,18 @@ func NewClient(client *http.Client) *Client {
3536

3637
// Do executes HTTP request.
3738
//
38-
// Stores the result in the value pointed to by v. If v is nil or not a pointer,
39-
// Do returns an InvalidUnmarshalError.
39+
// Stores the result in the value pointed to by v. If v is not a nil and not a pointer,
40+
// Do returns a json.InvalidUnmarshalError.
4041
// Use func `http.NewRequestWithContext` to create `req`.
4142
func (c *Client) Do(req *http.Request, v interface{}, expectedStatusCodes ...int) error {
42-
// TODO: check that `v` is a pointer or nil
43+
// check that `v` is a pointer or nil before doing the request.
44+
if v != nil {
45+
rv := reflect.ValueOf(v)
46+
if rv.Kind() != reflect.Ptr {
47+
return &json.InvalidUnmarshalError{Type: reflect.TypeOf(v)}
48+
}
49+
}
50+
4351
if req == nil {
4452
return errors.New("empty request")
4553
}

client_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ func TestClientDo(t *testing.T) {
2424
respBody []byte
2525
v interface{}
2626
want interface{}
27+
wantErr error
2728
wantWrappedErr error
2829
}{
2930
"positive_get": {
@@ -65,6 +66,16 @@ func TestClientDo(t *testing.T) {
6566
wantWrappedErr: errors.New("wrong status code (500 not in [200]): {\"error\":\"some error\"}"),
6667
},
6768

69+
"negative_not_a_pointer": {
70+
method: http.MethodGet,
71+
urlPostfix: "/health",
72+
statusCode: http.StatusOK,
73+
expectedStatusCode: http.StatusOK,
74+
respBody: []byte(`{"status":"ok"}`),
75+
v: Struct{},
76+
wantErr: errors.New("json: Unmarshal(non-pointer rest.Struct)"),
77+
},
78+
6879
// TODO: add more test cases
6980
}
7081

@@ -94,6 +105,10 @@ func TestClientDo(t *testing.T) {
94105

95106
// test:
96107
err = c.Do(req, tt.v, tt.expectedStatusCode)
108+
if tt.wantErr != nil {
109+
require.EqualError(t, err, tt.wantErr.Error())
110+
return
111+
}
97112
if tt.wantWrappedErr != nil {
98113
require.EqualError(t, errors.Unwrap(err), tt.wantWrappedErr.Error())
99114
return

error_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func TestStatusCodeFromAPIError(t *testing.T) {
4848

4949
v interface{} // for .Do(...)
5050

51-
wantErr error
51+
wantErr bool
5252
wantStatusCode int
5353
}{
5454
"no_errors": {
@@ -62,21 +62,21 @@ func TestStatusCodeFromAPIError(t *testing.T) {
6262
statusCode: 400,
6363
expectedStatusCode: 201,
6464
body: []byte(`{"errors":[{"message":"test error"}]}`),
65-
wantErr: errors.New("wrong status code (400 not in [201]): {\"errors\":[{\"message\":\"test error\"}]}"),
65+
wantErr: true,
6666
wantStatusCode: 400,
6767
},
6868
"not_found": {
6969
statusCode: 404,
7070
expectedStatusCode: 200,
7171
body: []byte(`{"errors":[{"message":"the entity not found"}]}`),
72-
wantErr: errors.New("wrong status code (404 not in [200]): {\"errors\":[{\"message\":\"the entity not found\"}]}"),
72+
wantErr: true,
7373
wantStatusCode: 404,
7474
},
7575
"not_APIError": {
7676
statusCode: 200,
7777
expectedStatusCode: 200,
7878
body: []byte(`{"foo":"bar"}`),
79-
wantErr: errors.New("failed to unmarshal the response body: json: Unmarshal(non-pointer chan struct {})"),
79+
wantErr: true,
8080
v: make(chan struct{}), // a channel just to fail Unmarshal
8181
wantStatusCode: 500,
8282
},
@@ -85,7 +85,7 @@ func TestStatusCodeFromAPIError(t *testing.T) {
8585
expectedStatusCode: 200,
8686
body: []byte(`{"foo":"bar"}`),
8787
v: new(S),
88-
wantErr: errors.New(`wrong status code (201 not in [200]): {"foo":"bar"}`),
88+
wantErr: true,
8989
wantStatusCode: 500,
9090
},
9191
}
@@ -109,10 +109,10 @@ func TestStatusCodeFromAPIError(t *testing.T) {
109109

110110
err = client.Do(req, tt.v, tt.expectedStatusCode)
111111
t.Logf("got error: %v", err)
112-
if tt.wantErr == nil {
113-
assert.NoError(t, err)
112+
if tt.wantErr {
113+
assert.Error(t, err)
114114
} else {
115-
assert.EqualError(t, errors.Unwrap(err), tt.wantErr.Error())
115+
assert.NoError(t, err)
116116
}
117117

118118
// got := StatusCodeFromAPIError(err)

0 commit comments

Comments
 (0)