From 459978310bc15020290c3cd32523d4960e3686e8 Mon Sep 17 00:00:00 2001 From: Rim Vilgalys Date: Fri, 31 May 2024 11:09:44 -0700 Subject: [PATCH] Adds optional max_diff_entries and max_database_entries in ComputeThreatListDiffRequest call from api.go PiperOrigin-RevId: 639091343 --- api.go | 23 +++++++++------ api_test.go | 54 +++++++++++++++++++++++++++++------ database.go | 7 +++-- webrisk_client.go | 35 +++++++++++++++++++++-- webrisk_client_system_test.go | 3 +- webrisk_client_test.go | 47 ++++++++++++++++++++++++++++++ 6 files changed, 147 insertions(+), 22 deletions(-) diff --git a/api.go b/api.go index ffa663f..410334b 100644 --- a/api.go +++ b/api.go @@ -20,6 +20,7 @@ import ( "io/ioutil" "net/http" "net/url" + "strconv" "strings" "google.golang.org/protobuf/encoding/protojson" @@ -34,6 +35,8 @@ const ( threatTypeString = "threat_type" versionTokenString = "version_token" supportedCompressionsString = "constraints.supported_compressions" + maxDiffEntriesKey = "constraints.max_diff_entries" + maxDatabaseEntriesKey = "constraints.max_database_entries" hashPrefixString = "hash_prefix" threatTypesString = "threat_types" userAgentString = "Webrisk-Client/0.2.1" @@ -41,8 +44,7 @@ const ( // The api interface specifies wrappers around the Web Risk API. type api interface { - ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte, - compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error) + ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error) HashLookup(ctx context.Context, hashPrefix []byte, threatTypes []pb.ThreatType) (*pb.SearchHashesResponse, error) } @@ -123,17 +125,22 @@ func (a *netAPI) parseError(httpResp *http.Response) error { } // ListUpdate issues a ComputeThreatListDiff API call and returns the response. -func (a *netAPI) ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte, - compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error) { +func (a *netAPI) ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error) { resp := new(pb.ComputeThreatListDiffResponse) u := *a.url // Make a copy of URL // Add fields from ComputeThreatListDiffRequest to URL request q := u.Query() - q.Set(threatTypeString, threatType.String()) - if len(versionToken) != 0 { - q.Set(versionTokenString, base64.StdEncoding.EncodeToString(versionToken)) + q.Set(threatTypeString, req.GetThreatType().String()) + if len(req.GetVersionToken()) != 0 { + q.Set(versionTokenString, base64.StdEncoding.EncodeToString(req.GetVersionToken())) } - for _, compressionType := range compressionTypes { + if req.GetConstraints().GetMaxDiffEntries() != 0 { + q.Add(maxDiffEntriesKey, strconv.FormatInt(int64(req.GetConstraints().GetMaxDiffEntries()), 10)) + } + if req.GetConstraints().GetMaxDatabaseEntries() != 0 { + q.Add(maxDatabaseEntriesKey, strconv.FormatInt(int64(req.GetConstraints().GetMaxDatabaseEntries()), 10)) + } + for _, compressionType := range req.GetConstraints().GetSupportedCompressions() { q.Add(supportedCompressionsString, compressionType.String()) } u.RawQuery = q.Encode() diff --git a/api_test.go b/api_test.go index 320c4f7..7546595 100644 --- a/api_test.go +++ b/api_test.go @@ -23,6 +23,7 @@ import ( "net/http" "net/http/httptest" "reflect" + "strconv" "testing" "github.com/google/go-cmp/cmp" @@ -39,9 +40,8 @@ type mockAPI struct { threatTypes []pb.ThreatType) (*pb.SearchHashesResponse, error) } -func (m *mockAPI) ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte, - compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error) { - return m.listUpdate(ctx, threatType, versionToken, compressionTypes) +func (m *mockAPI) ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error) { + return m.listUpdate(ctx, req.GetThreatType(), req.GetVersionToken(), req.GetConstraints().GetSupportedCompressions()) } func (m *mockAPI) HashLookup(ctx context.Context, hashPrefix []byte, @@ -55,6 +55,8 @@ func TestNetAPI(t *testing.T) { var gotReqHashPrefix, wantReqHashPrefix []byte var gotReqThreatTypes, wantReqThreatTypes []pb.ThreatType var gotResp, wantResp proto.Message + var gotMaxDiffEntries, wantMaxDiffEntries []int32 + var gotMaxDatabaseEntries, wantMaxDatabaseEntries []int32 responseMisformatter := "" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var p []byte @@ -89,15 +91,33 @@ func TestNetAPI(t *testing.T) { gotReqThreatTypes = append(gotReqThreatTypes, pb.ThreatType(pb.ThreatType_value[threat])) } + } else if key == maxDiffEntriesKey { + if len(value) == 0 { + t.Fatalf("Missing value for key %v", key) + } + i, err := strconv.ParseInt(value[0], 10, 32) + if err != nil { + t.Fatalf("Error parsing %q: %v", value[0], err) + } + gotMaxDiffEntries = append(gotMaxDiffEntries, int32(i)) + } else if key == maxDatabaseEntriesKey { + if len(value) == 0 { + t.Fatalf("Missing value for key %v", key) + } + i, err := strconv.ParseInt(value[0], 10, 32) + if err != nil { + t.Fatalf("Error parsing %q: %v", value[0], err) + } + gotMaxDatabaseEntries = append(gotMaxDatabaseEntries, int32(i)) } else if key != "key" { - t.Fatalf("unexpected request param error for key: %v", key) + t.Fatalf("Unexpected request param error for key: %v", key) } } if p, err = protojson.Marshal(wantResp); err != nil { - t.Fatalf("unexpected json MarshalToString error: %v", err) + t.Fatalf("Unexpected json MarshalToString error: %v", err) } if _, err := w.Write([]byte(responseMisformatter + string(p))); err != nil { - t.Fatalf("unexpected ResponseWriter.Write error: %v", err) + t.Fatalf("Unexpected ResponseWriter.Write error: %v", err) } })) defer ts.Close() @@ -110,6 +130,8 @@ func TestNetAPI(t *testing.T) { // Test that ListUpdate marshal/unmarshal works. wantReqThreatType = pb.ThreatType_MALWARE wantReqCompressionTypes = []pb.CompressionType{0, 1, 2} + wantMaxDiffEntries = []int32{1024} + wantMaxDatabaseEntries = []int32{1024} wantResp = &pb.ComputeThreatListDiffResponse{ ResponseType: 1, @@ -118,8 +140,16 @@ func TestNetAPI(t *testing.T) { RawIndices: &pb.RawIndices{Indices: []int32{1, 2, 3}}, }, } - resp1, err := api.ListUpdate(context.Background(), wantReqThreatType, []byte{}, - wantReqCompressionTypes) + req := &pb.ComputeThreatListDiffRequest{ + ThreatType: wantReqThreatType, + Constraints: &pb.ComputeThreatListDiffRequest_Constraints{ + SupportedCompressions: wantReqCompressionTypes, + MaxDiffEntries: 1024, + MaxDatabaseEntries: 1024, + }, + VersionToken: []byte{}, + } + resp1, err := api.ListUpdate(context.Background(), req) gotResp = resp1 if err != nil { t.Errorf("unexpected ListUpdate error: %v", err) @@ -135,6 +165,14 @@ func TestNetAPI(t *testing.T) { if !proto.Equal(gotResp, wantResp) { t.Errorf("mismatching ListUpdate responses:\ngot %+v\nwant %+v", gotResp, wantResp) } + if !reflect.DeepEqual(gotMaxDiffEntries, wantMaxDiffEntries) { + t.Errorf("mismatching ListUpdate max diff entries:\ngot %+v\nwant %+v", + gotMaxDiffEntries, wantMaxDiffEntries) + } + if !reflect.DeepEqual(gotMaxDatabaseEntries, wantMaxDatabaseEntries) { + t.Errorf("mismatching ListUpdate max database entries:\ngot %+v\nwant %+v", + gotMaxDatabaseEntries, wantMaxDatabaseEntries) + } // Test that HashLookup marshal/unmarshal works. wantReqHashPrefix = []byte("aaaa") diff --git a/database.go b/database.go index 0aacf0c..b721ee0 100644 --- a/database.go +++ b/database.go @@ -200,7 +200,10 @@ func (db *database) Update(ctx context.Context, api api) (time.Duration, bool) { s = append(s, &pb.ComputeThreatListDiffRequest{ ThreatType: pb.ThreatType(td), Constraints: &pb.ComputeThreatListDiffRequest_Constraints{ - SupportedCompressions: db.config.compressionTypes}, + SupportedCompressions: db.config.compressionTypes, + MaxDiffEntries: db.config.MaxDiffEntries, + MaxDatabaseEntries: db.config.MaxDatabaseEntries, + }, VersionToken: state, }) } @@ -212,7 +215,7 @@ func (db *database) Update(ctx context.Context, api api) (time.Duration, bool) { last := db.config.now() for _, req := range s { // Query the API for the threat list and update the database. - resp, err := api.ListUpdate(ctx, req.ThreatType, req.VersionToken, req.Constraints.SupportedCompressions) + resp, err := api.ListUpdate(ctx, req) if err != nil { db.log.Printf("ListUpdate failure (%d): %v", db.updateAPIErrors+1, err) db.setError(err) diff --git a/webrisk_client.go b/webrisk_client.go index ea4f98a..46aa5fc 100644 --- a/webrisk_client.go +++ b/webrisk_client.go @@ -94,8 +94,9 @@ const ( // Errors specific to this package. var ( - errClosed = errors.New("webrisk: handler is closed") - errStale = errors.New("webrisk: threat list is stale") + errClosed = errors.New("webrisk: handler is closed") + errStale = errors.New("webrisk: threat list is stale") + errMaxEntries = errors.New("webrisk: max entries must be a power of 2 between 2 ** 10 and 2 ** 20") ) // ThreatType is an enumeration type for threats classes. Examples of threat @@ -167,6 +168,18 @@ type Config struct { // If empty, ThreatLists will be loaded instead. ThreatListArg string + // MaxDiffEntries sets the maximum entries to request in a single call to ComputeThreatListDiff. + // This can be used in resource-constrained environments to limit the number of entries fetched + // at once. The default behavior (0) is to ignore this limit. + // If set, this should be a power of 2 between 2 ** 10 and 2 ** 20. + MaxDiffEntries int32 + + // MaxDatabaseEntries sets the maximum entries that the client will accept & store in the local + // database. This can be used to limit the size of the blocklist at the trade-off of decreased + // coverage. The default behavior (0) is to ignore this limit. + // If set, this should be a power of 2 between 2 ** 10 and 2 ** 20. + MaxDatabaseEntries int32 + // ThreatLists determines which threat lists that UpdateClient should // subscribe to. The threats reported by LookupURLs will only be ones that // are specified by this list. @@ -229,6 +242,19 @@ func parseThreatTypes(args string) ([]ThreatType, error) { return r, nil } +// validateMaxEntries validates a max entries argument, which must be either 0 or a power of 2 +// between 2 ** 10 and 2 ** 20. +func validateMaxEntries(n int32) error { + if n == 0 { + return nil + } + // Bitwise check confirms a power of 2. + if n&(n-1) != 0 || n < 1024 || n > 1048576 { + return errMaxEntries + } + return nil +} + func (c Config) copy() Config { c2 := c c2.ThreatLists = append([]ThreatType(nil), c.ThreatLists...) @@ -287,6 +313,11 @@ func NewUpdateClient(conf Config) (*UpdateClient, error) { conf.ThreatLists = tl } + // Validate max entries if args are passed. + if err := validateMaxEntries(conf.MaxDiffEntries); err != nil { + return nil, err + } + // Create the SafeBrowsing object. if conf.api == nil { var err error diff --git a/webrisk_client_system_test.go b/webrisk_client_system_test.go index 2a74fbc..60355c1 100644 --- a/webrisk_client_system_test.go +++ b/webrisk_client_system_test.go @@ -42,8 +42,7 @@ func TestNetworkAPIUpdate(t *testing.T) { ThreatType: pb.ThreatType_MALWARE, } - dat, err := nm.ListUpdate(context.Background(), req.ThreatType, - req.VersionToken, []pb.CompressionType{}) + dat, err := nm.ListUpdate(context.Background(), req) if err != nil { t.Fatal(err) } diff --git a/webrisk_client_test.go b/webrisk_client_test.go index d42a392..5c56b66 100644 --- a/webrisk_client_test.go +++ b/webrisk_client_test.go @@ -65,3 +65,50 @@ func TestParseThreatTypes(t *testing.T) { } } } + +func TestValidateMaxEntries(t *testing.T) { + tests := []struct { + n int32 + wantErr error + }{ + { + n: 0, + wantErr: nil, + }, + { + n: 1024, + wantErr: nil, + }, + { + n: 4096, + wantErr: nil, + }, + { + n: 1048576, + wantErr: nil, + }, + { + n: -1024, + wantErr: errMaxEntries, + }, + { + n: 100, + wantErr: errMaxEntries, + }, + { + n: 1026, + wantErr: errMaxEntries, + }, + { + n: 2097152, + wantErr: errMaxEntries, + }, + } + + for _, tc := range tests { + gotErr := validateMaxEntries(tc.n) + if gotErr != tc.wantErr { + t.Errorf("validateMaxEntries(%d) = %v, want %v", tc.n, gotErr, tc.wantErr) + } + } +}