Skip to content

Commit

Permalink
Adds optional max_diff_entries and max_database_entries in ComputeThr…
Browse files Browse the repository at this point in the history
…eatListDiffRequest call from api.go

PiperOrigin-RevId: 639091343
  • Loading branch information
rvilgalys authored and copybara-github committed Jun 3, 2024
1 parent d5412b6 commit 4599783
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 22 deletions.
23 changes: 15 additions & 8 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"io/ioutil"
"net/http"
"net/url"
"strconv"
"strings"

"google.golang.org/protobuf/encoding/protojson"
Expand All @@ -34,15 +35,16 @@ const (
threatTypeString = "threat_type"
versionTokenString = "version_token"
supportedCompressionsString = "constraints.supported_compressions"
maxDiffEntriesKey = "constraints.max_diff_entries"
maxDatabaseEntriesKey = "constraints.max_database_entries"
hashPrefixString = "hash_prefix"
threatTypesString = "threat_types"
userAgentString = "Webrisk-Client/0.2.1"
)

// The api interface specifies wrappers around the Web Risk API.
type api interface {
ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte,
compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error)
ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error)
HashLookup(ctx context.Context, hashPrefix []byte,
threatTypes []pb.ThreatType) (*pb.SearchHashesResponse, error)
}
Expand Down Expand Up @@ -123,17 +125,22 @@ func (a *netAPI) parseError(httpResp *http.Response) error {
}

// ListUpdate issues a ComputeThreatListDiff API call and returns the response.
func (a *netAPI) ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte,
compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error) {
func (a *netAPI) ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error) {
resp := new(pb.ComputeThreatListDiffResponse)
u := *a.url // Make a copy of URL
// Add fields from ComputeThreatListDiffRequest to URL request
q := u.Query()
q.Set(threatTypeString, threatType.String())
if len(versionToken) != 0 {
q.Set(versionTokenString, base64.StdEncoding.EncodeToString(versionToken))
q.Set(threatTypeString, req.GetThreatType().String())
if len(req.GetVersionToken()) != 0 {
q.Set(versionTokenString, base64.StdEncoding.EncodeToString(req.GetVersionToken()))
}
for _, compressionType := range compressionTypes {
if req.GetConstraints().GetMaxDiffEntries() != 0 {
q.Add(maxDiffEntriesKey, strconv.FormatInt(int64(req.GetConstraints().GetMaxDiffEntries()), 10))
}
if req.GetConstraints().GetMaxDatabaseEntries() != 0 {
q.Add(maxDatabaseEntriesKey, strconv.FormatInt(int64(req.GetConstraints().GetMaxDatabaseEntries()), 10))
}
for _, compressionType := range req.GetConstraints().GetSupportedCompressions() {
q.Add(supportedCompressionsString, compressionType.String())
}
u.RawQuery = q.Encode()
Expand Down
54 changes: 46 additions & 8 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"net/http"
"net/http/httptest"
"reflect"
"strconv"
"testing"

"github.com/google/go-cmp/cmp"
Expand All @@ -39,9 +40,8 @@ type mockAPI struct {
threatTypes []pb.ThreatType) (*pb.SearchHashesResponse, error)
}

func (m *mockAPI) ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte,
compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error) {
return m.listUpdate(ctx, threatType, versionToken, compressionTypes)
func (m *mockAPI) ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error) {
return m.listUpdate(ctx, req.GetThreatType(), req.GetVersionToken(), req.GetConstraints().GetSupportedCompressions())
}

func (m *mockAPI) HashLookup(ctx context.Context, hashPrefix []byte,
Expand All @@ -55,6 +55,8 @@ func TestNetAPI(t *testing.T) {
var gotReqHashPrefix, wantReqHashPrefix []byte
var gotReqThreatTypes, wantReqThreatTypes []pb.ThreatType
var gotResp, wantResp proto.Message
var gotMaxDiffEntries, wantMaxDiffEntries []int32
var gotMaxDatabaseEntries, wantMaxDatabaseEntries []int32
responseMisformatter := ""
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var p []byte
Expand Down Expand Up @@ -89,15 +91,33 @@ func TestNetAPI(t *testing.T) {
gotReqThreatTypes = append(gotReqThreatTypes,
pb.ThreatType(pb.ThreatType_value[threat]))
}
} else if key == maxDiffEntriesKey {
if len(value) == 0 {
t.Fatalf("Missing value for key %v", key)
}
i, err := strconv.ParseInt(value[0], 10, 32)
if err != nil {
t.Fatalf("Error parsing %q: %v", value[0], err)
}
gotMaxDiffEntries = append(gotMaxDiffEntries, int32(i))
} else if key == maxDatabaseEntriesKey {
if len(value) == 0 {
t.Fatalf("Missing value for key %v", key)
}
i, err := strconv.ParseInt(value[0], 10, 32)
if err != nil {
t.Fatalf("Error parsing %q: %v", value[0], err)
}
gotMaxDatabaseEntries = append(gotMaxDatabaseEntries, int32(i))
} else if key != "key" {
t.Fatalf("unexpected request param error for key: %v", key)
t.Fatalf("Unexpected request param error for key: %v", key)
}
}
if p, err = protojson.Marshal(wantResp); err != nil {
t.Fatalf("unexpected json MarshalToString error: %v", err)
t.Fatalf("Unexpected json MarshalToString error: %v", err)
}
if _, err := w.Write([]byte(responseMisformatter + string(p))); err != nil {
t.Fatalf("unexpected ResponseWriter.Write error: %v", err)
t.Fatalf("Unexpected ResponseWriter.Write error: %v", err)
}
}))
defer ts.Close()
Expand All @@ -110,6 +130,8 @@ func TestNetAPI(t *testing.T) {
// Test that ListUpdate marshal/unmarshal works.
wantReqThreatType = pb.ThreatType_MALWARE
wantReqCompressionTypes = []pb.CompressionType{0, 1, 2}
wantMaxDiffEntries = []int32{1024}
wantMaxDatabaseEntries = []int32{1024}

wantResp = &pb.ComputeThreatListDiffResponse{
ResponseType: 1,
Expand All @@ -118,8 +140,16 @@ func TestNetAPI(t *testing.T) {
RawIndices: &pb.RawIndices{Indices: []int32{1, 2, 3}},
},
}
resp1, err := api.ListUpdate(context.Background(), wantReqThreatType, []byte{},
wantReqCompressionTypes)
req := &pb.ComputeThreatListDiffRequest{
ThreatType: wantReqThreatType,
Constraints: &pb.ComputeThreatListDiffRequest_Constraints{
SupportedCompressions: wantReqCompressionTypes,
MaxDiffEntries: 1024,
MaxDatabaseEntries: 1024,
},
VersionToken: []byte{},
}
resp1, err := api.ListUpdate(context.Background(), req)
gotResp = resp1
if err != nil {
t.Errorf("unexpected ListUpdate error: %v", err)
Expand All @@ -135,6 +165,14 @@ func TestNetAPI(t *testing.T) {
if !proto.Equal(gotResp, wantResp) {
t.Errorf("mismatching ListUpdate responses:\ngot %+v\nwant %+v", gotResp, wantResp)
}
if !reflect.DeepEqual(gotMaxDiffEntries, wantMaxDiffEntries) {
t.Errorf("mismatching ListUpdate max diff entries:\ngot %+v\nwant %+v",
gotMaxDiffEntries, wantMaxDiffEntries)
}
if !reflect.DeepEqual(gotMaxDatabaseEntries, wantMaxDatabaseEntries) {
t.Errorf("mismatching ListUpdate max database entries:\ngot %+v\nwant %+v",
gotMaxDatabaseEntries, wantMaxDatabaseEntries)
}

// Test that HashLookup marshal/unmarshal works.
wantReqHashPrefix = []byte("aaaa")
Expand Down
7 changes: 5 additions & 2 deletions database.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,10 @@ func (db *database) Update(ctx context.Context, api api) (time.Duration, bool) {
s = append(s, &pb.ComputeThreatListDiffRequest{
ThreatType: pb.ThreatType(td),
Constraints: &pb.ComputeThreatListDiffRequest_Constraints{
SupportedCompressions: db.config.compressionTypes},
SupportedCompressions: db.config.compressionTypes,
MaxDiffEntries: db.config.MaxDiffEntries,
MaxDatabaseEntries: db.config.MaxDatabaseEntries,
},
VersionToken: state,
})
}
Expand All @@ -212,7 +215,7 @@ func (db *database) Update(ctx context.Context, api api) (time.Duration, bool) {
last := db.config.now()
for _, req := range s {
// Query the API for the threat list and update the database.
resp, err := api.ListUpdate(ctx, req.ThreatType, req.VersionToken, req.Constraints.SupportedCompressions)
resp, err := api.ListUpdate(ctx, req)
if err != nil {
db.log.Printf("ListUpdate failure (%d): %v", db.updateAPIErrors+1, err)
db.setError(err)
Expand Down
35 changes: 33 additions & 2 deletions webrisk_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,9 @@ const (

// Errors specific to this package.
var (
errClosed = errors.New("webrisk: handler is closed")
errStale = errors.New("webrisk: threat list is stale")
errClosed = errors.New("webrisk: handler is closed")
errStale = errors.New("webrisk: threat list is stale")
errMaxEntries = errors.New("webrisk: max entries must be a power of 2 between 2 ** 10 and 2 ** 20")
)

// ThreatType is an enumeration type for threats classes. Examples of threat
Expand Down Expand Up @@ -167,6 +168,18 @@ type Config struct {
// If empty, ThreatLists will be loaded instead.
ThreatListArg string

// MaxDiffEntries sets the maximum entries to request in a single call to ComputeThreatListDiff.
// This can be used in resource-constrained environments to limit the number of entries fetched
// at once. The default behavior (0) is to ignore this limit.
// If set, this should be a power of 2 between 2 ** 10 and 2 ** 20.
MaxDiffEntries int32

// MaxDatabaseEntries sets the maximum entries that the client will accept & store in the local
// database. This can be used to limit the size of the blocklist at the trade-off of decreased
// coverage. The default behavior (0) is to ignore this limit.
// If set, this should be a power of 2 between 2 ** 10 and 2 ** 20.
MaxDatabaseEntries int32

// ThreatLists determines which threat lists that UpdateClient should
// subscribe to. The threats reported by LookupURLs will only be ones that
// are specified by this list.
Expand Down Expand Up @@ -229,6 +242,19 @@ func parseThreatTypes(args string) ([]ThreatType, error) {
return r, nil
}

// validateMaxEntries validates a max entries argument, which must be either 0 or a power of 2
// between 2 ** 10 and 2 ** 20.
func validateMaxEntries(n int32) error {
if n == 0 {
return nil
}
// Bitwise check confirms a power of 2.
if n&(n-1) != 0 || n < 1024 || n > 1048576 {
return errMaxEntries
}
return nil
}

func (c Config) copy() Config {
c2 := c
c2.ThreatLists = append([]ThreatType(nil), c.ThreatLists...)
Expand Down Expand Up @@ -287,6 +313,11 @@ func NewUpdateClient(conf Config) (*UpdateClient, error) {
conf.ThreatLists = tl
}

// Validate max entries if args are passed.
if err := validateMaxEntries(conf.MaxDiffEntries); err != nil {
return nil, err
}

// Create the SafeBrowsing object.
if conf.api == nil {
var err error
Expand Down
3 changes: 1 addition & 2 deletions webrisk_client_system_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ func TestNetworkAPIUpdate(t *testing.T) {
ThreatType: pb.ThreatType_MALWARE,
}

dat, err := nm.ListUpdate(context.Background(), req.ThreatType,
req.VersionToken, []pb.CompressionType{})
dat, err := nm.ListUpdate(context.Background(), req)
if err != nil {
t.Fatal(err)
}
Expand Down
47 changes: 47 additions & 0 deletions webrisk_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,50 @@ func TestParseThreatTypes(t *testing.T) {
}
}
}

func TestValidateMaxEntries(t *testing.T) {
tests := []struct {
n int32
wantErr error
}{
{
n: 0,
wantErr: nil,
},
{
n: 1024,
wantErr: nil,
},
{
n: 4096,
wantErr: nil,
},
{
n: 1048576,
wantErr: nil,
},
{
n: -1024,
wantErr: errMaxEntries,
},
{
n: 100,
wantErr: errMaxEntries,
},
{
n: 1026,
wantErr: errMaxEntries,
},
{
n: 2097152,
wantErr: errMaxEntries,
},
}

for _, tc := range tests {
gotErr := validateMaxEntries(tc.n)
if gotErr != tc.wantErr {
t.Errorf("validateMaxEntries(%d) = %v, want %v", tc.n, gotErr, tc.wantErr)
}
}
}

0 comments on commit 4599783

Please sign in to comment.