diff --git a/pkg/abi/linux/netlink.go b/pkg/abi/linux/netlink.go index 2be0b7553c..91b7bf2516 100644 --- a/pkg/abi/linux/netlink.go +++ b/pkg/abi/linux/netlink.go @@ -157,3 +157,8 @@ type NetlinkErrorMessage struct { Error int32 Header NetlinkMessageHeader } + +// RTNetlink multicast groups, from uapi/linux/rtnetlink.h. +const ( + RTNLGRP_LINK = 1 +) diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index 035dc3b099..067656ac70 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -26,6 +26,13 @@ declare_mutex( prefix = "abstractSocketNamespace", ) +declare_mutex( + name = "nlmcast_table_mutex", + out = "nlmcast_table_mutex.go", + package = "inet", + prefix = "nlmcastTable", +) + go_library( name = "inet", srcs = [ @@ -35,6 +42,8 @@ go_library( "inet.go", "namespace.go", "namespace_refs.go", + "nlmcast.go", + "nlmcast_table_mutex.go", "test_stack.go", ], deps = [ diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index b5e4dadd9d..1166096329 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -32,7 +32,7 @@ type Stack interface { Interfaces() map[int32]Interface // RemoveInterface removes the specified network interface. - RemoveInterface(idx int32) error + RemoveInterface(ctx context.Context, idx int32) error // InterfaceAddrs returns all network interface addresses as a mapping from // interface indexes to a slice of associated interface address properties. diff --git a/pkg/sentry/inet/namespace.go b/pkg/sentry/inet/namespace.go index 0ffeee03d6..d2a4bf4618 100644 --- a/pkg/sentry/inet/namespace.go +++ b/pkg/sentry/inet/namespace.go @@ -45,6 +45,9 @@ type Namespace struct { // abstractSockets tracks abstract sockets that are in use. abstractSockets AbstractSocketNamespace + + // netlinkMcastTable manages multicast group membership for netlink sockets. 
+ netlinkMcastTable *McastTable } // NewRootNamespace creates the root network namespace, with creator @@ -52,10 +55,14 @@ type Namespace struct { // networking will function if the network is namespaced. func NewRootNamespace(stack Stack, creator NetworkStackCreator, userNS *auth.UserNamespace) *Namespace { n := &Namespace{ - stack: stack, - creator: creator, - isRoot: true, - userNS: userNS, + stack: stack, + creator: creator, + isRoot: true, + userNS: userNS, + netlinkMcastTable: NewNetlinkMcastTable(), + } + if eventPublishingStack, ok := stack.(InterfaceEventPublisher); ok { + eventPublishingStack.AddInterfaceEventSubscriber(n.netlinkMcastTable) } n.abstractSockets.init() return n @@ -79,8 +86,9 @@ func (n *Namespace) GetInode() *nsfs.Inode { // NewNamespace creates a new network namespace from the root. func NewNamespace(root *Namespace, userNS *auth.UserNamespace) *Namespace { n := &Namespace{ - creator: root.creator, - userNS: userNS, + creator: root.creator, + userNS: userNS, + netlinkMcastTable: NewNetlinkMcastTable(), } n.init() return n @@ -148,6 +156,9 @@ func (n *Namespace) init() { if err != nil { panic(err) } + if eventPublishingStack, ok := n.stack.(InterfaceEventPublisher); ok { + eventPublishingStack.AddInterfaceEventSubscriber(n.netlinkMcastTable) + } } n.abstractSockets.init() } @@ -162,6 +173,11 @@ func (n *Namespace) AbstractSockets() *AbstractSocketNamespace { return &n.abstractSockets } +// NetlinkMcastTable returns the netlink multicast group table. +func (n *Namespace) NetlinkMcastTable() *McastTable { + return n.netlinkMcastTable +} + // NetworkStackCreator allows new instances of a network stack to be created. It // is used by the kernel to create new network namespaces when requested. 
type NetworkStackCreator interface { diff --git a/pkg/sentry/inet/nlmcast.go b/pkg/sentry/inet/nlmcast.go new file mode 100644 index 0000000000..49d8614233 --- /dev/null +++ b/pkg/sentry/inet/nlmcast.go @@ -0,0 +1,142 @@ +// Copyright 2025 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inet + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" +) + +const ( + routeProtocol = linux.NETLINK_ROUTE + routeLinkMcastGroup = linux.RTNLGRP_LINK +) + +// InterfaceEventSubscriber allows clients to subscribe to events published by an inet.Stack. +// +// It is a rough parallel to the objects in Linux that subscribe to netdev +// events by calling register_netdevice_notifier(). +type InterfaceEventSubscriber interface { + // OnInterfaceChangeEvent is called by InterfaceEventPublishers when an interface event takes place. + OnInterfaceChangeEvent(ctx context.Context, idx int32, i Interface) + + // OnInterfaceDeleteEvent is called by InterfaceEventPublishers when an interface event takes place. + OnInterfaceDeleteEvent(ctx context.Context, idx int32, i Interface) +} + +// InterfaceEventPublisher is the interface event publishing aspect of an inet.Stack. +// +// The Linux parallel is how it notifies subscribers via call_netdev_notifiers(). +type InterfaceEventPublisher interface { + AddInterfaceEventSubscriber(sub InterfaceEventSubscriber) +} + +// NetlinkSocket corresponds to a netlink socket. 
+type NetlinkSocket interface { + // Protocol returns the netlink protocol value. + Protocol() int + + // Groups returns the bitmap of multicast groups the socket is bound to. + Groups() uint64 + + // HandleInterfaceChangeEvent is called on NetlinkSockets that are members of the RTNLGRP_LINK + // multicast group when an interface is modified. + HandleInterfaceChangeEvent(context.Context, int32, Interface) + + // HandleInterfaceDeleteEvent is called on NetlinkSockets that are members of the RTNLGRP_LINK + // multicast group when an interface is deleted. + HandleInterfaceDeleteEvent(context.Context, int32, Interface) +} + +// McastTable holds multicast group membership information for netlink netlinkSocket. +// It corresponds roughly to Linux's struct netlink_table. +// +// +stateify savable +type McastTable struct { + mu nlmcastTableMutex `state:"nosave"` + socks map[int]map[NetlinkSocket]struct{} +} + +// WithTableLocked runs fn with the table mutex held. +func (m *McastTable) WithTableLocked(fn func()) { + m.mu.Lock() + defer m.mu.Unlock() + fn() +} + +// AddSocket adds a netlinkSocket to the multicast-group table. +// +// Preconditions: the netlink multicast table is locked. +func (m *McastTable) AddSocket(s NetlinkSocket) { + p := s.Protocol() + if _, ok := m.socks[p]; !ok { + m.socks[p] = make(map[NetlinkSocket]struct{}) + } + if _, ok := m.socks[p][s]; ok { + return + } + m.socks[p][s] = struct{}{} +} + +// RemoveSocket removes a netlinkSocket from the multicast-group table. +// +// Preconditions: the netlink multicast table is locked. 
+func (m *McastTable) RemoveSocket(s NetlinkSocket) { + p := s.Protocol() + if _, ok := m.socks[p]; !ok { + return + } + if _, ok := m.socks[p][s]; !ok { + return + } + delete(m.socks[p], s) +} + +func (m *McastTable) forEachMcastSock(protocol int, mcastGroup int, fn func(s NetlinkSocket)) { + m.mu.Lock() + defer m.mu.Unlock() + if _, ok := m.socks[protocol]; !ok { + return + } + for s := range m.socks[protocol] { + if s.Groups()&(1<<(mcastGroup-1)) == 0 { + continue // Not bound to this group; keep scanning the remaining sockets. + } + fn(s) + } +} + +// OnInterfaceChangeEvent implements InterfaceEventSubscriber.OnInterfaceChangeEvent. +func (m *McastTable) OnInterfaceChangeEvent(ctx context.Context, idx int32, i Interface) { + // Relay the event to RTNLGRP_LINK subscribers. + m.forEachMcastSock(routeProtocol, routeLinkMcastGroup, func(s NetlinkSocket) { + s.HandleInterfaceChangeEvent(ctx, idx, i) + }) +} + +// OnInterfaceDeleteEvent implements InterfaceEventSubscriber.OnInterfaceDeleteEvent. +func (m *McastTable) OnInterfaceDeleteEvent(ctx context.Context, idx int32, i Interface) { + // Relay the event to RTNLGRP_LINK subscribers. + m.forEachMcastSock(routeProtocol, routeLinkMcastGroup, func(s NetlinkSocket) { + s.HandleInterfaceDeleteEvent(ctx, idx, i) + }) +} + +// NewNetlinkMcastTable creates a new McastTable. +func NewNetlinkMcastTable() *McastTable { + return &McastTable{ + socks: make(map[int]map[NetlinkSocket]struct{}), + } +} diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index d083aef3e0..894c3e96ab 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -61,7 +61,7 @@ func (s *TestStack) Destroy() { } // RemoveInterface implements Stack. 
-func (s *TestStack) RemoveInterface(idx int32) error { +func (s *TestStack) RemoveInterface(ctx context.Context, idx int32) error { delete(s.InterfacesMap, idx) return nil } diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 5419d991e5..72f23b8db0 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -152,7 +152,7 @@ func (s *Stack) Interfaces() map[int32]inet.Interface { } // RemoveInterface implements inet.Stack.RemoveInterface. -func (*Stack) RemoveInterface(idx int32) error { +func (*Stack) RemoveInterface(ctx context.Context, idx int32) error { return removeInterface(idx) } diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 997e43728e..15399c9148 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -15,6 +15,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/abi/linux/errno", + "//pkg/atomicbitops", "//pkg/context", "//pkg/errors/linuxerr", "//pkg/hostarch", diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index df302cd827..bf895439d2 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" + "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink/nlmsg" @@ -51,6 +52,18 @@ type Protocol interface { ProcessMessage(ctx context.Context, s *Socket, msg *nlmsg.Message, ms *nlmsg.MessageSet) *syserr.Error } +// RouteProtocol corresponds to the NETLINK_ROUTE family. +type RouteProtocol interface { + Protocol + + // AddNewLinkMessage is called when an interface is mutated or created by the stack. + // It is the rough equivalent of Linux's rtnetlink_event(). 
+ AddNewLinkMessage(ms *nlmsg.MessageSet, idx int32, i inet.Interface) + + // AddDelLinkMessage is called when an interface is deleted by the stack. + AddDelLinkMessage(ms *nlmsg.MessageSet, idx int32, i inet.Interface) +} + // Provider is a function that creates a new Protocol for a specific netlink // protocol. // diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index 34ec373154..fa1b0f546c 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -101,7 +101,7 @@ func (p *Protocol) dumpLinks(ctx context.Context, s *netlink.Socket, msg *nlmsg. } for idx, i := range stack.Interfaces() { - addNewLinkMessage(ms, idx, i) + p.AddNewLinkMessage(ms, idx, i) } return nil @@ -158,7 +158,7 @@ func (p *Protocol) getLink(ctx context.Context, s *netlink.Socket, msg *nlmsg.Me return syserr.ErrInvalidArgument } - addNewLinkMessage(ms, idx, i) + p.AddNewLinkMessage(ms, idx, i) found = true break } @@ -232,16 +232,10 @@ func (p *Protocol) delLink(ctx context.Context, s *netlink.Socket, msg *nlmsg.Me return syserr.ErrNoDevice } } - return syserr.FromError(stack.RemoveInterface(ifinfomsg.Index)) + return syserr.FromError(stack.RemoveInterface(ctx, ifinfomsg.Index)) } -// addNewLinkMessage appends RTM_NEWLINK message for the given interface into -// the message set. -func addNewLinkMessage(ms *nlmsg.MessageSet, idx int32, i inet.Interface) { - m := ms.AddMessage(linux.NetlinkMessageHeader{ - Type: linux.RTM_NEWLINK, - }) - +func writeLinkInfo(m *nlmsg.Message, idx int32, i inet.Interface) { m.Put(&linux.InterfaceInfoMessage{ Family: linux.AF_UNSPEC, Type: i.DeviceType, @@ -264,6 +258,26 @@ func addNewLinkMessage(ms *nlmsg.MessageSet, idx int32, i inet.Interface) { // TODO(gvisor.dev/issue/578): There are many more attributes. } +// AddNewLinkMessage appends an RTM_NEWLINK message for the given interface into +// the message set. 
+// AddNewLinkMessage implements netlink.RouteProtocol.AddNewLinkMessage. +func (p *Protocol) AddNewLinkMessage(ms *nlmsg.MessageSet, idx int32, i inet.Interface) { + m := ms.AddMessage(linux.NetlinkMessageHeader{ + Type: linux.RTM_NEWLINK, + }) + writeLinkInfo(m, idx, i) +} + +// AddDelLinkMessage appends an RTM_DELLINK message for the given interface into +// the message set. +// AddDelLinkMessage implements netlink.RouteProtocol.AddDelLinkMessage. +func (p *Protocol) AddDelLinkMessage(ms *nlmsg.MessageSet, idx int32, i inet.Interface) { + m := ms.AddMessage(linux.NetlinkMessageHeader{ + Type: linux.RTM_DELLINK, + }) + writeLinkInfo(m, idx, i) +} + // dumpAddrs handles RTM_GETADDR dump requests. func (p *Protocol) dumpAddrs(ctx context.Context, s *netlink.Socket, msg *nlmsg.Message, ms *nlmsg.MessageSet) *syserr.Error { // RTM_GETADDR dump requests need not contain anything more than the diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 60403b04db..b6a98d2c56 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -16,12 +16,14 @@ package netlink import ( + "fmt" "io" "math" "time" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/abi/linux/errno" + "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/hostarch" @@ -55,6 +57,9 @@ const ( // maxBufferSize is the largest size a send buffer can grow to. maxSendBufferSize = 4 << 20 // 4MB + + // supportedGroups is the set of multicast groups that are supported. + supportedGroups = 1 << (linux.RTNLGRP_LINK - 1) ) var errNoFilter = syserr.New("no filter attached", errno.ENOENT) @@ -92,6 +97,13 @@ type Socket struct { // sent to userspace. connection transport.ConnectedEndpoint + // netns is the network namespace associated with the socket. + netns *inet.Namespace + + // groups is a bitmap of the set of multicast groups this socket is bound to. 
+ // Writing to it requires the per-netns table lock to be held, reading it does not. + groups atomicbitops.Uint64 + // mu protects the fields below. mu sync.Mutex `state:"nosave"` @@ -110,13 +122,11 @@ type Socket struct { // TODO(gvisor.dev/issue/1119): We don't actually support filtering, // this is just bookkeeping for tracking add/remove. filter bool - - // netns is the network namespace associated with the socket. - netns *inet.Namespace } var _ socket.Socket = (*Socket)(nil) var _ transport.Credentialer = (*Socket)(nil) +var _ inet.NetlinkSocket = (*Socket)(nil) // New creates a new Socket. func New(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) { @@ -157,6 +167,12 @@ func (s *Socket) Stack() inet.Stack { // Release implements vfs.FileDescriptionImpl.Release. func (s *Socket) Release(ctx context.Context) { + if s.groups.Load() != 0 { + s.netns.NetlinkMcastTable().WithTableLocked(func() { + s.netns.NetlinkMcastTable().RemoveSocket(s) + }) + } + t := kernel.TaskFromContext(ctx) t.Kernel().DeleteSocket(&s.vfsfd) s.connection.Release(ctx) @@ -304,6 +320,117 @@ func (s *Socket) bindPort(t *kernel.Task, port int32) *syserr.Error { return nil } +func (s *Socket) checkMcastSupport(t *kernel.Task) *syserr.Error { + // Currently only ROUTE family sockets support multicast. + if s.Protocol() != linux.NETLINK_ROUTE { + return syserr.ErrNotSupported + } + // Not all inet.Stacks relay interface events, currently only netstack/tcpip does. + if _, ok := s.Stack().(inet.InterfaceEventPublisher); !ok { + return syserr.ErrNotSupported + } + // man 7 netlink: "Only processes with an effective UID of 0 or the CAP_NET_ADMIN + // capability may send or listen to a netlink multicast group." + if !t.HasCapability(linux.CAP_NET_ADMIN) { + return syserr.ErrPermissionDenied + } + return nil +} + +// preconditions: the netlink multicast table is locked. 
+func (s *Socket) joinGroups(t *kernel.Task, groups uint64) *syserr.Error { + if groups&supportedGroups != groups { + return syserr.ErrNotSupported + } + if err := s.checkMcastSupport(t); err != nil { + return err + } + + oldGroups := s.groups.Load() + s.groups.Store(groups) + if oldGroups == 0 && s.groups.Load() != 0 { + s.netns.NetlinkMcastTable().AddSocket(s) + } else if oldGroups != 0 && s.groups.Load() == 0 { + s.netns.NetlinkMcastTable().RemoveSocket(s) + } + return nil +} + +// preconditions: the netlink multicast table is locked. +func (s *Socket) joinGroup(t *kernel.Task, group uint32) *syserr.Error { + if group == 0 || group > 64 { + return syserr.ErrInvalidArgument + } + groups := uint64(1) << (group - 1) + if groups&supportedGroups != groups { + return syserr.ErrNotSupported + } + if err := s.checkMcastSupport(t); err != nil { + return err + } + + oldGroups := s.groups.Load() + s.groups.Store(oldGroups | groups) + if oldGroups == 0 { + s.netns.NetlinkMcastTable().AddSocket(s) + } + return nil +} + +// preconditions: the netlink multicast table is locked. +func (s *Socket) leaveGroup(t *kernel.Task, group uint32) *syserr.Error { + if group == 0 || group > 64 { + return syserr.ErrInvalidArgument + } + groups := uint64(1) << (group - 1) + if groups&supportedGroups != groups { + return syserr.ErrNotSupported + } + if err := s.checkMcastSupport(t); err != nil { + return err + } + + s.groups.Store(s.groups.Load() &^ groups) + if s.groups.Load() == 0 { + s.netns.NetlinkMcastTable().RemoveSocket(s) + } + return nil +} + +// Protocol implements inet.NetlinkSocket.Protocol. +func (s *Socket) Protocol() int { + return s.protocol.Protocol() +} + +// Groups implements inet.NetlinkSocket.Groups. +func (s *Socket) Groups() uint64 { + return s.groups.Load() +} + +// HandleInterfaceChangeEvent implements inet.NetlinkSocket.HandleInterfaceChangeEvent. 
+func (s *Socket) HandleInterfaceChangeEvent(ctx context.Context, idx int32, i inet.Interface) { + routeProtocol, ok := s.protocol.(RouteProtocol) + if !ok { + panic(fmt.Sprintf("Non-ROUTE netlink socket (protocol %d) cannot handle interface events", s.Protocol())) + } + ms := nlmsg.NewMessageSet(s.portID, 0) + routeProtocol.AddNewLinkMessage(ms, idx, i) + // TODO(b/456238795): Implement netlink ENOBUFS. + s.SendResponse(ctx, ms) +} + +// HandleInterfaceDeleteEvent implements inet.NetlinkSocket.HandleInterfaceDeleteEvent. +func (s *Socket) HandleInterfaceDeleteEvent(ctx context.Context, idx int32, i inet.Interface) { + routeProtocol, ok := s.protocol.(RouteProtocol) + if !ok { + panic(fmt.Sprintf("Non-ROUTE netlink socket (protocol %d) cannot handle interface events", s.Protocol())) + } + ms := nlmsg.NewMessageSet(s.portID, 0) + routeProtocol.AddDelLinkMessage(ms, idx, i) + // TODO(b/456238795): Implement netlink ENOBUFS. + s.SendResponse(ctx, ms) +} + // Bind implements socket.Socket.Bind. func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { a, err := ExtractSockAddr(sockaddr) @@ -311,14 +438,18 @@ func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { return err } - // No support for multicast groups yet. if a.Groups != 0 { - return syserr.ErrPermissionDenied + var err *syserr.Error + s.netns.NetlinkMcastTable().WithTableLocked(func() { + err = s.joinGroups(t, uint64(a.Groups)) + }) + if err != nil { + return err + } } s.mu.Lock() defer s.mu.Unlock() - return s.bindPort(t, int32(a.PortID)) } @@ -329,7 +460,7 @@ func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr return err } - // No support for multicast groups yet. + // No support for sending to destination multicast groups yet. 
if a.Groups != 0 { return syserr.ErrPermissionDenied } @@ -417,13 +548,19 @@ func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr hostarch } case linux.SOL_NETLINK: switch name { + case linux.NETLINK_LIST_MEMBERSHIPS: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + return primitive.AllocateUint64(s.groups.Load()), nil + case linux.NETLINK_BROADCAST_ERROR, linux.NETLINK_CAP_ACK, linux.NETLINK_DUMP_STRICT_CHK, linux.NETLINK_EXT_ACK, - linux.NETLINK_LIST_MEMBERSHIPS, linux.NETLINK_NO_ENOBUFS, linux.NETLINK_PKTINFO: + // Not supported. } } @@ -528,15 +665,36 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy } case linux.SOL_NETLINK: switch name { - case linux.NETLINK_ADD_MEMBERSHIP, - linux.NETLINK_BROADCAST_ERROR, + case linux.NETLINK_ADD_MEMBERSHIP: + if len(opt) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + group := hostarch.ByteOrder.Uint32(opt) + var err *syserr.Error + s.netns.NetlinkMcastTable().WithTableLocked(func() { + err = s.joinGroup(t, group) + }) + return err + + case linux.NETLINK_DROP_MEMBERSHIP: + if len(opt) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + group := hostarch.ByteOrder.Uint32(opt) + var err *syserr.Error + s.netns.NetlinkMcastTable().WithTableLocked(func() { + err = s.leaveGroup(t, group) + }) + return err + + case linux.NETLINK_BROADCAST_ERROR, linux.NETLINK_CAP_ACK, - linux.NETLINK_DROP_MEMBERSHIP, linux.NETLINK_DUMP_STRICT_CHK, linux.NETLINK_EXT_ACK, linux.NETLINK_LISTEN_ALL_NSID, linux.NETLINK_NO_ENOBUFS, linux.NETLINK_PKTINFO: + // Not supported. 
} } diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index cebd8db5ba..0bb273a64d 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -1,3 +1,4 @@ +load("//pkg/sync/locking:locking.bzl", "declare_mutex") load("//tools:defs.bzl", "go_library", "proto_library") package( @@ -5,10 +6,18 @@ package( licenses = ["notice"], ) +declare_mutex( + name = "netstack_link_mutex", + out = "netstack_link_mutex.go", + package = "netstack", + prefix = "netstackLink", +) + go_library( name = "netstack", srcs = [ "netstack.go", + "netstack_link_mutex.go", "netstack_state.go", "provider.go", "save_restore.go", @@ -48,6 +57,7 @@ go_library( "//pkg/sentry/socket/netstack/packetmmap", "//pkg/sentry/vfs", "//pkg/sync", + "//pkg/sync/locking", "//pkg/syserr", "//pkg/tcpip", "//pkg/tcpip/header", diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index d4934cf7bd..dba248b684 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -41,6 +41,46 @@ import ( // +stateify savable type Stack struct { Stack *stack.Stack `state:".(*stack.Stack)"` + + eventSubscriber inet.InterfaceEventSubscriber + + // linkLock serializes link creation, modification and deletion. + // It is a rough parallel to the per-netns rtnl_mutex in Linux. + linkLock netstackLinkMutex `state:"nosave"` +} + +// AddInterfaceEventSubscriber implements inet.InterfaceEventPublisher.AddInterfaceEventSubscriber. 
+func (s *Stack) AddInterfaceEventSubscriber(sub inet.InterfaceEventSubscriber) { + if s.eventSubscriber != nil { + panic("AddInterfaceEventSubscriber called twice: multiple subscribers yet to be supported") + } + s.eventSubscriber = sub +} + +func makeInterfaceInfo(ni *stack.NICInfo) inet.Interface { + return inet.Interface{ + Name: ni.Name, + Addr: []byte(ni.LinkAddress), + Flags: uint32(nicStateFlagsToLinux(ni.Flags)), + DeviceType: toLinuxARPHardwareType(ni.ARPHardwareType), + MTU: ni.MTU, + } +} + +func (s *Stack) sendChangeEvent(ctx context.Context, id tcpip.NICID) { + if s.eventSubscriber == nil { + return + } + if nicInfo, ok := s.Stack.SingleNICInfo(id); ok { + s.eventSubscriber.OnInterfaceChangeEvent(ctx, int32(id), makeInterfaceInfo(nicInfo)) + } +} + +func (s *Stack) sendDeleteEvent(ctx context.Context, id tcpip.NICID, nicInfo *stack.NICInfo) { + if s.eventSubscriber == nil { + return + } + s.eventSubscriber.OnInterfaceDeleteEvent(ctx, int32(id), makeInterfaceInfo(nicInfo)) } // EnableSaveRestore enables netstack s/r. @@ -90,22 +130,19 @@ func toLinuxARPHardwareType(t header.ARPHardwareType) uint16 { func (s *Stack) Interfaces() map[int32]inet.Interface { is := make(map[int32]inet.Interface) for id, ni := range s.Stack.NICInfo() { - is[int32(id)] = inet.Interface{ - Name: ni.Name, - Addr: []byte(ni.LinkAddress), - Flags: uint32(nicStateFlagsToLinux(ni.Flags)), - DeviceType: toLinuxARPHardwareType(ni.ARPHardwareType), - MTU: ni.MTU, - } + is[int32(id)] = makeInterfaceInfo(&ni) } return is } // RemoveInterface implements inet.Stack.RemoveInterface. 
-func (s *Stack) RemoveInterface(idx int32) error { +func (s *Stack) RemoveInterface(ctx context.Context, idx int32) error { + s.linkLock.Lock() + defer s.linkLock.Unlock() + nic := tcpip.NICID(idx) - nicInfo, ok := s.Stack.NICInfo()[nic] + nicInfo, ok := s.Stack.SingleNICInfo(nic) if !ok { return syserr.ErrUnknownNICID.ToError() } @@ -115,7 +152,12 @@ func (s *Stack) RemoveInterface(idx int32) error { return syserr.ErrNotSupported.ToError() } - return syserr.TranslateNetstackError(s.Stack.RemoveNIC(nic)).ToError() + if err := syserr.TranslateNetstackError(s.Stack.RemoveNIC(nic)); err != nil { + return err.ToError() + } + s.sendDeleteEvent(ctx, nic, nicInfo) + return nil + } // SetInterface implements inet.Stack.SetInterface. @@ -180,10 +222,18 @@ func (s *Stack) SetInterface(ctx context.Context, msg *nlmsg.Message) *syserr.Er // Netstack interfaces are always up. } - return s.setLink(ctx, tcpip.NICID(ifinfomsg.Index), attrs) + s.linkLock.Lock() + defer s.linkLock.Unlock() + return s.setLinkLocked(ctx, tcpip.NICID(ifinfomsg.Index), attrs) } -func (s *Stack) setLink(ctx context.Context, id tcpip.NICID, linkAttrs map[uint16]nlmsg.BytesView) *syserr.Error { +// precondition: s.linkLock is held. +func (s *Stack) setLinkLocked(ctx context.Context, id tcpip.NICID, linkAttrs map[uint16]nlmsg.BytesView) *syserr.Error { + oldNicInfo, ok := s.Stack.SingleNICInfo(id) + if !ok { + return syserr.ErrUnknownNICID + } + // IFLA_NET_NS_FD has to be handled first, because other parameters may be reset. 
if v, ok := linkAttrs[linux.IFLA_NET_NS_FD]; ok { fd, ok := v.Uint32() @@ -202,12 +252,21 @@ func (s *Stack) setLink(ctx context.Context, id tcpip.NICID, linkAttrs map[uint1 peer := ns.Stack().(*Stack) if peer.Stack != s.Stack { var err tcpip.Error + oldID := id + id, err = s.Stack.SetNICStack(id, peer.Stack) if err != nil { return syserr.TranslateNetstackError(err) } + + s.sendDeleteEvent(ctx, oldID, oldNicInfo) // inform about exit from old ns + peer.sendChangeEvent(ctx, id) // inform about entry into new ns + // TODO: Once we support IFLA_LINK_NETNSID, we need to call sendChangeEvent on + // the peer interface if this interface is part of a veth pair. } } + + changed := false for t, v := range linkAttrs { switch t { case linux.IFLA_MASTER: @@ -215,35 +274,55 @@ func (s *Stack) setLink(ctx context.Context, id tcpip.NICID, linkAttrs map[uint1 if !ok { return syserr.ErrInvalidArgument } + if mid, ok := s.Stack.GetNICCoordinatorID(id); ok && mid == tcpip.NICID(master) { + continue + } if master != 0 { if err := s.Stack.SetNICCoordinator(id, tcpip.NICID(master)); err != nil { return syserr.TranslateNetstackError(err) } + changed = true } case linux.IFLA_ADDRESS: if len(v) != tcpip.LinkAddressSize { return syserr.ErrInvalidArgument } addr := tcpip.LinkAddress(v) + if oldNicInfo.LinkAddress == addr { + continue + } if err := s.Stack.SetNICAddress(id, addr); err != nil { return syserr.TranslateNetstackError(err) } + changed = true case linux.IFLA_IFNAME: + if oldNicInfo.Name == v.String() { + continue + } if err := s.Stack.SetNICName(id, v.String()); err != nil { return syserr.TranslateNetstackError(err) } + changed = true case linux.IFLA_MTU: mtu, ok := v.Uint32() if !ok { return syserr.ErrInvalidArgument } + if oldNicInfo.MTU == mtu { + continue + } if err := s.Stack.SetNICMTU(id, mtu); err != nil { return syserr.TranslateNetstackError(err) } + changed = true case linux.IFLA_TXQLEN: // TODO(b/340388892): support IFLA_TXQLEN. 
} } + + if changed { + s.sendChangeEvent(ctx, id) + } return nil } @@ -298,6 +377,8 @@ func (s *Stack) newVeth(ctx context.Context, linkAttrs map[uint16]nlmsg.BytesVie } } } + + s.linkLock.Lock() ep, peerEP := veth.NewPair(defaultMTU, veth.DefaultBacklogSize) id := s.Stack.NextNICID() peerID := peerStack.Stack.NextNICID() @@ -308,16 +389,21 @@ func (s *Stack) newVeth(ctx context.Context, linkAttrs map[uint16]nlmsg.BytesVie Name: ifname, }) if err != nil { + s.linkLock.Unlock() return syserr.TranslateNetstackError(err) } - if err := s.setLink(ctx, id, linkAttrs); err != nil { + if err := s.setLinkLocked(ctx, id, linkAttrs); err != nil { + s.linkLock.Unlock() peerEP.Close() return err } + s.linkLock.Unlock() if peerName == "" { peerName = fmt.Sprintf("veth%d", peerID) } + peerStack.linkLock.Lock() + defer peerStack.linkLock.Unlock() err = peerStack.Stack.CreateNICWithOptions(peerID, packetsocket.New(ethernet.New(peerEP)), stack.NICOptions{ Name: peerName, }) @@ -326,7 +412,7 @@ func (s *Stack) newVeth(ctx context.Context, linkAttrs map[uint16]nlmsg.BytesVie return syserr.TranslateNetstackError(err) } if peerLinkAttrs != nil { - if err := peerStack.setLink(ctx, peerID, peerLinkAttrs); err != nil { + if err := peerStack.setLinkLocked(ctx, peerID, peerLinkAttrs); err != nil { peerStack.Stack.RemoveNIC(peerID) peerEP.Close() return err @@ -337,6 +423,9 @@ func (s *Stack) newVeth(ctx context.Context, linkAttrs map[uint16]nlmsg.BytesVie } func (s *Stack) newBridge(ctx context.Context, linkAttrs map[uint16]nlmsg.BytesView, linkInfoAttrs map[uint16]nlmsg.BytesView) *syserr.Error { + s.linkLock.Lock() + defer s.linkLock.Unlock() + ifname := "" if v, ok := linkAttrs[linux.IFLA_IFNAME]; ok { @@ -350,7 +439,7 @@ func (s *Stack) newBridge(ctx context.Context, linkAttrs map[uint16]nlmsg.BytesV if err != nil { return syserr.TranslateNetstackError(err) } - if err := s.setLink(ctx, id, linkAttrs); err != nil { + if err := s.setLinkLocked(ctx, id, linkAttrs); err != nil { return err } 
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index c486a91c81..23b1ba0c58 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1066,6 +1066,18 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) (func(), tcpip.Error) { return nic.remove(true /* closeLinkEndpoint */) } +// GetNICCoordinatorID returns the ID of the coordinator device of a NIC. +func (s *Stack) GetNICCoordinatorID(id tcpip.NICID) (tcpip.NICID, bool) { + s.mu.Lock() + defer s.mu.Unlock() + if nic, ok := s.nics[id]; ok { + if nic.Primary != nil { + return nic.Primary.id, true + } + } + return 0, false +} + // SetNICCoordinator sets a coordinator device. func (s *Stack) SetNICCoordinator(id tcpip.NICID, mid tcpip.NICID) tcpip.Error { s.mu.Lock() @@ -1176,65 +1188,83 @@ func (s *Stack) HasNIC(id tcpip.NICID) bool { return ok } -// NICInfo returns a map of NICIDs to their associated information. -func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { - s.mu.RLock() - defer s.mu.RUnlock() +type forwardingFn func(tcpip.NetworkProtocolNumber) (bool, tcpip.Error) - type forwardingFn func(tcpip.NetworkProtocolNumber) (bool, tcpip.Error) - forwardingValue := func(forwardingFn forwardingFn, proto tcpip.NetworkProtocolNumber, nicID tcpip.NICID, fnName string) (forward bool, ok bool) { - switch forwarding, err := forwardingFn(proto); err.(type) { - case nil: - return forwarding, true - case *tcpip.ErrUnknownProtocol: - panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nicID)) - case *tcpip.ErrNotSupported: - // Not all network protocols support forwarding. 
- default: - panic(fmt.Sprintf("nic(id=%d).%s(%d): %s", nicID, fnName, proto, err)) - } - return false, false +func forwardingValue(forwardingFn forwardingFn, proto tcpip.NetworkProtocolNumber, nicID tcpip.NICID, fnName string) (forward bool, ok bool) { + switch forwarding, err := forwardingFn(proto); err.(type) { + case nil: + return forwarding, true + case *tcpip.ErrUnknownProtocol: + panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nicID)) + case *tcpip.ErrNotSupported: + // Not all network protocols support forwarding. + default: + panic(fmt.Sprintf("nic(id=%d).%s(%d): %s", nicID, fnName, proto, err)) + } + return false, false +} - nics := make(map[tcpip.NICID]NICInfo) - for id, nic := range s.nics { - flags := NICStateFlags{ - Up: true, // Netstack interfaces are always up. - Running: nic.Enabled(), - Promiscuous: nic.Promiscuous(), - Loopback: nic.IsLoopback(), - } +// precondition: s.mu is held. +func (s *Stack) getNICInfo(nic *nic, id tcpip.NICID) *NICInfo { + flags := NICStateFlags{ + Up: true, // Netstack interfaces are always up.
+ Running: nic.Enabled(), + Promiscuous: nic.Promiscuous(), + Loopback: nic.IsLoopback(), + } - netStats := make(map[tcpip.NetworkProtocolNumber]NetworkEndpointStats) - for proto, netEP := range nic.networkEndpoints { - netStats[proto] = netEP.Stats() + netStats := make(map[tcpip.NetworkProtocolNumber]NetworkEndpointStats) + for proto, netEP := range nic.networkEndpoints { + netStats[proto] = netEP.Stats() + } + + info := NICInfo{ + Name: nic.name, + LinkAddress: nic.NetworkLinkEndpoint.LinkAddress(), + ProtocolAddresses: nic.primaryAddresses(), + Flags: flags, + MTU: nic.NetworkLinkEndpoint.MTU(), + Stats: nic.stats.local, + NetworkStats: netStats, + Context: nic.context, + ARPHardwareType: nic.NetworkLinkEndpoint.ARPHardwareType(), + Forwarding: make(map[tcpip.NetworkProtocolNumber]bool), + MulticastForwarding: make(map[tcpip.NetworkProtocolNumber]bool), + } + + for proto := range s.networkProtocols { + if forwarding, ok := forwardingValue(nic.forwarding, proto, id, "forwarding"); ok { + info.Forwarding[proto] = forwarding } - info := NICInfo{ - Name: nic.name, - LinkAddress: nic.NetworkLinkEndpoint.LinkAddress(), - ProtocolAddresses: nic.primaryAddresses(), - Flags: flags, - MTU: nic.NetworkLinkEndpoint.MTU(), - Stats: nic.stats.local, - NetworkStats: netStats, - Context: nic.context, - ARPHardwareType: nic.NetworkLinkEndpoint.ARPHardwareType(), - Forwarding: make(map[tcpip.NetworkProtocolNumber]bool), - MulticastForwarding: make(map[tcpip.NetworkProtocolNumber]bool), + if multicastForwarding, ok := forwardingValue(nic.multicastForwarding, proto, id, "multicastForwarding"); ok { + info.MulticastForwarding[proto] = multicastForwarding } + } - for proto := range s.networkProtocols { - if forwarding, ok := forwardingValue(nic.forwarding, proto, id, "forwarding"); ok { - info.Forwarding[proto] = forwarding - } + return &info +} - if multicastForwarding, ok := forwardingValue(nic.multicastForwarding, proto, id, "multicastForwarding"); ok { - 
- info.MulticastForwarding[proto] = multicastForwarding - } - } +// SingleNICInfo returns the NICInfo for the given NICID. +func (s *Stack) SingleNICInfo(id tcpip.NICID) (*NICInfo, bool) { + s.mu.RLock() + defer s.mu.RUnlock() - nics[id] = info + nic, ok := s.nics[id] + if !ok { + return nil, false + } + return s.getNICInfo(nic, id), true +} + +// NICInfo returns a map of NICIDs to their associated information. +func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { + s.mu.RLock() + defer s.mu.RUnlock() + + nics := make(map[tcpip.NICID]NICInfo) + for id, nic := range s.nics { + nics[id] = *s.getNICInfo(nic, id) } return nics } diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index 55f6e5afae..c0348f6c78 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -16,10 +16,11 @@ #include #include #include -#include #include #include #include +#include +#include #include #include #include @@ -280,6 +281,29 @@ TEST_P(NetlinkSetLinkTest, ChangeLinkName) { EXPECT_TRUE(found) << "Netlink response does not contain any links."; } +struct MtuRequest { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + struct rtattr rtattr; + uint32_t mtu; +}; + +MtuRequest GetMtuRequest(const Link& link, uint16_t nlmsg_type, uint32_t mtu) { + MtuRequest req = {}; + + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = nlmsg_type; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + req.ifm.ifi_index = link.index; + req.rtattr.rta_type = IFLA_MTU; + req.rtattr.rta_len = RTA_LENGTH(sizeof(uint32_t)); + req.mtu = mtu; + + return req; +} + TEST_P(NetlinkSetLinkTest, ChangeMTU) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); SKIP_IF(IsRunningWithHostinet()); @@ -289,23 +313,9 @@ TEST_P(NetlinkSetLinkTest, ChangeMTU) { FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); -
struct request { - struct nlmsghdr hdr; - struct ifinfomsg ifm; - struct rtattr rtattr; - uint32_t mtu; - } req = {}; - // Change the MTU. - req.hdr.nlmsg_len = sizeof(req); - req.hdr.nlmsg_type = GetParam(); - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; - req.hdr.nlmsg_seq = kSeq; - req.ifm.ifi_family = AF_UNSPEC; - req.ifm.ifi_index = loopback_link.index; - req.rtattr.rta_type = IFLA_MTU; - req.rtattr.rta_len = RTA_LENGTH(sizeof(uint32_t)); - req.mtu = loopback_link.mtu + 10; + MtuRequest req = + GetMtuRequest(loopback_link, GetParam(), loopback_link.mtu + 10); EXPECT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req))); // Update the local loopback_link's MTU to the requested value. @@ -1739,55 +1749,55 @@ void addattr(struct nlmsghdr* n, int maxlen, int type, const void* data, n->nlmsg_len = NLMSG_ALIGN(n->nlmsg_len) + RTA_ALIGN(len); } -TEST(NetlinkRouteTest, VethAdd) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); - SKIP_IF(IsRunningWithHostinet()); +struct VethRequest { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + char buf[1024]; +}; - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct ifinfomsg ifm; - char buf[1024]; - }; - - struct request req = {}; +struct VethRequest GetVethRequest(uint32_t seq, const char* ifname_first, + const char* ifname_second) { + struct VethRequest req = {}; req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifinfomsg)); req.hdr.nlmsg_type = RTM_NEWLINK; req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE; - req.hdr.nlmsg_seq = kSeq; + req.hdr.nlmsg_seq = seq; req.ifm.ifi_family = AF_UNSPEC; req.ifm.ifi_index = 0; - req.ifm.ifi_change = IFF_UP; - req.ifm.ifi_flags = IFF_UP; - const char veth_first[] = "veth_first"; - addattr(&req.hdr, sizeof(req), IFLA_IFNAME, veth_first, strlen(veth_first)); + addattr(&req.hdr, sizeof(req), IFLA_IFNAME, ifname_first, + strlen(ifname_first)); - struct rtattr* 
linkinfo; - linkinfo = NLMSG_TAIL(&req.hdr); + struct rtattr* linkinfo = NLMSG_TAIL(&req.hdr); { addattr(&req.hdr, sizeof(req), IFLA_LINKINFO, nullptr, 0); addattr(&req.hdr, sizeof(req), IFLA_INFO_KIND, "veth", 4); - - struct rtattr *veth_data, *peer_data; - veth_data = NLMSG_TAIL(&req.hdr); + struct rtattr* veth_data = NLMSG_TAIL(&req.hdr); { addattr(&req.hdr, sizeof(req), IFLA_INFO_DATA, NULL, 0); - peer_data = NLMSG_TAIL(&req.hdr); + struct rtattr* peer_data = NLMSG_TAIL(&req.hdr); { struct ifinfomsg ifm = {}; addattr(&req.hdr, sizeof(req), VETH_INFO_PEER, &ifm, sizeof(ifm)); - const char veth_second[] = "veth_second"; - addattr(&req.hdr, sizeof(req), IFLA_IFNAME, veth_second, - strlen(veth_second)); + addattr(&req.hdr, sizeof(req), IFLA_IFNAME, ifname_second, + strlen(ifname_second)); } peer_data->rta_len = (uint64_t)NLMSG_TAIL(&req.hdr) - (uint64_t)peer_data; } veth_data->rta_len = (uint64_t)NLMSG_TAIL(&req.hdr) - (uint64_t)veth_data; } linkinfo->rta_len = (uint64_t)NLMSG_TAIL(&req.hdr) - (uint64_t)linkinfo; + + return req; +} + +TEST(NetlinkRouteTest, VethAdd) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + SKIP_IF(IsRunningWithHostinet()); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + VethRequest req = GetVethRequest(kSeq, "veth1", "veth2"); EXPECT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len)); } @@ -1811,6 +1821,352 @@ TEST(NetlinkRouteTest, LookupAllAddrOrder) { freeifaddrs(if_addr_list); } } + +struct NameRequest { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + struct rtattr rtattr; + char name[IFNAMSIZ]; +}; + +NameRequest GetNameRequest(const Link& link, const char* name, uint32_t seq) { + NameRequest req = {}; + req.hdr.nlmsg_type = RTM_SETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = seq; + req.ifm.ifi_family = AF_UNSPEC; + req.ifm.ifi_index = link.index; + + const size_t payload_len = strlen(name) + 1; + req.rtattr.rta_type 
= IFLA_IFNAME; + req.rtattr.rta_len = RTA_LENGTH(payload_len); + memcpy(req.name, name, payload_len); + + req.hdr.nlmsg_len = + NLMSG_LENGTH(sizeof(struct ifinfomsg)) + RTA_SPACE(payload_len); + return req; +}; + +TEST(NetlinkRouteTest, LinkMulticastGroupBasic) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + SKIP_IF(IsRunningWithHostinet()); + // TODO(gvisor.dev/issue/4595): enable cooperative save tests. + const DisableSave ds; + + // nlsk_bound_group joins RTMGRP_LINK via bind(). + struct sockaddr_nl addr = {}; + addr.nl_family = AF_NETLINK; + addr.nl_groups = RTMGRP_LINK; + FileDescriptor nlsk_bound_group = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE, &addr)); + + // nlsk_sockopt_group joins RTMGRP_LINK via setsockopt(). + addr = {}; + addr.nl_family = AF_NETLINK; + FileDescriptor nlsk_sockopt_group = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE, &addr)); + unsigned int group = RTMGRP_LINK; + ASSERT_THAT(setsockopt(nlsk_sockopt_group.get(), SOL_NETLINK, + NETLINK_ADD_MEMBERSHIP, &group, sizeof(group)), + SyscallSucceeds()); + int64_t res_groups; + socklen_t res_groups_len = sizeof(res_groups); + EXPECT_THAT( + getsockopt(nlsk_sockopt_group.get(), SOL_NETLINK, + NETLINK_LIST_MEMBERSHIPS, &res_groups, &res_groups_len), + SyscallSucceeds()); + EXPECT_EQ(res_groups_len, sizeof(res_groups)); + EXPECT_EQ(res_groups, RTMGRP_LINK); + + FileDescriptor control_fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + // Change the name of the loopback interface. 
+ const Link link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + std::string old_loopback_name = link.name; + NameRequest name_request = GetNameRequest(link, "lo_test", kSeq); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(control_fd, kSeq, &name_request, + name_request.hdr.nlmsg_len)); + auto restore_loopback_name = Cleanup([&]() { + NameRequest name_request = + GetNameRequest(link, old_loopback_name.c_str(), kSeq); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(control_fd, kSeq, &name_request, + name_request.hdr.nlmsg_len)); + }); + const Link link_newname = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + + // Change the MTU of the loopback interface. + MtuRequest mtu_request = GetMtuRequest(link, RTM_SETLINK, link.mtu + 10); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(control_fd, kSeq, &mtu_request, + sizeof(mtu_request))); + auto restore_mtu = Cleanup([&]() { + MtuRequest mtu_request = GetMtuRequest(link, RTM_SETLINK, link.mtu); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(control_fd, kSeq, &mtu_request, + sizeof(mtu_request))); + }); + const Link link_newmtu = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + + struct TestCase { + const char* name; + FileDescriptor* nlsk; + const char* event_name; + const Link& link; + }; + std::vector test_cases = { + { + .name = "bound_group", + .nlsk = &nlsk_bound_group, + .event_name = "name_change", + .link = link_newname, + }, + { + .name = "sockopt_group", + .nlsk = &nlsk_sockopt_group, + .event_name = "name_change", + .link = link_newname, + }, + { + .name = "bound_group", + .nlsk = &nlsk_bound_group, + .event_name = "mtu_change", + .link = link_newmtu, + }, + { + .name = "sockopt_group", + .nlsk = &nlsk_sockopt_group, + .event_name = "mtu_change", + .link = link_newmtu, + }, + }; + + for (const auto& tc : test_cases) { + struct pollfd pfd = {.fd = tc.nlsk->get(), .events = POLLIN}; + constexpr int kPollTimeoutMs = 1000; + int poll_ret = RetryEINTR(poll)(&pfd, 1, kPollTimeoutMs); + ASSERT_EQ(poll_ret, 1) << "Did not get link event " << 
tc.event_name + << " on the " << tc.name << " netlink socket."; + + bool got_msg = false; + ASSERT_NO_ERRNO(NetlinkResponse( + *tc.nlsk, + [&](const struct nlmsghdr* hdr) { + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + if (msg->ifi_index != tc.link.index) { + return; + } + CheckLinkMsg(hdr, tc.link); + got_msg = true; + }, + /*expect_nlmsgerr=*/false)); + EXPECT_TRUE(got_msg) << "Did not get link event " << tc.event_name + << " on the " << tc.name << " netlink socket."; + } +} + +struct VethRequest GetSetNetNSRequest(uint32_t seq, int if_index, int ns_fd) { + struct VethRequest req = {}; + + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifinfomsg)); + req.hdr.nlmsg_type = RTM_NEWLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = seq; + req.ifm.ifi_family = AF_UNSPEC; + req.ifm.ifi_index = if_index; + addattr(&req.hdr, sizeof(req), IFLA_NET_NS_FD, &ns_fd, sizeof(ns_fd)); + + return req; +} + +// To verify the namespaced nature of the netlink multicast groups. +TEST(NetlinkRouteTest, LinkMulticastGroupNamespaced) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + SKIP_IF(IsRunningWithHostinet()); + // TODO(gvisor.dev/issue/4595): enable cooperative save tests. 
+ const DisableSave ds; + + FileDescriptor control_nlsk = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + VethRequest req = GetVethRequest(kSeq, "veth1", "veth2"); + EXPECT_NO_ERRNO( + NetlinkRequestAckOrError(control_nlsk, kSeq, &req, req.hdr.nlmsg_len)); + + int inner_veth_idx = if_nametoindex("veth2"); + ASSERT_NE(inner_veth_idx, 0); + + struct sockaddr_nl mcast_addr = {}; + mcast_addr.nl_family = AF_NETLINK; + mcast_addr.nl_groups = RTMGRP_LINK; + FileDescriptor root_nlsk = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE, &mcast_addr)); + + const FileDescriptor root_nsfd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/thread-self/ns/net", O_RDONLY)); + Cleanup restore_netns = Cleanup([&] { + ASSERT_THAT(setns(root_nsfd.get(), CLONE_NEWNET), + SyscallSucceedsWithValue(0)); + }); + + // Enter a new network namespace. + ASSERT_THAT(unshare(CLONE_NEWNET), SyscallSucceedsWithValue(0)); + FileDescriptor inner_nlsk = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE, &mcast_addr)); + + // And move veth2 into it. + const FileDescriptor inner_nsfd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/thread-self/ns/net", O_RDONLY)); + VethRequest set_netns_req = + GetSetNetNSRequest(kSeq, inner_veth_idx, inner_nsfd.get()); + EXPECT_NO_ERRNO(NetlinkRequestAckOrError(control_nlsk, kSeq, &set_netns_req, + set_netns_req.hdr.nlmsg_len)); + + constexpr int kPollTimeoutMs = 1000; + bool got_msg = false; + // We expect an RTM_DELLINK message for veth2 in the root netns socket. + // But an RTM_NEWLINK is also expected for veth1 because its peer was moved. + // Hence the two attempts. N.B. gVisor does not send the RTM_NEWLINK because + // IFLA_LINK_NETNSID is not yet supported.
+ for (int i = 0; i < 2; i++) { + struct pollfd pfd = {.fd = root_nlsk.get(), .events = POLLIN}; + ASSERT_EQ(RetryEINTR(poll)(&pfd, 1, kPollTimeoutMs), 1) + << "root_nlsk: Did not get veth2 DELLINK"; + + ASSERT_NO_ERRNO(NetlinkResponse( + root_nlsk, + [&](const struct nlmsghdr* hdr) { + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + if (hdr->nlmsg_type != RTM_DELLINK) return; + if (msg->ifi_index != inner_veth_idx) return; + got_msg = true; + }, + /*expect_nlmsgerr=*/false)); + if (got_msg) break; + } + EXPECT_TRUE(got_msg) << "root_nlsk: Did not get veth2 DELLINK"; + + // We expect an RTM_NEWLINK message for veth2 in the inner netns socket. + { + struct pollfd pfd = {.fd = inner_nlsk.get(), .events = POLLIN}; + ASSERT_EQ(RetryEINTR(poll)(&pfd, 1, kPollTimeoutMs), 1) + << "inner_nlsk: Did not get veth2 NEWLINK"; + + bool got_msg = false; + ASSERT_NO_ERRNO(NetlinkResponse( + inner_nlsk, + [&](const struct nlmsghdr* hdr) { + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + ASSERT_EQ(hdr->nlmsg_type, RTM_NEWLINK); + if (msg->ifi_index == 1) return; // Ignore the loopback interface. + + char ifname[IF_NAMESIZE]; + EXPECT_NE(if_indextoname(msg->ifi_index, ifname), nullptr); + EXPECT_STREQ(ifname, "veth2"); + got_msg = true; + }, + /*expect_nlmsgerr=*/false)); + EXPECT_TRUE(got_msg) << "inner_nlsk: Did not get veth2 NEWLINK"; + } +} + +// NOOP requests should not result in any netlink multicast messages. +TEST(NetlinkRouteTest, LinkMulticastGroupNoop) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + SKIP_IF(IsRunningWithHostinet()); + // TODO(gvisor.dev/issue/4595): enable cooperative save tests. 
+ const DisableSave ds; + + struct sockaddr_nl mcast_addr = {}; + mcast_addr.nl_family = AF_NETLINK; + mcast_addr.nl_groups = RTMGRP_LINK; + FileDescriptor nlsk = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE, &mcast_addr)); + + // Issue a request to set the name of the loopback interface to the same name. + const Link link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + NameRequest name_request = GetNameRequest(link, link.name.c_str(), kSeq); + FileDescriptor control_nlsk = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(control_nlsk, kSeq, &name_request, + name_request.hdr.nlmsg_len)); + + // We expect no RTM_NEWLINK message for the loopback interface. + struct pollfd pfd = {.fd = nlsk.get(), .events = POLLIN}; + constexpr int kPollTimeoutMs = 500; + bool got_msg = false; + if (RetryEINTR(poll)(&pfd, 1, kPollTimeoutMs) >= 1) { + ASSERT_NO_ERRNO(NetlinkResponse( + nlsk, + [&](const struct nlmsghdr* hdr) { + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + if (hdr->nlmsg_type != RTM_NEWLINK) return; + if (msg->ifi_index != link.index) return; + got_msg = true; + }, + /*expect_nlmsgerr=*/false)); + } + EXPECT_FALSE(got_msg) + << "Should not get a newlink event for the loopback interface."; +} + +// Userspace should know that it failed to keep up with its recvmsg()s, and the +// kernel alerts it to this by having recvmsg() return ENOBUFS. +TEST(NetlinkRouteTest, LinkMulticastGroupEnobufs) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + SKIP_IF(IsRunningWithHostinet()); + // TODO(gvisor.dev/issue/4595): enable cooperative save tests. + const DisableSave ds; + // TODO(b/456238795): enable this test once gVisor returns ENOBUFS. 
+ if (IsRunningOnGvisor()) { + GTEST_SKIP() << "gVisor never returns ENOBUFS."; + } + + struct sockaddr_nl mcast_addr = {}; + mcast_addr.nl_family = AF_NETLINK; + mcast_addr.nl_groups = RTMGRP_LINK; + FileDescriptor nlsk = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE, &mcast_addr)); + + // N.B. gvisor ignores the SO_RCVBUF value. + constexpr int kSmallRcvBufSize = 512; + ASSERT_THAT(setsockopt(nlsk.get(), SOL_SOCKET, SO_RCVBUF, &kSmallRcvBufSize, + sizeof(int)), + SyscallSucceeds()); + int recv_buf_size; + socklen_t rec_buf_size_len = sizeof(recv_buf_size); + ASSERT_THAT(getsockopt(nlsk.get(), SOL_SOCKET, SO_RCVBUF, &recv_buf_size, + &rec_buf_size_len), + SyscallSucceeds()); + + // Generate enough link events to overflow poor nlsk's receive buffer. + FileDescriptor control_nlsk = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + Link link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + constexpr int kMinimumNewlinkMsgSize = 32; + const int num_msgs = recv_buf_size / kMinimumNewlinkMsgSize; + int i = 0; + while (i < num_msgs || link.name != "lo") { + std::string name = link.name == "lo" ? 
"lo_test" : "lo"; + NameRequest name_request = GetNameRequest(link, name.c_str(), kSeq); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(control_nlsk, kSeq, &name_request, + name_request.hdr.nlmsg_len)); + link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + i++; + } + + std::vector buf(kSmallRcvBufSize); + struct iovec iov = {}; + iov.iov_base = buf.data(); + iov.iov_len = buf.size(); + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + EXPECT_THAT(RetryEINTR(recvmsg)(nlsk.get(), &msg, 0), + SyscallFailsWithErrno(ENOBUFS)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc index 61faeb50e2..f77ece9ed1 100644 --- a/test/syscalls/linux/socket_netlink_route_util.cc +++ b/test/syscalls/linux/socket_netlink_route_util.cc @@ -290,6 +290,7 @@ PosixErrorOr> DumpLinks() { rta_address == nullptr ? "" : std::string(reinterpret_cast(RTA_DATA(rta_address))); + links.back().flags = msg->ifi_flags; })); return links; } diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h index fdb1bb5a0b..2b37eb9e9e 100644 --- a/test/syscalls/linux/socket_netlink_route_util.h +++ b/test/syscalls/linux/socket_netlink_route_util.h @@ -31,6 +31,7 @@ struct Link { std::string name; uint32_t mtu; std::string address; + unsigned int flags; }; PosixError DumpLinks(const FileDescriptor& fd, uint32_t seq, diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc index c100066df6..9b810d02c0 100644 --- a/test/syscalls/linux/socket_netlink_util.cc +++ b/test/syscalls/linux/socket_netlink_util.cc @@ -38,14 +38,17 @@ namespace gvisor { namespace testing { PosixErrorOr NetlinkBoundSocket(int protocol) { - FileDescriptor fd; - ASSIGN_OR_RETURN_ERRNO(fd, Socket(AF_NETLINK, SOCK_RAW, protocol)); - struct sockaddr_nl addr = {}; addr.nl_family = AF_NETLINK; + return NetlinkBoundSocket(protocol, 
&addr); +} - RETURN_ERROR_IF_SYSCALL_FAIL( - bind(fd.get(), reinterpret_cast(&addr), sizeof(addr))); +PosixErrorOr NetlinkBoundSocket( + int protocol, const struct sockaddr_nl* addr) { + FileDescriptor fd; + ASSIGN_OR_RETURN_ERRNO(fd, Socket(AF_NETLINK, SOCK_RAW, protocol)); + + RETURN_ERROR_IF_SYSCALL_FAIL(bind(fd.get(), AsSockAddr(addr), sizeof(*addr))); MaybeSave(); return std::move(fd); diff --git a/test/syscalls/linux/socket_netlink_util.h b/test/syscalls/linux/socket_netlink_util.h index b21f513f69..c6661b7e90 100644 --- a/test/syscalls/linux/socket_netlink_util.h +++ b/test/syscalls/linux/socket_netlink_util.h @@ -35,6 +35,10 @@ namespace testing { // Returns a bound netlink socket. PosixErrorOr NetlinkBoundSocket(int protocol); +// Returns a bound netlink socket. +PosixErrorOr NetlinkBoundSocket(int protocol, + const struct sockaddr_nl* addr); + // Returns the port ID of the passed socket. PosixErrorOr NetlinkPortID(int fd); @@ -86,6 +90,14 @@ void InitNetlinkAttr(struct nlattr* attr, int payload_size, uint16_t attr_type); // Helper function to find a netlink attribute in a message. const struct nfattr* FindNfAttr(const struct nlmsghdr* hdr, const struct nfgenmsg* msg, int16_t attr); + +inline sockaddr* AsSockAddr(sockaddr_nl* s) { + return reinterpret_cast(s); +} +inline const sockaddr* AsSockAddr(const sockaddr_nl* s) { + return reinterpret_cast(s); +} + } // namespace testing } // namespace gvisor