From 7fb05167d2b58f0445da1310420395c1f821d1fd Mon Sep 17 00:00:00 2001 From: Rim Vilgalys Date: Fri, 31 May 2024 11:53:27 -0700 Subject: [PATCH] Updates to README with details about command line flags. PiperOrigin-RevId: 639105973 --- README.md | 23 +++++++++++++++++ api.go | 23 +++++++++++------ api_test.go | 48 +++++++++++++++++++++++++++++++---- cmd/wrlookup/main.go | 26 +++++++++++-------- cmd/wrserver/main.go | 26 +++++++++++-------- database.go | 7 +++-- webrisk_client.go | 35 +++++++++++++++++++++++-- webrisk_client_system_test.go | 3 +-- webrisk_client_test.go | 47 ++++++++++++++++++++++++++++++++++ 9 files changed, 197 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index 9a2f91e..27d5307 100644 --- a/README.md +++ b/README.md @@ -234,6 +234,29 @@ For 400 errors, this usually means the API key is incorrect or was not supplied For 403 errors, this could mean the Web Risk API is not enabled for your project **or** your project does not have Billing enabled. +# Configuration + +Both `wrserver` (used by `docker run`) and `wrlookup` support several command line flags. + +- `apikey` (required) -- Used to Authenticate requests with the Web Risk API. +The API itself must also be enabled on the same project & be linked to a Billing account. + +- `threatTypes` (optional) -- A comma-separated lists of different blocklists to load and check URLs against. +Available options include `MALWARE`,`UNWANTED_SOFTWARE`,`SOCIAL_ENGINEERING`, +`SOCIAL_ENGINEERING_EXTENDED_COVERAGE`. This arg will also accept `ALL` which is +the default behavior. + +- `maxDiffEntries` (optional) -- An int32 value that will set the max number of hash prefixes +returned in a single diff request. This can be used in resource-bound environments to control +bandwidth usage. The default value of 0 will result in this limit being ignored. Otherwise, this +must be set to a positive integer which must be a power of 2 between 2 ^ 10 and 2 ^ 20. + +- `maxDatabaseEntries` (optional) -- An in32 value that will set the upper boundary has prefixes +to be returned from the API and stored locally. This can be used to limit the number of hash +prefixes to be searched against. The default value of 0 will result in this limit being ignored. Otherwise, this +must be set to a positive integer which must be a power of 2 between 2 ^ 10 and 2 ^ 20. *Note*: Setting this limit +will decrease blocklist coverage. + # About the Social Engineering Extended Coverage List This is a newer blocklist that includes a greater range of risky URLs that diff --git a/api.go b/api.go index ffa663f..3a73c0e 100644 --- a/api.go +++ b/api.go @@ -20,6 +20,7 @@ import ( "io/ioutil" "net/http" "net/url" + "strconv" "strings" "google.golang.org/protobuf/encoding/protojson" @@ -34,6 +35,8 @@ const ( threatTypeString = "threat_type" versionTokenString = "version_token" supportedCompressionsString = "constraints.supported_compressions" + maxDiffEntriesString = "constraints.max_diff_entries" + maxDatabaseEntriesString = "constraints.max_database_entries" hashPrefixString = "hash_prefix" threatTypesString = "threat_types" userAgentString = "Webrisk-Client/0.2.1" @@ -41,8 +44,7 @@ const ( // The api interface specifies wrappers around the Web Risk API. type api interface { - ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte, - compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error) + ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error) HashLookup(ctx context.Context, hashPrefix []byte, threatTypes []pb.ThreatType) (*pb.SearchHashesResponse, error) } @@ -123,17 +125,22 @@ func (a *netAPI) parseError(httpResp *http.Response) error { } // ListUpdate issues a ComputeThreatListDiff API call and returns the response. -func (a *netAPI) ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte, - compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error) { +func (a *netAPI) ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error) { resp := new(pb.ComputeThreatListDiffResponse) u := *a.url // Make a copy of URL // Add fields from ComputeThreatListDiffRequest to URL request q := u.Query() - q.Set(threatTypeString, threatType.String()) - if len(versionToken) != 0 { - q.Set(versionTokenString, base64.StdEncoding.EncodeToString(versionToken)) + q.Set(threatTypeString, req.GetThreatType().String()) + if len(req.GetVersionToken()) != 0 { + q.Set(versionTokenString, base64.StdEncoding.EncodeToString(req.GetVersionToken())) } - for _, compressionType := range compressionTypes { + if req.GetConstraints().GetMaxDiffEntries() != 0 { + q.Add(maxDiffEntriesString, strconv.FormatInt(int64(req.GetConstraints().GetMaxDiffEntries()), 10)) + } + if req.GetConstraints().GetMaxDatabaseEntries() != 0 { + q.Add(maxDatabaseEntriesString, strconv.FormatInt(int64(req.GetConstraints().GetMaxDatabaseEntries()), 10)) + } + for _, compressionType := range req.GetConstraints().GetSupportedCompressions() { q.Add(supportedCompressionsString, compressionType.String()) } u.RawQuery = q.Encode() diff --git a/api_test.go b/api_test.go index 320c4f7..f95c2d1 100644 --- a/api_test.go +++ b/api_test.go @@ -23,6 +23,7 @@ import ( "net/http" "net/http/httptest" "reflect" + "strconv" "testing" "github.com/google/go-cmp/cmp" @@ -39,9 +40,8 @@ type mockAPI struct { threatTypes []pb.ThreatType) (*pb.SearchHashesResponse, error) } -func (m *mockAPI) ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte, - compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error) { - return m.listUpdate(ctx, threatType, versionToken, compressionTypes) +func (m *mockAPI) ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error) { + return m.listUpdate(ctx, req.GetThreatType(), req.GetVersionToken(), req.GetConstraints().GetSupportedCompressions()) } func (m *mockAPI) HashLookup(ctx context.Context, hashPrefix []byte, @@ -55,6 +55,8 @@ func TestNetAPI(t *testing.T) { var gotReqHashPrefix, wantReqHashPrefix []byte var gotReqThreatTypes, wantReqThreatTypes []pb.ThreatType var gotResp, wantResp proto.Message + var gotMaxDiffEntries, wantMaxDiffEntries []int32 + var gotMaxDatabaseEntries, wantMaxDatabaseEntries []int32 responseMisformatter := "" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var p []byte @@ -89,6 +91,24 @@ func TestNetAPI(t *testing.T) { gotReqThreatTypes = append(gotReqThreatTypes, pb.ThreatType(pb.ThreatType_value[threat])) } + } else if key == maxDiffEntriesString { + if len(value) == 0 { + t.Fatalf("missing value for key %v", key) + } + i, err := strconv.ParseInt(value[0], 10, 32) + if err != nil { + t.Fatalf("error parsing %v", value[0]) + } + gotMaxDiffEntries = append(gotMaxDiffEntries, int32(i)) + } else if key == maxDatabaseEntriesString { + if len(value) == 0 { + t.Fatalf("missing value for key %v", key) + } + i, err := strconv.ParseInt(value[0], 10, 32) + if err != nil { + t.Fatalf("error parsing %v", value[0]) + } + gotMaxDatabaseEntries = append(gotMaxDatabaseEntries, int32(i)) } else if key != "key" { t.Fatalf("unexpected request param error for key: %v", key) } @@ -110,6 +130,8 @@ func TestNetAPI(t *testing.T) { // Test that ListUpdate marshal/unmarshal works. wantReqThreatType = pb.ThreatType_MALWARE wantReqCompressionTypes = []pb.CompressionType{0, 1, 2} + wantMaxDiffEntries = []int32{1024} + wantMaxDatabaseEntries = []int32{1024} wantResp = &pb.ComputeThreatListDiffResponse{ ResponseType: 1, @@ -118,8 +140,16 @@ func TestNetAPI(t *testing.T) { RawIndices: &pb.RawIndices{Indices: []int32{1, 2, 3}}, }, } - resp1, err := api.ListUpdate(context.Background(), wantReqThreatType, []byte{}, - wantReqCompressionTypes) + req := &pb.ComputeThreatListDiffRequest{ + ThreatType: wantReqThreatType, + Constraints: &pb.ComputeThreatListDiffRequest_Constraints{ + SupportedCompressions: wantReqCompressionTypes, + MaxDiffEntries: 1024, + MaxDatabaseEntries: 1024, + }, + VersionToken: []byte{}, + } + resp1, err := api.ListUpdate(context.Background(), req) gotResp = resp1 if err != nil { t.Errorf("unexpected ListUpdate error: %v", err) @@ -135,6 +165,14 @@ func TestNetAPI(t *testing.T) { if !proto.Equal(gotResp, wantResp) { t.Errorf("mismatching ListUpdate responses:\ngot %+v\nwant %+v", gotResp, wantResp) } + if !reflect.DeepEqual(gotMaxDiffEntries, wantMaxDiffEntries) { + t.Errorf("mismatching ListUpdate max diff entries:\ngot %+v\nwant %+v", + gotMaxDiffEntries, wantMaxDiffEntries) + } + if !reflect.DeepEqual(gotMaxDatabaseEntries, wantMaxDatabaseEntries) { + t.Errorf("mismatching ListUpdate max database entries:\ngot %+v\nwant %+v", + gotMaxDatabaseEntries, wantMaxDatabaseEntries) + } // Test that HashLookup marshal/unmarshal works. wantReqHashPrefix = []byte("aaaa") diff --git a/cmd/wrlookup/main.go b/cmd/wrlookup/main.go index 096c901..5ae319e 100644 --- a/cmd/wrlookup/main.go +++ b/cmd/wrlookup/main.go @@ -40,11 +40,13 @@ import ( ) var ( - apiKeyFlag = flag.String("apikey", "", "specify your Web Risk API key") - databaseFlag = flag.String("db", "", "path to the Web Risk database. By default persistent storage is disabled (not recommended).") - serverURLFlag = flag.String("server", webrisk.DefaultServerURL, "Web Risk API server address.") - proxyFlag = flag.String("proxy", "", "proxy to use to connect to the HTTP server") - threatTypesFlag = flag.String("threatTypes", "ALL", "threat types to check against") + apiKeyFlag = flag.String("apikey", "", "specify your Web Risk API key") + databaseFlag = flag.String("db", "", "path to the Web Risk database. By default persistent storage is disabled (not recommended).") + serverURLFlag = flag.String("server", webrisk.DefaultServerURL, "Web Risk API server address.") + proxyFlag = flag.String("proxy", "", "proxy to use to connect to the HTTP server") + threatTypesFlag = flag.String("threatTypes", "ALL", "threat types to check against") + maxDiffEntriesFlag = flag.Int("maxDiffEntries", 0, "maximum number of diff entries to return from a ComputeThreatListDiff request") + maxDatabaseEntriesFlag = flag.Int("maxDatabaseEntries", 0, "maximum number of database entries to be stored in the local database") ) const usage = `wrlookup: command-line tool to lookup URLs with Web Risk. @@ -81,12 +83,14 @@ func main() { os.Exit(codeInvalid) } sb, err := webrisk.NewUpdateClient(webrisk.Config{ - APIKey: *apiKeyFlag, - DBPath: *databaseFlag, - Logger: os.Stderr, - ServerURL: *serverURLFlag, - ProxyURL: *proxyFlag, - ThreatListArg: *threatTypesFlag, + APIKey: *apiKeyFlag, + DBPath: *databaseFlag, + Logger: os.Stderr, + ServerURL: *serverURLFlag, + ProxyURL: *proxyFlag, + ThreatListArg: *threatTypesFlag, + MaxDiffEntries: int32(*maxDiffEntriesFlag), + MaxDatabaseEntries: int32(*maxDatabaseEntriesFlag), }) if err != nil { fmt.Fprintln(os.Stderr, "Unable to initialize Web Risk client: ", err) diff --git a/cmd/wrserver/main.go b/cmd/wrserver/main.go index 93a2e18..e9691fb 100644 --- a/cmd/wrserver/main.go +++ b/cmd/wrserver/main.go @@ -223,11 +223,13 @@ const ( ) var ( - apiKeyFlag = flag.String("apikey", os.Getenv("APIKEY"), "specify your Web Risk API key") - srvAddrFlag = flag.String("srvaddr", "0.0.0.0:8080", "TCP network address the HTTP server should use") - proxyFlag = flag.String("proxy", "", "proxy to use to connect to the HTTP server") - databaseFlag = flag.String("db", "", "path to the Web Risk database.") - threatTypesFlag = flag.String("threatTypes", "ALL", "threat types to check against") + apiKeyFlag = flag.String("apikey", os.Getenv("APIKEY"), "specify your Web Risk API key") + srvAddrFlag = flag.String("srvaddr", "0.0.0.0:8080", "TCP network address the HTTP server should use") + proxyFlag = flag.String("proxy", "", "proxy to use to connect to the HTTP server") + databaseFlag = flag.String("db", "", "path to the Web Risk database.") + threatTypesFlag = flag.String("threatTypes", "ALL", "threat types to check against") + maxDiffEntriesFlag = flag.Int("maxDiffEntries", 0, "maximum number of diff entries to return from a ComputeThreatListDiff request") + maxDatabaseEntriesFlag = flag.Int("maxDatabaseEntries", 0, "maximum number of database entries to be stored in the local database") ) var threatTemplate = map[webrisk.ThreatType]string{ @@ -477,7 +479,7 @@ func runServer(srv *http.Server) (chan os.Signal, <-chan struct{}) { // start listening for interrupts exit := make(chan os.Signal, 1) down := make(chan struct{}) - + // runs shutdown and cleanup on an exit signal go func() { <-exit @@ -518,11 +520,13 @@ func main() { os.Exit(1) } conf := webrisk.Config{ - APIKey: *apiKeyFlag, - ProxyURL: *proxyFlag, - DBPath: *databaseFlag, - ThreatListArg: *threatTypesFlag, - Logger: os.Stderr, + APIKey: *apiKeyFlag, + ProxyURL: *proxyFlag, + DBPath: *databaseFlag, + ThreatListArg: *threatTypesFlag, + MaxDiffEntries: int32(*maxDiffEntriesFlag), + MaxDatabaseEntries: int32(*maxDatabaseEntriesFlag), + Logger: os.Stderr, } wr, err := webrisk.NewUpdateClient(conf) if err != nil { diff --git a/database.go b/database.go index 0aacf0c..b721ee0 100644 --- a/database.go +++ b/database.go @@ -200,7 +200,10 @@ func (db *database) Update(ctx context.Context, api api) (time.Duration, bool) { s = append(s, &pb.ComputeThreatListDiffRequest{ ThreatType: pb.ThreatType(td), Constraints: &pb.ComputeThreatListDiffRequest_Constraints{ - SupportedCompressions: db.config.compressionTypes}, + SupportedCompressions: db.config.compressionTypes, + MaxDiffEntries: db.config.MaxDiffEntries, + MaxDatabaseEntries: db.config.MaxDatabaseEntries, + }, VersionToken: state, }) } @@ -212,7 +215,7 @@ func (db *database) Update(ctx context.Context, api api) (time.Duration, bool) { last := db.config.now() for _, req := range s { // Query the API for the threat list and update the database. - resp, err := api.ListUpdate(ctx, req.ThreatType, req.VersionToken, req.Constraints.SupportedCompressions) + resp, err := api.ListUpdate(ctx, req) if err != nil { db.log.Printf("ListUpdate failure (%d): %v", db.updateAPIErrors+1, err) db.setError(err) diff --git a/webrisk_client.go b/webrisk_client.go index ea4f98a..46aa5fc 100644 --- a/webrisk_client.go +++ b/webrisk_client.go @@ -94,8 +94,9 @@ const ( // Errors specific to this package. var ( - errClosed = errors.New("webrisk: handler is closed") - errStale = errors.New("webrisk: threat list is stale") + errClosed = errors.New("webrisk: handler is closed") + errStale = errors.New("webrisk: threat list is stale") + errMaxEntries = errors.New("webrisk: max entries must be a power of 2 between 2 ** 10 and 2 ** 20") ) // ThreatType is an enumeration type for threats classes. Examples of threat @@ -167,6 +168,18 @@ type Config struct { // If empty, ThreatLists will be loaded instead. ThreatListArg string + // MaxDiffEntries sets the maximum entries to request in a single call to ComputeThreatListDiff. + // This can be used in resource-constrained environments to limit the number of entries fetched + // at once. The default behavior (0) is to ignore this limit. + // If set, this should be a power of 2 between 2 ** 10 and 2 ** 20. + MaxDiffEntries int32 + + // MaxDatabaseEntries sets the maximum entries that the client will accept & store in the local + // database. This can be used to limit the size of the blocklist at the trade-off of decreased + // coverage. The default behavior (0) is to ignore this limit. + // If set, this should be a power of 2 between 2 ** 10 and 2 ** 20. + MaxDatabaseEntries int32 + // ThreatLists determines which threat lists that UpdateClient should // subscribe to. The threats reported by LookupURLs will only be ones that // are specified by this list. @@ -229,6 +242,19 @@ func parseThreatTypes(args string) ([]ThreatType, error) { return r, nil } +// validateMaxEntries validates a max entries argument, which must be either 0 or a power of 2 +// between 2 ** 10 and 2 ** 20. +func validateMaxEntries(n int32) error { + if n == 0 { + return nil + } + // Bitwise check confirms a power of 2. + if n&(n-1) != 0 || n < 1024 || n > 1048576 { + return errMaxEntries + } + return nil +} + func (c Config) copy() Config { c2 := c c2.ThreatLists = append([]ThreatType(nil), c.ThreatLists...) @@ -287,6 +313,11 @@ func NewUpdateClient(conf Config) (*UpdateClient, error) { conf.ThreatLists = tl } + // Validate max entries if args are passed. + if err := validateMaxEntries(conf.MaxDiffEntries); err != nil { + return nil, err + } + // Create the SafeBrowsing object. if conf.api == nil { var err error diff --git a/webrisk_client_system_test.go b/webrisk_client_system_test.go index 2a74fbc..60355c1 100644 --- a/webrisk_client_system_test.go +++ b/webrisk_client_system_test.go @@ -42,8 +42,7 @@ func TestNetworkAPIUpdate(t *testing.T) { ThreatType: pb.ThreatType_MALWARE, } - dat, err := nm.ListUpdate(context.Background(), req.ThreatType, - req.VersionToken, []pb.CompressionType{}) + dat, err := nm.ListUpdate(context.Background(), req) if err != nil { t.Fatal(err) } diff --git a/webrisk_client_test.go b/webrisk_client_test.go index d42a392..5c56b66 100644 --- a/webrisk_client_test.go +++ b/webrisk_client_test.go @@ -65,3 +65,50 @@ func TestParseThreatTypes(t *testing.T) { } } } + +func TestValidateMaxEntries(t *testing.T) { + tests := []struct { + n int32 + wantErr error + }{ + { + n: 0, + wantErr: nil, + }, + { + n: 1024, + wantErr: nil, + }, + { + n: 4096, + wantErr: nil, + }, + { + n: 1048576, + wantErr: nil, + }, + { + n: -1024, + wantErr: errMaxEntries, + }, + { + n: 100, + wantErr: errMaxEntries, + }, + { + n: 1026, + wantErr: errMaxEntries, + }, + { + n: 2097152, + wantErr: errMaxEntries, + }, + } + + for _, tc := range tests { + gotErr := validateMaxEntries(tc.n) + if gotErr != tc.wantErr { + t.Errorf("validateMaxEntries(%d) = %v, want %v", tc.n, gotErr, tc.wantErr) + } + } +}