diff --git a/database.go b/database.go index 7b86409..b721ee0 100644 --- a/database.go +++ b/database.go @@ -200,7 +200,10 @@ func (db *database) Update(ctx context.Context, api api) (time.Duration, bool) { s = append(s, &pb.ComputeThreatListDiffRequest{ ThreatType: pb.ThreatType(td), Constraints: &pb.ComputeThreatListDiffRequest_Constraints{ - SupportedCompressions: db.config.compressionTypes}, + SupportedCompressions: db.config.compressionTypes, + MaxDiffEntries: db.config.MaxDiffEntries, + MaxDatabaseEntries: db.config.MaxDatabaseEntries, + }, VersionToken: state, }) } diff --git a/webrisk_client.go b/webrisk_client.go index ea4f98a..46aa5fc 100644 --- a/webrisk_client.go +++ b/webrisk_client.go @@ -94,8 +94,9 @@ const ( // Errors specific to this package. var ( - errClosed = errors.New("webrisk: handler is closed") - errStale = errors.New("webrisk: threat list is stale") + errClosed = errors.New("webrisk: handler is closed") + errStale = errors.New("webrisk: threat list is stale") + errMaxEntries = errors.New("webrisk: max entries must be a power of 2 between 2 ** 10 and 2 ** 20") ) // ThreatType is an enumeration type for threats classes. Examples of threat @@ -167,6 +168,18 @@ type Config struct { // If empty, ThreatLists will be loaded instead. ThreatListArg string + // MaxDiffEntries sets the maximum entries to request in a single call to ComputeThreatListDiff. + // This can be used in resource-constrained environments to limit the number of entries fetched + // at once. The default behavior (0) is to ignore this limit. + // If set, this should be a power of 2 between 2 ** 10 and 2 ** 20. + MaxDiffEntries int32 + + // MaxDatabaseEntries sets the maximum entries that the client will accept & store in the local + // database. This can be used to limit the size of the blocklist at the trade-off of decreased + // coverage. The default behavior (0) is to ignore this limit. + // If set, this should be a power of 2 between 2 ** 10 and 2 ** 20. + MaxDatabaseEntries int32 + // ThreatLists determines which threat lists that UpdateClient should // subscribe to. The threats reported by LookupURLs will only be ones that // are specified by this list. @@ -229,6 +242,19 @@ func parseThreatTypes(args string) ([]ThreatType, error) { return r, nil } +// validateMaxEntries validates a max entries argument, which must be either 0 or a power of 2 +// between 2 ** 10 and 2 ** 20. +func validateMaxEntries(n int32) error { + if n == 0 { + return nil + } + // Bitwise check confirms a power of 2. + if n&(n-1) != 0 || n < 1024 || n > 1048576 { + return errMaxEntries + } + return nil +} + func (c Config) copy() Config { c2 := c c2.ThreatLists = append([]ThreatType(nil), c.ThreatLists...) @@ -287,6 +313,11 @@ func NewUpdateClient(conf Config) (*UpdateClient, error) { conf.ThreatLists = tl } + // Validate max entries if args are passed. + if err := validateMaxEntries(conf.MaxDiffEntries); err != nil { + return nil, err + } + // Create the SafeBrowsing object. if conf.api == nil { var err error diff --git a/webrisk_client_test.go b/webrisk_client_test.go index d42a392..5c56b66 100644 --- a/webrisk_client_test.go +++ b/webrisk_client_test.go @@ -65,3 +65,50 @@ func TestParseThreatTypes(t *testing.T) { } } } + +func TestValidateMaxEntries(t *testing.T) { + tests := []struct { + n int32 + wantErr error + }{ + { + n: 0, + wantErr: nil, + }, + { + n: 1024, + wantErr: nil, + }, + { + n: 4096, + wantErr: nil, + }, + { + n: 1048576, + wantErr: nil, + }, + { + n: -1024, + wantErr: errMaxEntries, + }, + { + n: 100, + wantErr: errMaxEntries, + }, + { + n: 1026, + wantErr: errMaxEntries, + }, + { + n: 2097152, + wantErr: errMaxEntries, + }, + } + + for _, tc := range tests { + gotErr := validateMaxEntries(tc.n) + if gotErr != tc.wantErr { + t.Errorf("validateMaxEntries(%d) = %v, want %v", tc.n, gotErr, tc.wantErr) + } + } +}