Skip to content

Commit

Permalink
Updates to README with details about command line flags.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 639105973
  • Loading branch information
rvilgalys authored and copybara-github committed May 31, 2024
1 parent d5412b6 commit 7fb0516
Show file tree
Hide file tree
Showing 9 changed files with 197 additions and 41 deletions.
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,29 @@ For 400 errors, this usually means the API key is incorrect or was not supplied

For 403 errors, this could mean the Web Risk API is not enabled for your project **or** your project does not have Billing enabled.

# Configuration

Both `wrserver` (used by `docker run`) and `wrlookup` support several command line flags.

- `apikey` (required) -- Used to Authenticate requests with the Web Risk API.
The API itself must also be enabled on the same project & be linked to a Billing account.

- `threatTypes` (optional) -- A comma-separated lists of different blocklists to load and check URLs against.
Available options include `MALWARE`,`UNWANTED_SOFTWARE`,`SOCIAL_ENGINEERING`,
`SOCIAL_ENGINEERING_EXTENDED_COVERAGE`. This arg will also accept `ALL` which is
the default behavior.

- `maxDiffEntries` (optional) -- An int32 value that will set the max number of hash prefixes
returned in a single diff request. This can be used in resource-bound environments to control
bandwidth usage. The default value of 0 will result in this limit being ignored. Otherwise, this
must be set to a positive integer which must be a power of 2 between 2 ^ 10 and 2 ^ 20.

- `maxDatabaseEntries` (optional) -- An in32 value that will set the upper boundary has prefixes
to be returned from the API and stored locally. This can be used to limit the number of hash
prefixes to be searched against. The default value of 0 will result in this limit being ignored. Otherwise, this
must be set to a positive integer which must be a power of 2 between 2 ^ 10 and 2 ^ 20. *Note*: Setting this limit
will decrease blocklist coverage.

# About the Social Engineering Extended Coverage List

This is a newer blocklist that includes a greater range of risky URLs that
Expand Down
23 changes: 15 additions & 8 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"io/ioutil"
"net/http"
"net/url"
"strconv"
"strings"

"google.golang.org/protobuf/encoding/protojson"
Expand All @@ -34,15 +35,16 @@ const (
threatTypeString = "threat_type"
versionTokenString = "version_token"
supportedCompressionsString = "constraints.supported_compressions"
maxDiffEntriesString = "constraints.max_diff_entries"
maxDatabaseEntriesString = "constraints.max_database_entries"
hashPrefixString = "hash_prefix"
threatTypesString = "threat_types"
userAgentString = "Webrisk-Client/0.2.1"
)

// The api interface specifies wrappers around the Web Risk API.
type api interface {
ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte,
compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error)
ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error)
HashLookup(ctx context.Context, hashPrefix []byte,
threatTypes []pb.ThreatType) (*pb.SearchHashesResponse, error)
}
Expand Down Expand Up @@ -123,17 +125,22 @@ func (a *netAPI) parseError(httpResp *http.Response) error {
}

// ListUpdate issues a ComputeThreatListDiff API call and returns the response.
func (a *netAPI) ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte,
compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error) {
func (a *netAPI) ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error) {
resp := new(pb.ComputeThreatListDiffResponse)
u := *a.url // Make a copy of URL
// Add fields from ComputeThreatListDiffRequest to URL request
q := u.Query()
q.Set(threatTypeString, threatType.String())
if len(versionToken) != 0 {
q.Set(versionTokenString, base64.StdEncoding.EncodeToString(versionToken))
q.Set(threatTypeString, req.GetThreatType().String())
if len(req.GetVersionToken()) != 0 {
q.Set(versionTokenString, base64.StdEncoding.EncodeToString(req.GetVersionToken()))
}
for _, compressionType := range compressionTypes {
if req.GetConstraints().GetMaxDiffEntries() != 0 {
q.Add(maxDiffEntriesString, strconv.FormatInt(int64(req.GetConstraints().GetMaxDiffEntries()), 10))
}
if req.GetConstraints().GetMaxDatabaseEntries() != 0 {
q.Add(maxDatabaseEntriesString, strconv.FormatInt(int64(req.GetConstraints().GetMaxDatabaseEntries()), 10))
}
for _, compressionType := range req.GetConstraints().GetSupportedCompressions() {
q.Add(supportedCompressionsString, compressionType.String())
}
u.RawQuery = q.Encode()
Expand Down
48 changes: 43 additions & 5 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"net/http"
"net/http/httptest"
"reflect"
"strconv"
"testing"

"github.com/google/go-cmp/cmp"
Expand All @@ -39,9 +40,8 @@ type mockAPI struct {
threatTypes []pb.ThreatType) (*pb.SearchHashesResponse, error)
}

func (m *mockAPI) ListUpdate(ctx context.Context, threatType pb.ThreatType, versionToken []byte,
compressionTypes []pb.CompressionType) (*pb.ComputeThreatListDiffResponse, error) {
return m.listUpdate(ctx, threatType, versionToken, compressionTypes)
func (m *mockAPI) ListUpdate(ctx context.Context, req *pb.ComputeThreatListDiffRequest) (*pb.ComputeThreatListDiffResponse, error) {
return m.listUpdate(ctx, req.GetThreatType(), req.GetVersionToken(), req.GetConstraints().GetSupportedCompressions())
}

func (m *mockAPI) HashLookup(ctx context.Context, hashPrefix []byte,
Expand All @@ -55,6 +55,8 @@ func TestNetAPI(t *testing.T) {
var gotReqHashPrefix, wantReqHashPrefix []byte
var gotReqThreatTypes, wantReqThreatTypes []pb.ThreatType
var gotResp, wantResp proto.Message
var gotMaxDiffEntries, wantMaxDiffEntries []int32
var gotMaxDatabaseEntries, wantMaxDatabaseEntries []int32
responseMisformatter := ""
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var p []byte
Expand Down Expand Up @@ -89,6 +91,24 @@ func TestNetAPI(t *testing.T) {
gotReqThreatTypes = append(gotReqThreatTypes,
pb.ThreatType(pb.ThreatType_value[threat]))
}
} else if key == maxDiffEntriesString {
if len(value) == 0 {
t.Fatalf("missing value for key %v", key)
}
i, err := strconv.ParseInt(value[0], 10, 32)
if err != nil {
t.Fatalf("error parsing %v", value[0])
}
gotMaxDiffEntries = append(gotMaxDiffEntries, int32(i))
} else if key == maxDatabaseEntriesString {
if len(value) == 0 {
t.Fatalf("missing value for key %v", key)
}
i, err := strconv.ParseInt(value[0], 10, 32)
if err != nil {
t.Fatalf("error parsing %v", value[0])
}
gotMaxDatabaseEntries = append(gotMaxDatabaseEntries, int32(i))
} else if key != "key" {
t.Fatalf("unexpected request param error for key: %v", key)
}
Expand All @@ -110,6 +130,8 @@ func TestNetAPI(t *testing.T) {
// Test that ListUpdate marshal/unmarshal works.
wantReqThreatType = pb.ThreatType_MALWARE
wantReqCompressionTypes = []pb.CompressionType{0, 1, 2}
wantMaxDiffEntries = []int32{1024}
wantMaxDatabaseEntries = []int32{1024}

wantResp = &pb.ComputeThreatListDiffResponse{
ResponseType: 1,
Expand All @@ -118,8 +140,16 @@ func TestNetAPI(t *testing.T) {
RawIndices: &pb.RawIndices{Indices: []int32{1, 2, 3}},
},
}
resp1, err := api.ListUpdate(context.Background(), wantReqThreatType, []byte{},
wantReqCompressionTypes)
req := &pb.ComputeThreatListDiffRequest{
ThreatType: wantReqThreatType,
Constraints: &pb.ComputeThreatListDiffRequest_Constraints{
SupportedCompressions: wantReqCompressionTypes,
MaxDiffEntries: 1024,
MaxDatabaseEntries: 1024,
},
VersionToken: []byte{},
}
resp1, err := api.ListUpdate(context.Background(), req)
gotResp = resp1
if err != nil {
t.Errorf("unexpected ListUpdate error: %v", err)
Expand All @@ -135,6 +165,14 @@ func TestNetAPI(t *testing.T) {
if !proto.Equal(gotResp, wantResp) {
t.Errorf("mismatching ListUpdate responses:\ngot %+v\nwant %+v", gotResp, wantResp)
}
if !reflect.DeepEqual(gotMaxDiffEntries, wantMaxDiffEntries) {
t.Errorf("mismatching ListUpdate max diff entries:\ngot %+v\nwant %+v",
gotMaxDiffEntries, wantMaxDiffEntries)
}
if !reflect.DeepEqual(gotMaxDatabaseEntries, wantMaxDatabaseEntries) {
t.Errorf("mismatching ListUpdate max database entries:\ngot %+v\nwant %+v",
gotMaxDatabaseEntries, wantMaxDatabaseEntries)
}

// Test that HashLookup marshal/unmarshal works.
wantReqHashPrefix = []byte("aaaa")
Expand Down
26 changes: 15 additions & 11 deletions cmd/wrlookup/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,13 @@ import (
)

var (
apiKeyFlag = flag.String("apikey", "", "specify your Web Risk API key")
databaseFlag = flag.String("db", "", "path to the Web Risk database. By default persistent storage is disabled (not recommended).")
serverURLFlag = flag.String("server", webrisk.DefaultServerURL, "Web Risk API server address.")
proxyFlag = flag.String("proxy", "", "proxy to use to connect to the HTTP server")
threatTypesFlag = flag.String("threatTypes", "ALL", "threat types to check against")
apiKeyFlag = flag.String("apikey", "", "specify your Web Risk API key")
databaseFlag = flag.String("db", "", "path to the Web Risk database. By default persistent storage is disabled (not recommended).")
serverURLFlag = flag.String("server", webrisk.DefaultServerURL, "Web Risk API server address.")
proxyFlag = flag.String("proxy", "", "proxy to use to connect to the HTTP server")
threatTypesFlag = flag.String("threatTypes", "ALL", "threat types to check against")
maxDiffEntriesFlag = flag.Int("maxDiffEntries", 0, "maximum number of diff entries to return from a ComputeThreatListDiff request")
maxDatabaseEntriesFlag = flag.Int("maxDatabaseEntries", 0, "maximum number of database entries to be stored in the local database")
)

const usage = `wrlookup: command-line tool to lookup URLs with Web Risk.
Expand Down Expand Up @@ -81,12 +83,14 @@ func main() {
os.Exit(codeInvalid)
}
sb, err := webrisk.NewUpdateClient(webrisk.Config{
APIKey: *apiKeyFlag,
DBPath: *databaseFlag,
Logger: os.Stderr,
ServerURL: *serverURLFlag,
ProxyURL: *proxyFlag,
ThreatListArg: *threatTypesFlag,
APIKey: *apiKeyFlag,
DBPath: *databaseFlag,
Logger: os.Stderr,
ServerURL: *serverURLFlag,
ProxyURL: *proxyFlag,
ThreatListArg: *threatTypesFlag,
MaxDiffEntries: int32(*maxDiffEntriesFlag),
MaxDatabaseEntries: int32(*maxDatabaseEntriesFlag),
})
if err != nil {
fmt.Fprintln(os.Stderr, "Unable to initialize Web Risk client: ", err)
Expand Down
26 changes: 15 additions & 11 deletions cmd/wrserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,13 @@ const (
)

var (
apiKeyFlag = flag.String("apikey", os.Getenv("APIKEY"), "specify your Web Risk API key")
srvAddrFlag = flag.String("srvaddr", "0.0.0.0:8080", "TCP network address the HTTP server should use")
proxyFlag = flag.String("proxy", "", "proxy to use to connect to the HTTP server")
databaseFlag = flag.String("db", "", "path to the Web Risk database.")
threatTypesFlag = flag.String("threatTypes", "ALL", "threat types to check against")
apiKeyFlag = flag.String("apikey", os.Getenv("APIKEY"), "specify your Web Risk API key")
srvAddrFlag = flag.String("srvaddr", "0.0.0.0:8080", "TCP network address the HTTP server should use")
proxyFlag = flag.String("proxy", "", "proxy to use to connect to the HTTP server")
databaseFlag = flag.String("db", "", "path to the Web Risk database.")
threatTypesFlag = flag.String("threatTypes", "ALL", "threat types to check against")
maxDiffEntriesFlag = flag.Int("maxDiffEntries", 0, "maximum number of diff entries to return from a ComputeThreatListDiff request")
maxDatabaseEntriesFlag = flag.Int("maxDatabaseEntries", 0, "maximum number of database entries to be stored in the local database")
)

var threatTemplate = map[webrisk.ThreatType]string{
Expand Down Expand Up @@ -477,7 +479,7 @@ func runServer(srv *http.Server) (chan os.Signal, <-chan struct{}) {
// start listening for interrupts
exit := make(chan os.Signal, 1)
down := make(chan struct{})

// runs shutdown and cleanup on an exit signal
go func() {
<-exit
Expand Down Expand Up @@ -518,11 +520,13 @@ func main() {
os.Exit(1)
}
conf := webrisk.Config{
APIKey: *apiKeyFlag,
ProxyURL: *proxyFlag,
DBPath: *databaseFlag,
ThreatListArg: *threatTypesFlag,
Logger: os.Stderr,
APIKey: *apiKeyFlag,
ProxyURL: *proxyFlag,
DBPath: *databaseFlag,
ThreatListArg: *threatTypesFlag,
MaxDiffEntries: int32(*maxDiffEntriesFlag),
MaxDatabaseEntries: int32(*maxDatabaseEntriesFlag),
Logger: os.Stderr,
}
wr, err := webrisk.NewUpdateClient(conf)
if err != nil {
Expand Down
7 changes: 5 additions & 2 deletions database.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,10 @@ func (db *database) Update(ctx context.Context, api api) (time.Duration, bool) {
s = append(s, &pb.ComputeThreatListDiffRequest{
ThreatType: pb.ThreatType(td),
Constraints: &pb.ComputeThreatListDiffRequest_Constraints{
SupportedCompressions: db.config.compressionTypes},
SupportedCompressions: db.config.compressionTypes,
MaxDiffEntries: db.config.MaxDiffEntries,
MaxDatabaseEntries: db.config.MaxDatabaseEntries,
},
VersionToken: state,
})
}
Expand All @@ -212,7 +215,7 @@ func (db *database) Update(ctx context.Context, api api) (time.Duration, bool) {
last := db.config.now()
for _, req := range s {
// Query the API for the threat list and update the database.
resp, err := api.ListUpdate(ctx, req.ThreatType, req.VersionToken, req.Constraints.SupportedCompressions)
resp, err := api.ListUpdate(ctx, req)
if err != nil {
db.log.Printf("ListUpdate failure (%d): %v", db.updateAPIErrors+1, err)
db.setError(err)
Expand Down
35 changes: 33 additions & 2 deletions webrisk_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,9 @@ const (

// Errors specific to this package.
var (
errClosed = errors.New("webrisk: handler is closed")
errStale = errors.New("webrisk: threat list is stale")
errClosed = errors.New("webrisk: handler is closed")
errStale = errors.New("webrisk: threat list is stale")
errMaxEntries = errors.New("webrisk: max entries must be a power of 2 between 2 ** 10 and 2 ** 20")
)

// ThreatType is an enumeration type for threats classes. Examples of threat
Expand Down Expand Up @@ -167,6 +168,18 @@ type Config struct {
// If empty, ThreatLists will be loaded instead.
ThreatListArg string

// MaxDiffEntries sets the maximum entries to request in a single call to ComputeThreatListDiff.
// This can be used in resource-constrained environments to limit the number of entries fetched
// at once. The default behavior (0) is to ignore this limit.
// If set, this should be a power of 2 between 2 ** 10 and 2 ** 20.
MaxDiffEntries int32

// MaxDatabaseEntries sets the maximum entries that the client will accept & store in the local
// database. This can be used to limit the size of the blocklist at the trade-off of decreased
// coverage. The default behavior (0) is to ignore this limit.
// If set, this should be a power of 2 between 2 ** 10 and 2 ** 20.
MaxDatabaseEntries int32

// ThreatLists determines which threat lists that UpdateClient should
// subscribe to. The threats reported by LookupURLs will only be ones that
// are specified by this list.
Expand Down Expand Up @@ -229,6 +242,19 @@ func parseThreatTypes(args string) ([]ThreatType, error) {
return r, nil
}

// validateMaxEntries validates a max entries argument, which must be either 0 or a power of 2
// between 2 ** 10 and 2 ** 20.
func validateMaxEntries(n int32) error {
if n == 0 {
return nil
}
// Bitwise check confirms a power of 2.
if n&(n-1) != 0 || n < 1024 || n > 1048576 {
return errMaxEntries
}
return nil
}

func (c Config) copy() Config {
c2 := c
c2.ThreatLists = append([]ThreatType(nil), c.ThreatLists...)
Expand Down Expand Up @@ -287,6 +313,11 @@ func NewUpdateClient(conf Config) (*UpdateClient, error) {
conf.ThreatLists = tl
}

// Validate max entries if args are passed.
if err := validateMaxEntries(conf.MaxDiffEntries); err != nil {
return nil, err
}

// Create the SafeBrowsing object.
if conf.api == nil {
var err error
Expand Down
3 changes: 1 addition & 2 deletions webrisk_client_system_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ func TestNetworkAPIUpdate(t *testing.T) {
ThreatType: pb.ThreatType_MALWARE,
}

dat, err := nm.ListUpdate(context.Background(), req.ThreatType,
req.VersionToken, []pb.CompressionType{})
dat, err := nm.ListUpdate(context.Background(), req)
if err != nil {
t.Fatal(err)
}
Expand Down
Loading

0 comments on commit 7fb0516

Please sign in to comment.