Skip to content

Commit

Permalink
Adds optional max_diff_entries and max_database_entries in ComputeThr…
Browse files Browse the repository at this point in the history
…eatListDiffRequest call from api.go. Version bump from User Agent.

PiperOrigin-RevId: 639091343
  • Loading branch information
rvilgalys authored and copybara-github committed Jun 4, 2024
1 parent 9c1c0dc commit 98ed702
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
11 changes: 10 additions & 1 deletion api.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"io/ioutil"
"net/http"
"net/url"
"strconv"
"strings"

"google.golang.org/protobuf/encoding/protojson"
Expand All @@ -34,9 +35,11 @@ const (
threatTypeString = "threat_type"
versionTokenString = "version_token"
supportedCompressionsString = "constraints.supported_compressions"
maxDiffEntriesKey = "constraints.max_diff_entries"
maxDatabaseEntriesKey = "constraints.max_database_entries"
hashPrefixString = "hash_prefix"
threatTypesString = "threat_types"
userAgentString = "Webrisk-Client/0.2.1"
userAgentString = "Webrisk-Client/0.2.2"
)

// The api interface specifies wrappers around the Web Risk API.
Expand Down Expand Up @@ -131,6 +134,12 @@ func (a *netAPI) ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRe
if len(req.GetVersionToken()) != 0 {
q.Set(versionTokenString, base64.StdEncoding.EncodeToString(req.GetVersionToken()))
}
if req.GetConstraints().GetMaxDiffEntries() != 0 {
q.Add(maxDiffEntriesKey, strconv.FormatInt(int64(req.GetConstraints().GetMaxDiffEntries()), 10))
}
if req.GetConstraints().GetMaxDatabaseEntries() != 0 {
q.Add(maxDatabaseEntriesKey, strconv.FormatInt(int64(req.GetConstraints().GetMaxDatabaseEntries()), 10))
}
for _, compressionType := range req.GetConstraints().GetSupportedCompressions() {
q.Add(supportedCompressionsString, compressionType.String())
}
Expand Down
39 changes: 36 additions & 3 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"net/http"
"net/http/httptest"
"reflect"
"strconv"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -54,6 +55,8 @@ func TestNetAPI(t *testing.T) {
var gotReqHashPrefix, wantReqHashPrefix []byte
var gotReqThreatTypes, wantReqThreatTypes []pb.ThreatType
var gotResp, wantResp proto.Message
var gotMaxDiffEntries, wantMaxDiffEntries []int32
var gotMaxDatabaseEntries, wantMaxDatabaseEntries []int32
responseMisformatter := ""
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var p []byte
Expand Down Expand Up @@ -88,15 +91,33 @@ func TestNetAPI(t *testing.T) {
gotReqThreatTypes = append(gotReqThreatTypes,
pb.ThreatType(pb.ThreatType_value[threat]))
}
} else if key == maxDiffEntriesKey {
if len(value) == 0 {
t.Fatalf("Missing value for key %v", key)
}
i, err := strconv.ParseInt(value[0], 10, 32)
if err != nil {
t.Fatalf("Error parsing %q: %v", value[0], err)
}
gotMaxDiffEntries = append(gotMaxDiffEntries, int32(i))
} else if key == maxDatabaseEntriesKey {
if len(value) == 0 {
t.Fatalf("Missing value for key %v", key)
}
i, err := strconv.ParseInt(value[0], 10, 32)
if err != nil {
t.Fatalf("Error parsing %q: %v", value[0], err)
}
gotMaxDatabaseEntries = append(gotMaxDatabaseEntries, int32(i))
} else if key != "key" {
t.Fatalf("unexpected request param error for key: %v", key)
t.Fatalf("Unexpected request param error for key: %v", key)
}
}
if p, err = protojson.Marshal(wantResp); err != nil {
t.Fatalf("unexpected json MarshalToString error: %v", err)
t.Fatalf("Unexpected json MarshalToString error: %v", err)
}
if _, err := w.Write([]byte(responseMisformatter + string(p))); err != nil {
t.Fatalf("unexpected ResponseWriter.Write error: %v", err)
t.Fatalf("Unexpected ResponseWriter.Write error: %v", err)
}
}))
defer ts.Close()
Expand All @@ -109,6 +130,8 @@ func TestNetAPI(t *testing.T) {
// Test that ListUpdate marshal/unmarshal works.
wantReqThreatType = pb.ThreatType_MALWARE
wantReqCompressionTypes = []pb.CompressionType{0, 1, 2}
wantMaxDiffEntries = []int32{1024}
wantMaxDatabaseEntries = []int32{1024}

wantResp = &pb.ComputeThreatListDiffResponse{
ResponseType: 1,
Expand All @@ -121,6 +144,8 @@ func TestNetAPI(t *testing.T) {
ThreatType: wantReqThreatType,
Constraints: &pb.ComputeThreatListDiffRequest_Constraints{
SupportedCompressions: wantReqCompressionTypes,
MaxDiffEntries: 1024,
MaxDatabaseEntries: 1024,
},
VersionToken: []byte{},
}
Expand All @@ -140,6 +165,14 @@ func TestNetAPI(t *testing.T) {
if !proto.Equal(gotResp, wantResp) {
t.Errorf("mismatching ListUpdate responses:\ngot %+v\nwant %+v", gotResp, wantResp)
}
if !reflect.DeepEqual(gotMaxDiffEntries, wantMaxDiffEntries) {
t.Errorf("mismatching ListUpdate max diff entries:\ngot %+v\nwant %+v",
gotMaxDiffEntries, wantMaxDiffEntries)
}
if !reflect.DeepEqual(gotMaxDatabaseEntries, wantMaxDatabaseEntries) {
t.Errorf("mismatching ListUpdate max database entries:\ngot %+v\nwant %+v",
gotMaxDatabaseEntries, wantMaxDatabaseEntries)
}

// Test that HashLookup marshal/unmarshal works.
wantReqHashPrefix = []byte("aaaa")
Expand Down

0 comments on commit 98ed702

Please sign in to comment.