Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set the rule handle after flush #88

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
55 changes: 52 additions & 3 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ import (
"golang.org/x/sys/unix"
)

type Entity interface {
HandleResponse(netlink.Message)
}

// A Conn represents a netlink connection of the nftables family.
//
// All methods return their input, so that variables can be defined from string
Expand All @@ -35,6 +39,7 @@ type Conn struct {
NetNS int // Network namespace netlink will interact with.
sync.Mutex
messages []netlink.Message
entities map[int]Entity
err error
}

Expand All @@ -43,6 +48,7 @@ func (cc *Conn) Flush() error {
cc.Lock()
defer func() {
cc.messages = nil
cc.entities = nil
cc.Unlock()
}()
if len(cc.messages) == 0 {
Expand All @@ -59,17 +65,60 @@ func (cc *Conn) Flush() error {

defer conn.Close()

if _, err := conn.SendMessages(batch(cc.messages)); err != nil {
smsg, err := conn.SendMessages(batch(cc.messages))

if err != nil {
return fmt.Errorf("SendMessages: %w", err)
}

if _, err := conn.Receive(); err != nil {
return fmt.Errorf("Receive: %w", err)
// Retrieving of seq number associated to entities
entitiesBySeq := make(map[uint32]Entity)
for i, e := range cc.entities {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of the separate entites book-keeping, why not just iterate over cc.messages and check if the type implements HandleResponse?

for _, msg := range cc.messages {
  m, ok := msg.(interface { HandleResponse(netlink.Message) })
  if !ok {
    continue
  }
  m.HandleResponse(rmsg)
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HandleResponse is not a method that belong to netlink.Message. It's belong to entity.
For a given entity, it's perform some operations on it by using the netlink response. We have to connect an entity (like a Rule) and a netlink message in case of the slice of responses is unordered. I connect them with the netlink sequence.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay. I’m still not particularly happy about having to manage messages and entities, which are separate but connected.

Instead of the current approach, we could make messages take an interface type which has a NetlinkMessage method. putMessages would then wrap a netlink.Message in a type which implements NetlinkMessage, and can optionally implement HandleResponse as well.

Copy link
Contributor Author

@alexispires alexispires Jan 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stapelberg I'm not sure to understand what do you have in mind but the main limit I have is that the SendMessages method of netlink library take a slice of netlink.Message and not a slice of netlink.Message's pointer. So the only data struct that give me the netlink sequence number is the slice of netlink.Message returned by sendMessages. So the slice of netlink.Message have to be connected in one way or another to something else. In your solution I have to connect the type that wrap netlink.Message with the netlink.Message that contains the seq even if both are the same netlink message.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m a little confused now: you’re saying the sequence numbers are only set in the []netlink.Message that’s returned by (*netlink.Conn).SendMessages, but in your PR, you’re discarding that result:

if _, err = conn.SendMessages(cc.messages); err != nil {

Is that a bug in your PR, or am I misunderstanding?

(Independently, related to your question: the fact that the netlink package takes a []netlink.Message instead of []*netlink.Message doesn’t limit us in which data structures we want to use. If need be, we can do a copy from one to the other. A netlink.Message’s Header is just a few ints, and copying the Body byte slice won’t need to copy its contents.)

entitiesBySeq[smsg[i].Header.Sequence] = e
}

// Trigger entities callback
for cc.checkReceive(conn) {
rmsg, err := conn.Receive()

alexispires marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return fmt.Errorf("Receive: %w", err)
}

for _, msg := range rmsg {
if e, ok := entitiesBySeq[msg.Header.Sequence]; ok {
e.HandleResponse(msg)
}
}
}

return nil
}

func (cc *Conn) checkReceive(c *netlink.Conn) bool {
alexispires marked this conversation as resolved.
Show resolved Hide resolved
if cc.TestDial != nil {
return false
}

sc, err := c.SyscallConn()
alexispires marked this conversation as resolved.
Show resolved Hide resolved

var n int

sc.Control(func(fd uintptr) {
alexispires marked this conversation as resolved.
Show resolved Hide resolved
var fdSet unix.FdSet
fdSet.Zero()
fdSet.Set(int(fd))

n, err = unix.Select(int(fd)+1, &fdSet, nil, nil, &unix.Timeval{})
})

if err == nil && n > 0 {
return true
}

return false
}

// FlushRuleset flushes the entire ruleset. See also
// https://wiki.nftables.org/wiki-nftables/index.php/Operations_at_ruleset_level
func (cc *Conn) FlushRuleset() {
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ go 1.12

require (
github.com/koneu/natend v0.0.0-20150829182554-ec0926ea948d
github.com/mdlayher/netlink v0.0.0-20191009155606-de872b0d824b
github.com/mdlayher/netlink v1.0.0
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc
golang.org/x/net v0.0.0-20191028085509-fe3aa8a45271 // indirect
golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c
golang.org/x/sys v0.0.0-20200106114638-5f8ca72cd632
)
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ github.com/koneu/natend v0.0.0-20150829182554-ec0926ea948d/go.mod h1:QHb4k4cr1fQ
github.com/mdlayher/netlink v0.0.0-20190409211403-11939a169225/go.mod h1:eQB3mZE4aiYnlUsyGGCOpPETfdQq4Jhsgf1fk3cwQaA=
github.com/mdlayher/netlink v0.0.0-20191009155606-de872b0d824b h1:W3er9pI7mt2gOqOWzwvx20iJ8Akiqz1mUMTxU6wdvl8=
github.com/mdlayher/netlink v0.0.0-20191009155606-de872b0d824b/go.mod h1:KxeJAFOFLG6AjpyDkQ/iIhxygIUKD+vcwqcnu43w/+M=
github.com/mdlayher/netlink v1.0.0 h1:vySPY5Oxnn/8lxAPn2cK6kAzcZzYJl3KriSLO46OT18=
github.com/mdlayher/netlink v1.0.0/go.mod h1:KxeJAFOFLG6AjpyDkQ/iIhxygIUKD+vcwqcnu43w/+M=
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc h1:R83G5ikgLMxrBvLh22JhdfI8K6YXEPHx5P03Uu3DRs4=
github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
Expand All @@ -25,4 +27,6 @@ golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c h1:S/FtSvpNLtFBgjTqcKsRpsa6a
golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191112214154-59a1497f0cea h1:Mz1TMnfJDRJLk8S8OPCoJYgrsp/Se/2TBre2+vwX128=
golang.org/x/sys v0.0.0-20191113150313-8ad342257130 h1:+sdNBpwFF05NvMnEyGynbOs/Gr2LQwORWEPKXuEXxzU=
golang.org/x/sys v0.0.0-20200106114638-5f8ca72cd632 h1:ateQkYCVYo8UwIBvoR3zj1Dh2K6Op/n3GxemXfB44/Y=
golang.org/x/sys v0.0.0-20200106114638-5f8ca72cd632/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
119 changes: 119 additions & 0 deletions nftables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3980,3 +3980,122 @@ func TestStatelessNAT(t *testing.T) {
t.Fatal(err)
}
}

func TestHandleBack(t *testing.T) {

if os.Getenv("TRAVIS") == "true" {
t.SkipNow()
}

// Create a new network namespace to test these operations,
// and tear down the namespace at test completion.
c, newNS := openSystemNFTConn(t)
defer cleanupSystemNFTConn(t, newNS)
// Clear all rules at the beginning + end of the test.
c.FlushRuleset()
defer c.FlushRuleset()

filter := c.AddTable(&nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "filter",
})

chain1 := c.AddChain(&nftables.Chain{
Name: "chain1",
Table: filter,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityFilter,
})

chain2 := c.AddChain(&nftables.Chain{
Name: "chain2",
Table: filter,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityFilter,
})

var rulesCreated1 []*nftables.Rule
var rulesCreated2 []*nftables.Rule

alexispires marked this conversation as resolved.
Show resolved Hide resolved
rulesCreated1 = append(rulesCreated1, c.AddRule(&nftables.Rule{
Table: filter,
Chain: chain1,
Exprs: []expr.Any{
&expr.Verdict{
// [ immediate reg 0 drop ]
Kind: expr.VerdictDrop,
},
},
}))

rulesCreated1 = append(rulesCreated1, c.AddRule(&nftables.Rule{
Table: filter,
Chain: chain1,
Exprs: []expr.Any{
&expr.Verdict{
// [ immediate reg 0 drop ]
Kind: expr.VerdictDrop,
},
},
}))

rulesCreated2 = append(rulesCreated2, c.AddRule(&nftables.Rule{
Table: filter,
Chain: chain2,
Exprs: []expr.Any{
&expr.Verdict{
// [ immediate reg 0 drop ]
Kind: expr.VerdictDrop,
},
},
}))

for i, r := range rulesCreated1 {
if r.Handle != 0 {
t.Fatalf("unexpected handle value at %d", i)
}
}

for i, r := range rulesCreated2 {
if r.Handle != 0 {
t.Fatalf("unexpected handle value at %d", i)
}
}

if err := c.Flush(); err != nil {
t.Fatal(err)
}

rulesGetted1, _ := c.GetRule(filter, chain1)
rulesGetted2, _ := c.GetRule(filter, chain2)

if len(rulesGetted1) != len(rulesCreated1) {
t.Fatalf("Bad ruleset length got %d want %d", len(rulesGetted1), len(rulesCreated1))
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this test running many workers concurrently? In other words: why is not sufficient to test with 1 worker?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To test the behaviors with concurrency as exprimed here: #88 (comment)
IMO it's safer to keep it to identify regression on concurent access. But it's not specific on this part of lib, I think concurrency have to be tested on the whole lib.

if len(rulesGetted2) != len(rulesCreated2) {
t.Fatalf("Bad ruleset length got %d want %d", len(rulesGetted2), len(rulesCreated2))
}

for i, r := range rulesGetted1 {
if r.Handle == 0 {
t.Fatalf("handle value is empty at %d", i)
}

if r.Handle != rulesCreated1[i].Handle {
t.Fatalf("mismatched handle at %d", i)
}
}

for i, r := range rulesGetted2 {
if r.Handle == 0 {
t.Fatalf("handle value is empty at %d", i)
}

if r.Handle != rulesCreated2[i].Handle {
t.Fatalf("mismatched handle at %d", i)
}
}
}
15 changes: 15 additions & 0 deletions rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ func (cc *Conn) AddRule(r *Rule) *Rule {
Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...),
})

if cc.entities == nil {
cc.entities = make(map[int]Entity)
}

cc.entities[len(cc.messages)] = r
alexispires marked this conversation as resolved.
Show resolved Hide resolved

return r
}

Expand Down Expand Up @@ -160,6 +166,15 @@ func (cc *Conn) DelRule(r *Rule) error {
return nil
}

// HandleResponse retrieves Handle in netlink response
func (r *Rule) HandleResponse(msg netlink.Message) {
rule, err := ruleFromMsg(msg)

alexispires marked this conversation as resolved.
Show resolved Hide resolved
if err == nil {
r.Handle = rule.Handle
alexispires marked this conversation as resolved.
Show resolved Hide resolved
}
}

func exprsFromMsg(b []byte) ([]expr.Any, error) {
ad, err := netlink.NewAttributeDecoder(b)
if err != nil {
Expand Down