Skip to content

Commit 7756ff6

Browse files
committed
linux: use rtnetlink directly
This change uses rtnetlink directly per route query instead of relying on the routing information base. This fixes a routing issue where, with some VPNs (wireguard based ones), this library would return the incorrect source IP when queried. This is because Wireguard creates a new table and routing rules that are not reflected in the routing information base. Instead of trying to recreate this logic, we can query kernel directly via the rtnetlink socket. This is a bit painful from Go, but doable. Consult `man 7 rtnetlink` for more information on the rtnetlink interface.
1 parent 7c836b7 commit 7756ff6

File tree

1 file changed

+275
-82
lines changed

1 file changed

+275
-82
lines changed

netroute_linux.go

Lines changed: 275 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -12,105 +12,298 @@
1212
package netroute
1313

1414
import (
15+
"bytes"
16+
"encoding/binary"
17+
"errors"
18+
"fmt"
1519
"net"
16-
"sort"
20+
"os"
21+
"slices"
22+
"sync/atomic"
1723
"syscall"
1824
"unsafe"
1925
)
2026

27+
var nlSequence uint32
28+
29+
type linuxRouter struct{}
30+
2131
func New() (Router, error) {
22-
rtr := &router{}
23-
rtr.ifaces = make(map[int]net.Interface)
24-
rtr.addrs = make(map[int]ipAddrs)
25-
tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC)
26-
if err != nil {
27-
return nil, err
32+
return &linuxRouter{}, nil
33+
}
34+
35+
// Route implements Router.
36+
func (l *linuxRouter) Route(dst net.IP) (iface *net.Interface, gateway net.IP, preferredSrc net.IP, err error) {
37+
return l.RouteWithSrc(nil, nil, dst)
38+
}
39+
40+
// RouteWithSrc implements Router.
41+
func (l *linuxRouter) RouteWithSrc(input net.HardwareAddr, src net.IP, dst net.IP) (iface *net.Interface, gateway net.IP, preferredSrc net.IP, err error) {
42+
if dst == nil || dst.IsUnspecified() {
43+
return nil, nil, nil, errors.New("destination IP must be specified")
2844
}
29-
msgs, err := syscall.ParseNetlinkMessage(tab)
45+
46+
// trim bytes if this is a v4 addr
47+
if v4 := dst.To4(); v4 != nil {
48+
dst = v4
49+
}
50+
51+
dstFamily := addressFamily(dst)
52+
if src != nil && !src.IsUnspecified() {
53+
srcFamily := addressFamily(src)
54+
if srcFamily != dstFamily {
55+
return nil, nil, nil, fmt.Errorf("source %q and destination %q use different address families", src.String(), dst.String())
56+
}
57+
}
58+
59+
ifaces, err := net.Interfaces()
3060
if err != nil {
31-
return nil, err
32-
}
33-
loop:
34-
for _, m := range msgs {
35-
switch m.Header.Type {
36-
case syscall.NLMSG_DONE:
37-
break loop
38-
case syscall.RTM_NEWROUTE:
39-
rt := (*syscall.RtMsg)(unsafe.Pointer(&m.Data[0]))
40-
routeInfo := rtInfo{}
41-
attrs, err := syscall.ParseNetlinkRouteAttr(&m)
42-
if err != nil {
43-
return nil, err
44-
}
45-
if rt.Family != syscall.AF_INET && rt.Family != syscall.AF_INET6 {
46-
continue loop
47-
}
48-
for _, attr := range attrs {
49-
switch attr.Attr.Type {
50-
case syscall.RTA_DST:
51-
routeInfo.Dst = &net.IPNet{
52-
IP: net.IP(attr.Value),
53-
Mask: net.CIDRMask(int(rt.Dst_len), len(attr.Value)*8),
54-
}
55-
case syscall.RTA_SRC:
56-
routeInfo.Src = &net.IPNet{
57-
IP: net.IP(attr.Value),
58-
Mask: net.CIDRMask(int(rt.Src_len), len(attr.Value)*8),
59-
}
60-
case syscall.RTA_GATEWAY:
61-
routeInfo.Gateway = net.IP(attr.Value)
62-
case syscall.RTA_PREFSRC:
63-
routeInfo.PrefSrc = net.IP(attr.Value)
64-
case syscall.RTA_IIF:
65-
routeInfo.InputIface = *(*uint32)(unsafe.Pointer(&attr.Value[0]))
66-
case syscall.RTA_OIF:
67-
routeInfo.OutputIface = *(*uint32)(unsafe.Pointer(&attr.Value[0]))
68-
case syscall.RTA_PRIORITY:
69-
routeInfo.Priority = *(*uint32)(unsafe.Pointer(&attr.Value[0]))
70-
}
71-
}
72-
if routeInfo.Dst == nil && routeInfo.Src == nil && routeInfo.Gateway == nil {
73-
continue loop
74-
}
75-
switch rt.Family {
76-
case syscall.AF_INET:
77-
rtr.v4 = append(rtr.v4, &routeInfo)
78-
case syscall.AF_INET6:
79-
rtr.v6 = append(rtr.v6, &routeInfo)
80-
default:
81-
// should not happen.
82-
continue loop
61+
return nil, nil, nil, fmt.Errorf("list interfaces: %w", err)
62+
}
63+
64+
var oif int
65+
if len(input) > 0 {
66+
for i := range ifaces {
67+
iface := ifaces[i]
68+
if bytes.Equal(iface.HardwareAddr, input) {
69+
oif = iface.Index
70+
break
8371
}
8472
}
73+
if oif == 0 {
74+
return nil, nil, nil, fmt.Errorf("no interface with address %s found", input.String())
75+
}
8576
}
86-
sort.Sort(rtr.v4)
87-
sort.Sort(rtr.v6)
88-
ifaces, err := net.Interfaces()
77+
78+
fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_DGRAM, syscall.NETLINK_ROUTE)
8979
if err != nil {
90-
return nil, err
80+
return nil, nil, nil, fmt.Errorf("open netlink socket: %w", err)
9181
}
92-
for _, iface := range ifaces {
93-
rtr.ifaces[iface.Index] = iface
94-
var addrs ipAddrs
95-
ifaceAddrs, err := iface.Addrs()
96-
if err != nil {
82+
defer syscall.Close(fd)
83+
84+
sa := &syscall.SockaddrNetlink{Family: syscall.AF_NETLINK}
85+
if err := syscall.Bind(fd, sa); err != nil {
86+
return nil, nil, nil, fmt.Errorf("bind netlink socket: %w", err)
87+
}
88+
89+
seq := atomic.AddUint32(&nlSequence, 1)
90+
request, err := buildRouteRequest(dstFamily, dst, src, oif, seq, uint32(os.Getpid()))
91+
if err != nil {
92+
return nil, nil, nil, err
93+
}
94+
95+
if err := syscall.Sendto(fd, request, 0, &syscall.SockaddrNetlink{Family: syscall.AF_NETLINK}); err != nil {
96+
return nil, nil, nil, fmt.Errorf("send netlink request: %w", err)
97+
}
98+
99+
route, err := readRouteResponse(fd, seq)
100+
if err != nil {
101+
return nil, nil, nil, err
102+
}
103+
104+
outIface, err := net.InterfaceByIndex(int(route.OutputIface))
105+
if err != nil {
106+
return nil, nil, nil, fmt.Errorf("get interface by index: %w", err)
107+
}
108+
109+
if len(input) > 0 && iface != nil && len(iface.HardwareAddr) > 0 && !bytes.Equal(iface.HardwareAddr, input) {
110+
return nil, nil, nil, fmt.Errorf("route resolved to interface %s (%s), expected %s", iface.Name, iface.HardwareAddr, input)
111+
}
112+
113+
var srcIP net.IP
114+
if len(route.PrefSrc) > 0 && !route.PrefSrc.IsUnspecified() {
115+
srcIP = route.PrefSrc
116+
} else if route.Src != nil {
117+
srcIP = route.Src.IP
118+
} else {
119+
return nil, nil, nil, fmt.Errorf("no source IP found")
120+
}
121+
122+
// copyIP so we don't leak a reference to our working buffer
123+
return outIface, copyIP(route.Gateway), copyIP(srcIP), nil
124+
}
125+
126+
func buildRouteRequest(family int, dst []byte, src []byte, oif int, seq, pid uint32) ([]byte, error) {
127+
bodyBuf := new(bytes.Buffer)
128+
129+
rtm := syscall.RtMsg{
130+
Family: uint8(family),
131+
Dst_len: uint8(len(dst) * 8),
132+
Src_len: uint8(len(src) * 8),
133+
Table: syscall.RT_TABLE_UNSPEC,
134+
}
135+
if err := binary.Write(bodyBuf, binary.NativeEndian, rtm); err != nil {
136+
return nil, fmt.Errorf("marshal rtmsg: %w", err)
137+
}
138+
if len(dst) > 0 {
139+
if err := writeRouteAttr(bodyBuf, syscall.RTA_DST, dst); err != nil {
140+
return nil, err
141+
}
142+
}
143+
if len(src) > 0 {
144+
if err := writeRouteAttr(bodyBuf, syscall.RTA_SRC, src); err != nil {
97145
return nil, err
98146
}
99-
for _, addr := range ifaceAddrs {
100-
if inet, ok := addr.(*net.IPNet); ok {
101-
// Go has a nasty habit of giving you IPv4s as ::ffff:1.2.3.4 instead of 1.2.3.4.
102-
// We want to use mapped v4 addresses as v4 preferred addresses, never as v6
103-
// preferred addresses.
104-
if v4 := inet.IP.To4(); v4 != nil {
105-
if addrs.v4 == nil {
106-
addrs.v4 = v4
107-
}
108-
} else if addrs.v6 == nil {
109-
addrs.v6 = inet.IP
147+
}
148+
if oif != 0 {
149+
oifBytes := make([]byte, 4)
150+
binary.NativeEndian.PutUint32(oifBytes, uint32(oif))
151+
if err := writeRouteAttr(bodyBuf, syscall.RTA_OIF, oifBytes); err != nil {
152+
return nil, err
153+
}
154+
}
155+
156+
header := syscall.NlMsghdr{
157+
Len: uint32(syscall.NLMSG_HDRLEN + bodyBuf.Len()),
158+
Type: uint16(syscall.RTM_GETROUTE),
159+
Flags: uint16(syscall.NLM_F_REQUEST),
160+
Seq: seq,
161+
Pid: pid,
162+
}
163+
164+
msgBuf := new(bytes.Buffer)
165+
if err := binary.Write(msgBuf, binary.NativeEndian, header); err != nil {
166+
return nil, fmt.Errorf("marshal nlmsghdr: %w", err)
167+
}
168+
if _, err := msgBuf.Write(bodyBuf.Bytes()); err != nil {
169+
return nil, fmt.Errorf("assemble netlink request: %w", err)
170+
}
171+
return msgBuf.Bytes(), nil
172+
}
173+
174+
func writeRouteAttr(buf *bytes.Buffer, attrType uint16, payload []byte) error {
175+
attrLen := uint16(syscall.SizeofRtAttr + len(payload))
176+
attr := syscall.RtAttr{
177+
Len: attrLen,
178+
Type: attrType,
179+
}
180+
if err := binary.Write(buf, binary.NativeEndian, attr); err != nil {
181+
return fmt.Errorf("marshal rtattr: %w", err)
182+
}
183+
if _, err := buf.Write(payload); err != nil {
184+
return fmt.Errorf("write rtattr payload: %w", err)
185+
}
186+
187+
const nlMsghdrLenAlignment = 4
188+
padLen := alignTo(attrLen, nlMsghdrLenAlignment) - int(attrLen)
189+
if padLen > 0 {
190+
_, _ = buf.Write(make([]byte, padLen))
191+
}
192+
return nil
193+
}
194+
195+
func alignTo(length uint16, alignment int) int {
196+
l := int(length)
197+
return (l + alignment - 1) & ^(alignment - 1)
198+
}
199+
200+
func readRouteResponse(fd int, seq uint32) (*rtInfo, error) {
201+
buf := make([]byte, 1<<16)
202+
for {
203+
n, _, err := syscall.Recvfrom(fd, buf, 0)
204+
if err != nil {
205+
return nil, fmt.Errorf("receive netlink response: %w", err)
206+
}
207+
208+
msgs, err := syscall.ParseNetlinkMessage(buf[:n])
209+
if err != nil {
210+
return nil, fmt.Errorf("parse netlink response: %w", err)
211+
}
212+
213+
for _, m := range msgs {
214+
if m.Header.Seq != seq {
215+
continue
216+
}
217+
switch m.Header.Type {
218+
case syscall.NLMSG_ERROR:
219+
if err := parseNetlinkError(m.Data); err != nil {
220+
return nil, fmt.Errorf("netlink error: %w", err)
110221
}
222+
case syscall.NLMSG_DONE:
223+
return nil, errors.New("route lookup returned no result")
224+
case syscall.RTM_NEWROUTE:
225+
var routeInfo rtInfo
226+
if err := routeInfo.parse(&m); err != nil {
227+
return nil, err
228+
}
229+
if routeInfo.Dst == nil && routeInfo.Src == nil && routeInfo.Gateway == nil {
230+
continue
231+
}
232+
return &routeInfo, nil
111233
}
112234
}
113-
rtr.addrs[iface.Index] = addrs
114235
}
115-
return rtr, nil
236+
}
237+
238+
func parseNetlinkError(data []byte) error {
239+
if len(data) < syscall.SizeofNlMsgerr {
240+
return fmt.Errorf("short netlink error payload: %d", len(data))
241+
}
242+
var msg syscall.NlMsgerr
243+
reader := bytes.NewReader(data[:syscall.SizeofNlMsgerr])
244+
if err := binary.Read(reader, binary.NativeEndian, &msg); err != nil {
245+
return fmt.Errorf("decode nlmsgerr: %w", err)
246+
}
247+
if msg.Error == 0 {
248+
return nil
249+
}
250+
251+
// Error is negative errno or 0 for acknowledgements
252+
return syscall.Errno(-msg.Error)
253+
}
254+
255+
func (routeInfo *rtInfo) parse(msg *syscall.NetlinkMessage) error {
256+
rt := (*syscall.RtMsg)(unsafe.Pointer(&msg.Data[0]))
257+
if rt.Family != syscall.AF_INET && rt.Family != syscall.AF_INET6 {
258+
return errors.New("unsupported address family")
259+
}
260+
261+
attrs, err := syscall.ParseNetlinkRouteAttr(msg)
262+
if err != nil {
263+
return err
264+
}
265+
for _, attr := range attrs {
266+
switch attr.Attr.Type {
267+
case syscall.RTA_DST:
268+
routeInfo.Dst = &net.IPNet{
269+
IP: net.IP(attr.Value),
270+
Mask: net.CIDRMask(int(rt.Dst_len), len(attr.Value)*8),
271+
}
272+
case syscall.RTA_SRC:
273+
routeInfo.Src = &net.IPNet{
274+
// Copy the IP so we don't keep a reference to this buffer
275+
IP: net.IP(attr.Value),
276+
Mask: net.CIDRMask(int(rt.Src_len), len(attr.Value)*8),
277+
}
278+
case syscall.RTA_GATEWAY:
279+
routeInfo.Gateway = net.IP(attr.Value)
280+
case syscall.RTA_PREFSRC:
281+
routeInfo.PrefSrc = net.IP(attr.Value)
282+
case syscall.RTA_IIF:
283+
routeInfo.InputIface = *(*uint32)(unsafe.Pointer(&attr.Value[0]))
284+
case syscall.RTA_OIF:
285+
routeInfo.OutputIface = *(*uint32)(unsafe.Pointer(&attr.Value[0]))
286+
case syscall.RTA_PRIORITY:
287+
routeInfo.Priority = *(*uint32)(unsafe.Pointer(&attr.Value[0]))
288+
}
289+
}
290+
return nil
291+
}
292+
293+
func addressFamily(ip net.IP) int {
294+
if ip.To4() != nil {
295+
return syscall.AF_INET
296+
}
297+
return syscall.AF_INET6
298+
}
299+
300+
func copyIP(ip net.IP) net.IP {
301+
if len(ip) == 0 {
302+
return nil
303+
}
304+
out := slices.Clone(ip)
305+
if v4 := out.To4(); v4 != nil {
306+
out = v4
307+
}
308+
return out
116309
}

0 commit comments

Comments
 (0)