Skip to content

Commit

Permalink
Refactor ListUpdate args.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 639046109
  • Loading branch information
rvilgalys authored and copybara-github committed Jun 4, 2024
1 parent d5412b6 commit 50ac1fe
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 16 deletions.
14 changes: 6 additions & 8 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ const (

// The api interface specifies wrappers around the Web Risk API.
type api interface {
ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte,
compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error)
ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error)
HashLookup(ctx context.Context, hashPrefix []byte,
threatTypes []pb.ThreatType) (*pb.SearchHashesResponse, error)
}
Expand Down Expand Up @@ -123,17 +122,16 @@ func (a *netAPI) parseError(httpResp *http.Response) error {
}

// ListUpdate issues a ComputeThreatListDiff API call and returns the response.
func (a *netAPI) ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte,
compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error) {
func (a *netAPI) ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error) {
resp := new(pb.ComputeThreatListDiffResponse)
u := *a.url // Make a copy of URL
// Add fields from ComputeThreatListDiffRequest to URL request
q := u.Query()
q.Set(threatTypeString, threatType.String())
if len(versionToken) != 0 {
q.Set(versionTokenString, base64.StdEncoding.EncodeToString(versionToken))
q.Set(threatTypeString, req.GetThreatType().String())
if len(req.GetVersionToken()) != 0 {
q.Set(versionTokenString, base64.StdEncoding.EncodeToString(req.GetVersionToken()))
}
for _, compressionType := range compressionTypes {
for _, compressionType := range req.GetConstraints().GetSupportedCompressions() {
q.Add(supportedCompressionsString, compressionType.String())
}
u.RawQuery = q.Encode()
Expand Down
15 changes: 10 additions & 5 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,8 @@ type mockAPI struct {
threatTypes []pb.ThreatType) (*pb.SearchHashesResponse, error)
}

func (m *mockAPI) ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte,
compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error) {
return m.listUpdate(ctx, threatType, versionToken, compressionTypes)
func (m *mockAPI) ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error) {
return m.listUpdate(ctx, req.GetThreatType(), req.GetVersionToken(), req.GetConstraints().GetSupportedCompressions())
}

func (m *mockAPI) HashLookup(ctx context.Context, hashPrefix []byte,
Expand Down Expand Up @@ -118,8 +117,14 @@ func TestNetAPI(t *testing.T) {
RawIndices: &pb.RawIndices{Indices: []int32{1, 2, 3}},
},
}
resp1, err := api.ListUpdate(context.Background(), wantReqThreatType, []byte{},
wantReqCompressionTypes)
req := &pb.ComputeThreatListDiffRequest{
ThreatType: wantReqThreatType,
Constraints: &pb.ComputeThreatListDiffRequest_Constraints{
SupportedCompressions: wantReqCompressionTypes,
},
VersionToken: []byte{},
}
resp1, err := api.ListUpdate(context.Background(), req)
gotResp = resp1
if err != nil {
t.Errorf("unexpected ListUpdate error: %v", err)
Expand Down
2 changes: 1 addition & 1 deletion database.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func (db *database) Update(ctx context.Context, api api) (time.Duration, bool) {
last := db.config.now()
for _, req := range s {
// Query the API for the threat list and update the database.
resp, err := api.ListUpdate(ctx, req.ThreatType, req.VersionToken, req.Constraints.SupportedCompressions)
resp, err := api.ListUpdate(ctx, req)
if err != nil {
db.log.Printf("ListUpdate failure (%d): %v", db.updateAPIErrors+1, err)
db.setError(err)
Expand Down
3 changes: 1 addition & 2 deletions webrisk_client_system_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ func TestNetworkAPIUpdate(t *testing.T) {
ThreatType: pb.ThreatType_MALWARE,
}

dat, err := nm.ListUpdate(context.Background(), req.ThreatType,
req.VersionToken, []pb.CompressionType{})
dat, err := nm.ListUpdate(context.Background(), req)
if err != nil {
t.Fatal(err)
}
Expand Down

0 comments on commit 50ac1fe

Please sign in to comment.