From 8cb78f74320eeeab3953455cd6aafdba86b8177c Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Thu, 26 Dec 2019 12:42:43 +0100 Subject: [PATCH 01/16] set handle back --- conn.go | 23 +++++++++++++-- nftables_test.go | 74 ++++++++++++++++++++++++++++++++++++++++++++++++ rule.go | 2 ++ 3 files changed, 97 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index 3768645..094c17a 100644 --- a/conn.go +++ b/conn.go @@ -35,6 +35,7 @@ type Conn struct { NetNS int // Network namespace netlink will interact with. sync.Mutex messages []netlink.Message + rules []*Rule err error } @@ -43,6 +44,7 @@ func (cc *Conn) Flush() error { cc.Lock() defer func() { cc.messages = nil + cc.rules = nil cc.Unlock() }() if len(cc.messages) == 0 { @@ -63,8 +65,25 @@ func (cc *Conn) Flush() error { return fmt.Errorf("SendMessages: %w", err) } - if _, err := conn.Receive(); err != nil { - return fmt.Errorf("Receive: %w", err) + echoedRules := 0 + + for len(cc.rules) > echoedRules { + rmsg, err := conn.Receive() + + if err != nil { + return fmt.Errorf("Receive: %w", err) + } + + for _, msg := range rmsg { + if msg.Header.Type == ruleHeaderType { + rule, err := ruleFromMsg(msg) + if err == nil { + cc.rules[echoedRules].Handle = rule.Handle + echoedRules++ + } + } + } + } return nil diff --git a/nftables_test.go b/nftables_test.go index e3337ef..19b0f9a 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -3980,3 +3980,77 @@ func TestStatelessNAT(t *testing.T) { t.Fatal(err) } } + +func TestHandleBack(t *testing.T) { + + // Create a new network namespace to test these operations, + // and tear down the namespace at test completion. + c, newNS := openSystemNFTConn(t) + defer cleanupSystemNFTConn(t, newNS) + // Clear all rules at the beginning + end of the test. + c.FlushRuleset() + defer c.FlushRuleset() + + filter := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }) + + prerouting := c.AddChain(&nftables.Chain{ + Name: "base-chain", + Table: filter, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityFilter, + }) + + var rulesCreated []*nftables.Rule + + rulesCreated = append(rulesCreated, c.AddRule(&nftables.Rule{ + Table: filter, + Chain: prerouting, + Exprs: []expr.Any{ + &expr.Verdict{ + // [ immediate reg 0 drop ] + Kind: expr.VerdictDrop, + }, + }, + })) + + rulesCreated = append(rulesCreated, c.AddRule(&nftables.Rule{ + Table: filter, + Chain: prerouting, + Exprs: []expr.Any{ + &expr.Verdict{ + // [ immediate reg 0 drop ] + Kind: expr.VerdictDrop, + }, + }, + })) + + for i, r := range rulesCreated { + if r.Handle != 0 { + t.Fatalf("unexpected handle value at %d", i) + } + } + + if err := c.Flush(); err != nil { + t.Fatal(err) + } + + rulesGetted, _ := c.GetRule(filter, prerouting) + + if len(rulesGetted) != len(rulesCreated) { + t.Fatalf("Bad ruleset lenght got %d want %d", len(rulesGetted), len(rulesCreated)) + } + + for i, r := range rulesGetted { + if r.Handle == 0 { + t.Fatalf("handle value is empty at %d", i) + } + + if r.Handle != rulesCreated[i].Handle { + t.Fatalf("mismatched handle at %d", i) + } + } +} diff --git a/rule.go b/rule.go index 48d79d1..561da54 100644 --- a/rule.go +++ b/rule.go @@ -130,6 +130,8 @@ func (cc *Conn) AddRule(r *Rule) *Rule { Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...), }) + cc.rules = append(cc.rules, r) + return r } From 5a645a16e0b978ab100de1b7b7f589431ccaff06 Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Fri, 27 Dec 2019 23:18:57 +0100 Subject: [PATCH 02/16] improve reliability --- conn.go | 20 ++++++++++----- nftables_test.go | 67 ++++++++++++++++++++++++++++++++++++++---------- rule.go | 6 ++++- 3 files changed, 73 insertions(+), 20 deletions(-) diff --git a/conn.go b/conn.go index 094c17a..e7ba741 100644 --- a/conn.go +++ b/conn.go @@ -35,7 +35,7 @@ type Conn struct { NetNS int // Network namespace netlink will interact with. sync.Mutex messages []netlink.Message - rules []*Rule + rules map[int]*Rule err error } @@ -61,12 +61,20 @@ func (cc *Conn) Flush() error { defer conn.Close() - if _, err := conn.SendMessages(batch(cc.messages)); err != nil { + smsg, err := conn.SendMessages(batch(cc.messages)) + + if err != nil { return fmt.Errorf("SendMessages: %w", err) } - echoedRules := 0 + // Retrieving of seq number associated to rules + rulesBySeq := make(map[uint32]*Rule) + for i, rule := range cc.rules { + rulesBySeq[smsg[i].Header.Sequence] = rule + } + // Search handle in netlink messages based on requests seq + echoedRules := 0 for len(cc.rules) > echoedRules { rmsg, err := conn.Receive() @@ -75,10 +83,10 @@ func (cc *Conn) Flush() error { } for _, msg := range rmsg { - if msg.Header.Type == ruleHeaderType { - rule, err := ruleFromMsg(msg) + if srule, ok := rulesBySeq[msg.Header.Sequence]; ok { + rrule, err := ruleFromMsg(msg) if err == nil { - cc.rules[echoedRules].Handle = rule.Handle + srule.Handle = rrule.Handle echoedRules++ } } diff --git a/nftables_test.go b/nftables_test.go index 19b0f9a..02795b2 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -3996,19 +3996,28 @@ func TestHandleBack(t *testing.T) { Name: "filter", }) - prerouting := c.AddChain(&nftables.Chain{ - Name: "base-chain", + chain1 := c.AddChain(&nftables.Chain{ + Name: "chain1", Table: filter, Type: nftables.ChainTypeFilter, Hooknum: nftables.ChainHookPrerouting, Priority: nftables.ChainPriorityFilter, }) - var rulesCreated []*nftables.Rule + chain2 := c.AddChain(&nftables.Chain{ + Name: "chain2", + Table: filter, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityFilter, + }) - rulesCreated = append(rulesCreated, c.AddRule(&nftables.Rule{ + var rulesCreated1 []*nftables.Rule + var rulesCreated2 []*nftables.Rule + + rulesCreated1 = append(rulesCreated1, c.AddRule(&nftables.Rule{ Table: filter, - Chain: prerouting, + Chain: chain1, Exprs: []expr.Any{ &expr.Verdict{ // [ immediate reg 0 drop ] @@ -4017,9 +4026,20 @@ func TestHandleBack(t *testing.T) { }, })) - rulesCreated = append(rulesCreated, c.AddRule(&nftables.Rule{ + rulesCreated1 = append(rulesCreated1, c.AddRule(&nftables.Rule{ Table: filter, - Chain: prerouting, + Chain: chain1, + Exprs: []expr.Any{ + &expr.Verdict{ + // [ immediate reg 0 drop ] + Kind: expr.VerdictDrop, + }, + }, + })) + + rulesCreated2 = append(rulesCreated2, c.AddRule(&nftables.Rule{ + Table: filter, + Chain: chain2, Exprs: []expr.Any{ &expr.Verdict{ // [ immediate reg 0 drop ] @@ -4028,7 +4048,13 @@ func TestHandleBack(t *testing.T) { }, })) - for i, r := range rulesCreated { + for i, r := range rulesCreated1 { + if r.Handle != 0 { + t.Fatalf("unexpected handle value at %d", i) + } + } + + for i, r := range rulesCreated2 { if r.Handle != 0 { t.Fatalf("unexpected handle value at %d", i) } @@ -4038,18 +4064,33 @@ func TestHandleBack(t *testing.T) { t.Fatal(err) } - rulesGetted, _ := c.GetRule(filter, prerouting) + rulesGetted1, _ := c.GetRule(filter, chain1) + rulesGetted2, _ := c.GetRule(filter, chain2) + + if len(rulesGetted1) != len(rulesCreated1) { + t.Fatalf("Bad ruleset lenght got %d want %d", len(rulesGetted1), len(rulesCreated1)) + } - if len(rulesGetted) != len(rulesCreated) { - t.Fatalf("Bad ruleset lenght got %d want %d", len(rulesGetted), len(rulesCreated)) + if len(rulesGetted2) != len(rulesCreated2) { + t.Fatalf("Bad ruleset lenght got %d want %d", len(rulesGetted2), len(rulesCreated2)) + } + + for i, r := range rulesGetted1 { + if r.Handle == 0 { + t.Fatalf("handle value is empty at %d", i) + } + + if r.Handle != rulesCreated1[i].Handle { + t.Fatalf("mismatched handle at %d", i) + } } - for i, r := range rulesGetted { + for i, r := range rulesGetted2 { if r.Handle == 0 { t.Fatalf("handle value is empty at %d", i) } - if r.Handle != rulesCreated[i].Handle { + if r.Handle != rulesCreated2[i].Handle { t.Fatalf("mismatched handle at %d", i) } } diff --git a/rule.go b/rule.go index 561da54..6a403eb 100644 --- a/rule.go +++ b/rule.go @@ -130,7 +130,11 @@ func (cc *Conn) AddRule(r *Rule) *Rule { Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...), }) - cc.rules = append(cc.rules, r) + if cc.rules == nil { + cc.rules = make(map[int]*Rule) + } + + cc.rules[len(cc.messages)] = r return r } From b00522b83ee99a943f903bf272312364456f38c8 Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Fri, 27 Dec 2019 23:24:48 +0100 Subject: [PATCH 03/16] fix typo --- nftables_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nftables_test.go b/nftables_test.go index 02795b2..173a497 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -4068,11 +4068,11 @@ func TestHandleBack(t *testing.T) { rulesGetted2, _ := c.GetRule(filter, chain2) if len(rulesGetted1) != len(rulesCreated1) { - t.Fatalf("Bad ruleset lenght got %d want %d", len(rulesGetted1), len(rulesCreated1)) + t.Fatalf("Bad ruleset length got %d want %d", len(rulesGetted1), len(rulesCreated1)) } if len(rulesGetted2) != len(rulesCreated2) { - t.Fatalf("Bad ruleset lenght got %d want %d", len(rulesGetted2), len(rulesCreated2)) + t.Fatalf("Bad ruleset length got %d want %d", len(rulesGetted2), len(rulesCreated2)) } for i, r := range rulesGetted1 { From 39f8fec1299e15e18cabb00791a38ab0594bcc24 Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Fri, 27 Dec 2019 23:42:50 +0100 Subject: [PATCH 04/16] disable test on travis --- nftables_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nftables_test.go b/nftables_test.go index 173a497..5d9258e 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -3983,6 +3983,10 @@ func TestStatelessNAT(t *testing.T) { func TestHandleBack(t *testing.T) { + if os.Getenv("TRAVIS") == "true" { + t.SkipNow() + } + // Create a new network namespace to test these operations, // and tear down the namespace at test completion. c, newNS := openSystemNFTConn(t) From 87f28cef6e6a5ccf12f09ea9dfa85158e0e6e654 Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Fri, 3 Jan 2020 10:56:42 +0100 Subject: [PATCH 05/16] generic --- conn.go | 29 +++++++++++++++-------------- rule.go | 15 ++++++++++++--- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/conn.go b/conn.go index e7ba741..4978618 100644 --- a/conn.go +++ b/conn.go @@ -24,6 +24,10 @@ import ( "golang.org/x/sys/unix" ) +type Entity interface { + HandleResponse(netlink.Message) +} + // A Conn represents a netlink connection of the nftables family. // // All methods return their input, so that variables can be defined from string @@ -35,7 +39,7 @@ type Conn struct { NetNS int // Network namespace netlink will interact with. sync.Mutex messages []netlink.Message - rules map[int]*Rule + entities map[int]Entity err error } @@ -44,7 +48,7 @@ func (cc *Conn) Flush() error { cc.Lock() defer func() { cc.messages = nil - cc.rules = nil + cc.entities = nil cc.Unlock() }() if len(cc.messages) == 0 { @@ -67,15 +71,15 @@ func (cc *Conn) Flush() error { return fmt.Errorf("SendMessages: %w", err) } - // Retrieving of seq number associated to rules - rulesBySeq := make(map[uint32]*Rule) - for i, rule := range cc.rules { - rulesBySeq[smsg[i].Header.Sequence] = rule + // Retrieving of seq number associated to entities + entitiesBySeq := make(map[uint32]Entity) + for i, e := range cc.entities { + entitiesBySeq[smsg[i].Header.Sequence] = e } // Search handle in netlink messages based on requests seq - echoedRules := 0 - for len(cc.rules) > echoedRules { + echoedEntities := 0 + for len(cc.entities) > echoedEntities { rmsg, err := conn.Receive() if err != nil { @@ -83,12 +87,9 @@ func (cc *Conn) Flush() error { } for _, msg := range rmsg { - if srule, ok := rulesBySeq[msg.Header.Sequence]; ok { - rrule, err := ruleFromMsg(msg) - if err == nil { - srule.Handle = rrule.Handle - echoedRules++ - } + if e, ok := entitiesBySeq[msg.Header.Sequence]; ok { + e.HandleResponse(msg) + echoedEntities++ } } diff --git a/rule.go b/rule.go index 6a403eb..2f3deae 100644 --- a/rule.go +++ b/rule.go @@ -130,11 +130,11 @@ func (cc *Conn) AddRule(r *Rule) *Rule { Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...), }) - if cc.rules == nil { - cc.rules = make(map[int]*Rule) + if cc.entities == nil { + cc.entities = make(map[int]Entity) } - cc.rules[len(cc.messages)] = r + cc.entities[len(cc.messages)] = r return r } @@ -166,6 +166,15 @@ func (cc *Conn) DelRule(r *Rule) error { return nil } +// HandleResponse retrieves Handle in netlink response +func (r *Rule) HandleResponse(msg netlink.Message) { + rule, err := ruleFromMsg(msg) + + if err == nil { + r.Handle = rule.Handle + } +} + func exprsFromMsg(b []byte) ([]expr.Any, error) { ad, err := netlink.NewAttributeDecoder(b) if err != nil { From 994c20d585026feba851ab5c6cf1291b5869da73 Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Mon, 6 Jan 2020 13:19:59 +0100 Subject: [PATCH 06/16] better loop control --- conn.go | 27 ++++++++++++++++++++++----- go.mod | 4 ++-- go.sum | 4 ++++ 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/conn.go b/conn.go index 4978618..705b2e0 100644 --- a/conn.go +++ b/conn.go @@ -77,9 +77,8 @@ func (cc *Conn) Flush() error { entitiesBySeq[smsg[i].Header.Sequence] = e } - // Search handle in netlink messages based on requests seq - echoedEntities := 0 - for len(cc.entities) > echoedEntities { + // Trigger entities callback + for checkReceive(conn) { rmsg, err := conn.Receive() if err != nil { @@ -89,15 +88,33 @@ func (cc *Conn) Flush() error { for _, msg := range rmsg { if e, ok := entitiesBySeq[msg.Header.Sequence]; ok { e.HandleResponse(msg) - echoedEntities++ } } - } return nil } +func checkReceive(c *netlink.Conn) bool { + sc, err := c.SyscallConn() + + var n int + + sc.Control(func(fd uintptr) { + var fdSet unix.FdSet + fdSet.Zero() + fdSet.Set(int(fd)) + + n, err = unix.Select(int(fd)+1, &fdSet, nil, nil, &unix.Timeval{}) + }) + + if err == nil && n > 0 { + return true + } + + return false +} + // FlushRuleset flushes the entire ruleset. See also // https://wiki.nftables.org/wiki-nftables/index.php/Operations_at_ruleset_level func (cc *Conn) FlushRuleset() { diff --git a/go.mod b/go.mod index dfd5143..c5bf29e 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,8 @@ go 1.12 require ( github.com/koneu/natend v0.0.0-20150829182554-ec0926ea948d - github.com/mdlayher/netlink v0.0.0-20191009155606-de872b0d824b + github.com/mdlayher/netlink v1.0.0 github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc golang.org/x/net v0.0.0-20191028085509-fe3aa8a45271 // indirect - golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c + golang.org/x/sys v0.0.0-20200106114638-5f8ca72cd632 ) diff --git a/go.sum b/go.sum index 452fd2b..d20b949 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/koneu/natend v0.0.0-20150829182554-ec0926ea948d/go.mod h1:QHb4k4cr1fQ github.com/mdlayher/netlink v0.0.0-20190409211403-11939a169225/go.mod h1:eQB3mZE4aiYnlUsyGGCOpPETfdQq4Jhsgf1fk3cwQaA= github.com/mdlayher/netlink v0.0.0-20191009155606-de872b0d824b h1:W3er9pI7mt2gOqOWzwvx20iJ8Akiqz1mUMTxU6wdvl8= github.com/mdlayher/netlink v0.0.0-20191009155606-de872b0d824b/go.mod h1:KxeJAFOFLG6AjpyDkQ/iIhxygIUKD+vcwqcnu43w/+M= +github.com/mdlayher/netlink v1.0.0 h1:vySPY5Oxnn/8lxAPn2cK6kAzcZzYJl3KriSLO46OT18= +github.com/mdlayher/netlink v1.0.0/go.mod h1:KxeJAFOFLG6AjpyDkQ/iIhxygIUKD+vcwqcnu43w/+M= github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc h1:R83G5ikgLMxrBvLh22JhdfI8K6YXEPHx5P03Uu3DRs4= github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -25,4 +27,6 @@ golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c h1:S/FtSvpNLtFBgjTqcKsRpsa6a golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191112214154-59a1497f0cea h1:Mz1TMnfJDRJLk8S8OPCoJYgrsp/Se/2TBre2+vwX128= golang.org/x/sys v0.0.0-20191113150313-8ad342257130 h1:+sdNBpwFF05NvMnEyGynbOs/Gr2LQwORWEPKXuEXxzU= +golang.org/x/sys v0.0.0-20200106114638-5f8ca72cd632 h1:ateQkYCVYo8UwIBvoR3zj1Dh2K6Op/n3GxemXfB44/Y= +golang.org/x/sys v0.0.0-20200106114638-5f8ca72cd632/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= From 3650f836aebb4bd2f7d0a2eda65ed2f45a58e439 Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Mon, 6 Jan 2020 14:25:53 +0100 Subject: [PATCH 07/16] fix tests --- conn.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index 705b2e0..57ce29b 100644 --- a/conn.go +++ b/conn.go @@ -78,7 +78,7 @@ func (cc *Conn) Flush() error { } // Trigger entities callback - for checkReceive(conn) { + for cc.checkReceive(conn) { rmsg, err := conn.Receive() if err != nil { @@ -95,7 +95,11 @@ func (cc *Conn) Flush() error { return nil } -func checkReceive(c *netlink.Conn) bool { +func (cc *Conn) checkReceive(c *netlink.Conn) bool { + if cc.TestDial != nil { + return false + } + sc, err := c.SyscallConn() var n int From 9dc224dc94a35b32e3c38cb6143b5044085efb39 Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Mon, 6 Jan 2020 15:06:13 +0100 Subject: [PATCH 08/16] improved tests --- nftables_test.go | 124 ++++++++++++++++++----------------------------- 1 file changed, 46 insertions(+), 78 deletions(-) diff --git a/nftables_test.go b/nftables_test.go index 5d9258e..3201457 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -23,6 +23,7 @@ import ( "reflect" "runtime" "strings" + "sync" "testing" "github.com/google/nftables" @@ -4000,102 +4001,69 @@ func TestHandleBack(t *testing.T) { Name: "filter", }) - chain1 := c.AddChain(&nftables.Chain{ - Name: "chain1", + chain := c.AddChain(&nftables.Chain{ + Name: "chain", Table: filter, Type: nftables.ChainTypeFilter, Hooknum: nftables.ChainHookPrerouting, Priority: nftables.ChainPriorityFilter, }) - chain2 := c.AddChain(&nftables.Chain{ - Name: "chain2", - Table: filter, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookPrerouting, - Priority: nftables.ChainPriorityFilter, - }) - - var rulesCreated1 []*nftables.Rule - var rulesCreated2 []*nftables.Rule - - rulesCreated1 = append(rulesCreated1, c.AddRule(&nftables.Rule{ - Table: filter, - Chain: chain1, - Exprs: []expr.Any{ - &expr.Verdict{ - // [ immediate reg 0 drop ] - Kind: expr.VerdictDrop, - }, - }, - })) - - rulesCreated1 = append(rulesCreated1, c.AddRule(&nftables.Rule{ - Table: filter, - Chain: chain1, - Exprs: []expr.Any{ - &expr.Verdict{ - // [ immediate reg 0 drop ] - Kind: expr.VerdictDrop, - }, - }, - })) + c.Flush() - rulesCreated2 = append(rulesCreated2, c.AddRule(&nftables.Rule{ - Table: filter, - Chain: chain2, - Exprs: []expr.Any{ - &expr.Verdict{ - // [ immediate reg 0 drop ] - Kind: expr.VerdictDrop, - }, - }, - })) + execN := func(w int, n int) { - for i, r := range rulesCreated1 { - if r.Handle != 0 { - t.Fatalf("unexpected handle value at %d", i) - } - } + c := &nftables.Conn{NetNS: int(newNS)} - for i, r := range rulesCreated2 { - if r.Handle != 0 { - t.Fatalf("unexpected handle value at %d", i) - } - } + for i := 0; i < n; i++ { - if err := c.Flush(); err != nil { - t.Fatal(err) - } + r := c.AddRule(&nftables.Rule{ + Table: filter, + Chain: chain, + UserData: []byte(fmt.Sprintf("%d-%d", w, i)), + Exprs: []expr.Any{ + &expr.Verdict{ + // [ immediate reg 0 drop ] + Kind: expr.VerdictDrop, + }, + }, + }) - rulesGetted1, _ := c.GetRule(filter, chain1) - rulesGetted2, _ := c.GetRule(filter, chain2) + if r.Handle != 0 { + t.Fatalf("unexpected handle value at %d", i) + } - if len(rulesGetted1) != len(rulesCreated1) { - t.Fatalf("Bad ruleset length got %d want %d", len(rulesGetted1), len(rulesCreated1)) - } + if err := c.Flush(); err != nil { + t.Fatal(err) + } - if len(rulesGetted2) != len(rulesCreated2) { - t.Fatalf("Bad ruleset length got %d want %d", len(rulesGetted2), len(rulesCreated2)) - } + rulesGetted, _ := c.GetRule(filter, chain) - for i, r := range rulesGetted1 { - if r.Handle == 0 { - t.Fatalf("handle value is empty at %d", i) - } + for i, rg := range rulesGetted { + if r.Handle == 0 { + t.Fatalf("handle value is empty at %d", i) + } - if r.Handle != rulesCreated1[i].Handle { - t.Fatalf("mismatched handle at %d", i) + if bytes.Equal(rg.UserData, r.UserData) && rg.Handle != r.Handle { + t.Fatalf("mismatched handle at %d-%d, got: %d, want: %d", w, i, r.Handle, rg.Handle) + } + } } } - for i, r := range rulesGetted2 { - if r.Handle == 0 { - t.Fatalf("handle value is empty at %d", i) - } + const ( + workers = 16 + iterations = 256 + ) - if r.Handle != rulesCreated2[i].Handle { - t.Fatalf("mismatched handle at %d", i) - } + var wg sync.WaitGroup + wg.Add(workers) + for i := 0; i < workers; i++ { + go func(n int) { + defer wg.Done() + execN(n, iterations) + }(i) } + + wg.Wait() } From e906f3354d4d58ed9f441cef7c1761210d17541a Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Mon, 6 Jan 2020 15:19:16 +0100 Subject: [PATCH 09/16] errors handling --- conn.go | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/conn.go b/conn.go index 57ce29b..6011341 100644 --- a/conn.go +++ b/conn.go @@ -78,7 +78,8 @@ func (cc *Conn) Flush() error { } // Trigger entities callback - for cc.checkReceive(conn) { + msg, err := cc.checkReceive(conn) + for msg { rmsg, err := conn.Receive() if err != nil { @@ -90,18 +91,23 @@ func (cc *Conn) Flush() error { e.HandleResponse(msg) } } + msg, err = cc.checkReceive(conn) } - return nil + return err } -func (cc *Conn) checkReceive(c *netlink.Conn) bool { +func (cc *Conn) checkReceive(c *netlink.Conn) (bool, error) { if cc.TestDial != nil { - return false + return false, nil } sc, err := c.SyscallConn() + if err != nil { + return false, fmt.Errorf("SyscallConn error: %w", err) + } + var n int sc.Control(func(fd uintptr) { @@ -113,10 +119,10 @@ func (cc *Conn) checkReceive(c *netlink.Conn) bool { }) if err == nil && n > 0 { - return true + return true, nil } - return false + return false, err } // FlushRuleset flushes the entire ruleset. See also From 3aaad4cf4ce68b0b83a4e46f498c338ab4baf6d9 Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Mon, 6 Jan 2020 15:40:44 +0100 Subject: [PATCH 10/16] allow handle test in travis --- nftables_test.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/nftables_test.go b/nftables_test.go index 3201457..8d0e5c4 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -3982,11 +3982,7 @@ func TestStatelessNAT(t *testing.T) { } } -func TestHandleBack(t *testing.T) { - - if os.Getenv("TRAVIS") == "true" { - t.SkipNow() - } +func TestIntegrationAddRule(t *testing.T) { // Create a new network namespace to test these operations, // and tear down the namespace at test completion. From c578ee35d6bef766dc023c00cc30e409630f896c Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Wed, 8 Jan 2020 10:52:26 +0100 Subject: [PATCH 11/16] improve reliability --- chain.go | 6 ++--- conn.go | 62 +++++++++++++++++++++++++++++++++--------------- nftables_test.go | 8 +++---- obj.go | 2 +- rule.go | 11 ++++----- set.go | 12 +++++----- table.go | 6 ++--- 7 files changed, 64 insertions(+), 43 deletions(-) diff --git a/chain.go b/chain.go index 74caca5..9b77640 100644 --- a/chain.go +++ b/chain.go @@ -123,7 +123,7 @@ func (cc *Conn) AddChain(c *Chain) *Chain { {Type: unix.NFTA_CHAIN_TYPE, Data: []byte(c.Type + "\x00")}, })...) } - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -144,7 +144,7 @@ func (cc *Conn) DelChain(c *Chain) { {Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELCHAIN), Flags: netlink.Request | netlink.Acknowledge, @@ -162,7 +162,7 @@ func (cc *Conn) FlushChain(c *Chain) { {Type: unix.NFTA_RULE_TABLE, Data: []byte(c.Table.Name + "\x00")}, {Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Flags: netlink.Request | netlink.Acknowledge, diff --git a/conn.go b/conn.go index 6011341..67c97a6 100644 --- a/conn.go +++ b/conn.go @@ -17,6 +17,7 @@ package nftables import ( "fmt" "sync" + "sync/atomic" "github.com/google/nftables/expr" "github.com/mdlayher/netlink" @@ -39,7 +40,8 @@ type Conn struct { NetNS int // Network namespace netlink will interact with. sync.Mutex messages []netlink.Message - entities map[int]Entity + entities map[int32]Entity + it int32 err error } @@ -49,6 +51,7 @@ func (cc *Conn) Flush() error { defer func() { cc.messages = nil cc.entities = nil + cc.it = 0 cc.Unlock() }() if len(cc.messages) == 0 { @@ -65,7 +68,9 @@ func (cc *Conn) Flush() error { defer conn.Close() - smsg, err := conn.SendMessages(batch(cc.messages)) + cc.endBatch(cc.messages) + + _, err = conn.SendMessages(cc.messages[:cc.it+1]) if err != nil { return fmt.Errorf("SendMessages: %w", err) @@ -74,7 +79,7 @@ func (cc *Conn) Flush() error { // Retrieving of seq number associated to entities entitiesBySeq := make(map[uint32]Entity) for i, e := range cc.entities { - entitiesBySeq[smsg[i].Header.Sequence] = e + entitiesBySeq[cc.messages[i].Header.Sequence] = e } // Trigger entities callback @@ -97,6 +102,36 @@ func (cc *Conn) Flush() error { return err } +// PutMessage store netlink message to sent after +func (cc *Conn) PutMessage(msg netlink.Message) int32 { + if cc.messages == nil { + cc.messages = make([]netlink.Message, 128) + cc.messages = append(cc.messages, netlink.Message{}) + cc.messages[0] = netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), + Flags: netlink.Request, + }, + Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), + } + } + + i := atomic.AddInt32(&cc.it, 1) + + cc.messages = append(cc.messages, netlink.Message{}) + cc.messages[i] = msg + + return i +} + +// PutEntity store entity to relate to netlink response +func (cc *Conn) PutEntity(i int32, e Entity) { + if cc.entities == nil { + cc.entities = make(map[int32]Entity) + } + cc.entities[i] = e +} + func (cc *Conn) checkReceive(c *netlink.Conn) (bool, error) { if cc.TestDial != nil { return false, nil @@ -130,7 +165,7 @@ func (cc *Conn) checkReceive(c *netlink.Conn) (bool, error) { func (cc *Conn) FlushRuleset() { cc.Lock() defer cc.Unlock() - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -171,26 +206,15 @@ func (cc *Conn) marshalExpr(e expr.Any) []byte { return b } -func batch(messages []netlink.Message) []netlink.Message { - batch := []netlink.Message{ - { - Header: netlink.Header{ - Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), - Flags: netlink.Request, - }, - Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), - }, - } +func (cc *Conn) endBatch(messages []netlink.Message) { - batch = append(batch, messages...) + i := atomic.AddInt32(&cc.it, 1) - batch = append(batch, netlink.Message{ + cc.messages[i] = netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END), Flags: netlink.Request, }, Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), - }) - - return batch + } } diff --git a/nftables_test.go b/nftables_test.go index 8d0e5c4..c90d157 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -4033,13 +4033,13 @@ func TestIntegrationAddRule(t *testing.T) { t.Fatal(err) } + if r.Handle == 0 { + t.Fatalf("handle value is empty at %d", i) + } + rulesGetted, _ := c.GetRule(filter, chain) for i, rg := range rulesGetted { - if r.Handle == 0 { - t.Fatalf("handle value is empty at %d", i) - } - if bytes.Equal(rg.UserData, r.UserData) && rg.Handle != r.Handle { t.Fatalf("mismatched handle at %d-%d, got: %d, want: %d", w, i, r.Handle, rg.Handle) } diff --git a/obj.go b/obj.go index f3627df..d3528f8 100644 --- a/obj.go +++ b/obj.go @@ -43,7 +43,7 @@ func (cc *Conn) AddObj(o Obj) Obj { return nil } - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, diff --git a/rule.go b/rule.go index 2f3deae..9ca4168 100644 --- a/rule.go +++ b/rule.go @@ -122,19 +122,16 @@ func (cc *Conn) AddRule(r *Rule) *Rule { flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO | unix.NLM_F_APPEND } - cc.messages = append(cc.messages, netlink.Message{ + m := netlink.Message{ Header: netlink.Header{ Type: ruleHeaderType, Flags: flags, }, Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...), - }) - - if cc.entities == nil { - cc.entities = make(map[int]Entity) } - cc.entities[len(cc.messages)] = r + i := cc.PutMessage(m) + cc.PutEntity(i, r) return r } @@ -155,7 +152,7 @@ func (cc *Conn) DelRule(r *Rule) error { })...) flags := netlink.Request | netlink.Acknowledge - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Flags: flags, diff --git a/set.go b/set.go index 2b9ee7e..4fa283d 100644 --- a/set.go +++ b/set.go @@ -165,7 +165,7 @@ func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error { if err != nil { return err } - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -327,7 +327,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { netlink.Attribute{Type: unix.NFTA_SET_USERDATA, Data: []byte("\x00\x04\x02\x00\x00\x00")}) } - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -342,7 +342,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { if err != nil { return err } - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -362,7 +362,7 @@ func (cc *Conn) DelSet(s *Set) { {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSET), Flags: netlink.Request | netlink.Acknowledge, @@ -383,7 +383,7 @@ func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { if err != nil { return err } - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -402,7 +402,7 @@ func (cc *Conn) FlushSet(s *Set) { {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), Flags: netlink.Request | netlink.Acknowledge, diff --git a/table.go b/table.go index da0126a..9b47f1f 100644 --- a/table.go +++ b/table.go @@ -53,7 +53,7 @@ func (cc *Conn) DelTable(t *Table) { {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), Flags: netlink.Request | netlink.Acknowledge, @@ -71,7 +71,7 @@ func (cc *Conn) AddTable(t *Table) *Table { {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -89,7 +89,7 @@ func (cc *Conn) FlushTable(t *Table) { data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.PutMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Flags: netlink.Request | netlink.Acknowledge, From 4edcb6035255592f184c898f0761fc368875ade3 Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Wed, 8 Jan 2020 10:58:01 +0100 Subject: [PATCH 12/16] fix --- conn.go | 1 - 1 file changed, 1 deletion(-) diff --git a/conn.go b/conn.go index 67c97a6..3c452cb 100644 --- a/conn.go +++ b/conn.go @@ -118,7 +118,6 @@ func (cc *Conn) PutMessage(msg netlink.Message) int32 { i := atomic.AddInt32(&cc.it, 1) - cc.messages = append(cc.messages, netlink.Message{}) cc.messages[i] = msg return i From c8335d667e6136c5b4691e71b9cdd797c09069f8 Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Wed, 8 Jan 2020 11:45:04 +0100 Subject: [PATCH 13/16] fix 2 --- conn.go | 1 - 1 file changed, 1 deletion(-) diff --git a/conn.go b/conn.go index 3c452cb..fba8d51 100644 --- a/conn.go +++ b/conn.go @@ -106,7 +106,6 @@ func (cc *Conn) Flush() error { func (cc *Conn) PutMessage(msg netlink.Message) int32 { if cc.messages == nil { cc.messages = make([]netlink.Message, 128) - cc.messages = append(cc.messages, netlink.Message{}) cc.messages[0] = netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), From a400c5deff1426d6fc7d3c9f4da71cec961de844 Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Wed, 8 Jan 2020 14:27:53 +0100 Subject: [PATCH 14/16] remove slice limit --- conn.go | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index fba8d51..6a552d8 100644 --- a/conn.go +++ b/conn.go @@ -39,6 +39,7 @@ type Conn struct { TestDial nltest.Func // for testing only; passed to nltest.Dial NetNS int // Network namespace netlink will interact with. sync.Mutex + put sync.Mutex messages []netlink.Message entities map[int32]Entity it int32 @@ -104,8 +105,11 @@ func (cc *Conn) Flush() error { // PutMessage store netlink message to sent after func (cc *Conn) PutMessage(msg netlink.Message) int32 { + cc.put.Lock() + defer cc.put.Unlock() + if cc.messages == nil { - cc.messages = make([]netlink.Message, 128) + cc.messages = make([]netlink.Message, 16) cc.messages[0] = netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), @@ -117,6 +121,10 @@ func (cc *Conn) PutMessage(msg netlink.Message) int32 { i := atomic.AddInt32(&cc.it, 1) + if len(cc.messages) <= int(i) { + cc.messages = resize(cc.messages) + } + cc.messages[i] = msg return i @@ -208,6 +216,10 @@ func (cc *Conn) endBatch(messages []netlink.Message) { i := atomic.AddInt32(&cc.it, 1) + if len(cc.messages) <= int(i) { + cc.messages = resize(cc.messages) + } + cc.messages[i] = netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END), @@ -216,3 +228,9 @@ func (cc *Conn) endBatch(messages []netlink.Message) { Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), } } + +func resize(messages []netlink.Message) []netlink.Message { + new := make([]netlink.Message, cap(messages)*2) + copy(new, messages) + return new +} From 0d3f3ffbed84f3ecaca0845a080bebd3c72c4618 Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Sat, 11 Jan 2020 21:29:28 +0100 Subject: [PATCH 15/16] cleaner way to put msg --- conn.go | 44 +++++++++++++------------------------------- 1 file changed, 13 insertions(+), 31 deletions(-) diff --git a/conn.go b/conn.go index 6a552d8..34cf0b9 100644 --- a/conn.go +++ b/conn.go @@ -17,7 +17,6 @@ package nftables import ( "fmt" "sync" - "sync/atomic" "github.com/google/nftables/expr" "github.com/mdlayher/netlink" @@ -41,7 +40,7 @@ type Conn struct { sync.Mutex put sync.Mutex messages []netlink.Message - entities map[int32]Entity + entities map[int]Entity it int32 err error } @@ -52,7 +51,6 @@ func (cc *Conn) Flush() error { defer func() { cc.messages = nil cc.entities = nil - cc.it = 0 cc.Unlock() }() if len(cc.messages) == 0 { @@ -71,7 +69,7 @@ func (cc *Conn) Flush() error { cc.endBatch(cc.messages) - _, err = conn.SendMessages(cc.messages[:cc.it+1]) + _, err = conn.SendMessages(cc.messages) if err != nil { return fmt.Errorf("SendMessages: %w", err) @@ -104,36 +102,29 @@ func (cc *Conn) Flush() error { } // PutMessage store netlink message to sent after -func (cc *Conn) PutMessage(msg netlink.Message) int32 { +func (cc *Conn) PutMessage(msg netlink.Message) int { cc.put.Lock() defer cc.put.Unlock() if cc.messages == nil { - cc.messages = make([]netlink.Message, 16) - cc.messages[0] = netlink.Message{ + cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), Flags: netlink.Request, }, Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), - } - } - - i := atomic.AddInt32(&cc.it, 1) - - if len(cc.messages) <= int(i) { - cc.messages = resize(cc.messages) + }) } - cc.messages[i] = msg + cc.messages = append(cc.messages, msg) - return i + return len(cc.messages) - 1 } // PutEntity store entity to relate to netlink response -func (cc *Conn) PutEntity(i int32, e Entity) { +func (cc *Conn) PutEntity(i int, e Entity) { if cc.entities == nil { - cc.entities = make(map[int32]Entity) + cc.entities = make(map[int]Entity) } cc.entities[i] = e } @@ -214,23 +205,14 @@ func (cc *Conn) marshalExpr(e expr.Any) []byte { func (cc *Conn) endBatch(messages []netlink.Message) { - i := atomic.AddInt32(&cc.it, 1) - - if len(cc.messages) <= int(i) { - cc.messages = resize(cc.messages) - } + cc.put.Lock() + defer cc.put.Unlock() - cc.messages[i] = netlink.Message{ + cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END), Flags: netlink.Request, }, Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), - } -} - -func resize(messages []netlink.Message) []netlink.Message { - new := make([]netlink.Message, cap(messages)*2) - copy(new, messages) - return new + }) } From 0f16d393b29063e205c54cdea61e0e8ca7d4bc2c Mon Sep 17 00:00:00 2001 From: Alexis PIRES Date: Wed, 15 Jan 2020 11:38:46 +0100 Subject: [PATCH 16/16] refactoring conn.go --- chain.go | 6 +++--- conn.go | 40 ++++++++++++++++++++++------------------ nftables_test.go | 1 - obj.go | 2 +- rule.go | 11 ++++++----- set.go | 12 ++++++------ table.go | 6 +++--- 7 files changed, 41 insertions(+), 37 deletions(-) diff --git a/chain.go b/chain.go index 9b77640..48ebadf 100644 --- a/chain.go +++ b/chain.go @@ -123,7 +123,7 @@ func (cc *Conn) AddChain(c *Chain) *Chain { {Type: unix.NFTA_CHAIN_TYPE, Data: []byte(c.Type + "\x00")}, })...) } - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -144,7 +144,7 @@ func (cc *Conn) DelChain(c *Chain) { {Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")}, }) - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELCHAIN), Flags: netlink.Request | netlink.Acknowledge, @@ -162,7 +162,7 @@ func (cc *Conn) FlushChain(c *Chain) { {Type: unix.NFTA_RULE_TABLE, Data: []byte(c.Table.Name + "\x00")}, {Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")}, }) - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Flags: netlink.Request | netlink.Acknowledge, diff --git a/conn.go b/conn.go index 34cf0b9..533336d 100644 --- a/conn.go +++ b/conn.go @@ -35,14 +35,13 @@ type Entity interface { // // Commands are buffered. Flush sends all buffered commands in a single batch. type Conn struct { - TestDial nltest.Func // for testing only; passed to nltest.Dial - NetNS int // Network namespace netlink will interact with. sync.Mutex - put sync.Mutex - messages []netlink.Message - entities map[int]Entity - it int32 - err error + TestDial nltest.Func // for testing only; passed to nltest.Dial + NetNS int // Network namespace netlink will interact with. + entities map[int]Entity + messagesMu sync.Mutex + messages []netlink.Message + err error } // Flush sends all buffered commands in a single batch to nftables. @@ -69,9 +68,7 @@ func (cc *Conn) Flush() error { cc.endBatch(cc.messages) - _, err = conn.SendMessages(cc.messages) - - if err != nil { + if _, err = conn.SendMessages(cc.messages); err != nil { return fmt.Errorf("SendMessages: %w", err) } @@ -83,9 +80,12 @@ func (cc *Conn) Flush() error { // Trigger entities callback msg, err := cc.checkReceive(conn) + if err != nil { + return err + } + for msg { rmsg, err := conn.Receive() - if err != nil { return fmt.Errorf("Receive: %w", err) } @@ -93,18 +93,22 @@ func (cc *Conn) Flush() error { for _, msg := range rmsg { if e, ok := entitiesBySeq[msg.Header.Sequence]; ok { e.HandleResponse(msg) + } } msg, err = cc.checkReceive(conn) + if err != nil { + return err + } } return err } -// PutMessage store netlink message to sent after -func (cc *Conn) PutMessage(msg netlink.Message) int { - cc.put.Lock() - defer cc.put.Unlock() +// putMessage store netlink message to sent after +func (cc *Conn) putMessage(msg netlink.Message) int { + cc.messagesMu.Lock() + defer cc.messagesMu.Unlock() if cc.messages == nil { cc.messages = append(cc.messages, netlink.Message{ @@ -162,7 +166,7 @@ func (cc *Conn) checkReceive(c *netlink.Conn) (bool, error) { func (cc *Conn) FlushRuleset() { cc.Lock() defer cc.Unlock() - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -205,8 +209,8 @@ func (cc *Conn) marshalExpr(e expr.Any) []byte { func (cc *Conn) endBatch(messages []netlink.Message) { - cc.put.Lock() - defer cc.put.Unlock() + cc.messagesMu.Lock() + defer cc.messagesMu.Unlock() cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ diff --git a/nftables_test.go b/nftables_test.go index c90d157..bc966d6 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -4008,7 +4008,6 @@ func TestIntegrationAddRule(t *testing.T) { c.Flush() execN := func(w int, n int) { - c := &nftables.Conn{NetNS: int(newNS)} for i := 0; i < n; i++ { diff --git a/obj.go b/obj.go index d3528f8..99d51e0 100644 --- a/obj.go +++ b/obj.go @@ -43,7 +43,7 @@ func (cc *Conn) AddObj(o Obj) Obj { return nil } - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, diff --git a/rule.go b/rule.go index 9ca4168..d878b5e 100644 --- a/rule.go +++ b/rule.go @@ -130,7 +130,7 @@ func (cc *Conn) AddRule(r *Rule) *Rule { Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...), } - i := cc.PutMessage(m) + i := cc.putMessage(m) cc.PutEntity(i, r) return r @@ -152,7 +152,7 @@ func (cc *Conn) DelRule(r *Rule) error { })...) flags := netlink.Request | netlink.Acknowledge - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Flags: flags, @@ -166,10 +166,11 @@ func (cc *Conn) DelRule(r *Rule) error { // HandleResponse retrieves Handle in netlink response func (r *Rule) HandleResponse(msg netlink.Message) { rule, err := ruleFromMsg(msg) - - if err == nil { - r.Handle = rule.Handle + if err != nil { + return } + + r.Handle = rule.Handle } func exprsFromMsg(b []byte) ([]expr.Any, error) { diff --git a/set.go b/set.go index 4fa283d..f45e0be 100644 --- a/set.go +++ b/set.go @@ -165,7 +165,7 @@ func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error { if err != nil { return err } - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -327,7 +327,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { netlink.Attribute{Type: unix.NFTA_SET_USERDATA, Data: []byte("\x00\x04\x02\x00\x00\x00")}) } - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -342,7 +342,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { if err != nil { return err } - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -362,7 +362,7 @@ func (cc *Conn) DelSet(s *Set) { {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, }) - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSET), Flags: netlink.Request | netlink.Acknowledge, @@ -383,7 +383,7 @@ func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { if err != nil { return err } - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -402,7 +402,7 @@ func (cc *Conn) FlushSet(s *Set) { {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, }) - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), Flags: netlink.Request | netlink.Acknowledge, diff --git a/table.go b/table.go index 9b47f1f..08c83f7 100644 --- a/table.go +++ b/table.go @@ -53,7 +53,7 @@ func (cc *Conn) DelTable(t *Table) { {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, }) - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), Flags: netlink.Request | netlink.Acknowledge, @@ -71,7 +71,7 @@ func (cc *Conn) AddTable(t *Table) *Table { {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, }) - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -89,7 +89,7 @@ func (cc *Conn) FlushTable(t *Table) { data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")}, }) - cc.PutMessage(netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Flags: netlink.Request | netlink.Acknowledge,