diff --git a/README.md b/README.md index e069c3a..60fabb5 100644 --- a/README.md +++ b/README.md @@ -6,11 +6,78 @@ implemented in pure Go, i.e. does not wrap libnftnl. This is not an official Google product. -## Breaking changes +## Alpha status + +This package is in early stages, and only implements a subset of nftables +features. While the developers intend to keep interfaces & function signatures +backwards-compatible, no guarantees are made; bugs or any unexpected +structuring of nftables features may result in breaking changes. + +## Usage + +Issue commands to mutate or read nftables state. Commands that mutate state +(eg: `AddTable`, `AddChain`, `AddSet`, `AddRule`) are queued until `Flush` +is called. + +### Expressions + +The following expressions are implemented for use in rule logic: + +TODO + +### Examples + +#### Drop outgoing packets to 1.2.3.4 + +```go +c := &nftables.Conn{} + +myTable := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "myFilter", +}) + +myChain := c.AddChain(&nftables.Chain{ + Name: "myChain", + Table: myTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookOutput, + Priority: nftables.ChainPriorityFilter, +}) + +c.AddRule(&nftables.Rule{ + Table: myTable, + Chain: myChain, + Exprs: []expr.Any{ + // payload load 4b @ network header + 16 => reg 1 + // (Load the destination IP into register 1) + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 16, + Len: 4, + }, + // cmp eq reg 1 0x01020304 + // (bail if register 1 != 1.2.3.4) + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: net.ParseIP("1.2.3.4").To4(), + }, + // [ immediate reg 0 drop ] + // (drop the packet) + &expr.Verdict{ + Kind: expr.VerdictDrop, + }, + }, +}) + +if err := c.Flush(); err != nil { + // handle error +} + +``` -This package is in very early stages, and only contains enough data types and -functions to install very basic nftables rules. It is likely that mistakes with -the data types/API will be identified as more functionality is added. ## Contributions diff --git a/expr/immediate.go b/expr/immediate.go index f050ce5..a503845 100644 --- a/expr/immediate.go +++ b/expr/immediate.go @@ -15,6 +15,7 @@ package expr import ( + "encoding/binary" "fmt" "github.com/google/nftables/binaryutil" @@ -49,5 +50,30 @@ func (e *Immediate) marshal() ([]byte, error) { } func (e *Immediate) unmarshal(data []byte) error { - return fmt.Errorf("not yet implemented") + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_IMMEDIATE_DREG: + e.Register = ad.Uint32() + case unix.NFTA_IMMEDIATE_DATA: + nestedAD, err := netlink.NewAttributeDecoder(ad.Bytes()) + if err != nil { + return fmt.Errorf("nested NewAttributeDecoder() failed: %v", err) + } + for nestedAD.Next() { + switch nestedAD.Type() { + case unix.NFTA_DATA_VALUE: + e.Data = nestedAD.Bytes() + } + } + if nestedAD.Err() != nil { + return fmt.Errorf("decoding immediate: %v", nestedAD.Err()) + } + } + } + return ad.Err() } diff --git a/expr/lookup.go b/expr/lookup.go index 67ab165..6bdf986 100644 --- a/expr/lookup.go +++ b/expr/lookup.go @@ -15,7 +15,7 @@ package expr import ( - "fmt" + "encoding/binary" "github.com/google/nftables/binaryutil" "github.com/mdlayher/netlink" @@ -60,5 +60,24 @@ func (e *Lookup) marshal() ([]byte, error) { } func (e *Lookup) unmarshal(data []byte) error { - return fmt.Errorf("not yet implemented") + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_LOOKUP_SET: + e.SetName = ad.String() + case unix.NFTA_LOOKUP_SET_ID: + e.SetID = ad.Uint32() + case unix.NFTA_LOOKUP_SREG: + e.SourceRegister = ad.Uint32() + case unix.NFTA_LOOKUP_DREG: + e.DestRegister = ad.Uint32() + case unix.NFTA_LOOKUP_FLAGS: + e.Invert = (ad.Uint32() & unix.NFT_LOOKUP_F_INV) != 0 + } + } + return ad.Err() } diff --git a/expr/verdict.go b/expr/verdict.go index 92cc98b..37884ad 100644 --- a/expr/verdict.go +++ b/expr/verdict.go @@ -15,6 +15,7 @@ package expr import ( + "encoding/binary" "fmt" "github.com/google/nftables/binaryutil" @@ -85,5 +86,28 @@ func (e *Verdict) marshal() ([]byte, error) { } func (e *Verdict) unmarshal(data []byte) error { - return fmt.Errorf("not yet implemented") + ad, err := netlink.NewAttributeDecoder(data) + if err != nil { + return err + } + ad.ByteOrder = binary.BigEndian + for ad.Next() { + switch ad.Type() { + case unix.NFTA_IMMEDIATE_DATA: + nestedAD, err := netlink.NewAttributeDecoder(ad.Bytes()) + if err != nil { + return fmt.Errorf("nested NewAttributeDecoder() failed: %v", err) + } + for nestedAD.Next() { + switch nestedAD.Type() { + case unix.NFTA_DATA_VERDICT: + e.Kind = VerdictKind(binaryutil.BigEndian.Uint32(nestedAD.Bytes()[4:])) + } + } + if nestedAD.Err() != nil { + return fmt.Errorf("decoding immediate: %v", nestedAD.Err()) + } + } + } + return ad.Err() } diff --git a/nftables.go b/nftables.go index 6fd9649..d59769b 100644 --- a/nftables.go +++ b/nftables.go @@ -213,6 +213,10 @@ func exprsFromMsg(b []byte) ([]expr.Any, error) { e = &expr.Counter{} case "payload": e = &expr.Payload{} + case "lookup": + e = &expr.Lookup{} + case "immediate": + e = &expr.Immediate{} } if e == nil { // TODO: introduce an opaque expression type so that users know @@ -224,6 +228,15 @@ func exprsFromMsg(b []byte) ([]expr.Any, error) { if err := expr.Unmarshal(b, e); err != nil { return err } + // Verdict expressions are a special-case of immediate expressions, so + // if the expression is an immediate writing nothing into the verdict + // register (invalid), re-parse it as a verdict expression. + if imm, isImmediate := e.(*expr.Immediate); isImmediate && imm.Register == unix.NFT_REG_VERDICT && len(imm.Data) == 0 { + e = &expr.Verdict{} + if err := expr.Unmarshal(b, e); err != nil { + return err + } + } exprs = append(exprs, e) return nil }) diff --git a/nftables_test.go b/nftables_test.go index 04cdafd..2cde4ef 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -19,6 +19,7 @@ import ( "flag" "fmt" "net" + "reflect" "runtime" "strings" "testing" @@ -1112,3 +1113,130 @@ func TestDeleteElementNamedSet(t *testing.T) { t.Errorf("elems[0].Key = %v, want 22", elems[0].Key) } } + +func TestGetRuleLookupVerdictImmediate(t *testing.T) { + // Create a new network namespace to test these operations, + // and tear down the namespace at test completion. + c, newNS := openSystemNFTConn(t) + defer cleanupSystemNFTConn(t, newNS) + // Clear all rules at the beginning + end of the test. + c.FlushRuleset() + defer c.FlushRuleset() + + filter := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }) + forward := c.AddChain(&nftables.Chain{ + Name: "forward", + Table: filter, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }) + + set := &nftables.Set{ + Table: filter, + Name: "kek", + KeyType: nftables.TypeInetService, + } + if err := c.AddSet(set, nil); err != nil { + t.Errorf("c.AddSet(portSet) failed: %v", err) + } + if err := c.Flush(); err != nil { + t.Errorf("c.Flush() failed: %v", err) + } + + c.AddRule(&nftables.Rule{ + Table: filter, + Chain: forward, + Exprs: []expr.Any{ + // [ meta load l4proto => reg 1 ] + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, + // [ cmp eq reg 1 0x00000006 ] + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.IPPROTO_TCP}, + }, + // [ payload load 2b @ transport header + 2 => reg 1 ] + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseTransportHeader, + Offset: 2, + Len: 2, + }, + // [ lookup reg 1 set __set%d ] + &expr.Lookup{ + SourceRegister: 1, + SetName: set.Name, + SetID: set.ID, + }, + // [ immediate reg 0 drop ] + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + // [ immediate reg 2 kek ] + &expr.Immediate{ + Register: 2, + Data: []byte("kek"), + }, + }, + }) + + if err := c.Flush(); err != nil { + t.Errorf("c.Flush() failed: %v", err) + } + + rules, err := c.GetRule( + &nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }, + &nftables.Chain{ + Name: "forward", + }, + ) + if err != nil { + t.Fatal(err) + } + + if got, want := len(rules), 1; got != want { + t.Fatalf("unexpected number of rules: got %d, want %d", got, want) + } + if got, want := len(rules[0].Exprs), 6; got != want { + t.Fatalf("unexpected number of exprs: got %d, want %d", got, want) + } + + lookup, lookupOk := rules[0].Exprs[3].(*expr.Lookup) + if !lookupOk { + t.Fatalf("Exprs[3] is type %T, want *expr.Lookup", rules[0].Exprs[3]) + } + if want := (&expr.Lookup{ + SourceRegister: 1, + SetName: set.Name, + }); !reflect.DeepEqual(lookup, want) { + t.Errorf("lookup expr = %+v, wanted %+v", lookup, want) + } + + verdict, verdictOk := rules[0].Exprs[4].(*expr.Verdict) + if !verdictOk { + t.Fatalf("Exprs[4] is type %T, want *expr.Verdict", rules[0].Exprs[4]) + } + if want := (&expr.Verdict{ + Kind: expr.VerdictAccept, + }); !reflect.DeepEqual(verdict, want) { + t.Errorf("verdict expr = %+v, wanted %+v", verdict, want) + } + + imm, immOk := rules[0].Exprs[5].(*expr.Immediate) + if !immOk { + t.Fatalf("Exprs[4] is type %T, want *expr.Immediate", rules[0].Exprs[5]) + } + if want := (&expr.Immediate{ + Register: 2, + Data: []byte("kek"), + }); !reflect.DeepEqual(imm, want) { + t.Errorf("verdict expr = %+v, wanted %+v", imm, want) + } +}