Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add select #160

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"golang.org/x/net/bpf"
"golang.org/x/sys/unix"
)

// A Conn is a connection to netlink. A Conn can be used to send and
Expand Down Expand Up @@ -49,6 +50,7 @@ type Socket interface {
Send(m Message) error
SendMessages(m []Message) error
Receive() ([]Message, error)
Select(tv *unix.Timeval) (int, error)
}

// Dial dials a connection to netlink, using the specified netlink family.
Expand Down Expand Up @@ -319,6 +321,11 @@ func (c *Conn) receive() ([]Message, error) {
}
}

// Select allow to check whether netlink messages are available
func (c *Conn) Select(tv *unix.Timeval) (int, error) {
return c.sock.Select(tv)
}

// A groupJoinLeaver is a Socket that supports joining and leaving
// netlink multicast groups.
type groupJoinLeaver interface {
Expand Down
17 changes: 17 additions & 0 deletions conn_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type socket interface {
Close() error
FD() int
File() *os.File
Select(tv *unix.Timeval) (int, error)
Getsockname() (unix.Sockaddr, error)
Recvmsg(p, oob []byte, flags int) (n int, oobn int, recvflags int, from unix.Sockaddr, err error)
Sendmsg(p, oob []byte, to unix.Sockaddr, flags int) error
Expand Down Expand Up @@ -184,6 +185,11 @@ func (c *conn) Receive() ([]Message, error) {
return msgs, nil
}

// Select allow to check whether netlink messages are available
func (c *conn) Select(tv *unix.Timeval) (int, error) {
return c.s.Select(tv)
}

// Close closes the connection.
func (c *conn) Close() error {
return os.NewSyscallError("close", c.s.Close())
Expand Down Expand Up @@ -517,6 +523,17 @@ func (s *sysSocket) FD() int { return int(s.fd.Fd()) }

func (s *sysSocket) File() *os.File { return s.fd }

func (s *sysSocket) Select(tv *unix.Timeval) (int, error) {

var fdSet unix.FdSet
fdSet.Zero()
fdSet.Set(s.FD())

n, err := unix.Select(s.FD()+1, &fdSet, nil, nil, tv)

return n, err
}

func (s *sysSocket) Getsockname() (unix.Sockaddr, error) {
var (
sa unix.Sockaddr
Expand Down
23 changes: 23 additions & 0 deletions conn_linux_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,29 @@ func TestIntegrationConnConcurrentManyConns(t *testing.T) {
panicf("unexpected number of reply messages: %d", l)
}
}

nsel, err := c.Select(&unix.Timeval{})
if err != nil {
t.Fatalf("failed to execute select: %v", err)
}

if nsel > 0 {
t.Fatalf("expected no messages")
}

_, err = c.Send(req)
if err != nil {
panicf("failed to send request: %v", err)
}

nsel, err = c.Select(&unix.Timeval{})
if err != nil {
t.Fatalf("failed to execute select: %v", err)
}

if nsel != 1 {
t.Fatalf("expected messages got none")
}
}

const (
Expand Down
31 changes: 31 additions & 0 deletions conn_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,15 @@ func TestLinuxConnReceive(t *testing.T) {
s.recvmsg.p = resb
s.recvmsg.from = from

n, err := c.Select(&unix.Timeval{})
if err != nil {
t.Fatalf("failed to execute: %v", err)
}

if n != 1 {
t.Fatalf("expected messages")
}

msgs, err := c.Receive()
if err != nil {
t.Fatalf("failed to receive messages: %v", err)
Expand Down Expand Up @@ -186,6 +195,20 @@ func TestLinuxConnReceive(t *testing.T) {
}
}

func TestLinuxConnReceiveNoMessage(t *testing.T) {

c, _ := testLinuxConn(t, nil)

n, err := c.Select(&unix.Timeval{})
if err != nil {
t.Fatalf("failed to execute: %v", err)
}

if n > 0 {
t.Fatalf("expected no messages")
}
}

func TestLinuxConnReceiveLargeMessage(t *testing.T) {
n := os.Getpagesize() * 4

Expand Down Expand Up @@ -633,6 +656,14 @@ func (s *testSocket) Getsockname() (unix.Sockaddr, error) {
return s.getsockname, s.getsocknameErr
}

func (s *testSocket) Select(tv *unix.Timeval) (int, error) {
if len(s.recvmsg.p) > 0 {
return 1, nil
}

return 0, nil
}

func (s *testSocket) Recvmsg(p, oob []byte, flags int) (int, int, int, unix.Sockaddr, error) {
s.recvmsg.flags = append(s.recvmsg.flags, flags)
n := copy(p, s.recvmsg.p)
Expand Down
11 changes: 7 additions & 4 deletions conn_others.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package netlink
import (
"fmt"
"runtime"

"golang.org/x/sys/unix"
)

// errUnimplemented is returned by all functions on platforms that
Expand All @@ -23,7 +25,8 @@ type conn struct{}
func dial(_ int, _ *Config) (*conn, uint32, error) { return nil, 0, errUnimplemented }
func newError(_ int) error { return errUnimplemented }

func (c *conn) Send(_ Message) error { return errUnimplemented }
func (c *conn) SendMessages(_ []Message) error { return errUnimplemented }
func (c *conn) Receive() ([]Message, error) { return nil, errUnimplemented }
func (c *conn) Close() error { return errUnimplemented }
func (c *conn) Send(_ Message) error { return errUnimplemented }
func (c *conn) SendMessages(_ []Message) error { return errUnimplemented }
func (c *conn) Receive() ([]Message, error) { return nil, errUnimplemented }
func (c *conn) Close() error { return errUnimplemented }
func (c *conn) Select(_ *unix.Timeval) (int, error) { return -1, errUnimplemented }
11 changes: 10 additions & 1 deletion conn_others_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

package netlink

import "testing"
import (
"testing"

"golang.org/x/sys/unix"
)

func TestOthersConnUnimplemented(t *testing.T) {
c := &conn{}
Expand Down Expand Up @@ -37,4 +41,9 @@ func TestOthersConnUnimplemented(t *testing.T) {
t.Fatalf("unexpected error during c.Close:\n- want: %v\n- got: %v",
want, got)
}

if _, got := c.Select(&unix.Timeval{}); want != got {
t.Fatalf("unexpected error during c.Select:\n- want: %v\n- got: %v",
want, got)
}
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ require (
github.com/google/go-cmp v0.3.1
github.com/jsimonetti/rtnetlink v0.0.0-20190606172950-9527aa82566a
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297
golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456
golang.org/x/sys v0.0.0-20200103143344-a1369afcdac7
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190411185658-b44545bcd369/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456 h1:ng0gs1AKnRRuEMZoTLLlbOd+C17zUDepwGQBb/n+JVg=
golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200103143344-a1369afcdac7 h1:/W9OPMnnpmFXHYkcp2rQsbFUbRlRzfECQjmAFiOyHE8=
golang.org/x/sys v0.0.0-20200103143344-a1369afcdac7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
11 changes: 11 additions & 0 deletions nltest/nltest.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/mdlayher/netlink"
"github.com/mdlayher/netlink/nlenc"
"golang.org/x/sys/unix"
)

// PID is the netlink header PID value assigned by nltest.
Expand Down Expand Up @@ -204,6 +205,16 @@ func (c *socket) Receive() ([]netlink.Message, error) {
return msgs, err
}

func (c *socket) Select(tv *unix.Timeval) (int, error) {
msgs, _ := c.fn(nil)

if len(c.msgs) > 0 || len(msgs) > 0 {
return 1, nil
}

return 0, nil
}

func panicf(format string, a ...interface{}) {
panic(fmt.Sprintf(format, a...))
}
19 changes: 19 additions & 0 deletions nltest/nltest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/mdlayher/netlink"
"github.com/mdlayher/netlink/nltest"
"golang.org/x/sys/unix"
)

func TestConnSend(t *testing.T) {
Expand Down Expand Up @@ -54,6 +55,15 @@ func TestConnReceiveMulticast(t *testing.T) {
})
defer c.Close()

n, err := c.Select(&unix.Timeval{})
if err != nil {
t.Fatalf("failed to execute: %v", err)
}

if n != 1 {
t.Fatalf("expected messages")
}

got, err := c.Receive()
if err != nil {
t.Fatalf("failed to receive messages: %v", err)
Expand All @@ -71,6 +81,15 @@ func TestConnReceiveNoMessages(t *testing.T) {
})
defer c.Close()

n, err := c.Select(&unix.Timeval{})
if err != nil {
t.Fatalf("failed to execute: %v", err)
}

if n > 0 {
t.Fatalf("expected no messages")
}

msgs, err := c.Receive()
if err != nil {
t.Fatalf("failed to execute: %v", err)
Expand Down