From b36ed2648cbd0bd42cae322966390fff00a0855b Mon Sep 17 00:00:00 2001 From: Rim Vilgalys Date: Fri, 31 May 2024 08:40:06 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 639046109 --- api.go | 14 ++++++-------- api_test.go | 15 ++++++++++----- database.go | 2 +- webrisk_client_system_test.go | 3 +-- 4 files changed, 18 insertions(+), 16 deletions(-) diff --git a/api.go b/api.go index ffa663f..17c6238 100644 --- a/api.go +++ b/api.go @@ -41,8 +41,7 @@ const ( // The api interface specifies wrappers around the Web Risk API. type api interface { - ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte, - compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error) + ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error) HashLookup(ctx context.Context, hashPrefix []byte, threatTypes []pb.ThreatType) (*pb.SearchHashesResponse, error) } @@ -123,17 +122,16 @@ func (a *netAPI) parseError(httpResp *http.Response) error { } // ListUpdate issues a ComputeThreatListDiff API call and returns the response. -func (a *netAPI) ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte, - compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error) { +func (a *netAPI) ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error) { resp := new(pb.ComputeThreatListDiffResponse) u := *a.url // Make a copy of URL // Add fields from ComputeThreatListDiffRequest to URL request q := u.Query() - q.Set(threatTypeString, threatType.String()) - if len(versionToken) != 0 { - q.Set(versionTokenString, base64.StdEncoding.EncodeToString(versionToken)) + q.Set(threatTypeString, req.GetThreatType().String()) + if len(req.GetVersionToken()) != 0 { + q.Set(versionTokenString, base64.StdEncoding.EncodeToString(req.GetVersionToken())) } - for _, compressionType := range compressionTypes { + for _, compressionType := range req.GetConstraints().GetSupportedCompressions() { q.Add(supportedCompressionsString, compressionType.String()) } u.RawQuery = q.Encode() diff --git a/api_test.go b/api_test.go index 320c4f7..a58d12f 100644 --- a/api_test.go +++ b/api_test.go @@ -39,9 +39,8 @@ type mockAPI struct { threatTypes []pb.ThreatType) (*pb.SearchHashesResponse, error) } -func (m *mockAPI) ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte, - compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error) { - return m.listUpdate(ctx, threatType, versionToken, compressionTypes) +func (m *mockAPI) ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error) { + return m.listUpdate(ctx, req.GetThreatType(), req.GetVersionToken(), req.GetConstraints().GetSupportedCompressions()) } func (m *mockAPI) HashLookup(ctx context.Context, hashPrefix []byte, @@ -118,8 +117,14 @@ func TestNetAPI(t *testing.T) { RawIndices: &pb.RawIndices{Indices: []int32{1, 2, 3}}, }, } - resp1, err := api.ListUpdate(context.Background(), wantReqThreatType, []byte{}, - wantReqCompressionTypes) + req := &pb.ComputeThreatListDiffRequest{ + ThreatType: wantReqThreatType, + Constraints: &pb.ComputeThreatListDiffRequest_Constraints{ + SupportedCompressions: wantReqCompressionTypes, + }, + VersionToken: []byte{}, + } + resp1, err := api.ListUpdate(context.Background(), req) gotResp = resp1 if err != nil { t.Errorf("unexpected ListUpdate error: %v", err) diff --git a/database.go b/database.go index 0aacf0c..7b86409 100644 --- a/database.go +++ b/database.go @@ -212,7 +212,7 @@ func (db *database) Update(ctx context.Context, api api) (time.Duration, bool) { last := db.config.now() for _, req := range s { // Query the API for the threat list and update the database. - resp, err := api.ListUpdate(ctx, req.ThreatType, req.VersionToken, req.Constraints.SupportedCompressions) + resp, err := api.ListUpdate(ctx, req) if err != nil { db.log.Printf("ListUpdate failure (%d): %v", db.updateAPIErrors+1, err) db.setError(err) diff --git a/webrisk_client_system_test.go b/webrisk_client_system_test.go index 2a74fbc..60355c1 100644 --- a/webrisk_client_system_test.go +++ b/webrisk_client_system_test.go @@ -42,8 +42,7 @@ func TestNetworkAPIUpdate(t *testing.T) { ThreatType: pb.ThreatType_MALWARE, } - dat, err := nm.ListUpdate(context.Background(), req.ThreatType, - req.VersionToken, []pb.CompressionType{}) + dat, err := nm.ListUpdate(context.Background(), req) if err != nil { t.Fatal(err) }