|
12 | 12 | package netroute |
13 | 13 |
|
14 | 14 | import ( |
| 15 | + "bytes" |
| 16 | + "encoding/binary" |
| 17 | + "errors" |
| 18 | + "fmt" |
15 | 19 | "net" |
16 | | - "sort" |
| 20 | + "os" |
| 21 | + "slices" |
| 22 | + "sync/atomic" |
17 | 23 | "syscall" |
18 | 24 | "unsafe" |
19 | 25 | ) |
20 | 26 |
|
| 27 | +var nlSequence uint32 |
| 28 | + |
| 29 | +type linuxRouter struct{} |
| 30 | + |
21 | 31 | func New() (Router, error) { |
22 | | - rtr := &router{} |
23 | | - rtr.ifaces = make(map[int]net.Interface) |
24 | | - rtr.addrs = make(map[int]ipAddrs) |
25 | | - tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC) |
26 | | - if err != nil { |
27 | | - return nil, err |
| 32 | + return &linuxRouter{}, nil |
| 33 | +} |
| 34 | + |
| 35 | +// Route implements Router. |
| 36 | +func (l *linuxRouter) Route(dst net.IP) (iface *net.Interface, gateway net.IP, preferredSrc net.IP, err error) { |
| 37 | + return l.RouteWithSrc(nil, nil, dst) |
| 38 | +} |
| 39 | + |
| 40 | +// RouteWithSrc implements Router. |
| 41 | +func (l *linuxRouter) RouteWithSrc(input net.HardwareAddr, src net.IP, dst net.IP) (iface *net.Interface, gateway net.IP, preferredSrc net.IP, err error) { |
| 42 | + if dst == nil || dst.IsUnspecified() { |
| 43 | + return nil, nil, nil, errors.New("destination IP must be specified") |
28 | 44 | } |
29 | | - msgs, err := syscall.ParseNetlinkMessage(tab) |
| 45 | + |
| 46 | + // trim bytes if this is a v4 addr |
| 47 | + if v4 := dst.To4(); v4 != nil { |
| 48 | + dst = v4 |
| 49 | + } |
| 50 | + |
| 51 | + dstFamily := addressFamily(dst) |
| 52 | + if src != nil && !src.IsUnspecified() { |
| 53 | + srcFamily := addressFamily(src) |
| 54 | + if srcFamily != dstFamily { |
| 55 | + return nil, nil, nil, fmt.Errorf("source %q and destination %q use different address families", src.String(), dst.String()) |
| 56 | + } |
| 57 | + } |
| 58 | + |
| 59 | + ifaces, err := net.Interfaces() |
30 | 60 | if err != nil { |
31 | | - return nil, err |
32 | | - } |
33 | | -loop: |
34 | | - for _, m := range msgs { |
35 | | - switch m.Header.Type { |
36 | | - case syscall.NLMSG_DONE: |
37 | | - break loop |
38 | | - case syscall.RTM_NEWROUTE: |
39 | | - rt := (*syscall.RtMsg)(unsafe.Pointer(&m.Data[0])) |
40 | | - routeInfo := rtInfo{} |
41 | | - attrs, err := syscall.ParseNetlinkRouteAttr(&m) |
42 | | - if err != nil { |
43 | | - return nil, err |
44 | | - } |
45 | | - if rt.Family != syscall.AF_INET && rt.Family != syscall.AF_INET6 { |
46 | | - continue loop |
47 | | - } |
48 | | - for _, attr := range attrs { |
49 | | - switch attr.Attr.Type { |
50 | | - case syscall.RTA_DST: |
51 | | - routeInfo.Dst = &net.IPNet{ |
52 | | - IP: net.IP(attr.Value), |
53 | | - Mask: net.CIDRMask(int(rt.Dst_len), len(attr.Value)*8), |
54 | | - } |
55 | | - case syscall.RTA_SRC: |
56 | | - routeInfo.Src = &net.IPNet{ |
57 | | - IP: net.IP(attr.Value), |
58 | | - Mask: net.CIDRMask(int(rt.Src_len), len(attr.Value)*8), |
59 | | - } |
60 | | - case syscall.RTA_GATEWAY: |
61 | | - routeInfo.Gateway = net.IP(attr.Value) |
62 | | - case syscall.RTA_PREFSRC: |
63 | | - routeInfo.PrefSrc = net.IP(attr.Value) |
64 | | - case syscall.RTA_IIF: |
65 | | - routeInfo.InputIface = *(*uint32)(unsafe.Pointer(&attr.Value[0])) |
66 | | - case syscall.RTA_OIF: |
67 | | - routeInfo.OutputIface = *(*uint32)(unsafe.Pointer(&attr.Value[0])) |
68 | | - case syscall.RTA_PRIORITY: |
69 | | - routeInfo.Priority = *(*uint32)(unsafe.Pointer(&attr.Value[0])) |
70 | | - } |
71 | | - } |
72 | | - if routeInfo.Dst == nil && routeInfo.Src == nil && routeInfo.Gateway == nil { |
73 | | - continue loop |
74 | | - } |
75 | | - switch rt.Family { |
76 | | - case syscall.AF_INET: |
77 | | - rtr.v4 = append(rtr.v4, &routeInfo) |
78 | | - case syscall.AF_INET6: |
79 | | - rtr.v6 = append(rtr.v6, &routeInfo) |
80 | | - default: |
81 | | - // should not happen. |
82 | | - continue loop |
| 61 | + return nil, nil, nil, fmt.Errorf("list interfaces: %w", err) |
| 62 | + } |
| 63 | + |
| 64 | + var oif int |
| 65 | + if len(input) > 0 { |
| 66 | + for i := range ifaces { |
| 67 | + iface := ifaces[i] |
| 68 | + if bytes.Equal(iface.HardwareAddr, input) { |
| 69 | + oif = iface.Index |
| 70 | + break |
83 | 71 | } |
84 | 72 | } |
| 73 | + if oif == 0 { |
| 74 | + return nil, nil, nil, fmt.Errorf("no interface with address %s found", input.String()) |
| 75 | + } |
85 | 76 | } |
86 | | - sort.Sort(rtr.v4) |
87 | | - sort.Sort(rtr.v6) |
88 | | - ifaces, err := net.Interfaces() |
| 77 | + |
| 78 | + fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_DGRAM, syscall.NETLINK_ROUTE) |
89 | 79 | if err != nil { |
90 | | - return nil, err |
| 80 | + return nil, nil, nil, fmt.Errorf("open netlink socket: %w", err) |
91 | 81 | } |
92 | | - for _, iface := range ifaces { |
93 | | - rtr.ifaces[iface.Index] = iface |
94 | | - var addrs ipAddrs |
95 | | - ifaceAddrs, err := iface.Addrs() |
96 | | - if err != nil { |
| 82 | + defer syscall.Close(fd) |
| 83 | + |
| 84 | + sa := &syscall.SockaddrNetlink{Family: syscall.AF_NETLINK} |
| 85 | + if err := syscall.Bind(fd, sa); err != nil { |
| 86 | + return nil, nil, nil, fmt.Errorf("bind netlink socket: %w", err) |
| 87 | + } |
| 88 | + |
| 89 | + seq := atomic.AddUint32(&nlSequence, 1) |
| 90 | + request, err := buildRouteRequest(dstFamily, dst, src, oif, seq, uint32(os.Getpid())) |
| 91 | + if err != nil { |
| 92 | + return nil, nil, nil, err |
| 93 | + } |
| 94 | + |
| 95 | + if err := syscall.Sendto(fd, request, 0, &syscall.SockaddrNetlink{Family: syscall.AF_NETLINK}); err != nil { |
| 96 | + return nil, nil, nil, fmt.Errorf("send netlink request: %w", err) |
| 97 | + } |
| 98 | + |
| 99 | + route, err := readRouteResponse(fd, seq) |
| 100 | + if err != nil { |
| 101 | + return nil, nil, nil, err |
| 102 | + } |
| 103 | + |
| 104 | + outIface, err := net.InterfaceByIndex(int(route.OutputIface)) |
| 105 | + if err != nil { |
| 106 | + return nil, nil, nil, fmt.Errorf("get interface by index: %w", err) |
| 107 | + } |
| 108 | + |
| 109 | + if len(input) > 0 && iface != nil && len(iface.HardwareAddr) > 0 && !bytes.Equal(iface.HardwareAddr, input) { |
| 110 | + return nil, nil, nil, fmt.Errorf("route resolved to interface %s (%s), expected %s", iface.Name, iface.HardwareAddr, input) |
| 111 | + } |
| 112 | + |
| 113 | + var srcIP net.IP |
| 114 | + if len(route.PrefSrc) > 0 && !route.PrefSrc.IsUnspecified() { |
| 115 | + srcIP = route.PrefSrc |
| 116 | + } else if route.Src != nil { |
| 117 | + srcIP = route.Src.IP |
| 118 | + } else { |
| 119 | + return nil, nil, nil, fmt.Errorf("no source IP found") |
| 120 | + } |
| 121 | + |
| 122 | + // copyIP so we don't leak a reference to our working buffer |
| 123 | + return outIface, copyIP(route.Gateway), copyIP(srcIP), nil |
| 124 | +} |
| 125 | + |
| 126 | +func buildRouteRequest(family int, dst []byte, src []byte, oif int, seq, pid uint32) ([]byte, error) { |
| 127 | + bodyBuf := new(bytes.Buffer) |
| 128 | + |
| 129 | + rtm := syscall.RtMsg{ |
| 130 | + Family: uint8(family), |
| 131 | + Dst_len: uint8(len(dst) * 8), |
| 132 | + Src_len: uint8(len(src) * 8), |
| 133 | + Table: syscall.RT_TABLE_UNSPEC, |
| 134 | + } |
| 135 | + if err := binary.Write(bodyBuf, binary.NativeEndian, rtm); err != nil { |
| 136 | + return nil, fmt.Errorf("marshal rtmsg: %w", err) |
| 137 | + } |
| 138 | + if len(dst) > 0 { |
| 139 | + if err := writeRouteAttr(bodyBuf, syscall.RTA_DST, dst); err != nil { |
| 140 | + return nil, err |
| 141 | + } |
| 142 | + } |
| 143 | + if len(src) > 0 { |
| 144 | + if err := writeRouteAttr(bodyBuf, syscall.RTA_SRC, src); err != nil { |
97 | 145 | return nil, err |
98 | 146 | } |
99 | | - for _, addr := range ifaceAddrs { |
100 | | - if inet, ok := addr.(*net.IPNet); ok { |
101 | | - // Go has a nasty habit of giving you IPv4s as ::ffff:1.2.3.4 instead of 1.2.3.4. |
102 | | - // We want to use mapped v4 addresses as v4 preferred addresses, never as v6 |
103 | | - // preferred addresses. |
104 | | - if v4 := inet.IP.To4(); v4 != nil { |
105 | | - if addrs.v4 == nil { |
106 | | - addrs.v4 = v4 |
107 | | - } |
108 | | - } else if addrs.v6 == nil { |
109 | | - addrs.v6 = inet.IP |
| 147 | + } |
| 148 | + if oif != 0 { |
| 149 | + oifBytes := make([]byte, 4) |
| 150 | + binary.NativeEndian.PutUint32(oifBytes, uint32(oif)) |
| 151 | + if err := writeRouteAttr(bodyBuf, syscall.RTA_OIF, oifBytes); err != nil { |
| 152 | + return nil, err |
| 153 | + } |
| 154 | + } |
| 155 | + |
| 156 | + header := syscall.NlMsghdr{ |
| 157 | + Len: uint32(syscall.NLMSG_HDRLEN + bodyBuf.Len()), |
| 158 | + Type: uint16(syscall.RTM_GETROUTE), |
| 159 | + Flags: uint16(syscall.NLM_F_REQUEST), |
| 160 | + Seq: seq, |
| 161 | + Pid: pid, |
| 162 | + } |
| 163 | + |
| 164 | + msgBuf := new(bytes.Buffer) |
| 165 | + if err := binary.Write(msgBuf, binary.NativeEndian, header); err != nil { |
| 166 | + return nil, fmt.Errorf("marshal nlmsghdr: %w", err) |
| 167 | + } |
| 168 | + if _, err := msgBuf.Write(bodyBuf.Bytes()); err != nil { |
| 169 | + return nil, fmt.Errorf("assemble netlink request: %w", err) |
| 170 | + } |
| 171 | + return msgBuf.Bytes(), nil |
| 172 | +} |
| 173 | + |
| 174 | +func writeRouteAttr(buf *bytes.Buffer, attrType uint16, payload []byte) error { |
| 175 | + attrLen := uint16(syscall.SizeofRtAttr + len(payload)) |
| 176 | + attr := syscall.RtAttr{ |
| 177 | + Len: attrLen, |
| 178 | + Type: attrType, |
| 179 | + } |
| 180 | + if err := binary.Write(buf, binary.NativeEndian, attr); err != nil { |
| 181 | + return fmt.Errorf("marshal rtattr: %w", err) |
| 182 | + } |
| 183 | + if _, err := buf.Write(payload); err != nil { |
| 184 | + return fmt.Errorf("write rtattr payload: %w", err) |
| 185 | + } |
| 186 | + |
| 187 | + const nlMsghdrLenAlignment = 4 |
| 188 | + padLen := alignTo(attrLen, nlMsghdrLenAlignment) - int(attrLen) |
| 189 | + if padLen > 0 { |
| 190 | + _, _ = buf.Write(make([]byte, padLen)) |
| 191 | + } |
| 192 | + return nil |
| 193 | +} |
| 194 | + |
| 195 | +func alignTo(length uint16, alignment int) int { |
| 196 | + l := int(length) |
| 197 | + return (l + alignment - 1) & ^(alignment - 1) |
| 198 | +} |
| 199 | + |
| 200 | +func readRouteResponse(fd int, seq uint32) (*rtInfo, error) { |
| 201 | + buf := make([]byte, 1<<16) |
| 202 | + for { |
| 203 | + n, _, err := syscall.Recvfrom(fd, buf, 0) |
| 204 | + if err != nil { |
| 205 | + return nil, fmt.Errorf("receive netlink response: %w", err) |
| 206 | + } |
| 207 | + |
| 208 | + msgs, err := syscall.ParseNetlinkMessage(buf[:n]) |
| 209 | + if err != nil { |
| 210 | + return nil, fmt.Errorf("parse netlink response: %w", err) |
| 211 | + } |
| 212 | + |
| 213 | + for _, m := range msgs { |
| 214 | + if m.Header.Seq != seq { |
| 215 | + continue |
| 216 | + } |
| 217 | + switch m.Header.Type { |
| 218 | + case syscall.NLMSG_ERROR: |
| 219 | + if err := parseNetlinkError(m.Data); err != nil { |
| 220 | + return nil, fmt.Errorf("netlink error: %w", err) |
110 | 221 | } |
| 222 | + case syscall.NLMSG_DONE: |
| 223 | + return nil, errors.New("route lookup returned no result") |
| 224 | + case syscall.RTM_NEWROUTE: |
| 225 | + var routeInfo rtInfo |
| 226 | + if err := routeInfo.parse(&m); err != nil { |
| 227 | + return nil, err |
| 228 | + } |
| 229 | + if routeInfo.Dst == nil && routeInfo.Src == nil && routeInfo.Gateway == nil { |
| 230 | + continue |
| 231 | + } |
| 232 | + return &routeInfo, nil |
111 | 233 | } |
112 | 234 | } |
113 | | - rtr.addrs[iface.Index] = addrs |
114 | 235 | } |
115 | | - return rtr, nil |
| 236 | +} |
| 237 | + |
| 238 | +func parseNetlinkError(data []byte) error { |
| 239 | + if len(data) < syscall.SizeofNlMsgerr { |
| 240 | + return fmt.Errorf("short netlink error payload: %d", len(data)) |
| 241 | + } |
| 242 | + var msg syscall.NlMsgerr |
| 243 | + reader := bytes.NewReader(data[:syscall.SizeofNlMsgerr]) |
| 244 | + if err := binary.Read(reader, binary.NativeEndian, &msg); err != nil { |
| 245 | + return fmt.Errorf("decode nlmsgerr: %w", err) |
| 246 | + } |
| 247 | + if msg.Error == 0 { |
| 248 | + return nil |
| 249 | + } |
| 250 | + |
| 251 | + // Error is negative errno or 0 for acknowledgements |
| 252 | + return syscall.Errno(-msg.Error) |
| 253 | +} |
| 254 | + |
| 255 | +func (routeInfo *rtInfo) parse(msg *syscall.NetlinkMessage) error { |
| 256 | + rt := (*syscall.RtMsg)(unsafe.Pointer(&msg.Data[0])) |
| 257 | + if rt.Family != syscall.AF_INET && rt.Family != syscall.AF_INET6 { |
| 258 | + return errors.New("unsupported address family") |
| 259 | + } |
| 260 | + |
| 261 | + attrs, err := syscall.ParseNetlinkRouteAttr(msg) |
| 262 | + if err != nil { |
| 263 | + return err |
| 264 | + } |
| 265 | + for _, attr := range attrs { |
| 266 | + switch attr.Attr.Type { |
| 267 | + case syscall.RTA_DST: |
| 268 | + routeInfo.Dst = &net.IPNet{ |
| 269 | + IP: net.IP(attr.Value), |
| 270 | + Mask: net.CIDRMask(int(rt.Dst_len), len(attr.Value)*8), |
| 271 | + } |
| 272 | + case syscall.RTA_SRC: |
| 273 | + routeInfo.Src = &net.IPNet{ |
| 274 | + // Copy the IP so we don't keep a reference to this buffer |
| 275 | + IP: net.IP(attr.Value), |
| 276 | + Mask: net.CIDRMask(int(rt.Src_len), len(attr.Value)*8), |
| 277 | + } |
| 278 | + case syscall.RTA_GATEWAY: |
| 279 | + routeInfo.Gateway = net.IP(attr.Value) |
| 280 | + case syscall.RTA_PREFSRC: |
| 281 | + routeInfo.PrefSrc = net.IP(attr.Value) |
| 282 | + case syscall.RTA_IIF: |
| 283 | + routeInfo.InputIface = *(*uint32)(unsafe.Pointer(&attr.Value[0])) |
| 284 | + case syscall.RTA_OIF: |
| 285 | + routeInfo.OutputIface = *(*uint32)(unsafe.Pointer(&attr.Value[0])) |
| 286 | + case syscall.RTA_PRIORITY: |
| 287 | + routeInfo.Priority = *(*uint32)(unsafe.Pointer(&attr.Value[0])) |
| 288 | + } |
| 289 | + } |
| 290 | + return nil |
| 291 | +} |
| 292 | + |
| 293 | +func addressFamily(ip net.IP) int { |
| 294 | + if ip.To4() != nil { |
| 295 | + return syscall.AF_INET |
| 296 | + } |
| 297 | + return syscall.AF_INET6 |
| 298 | +} |
| 299 | + |
| 300 | +func copyIP(ip net.IP) net.IP { |
| 301 | + if len(ip) == 0 { |
| 302 | + return nil |
| 303 | + } |
| 304 | + out := slices.Clone(ip) |
| 305 | + if v4 := out.To4(); v4 != nil { |
| 306 | + out = v4 |
| 307 | + } |
| 308 | + return out |
116 | 309 | } |
0 commit comments