diff --git a/go.mod b/go.mod index 0070099..1dca28f 100644 --- a/go.mod +++ b/go.mod @@ -6,11 +6,14 @@ require ( github.com/google/go-cmp v0.7.0 github.com/mdlayher/genetlink v1.3.2 github.com/mdlayher/netlink v1.8.0 + github.com/ti-mo/conntrack v0.5.2 + github.com/ti-mo/netfilter v0.5.3 golang.org/x/sys v0.37.0 ) require ( github.com/mdlayher/socket v0.5.1 // indirect + github.com/pkg/errors v0.9.1 // indirect golang.org/x/net v0.45.0 // indirect - golang.org/x/sync v0.3.0 // indirect + golang.org/x/sync v0.14.0 // indirect ) diff --git a/go.sum b/go.sum index 138a534..e9e4608 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= @@ -6,9 +8,23 @@ github.com/mdlayher/netlink v1.8.0 h1:e7XNIYJKD7hUct3Px04RuIGJbBxy1/c4nX7D5Yyvvl github.com/mdlayher/netlink v1.8.0/go.mod h1:UhgKXUlDQhzb09DrCl2GuRNEglHmhYoWAHid9HK3594= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/ti-mo/conntrack v0.5.2 
h1:PQ7MCdFjniEiTJT+qsAysREUsT5iH62/VNyhkB06HOI= +github.com/ti-mo/conntrack v0.5.2/go.mod h1:4HZrFQQLOSuBzgQNid3H/wYyyp1kfGXUYxueXjIGibo= +github.com/ti-mo/netfilter v0.5.3 h1:ikzduvnaUMwre5bhbNwWOd6bjqLMVb33vv0XXbK0xGQ= +github.com/ti-mo/netfilter v0.5.3/go.mod h1:08SyBCg6hu1qyQk4s3DjjJKNrm3RTb32nm6AzyT972E= +github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= +github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM= golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= -golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= -golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= +golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/ovsnl/client.go b/ovsnl/client.go index 4e522bc..518f621 100644 --- a/ovsnl/client.go +++ b/ovsnl/client.go @@ -37,7 +37,8 @@ type Client struct { // Datapath provides access to DatapathService methods. Datapath *DatapathService - c *genetlink.Conn + c *genetlink.Conn + Agg *ZoneMarkAggregator } // New creates a new Linux Open vSwitch generic netlink client. @@ -45,37 +46,50 @@ type Client struct { // If no OvS generic netlink families are available on this system, an // error will be returned which can be checked using os.IsNotExist. func New() (*Client, error) { - c, err := genetlink.Dial(nil) + c := &Client{} // Create client instance first + + // Initialize the underlying genetlink connection. 
+ conn, err := genetlink.Dial(nil) if err != nil { return nil, err } + c.c = conn - return newClient(c) -} - -// newClient is the internal Client constructor, used in tests. -func newClient(c *genetlink.Conn) (*Client, error) { - // Must ensure that the generic netlink connection is closed on any errors - // that occur before it is returned to the caller. - - families, err := c.ListFamilies() + // Initialize services. + families, err := c.c.ListFamilies() if err != nil { - _ = c.Close() + _ = c.c.Close() return nil, err } - client := &Client{c: c} - if err := client.init(families); err != nil { - _ = c.Close() + if err := c.init(families); err != nil { + _ = c.c.Close() return nil, err } - return client, nil + // Initialize aggregator as nil - will be created when needed + c.Agg = nil + + return c, nil } // Close closes the Client's generic netlink connection. func (c *Client) Close() error { - return c.c.Close() + var errs []error + + if c.Agg != nil { + c.Agg.Stop() + } + + if c.c != nil { + if err := c.c.Close(); err != nil { + errs = append(errs, err) + } + } + if len(errs) > 0 { + return fmt.Errorf("errors closing client: %v", errs) + } + return nil } // init initializes the generic netlink family service of Client. @@ -83,18 +97,22 @@ func (c *Client) init(families []genetlink.Family) error { var gotf int for _, f := range families { - // Ignore any families without the OVS prefix. - if !strings.HasPrefix(f.Name, "ovs_") { - continue - } - // Ignore any families that might be unknown. 
- if err := c.initFamily(f); err != nil { + // Initialize OVS-specific families + if strings.HasPrefix(f.Name, "ovs_") { + if err := c.initFamily(f); err != nil { + // Log but continue if an OVS family fails to init + fmt.Printf("Warning: failed to initialize OVS family %q: %v\n", f.Name, err) + continue + } + } else if f.Name == "nf_conntrack" { // Explicitly initialize for Netfilter conntrack family + // Acknowledge that conntrack family exists - aggregator will handle conntrack operations + } else { + // Skip other non-OVS/non-conntrack families continue } gotf++ } - // No known families; return error for os.IsNotExist check. if gotf == 0 { return os.ErrNotExist } diff --git a/ovsnl/client_linux_test.go b/ovsnl/client_linux_test.go index d09239d..a18b760 100644 --- a/ovsnl/client_linux_test.go +++ b/ovsnl/client_linux_test.go @@ -29,6 +29,28 @@ import ( "golang.org/x/sys/unix" ) +func newTestClient(conn *genetlink.Conn) (*Client, error) { + c := &Client{} + c.c = conn + + families, err := c.c.ListFamilies() + if err != nil { + return nil, err + } + + if err := c.init(families); err != nil { + return nil, err + } + + // Inject our mock connection directly into the datapath service + if c.Datapath != nil { + c.Datapath.c = c + c.c = conn + } + + return c, nil +} + func TestClientNoFamiliesIsNotExist(t *testing.T) { conn := genltest.Dial(func(greq genetlink.Message, nreq netlink.Message) ([]genetlink.Message, error) { // Unrelated generic netlink families. 
@@ -38,7 +60,7 @@ func TestClientNoFamiliesIsNotExist(t *testing.T) { }), nil }) - _, err := newClient(conn) + _, err := newTestClient(conn) if !os.IsNotExist(err) { t.Fatalf("expected is not exist error, but got: %v", err) } @@ -53,7 +75,7 @@ func TestClientUnknownFamilies(t *testing.T) { }), nil }) - _, err := newClient(conn) + _, err := newTestClient(conn) if err == nil { t.Fatalf("expected an error, but none occurred") } @@ -67,7 +89,7 @@ func TestClientNoFamilies(t *testing.T) { return nil, nil }) - _, err := newClient(conn) + _, err := newTestClient(conn) if err == nil { t.Fatalf("expected an error, but none occurred") } @@ -82,7 +104,7 @@ func TestClientKnownFamilies(t *testing.T) { }), nil }) - _, err := newClient(conn) + _, err := newTestClient(conn) if err != nil { t.Fatalf("failed to create client: %v", err) } @@ -112,24 +134,6 @@ func familyMessages(families []string) []genetlink.Message { return msgs } -// ovsFamilies creates a genltest.Func which intercepts "list family" requests -// and returns all the OVS families. Other requests are passed through to fn. -func ovsFamilies(fn genltest.Func) genltest.Func { - return func(greq genetlink.Message, nreq netlink.Message) ([]genetlink.Message, error) { - if nreq.Header.Type == unix.GENL_ID_CTRL && greq.Header.Command == unix.CTRL_CMD_GETFAMILY { - return familyMessages([]string{ - ovsh.DatapathFamily, - ovsh.FlowFamily, - ovsh.PacketFamily, - ovsh.VportFamily, - ovsh.MeterFamily, - }), nil - } - - return fn(greq, nreq) - } -} - func mustMarshalAttributes(attrs []netlink.Attribute) []byte { b, err := netlink.MarshalAttributes(attrs) if err != nil { diff --git a/ovsnl/conntrack_common.go b/ovsnl/conntrack_common.go new file mode 100644 index 0000000..205c317 --- /dev/null +++ b/ovsnl/conntrack_common.go @@ -0,0 +1,61 @@ +// Copyright 2017 DigitalOcean. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ovsnl
+
+import (
+	"net"
+)
+
+// ConntrackEntry represents a single connection tracking entry from the kernel.
+//
+// NOTE(review): the Orig* and Reply* fields presumably describe the
+// original-direction and reply-direction tuples of the flow; nothing in this
+// file populates them, so confirm against the code that fills the struct.
+type ConntrackEntry struct {
+	Protocol   string // "tcp", "udp", "icmp" etc.
+	OrigSrc    net.IP
+	OrigDst    net.IP
+	OrigSPort  uint16
+	OrigDPort  uint16
+	ReplySrc   net.IP
+	ReplyDst   net.IP
+	ReplySPort uint16
+	ReplyDPort uint16
+	Zone       uint16
+	Mark       uint32
+	State      string
+}
+
+// ZoneStats holds statistics for a zone.
+type ZoneStats struct {
+	TotalCount int
+	// Only populated if TotalCount > threshold. The threshold itself is not
+	// defined in this file (TODO confirm where it lives).
+	Entries []ConntrackEntry
+}
+
+// ConntrackPerformanceStats represents aggregated performance counters from all CPUs.
+//
+// NOTE(review): every Total* counter is a uint32; presumably each one is a sum
+// over CPUs entries, so confirm the producer accounts for wrap-around.
+type ConntrackPerformanceStats struct {
+	TotalFound         uint32
+	TotalInvalid       uint32
+	TotalIgnore        uint32
+	TotalInsert        uint32
+	TotalInsertFailed  uint32
+	TotalDrop          uint32
+	TotalEarlyDrop     uint32
+	TotalError         uint32
+	TotalSearchRestart uint32
+	CPUs               int
+}
+
+// ZmKey is a compact key for (zone,mark).
+//
+// Both fields are plain integers, so ZmKey values are comparable with == and
+// can be used directly as map keys.
+type ZmKey struct {
+	Zone uint16
+	Mark uint32
+}
diff --git a/ovsnl/conntrack_linux.go b/ovsnl/conntrack_linux.go
new file mode 100644
index 0000000..d2ad1ef
--- /dev/null
+++ b/ovsnl/conntrack_linux.go
@@ -0,0 +1,411 @@
+// Copyright 2017 DigitalOcean.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+// +build linux
+
+package ovsnl
+
+import (
+	"fmt"
+	"log"
+	"runtime"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/ti-mo/conntrack"
+	"github.com/ti-mo/netfilter"
+)
+
+//
+// Conntrack aggregator with bounded ingestion + DESTROY aggregation
+// to handle massive bursts of conntrack DESTROY events without OOMing.
+//
+
+// Tunables - adjust for your environment.
+const (
+	eventChanSize      = 512 * 1024             // capacity of the bounded eventsCh
+	eventWorkerCount   = 100                    // goroutines draining eventsCh
+	destroyFlushIntvl  = 100 * time.Millisecond // flush aggregated DESTROYs every 100ms for minimal lag
+	destroyDeltaCap    = 200000                 // maximum distinct (zone,mark) entries in destroyDeltas
+	dropsWarnThreshold = 100                    // threshold of missedEvents to log a stronger warning
+)
+
+// ZoneMarkAggregator keeps live counts (ZmKey -> count) with bounded ingestion.
+//
+// A ZoneMarkAggregator contains mutexes and must not be copied; use the
+// pointer returned by NewZoneMarkAggregator.
+type ZoneMarkAggregator struct {
+	// Counters accessed through sync/atomic. They are kept at the very start
+	// of the struct because the atomic functions require 64-bit alignment of
+	// int64 operands on 32-bit platforms, and only the first word of an
+	// allocated struct is guaranteed to be aligned that way.
+	eventCount   int64 // events relayed into eventsCh
+	missedEvents int64 // dropped events, listener errors and unmatched DESTROYs
+
+	// primary counts (ZmKey -> count) - simplified flat mapping, guarded by mu
+	counts map[ZmKey]int
+	mu     sync.RWMutex
+
+	// conntrack listening connection
+	listenCli *conntrack.Conn
+
+	// lifecycle: stopCh is closed by Stop. stoppedCh is allocated by the
+	// constructor but not otherwise used in this file.
+	stopCh    chan struct{}
+	stoppedCh chan struct{}
+
+	// bounded event ingestion
+	eventsCh chan conntrack.Event
+
+	// aggregated DESTROY deltas (bounded by destroyDeltaCap), guarded by deltaMu
+	deltaMu       sync.Mutex
+	destroyDeltas map[ZmKey]int
+
+	// metrics / health
+	lastEventTime   time.Time
+	eventRate       float64
+	lastHealthCheck time.Time
+
+	// initial snapshot state (we keep disabled for huge tables)
+	initialSnapshotComplete bool
+	initialSnapshotError    error
+}
+
+// NewZoneMarkAggregator creates a new aggregator with its own listening connection.
+//
+// The connection's socket buffers are raised to 64 MiB on a best-effort basis:
+// a failure to resize them is logged and otherwise ignored. An error is
+// returned only when the conntrack connection cannot be opened. The returned
+// aggregator is idle until Start is called.
+func NewZoneMarkAggregator() (*ZoneMarkAggregator, error) {
+	// 64MB buffers, sized by the original author for 1.4M events/sec.
+	const sockBufBytes = 64 * 1024 * 1024
+
+	// Create a separate connection for listening to events
+	listenCli, err := conntrack.Dial(nil)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create listening connection: %w", err)
+	}
+
+	if err := listenCli.SetReadBuffer(sockBufBytes); err != nil {
+		log.Printf("Warning: Failed to set read buffer size: %v", err)
+	}
+	if err := listenCli.SetWriteBuffer(sockBufBytes); err != nil {
+		log.Printf("Warning: Failed to set write buffer size: %v", err)
+	}
+
+	now := time.Now()
+	return &ZoneMarkAggregator{
+		counts:          make(map[ZmKey]int),
+		listenCli:       listenCli,
+		stopCh:          make(chan struct{}),
+		stoppedCh:       make(chan struct{}),
+		eventsCh:        make(chan conntrack.Event, eventChanSize),
+		destroyDeltas:   make(map[ZmKey]int),
+		lastEventTime:   now,
+		lastHealthCheck: now,
+	}, nil
+}
+
+// Start subscribes to NEW/DESTROY/UPDATE events and maintains counts with bounded ingestion.
+//
+// It starts the event listener, eventWorkerCount worker goroutines, the
+// DESTROY flusher and the health monitor. If the listener cannot be started
+// the error is returned and no goroutines are spawned.
+func (a *ZoneMarkAggregator) Start() error {
+	if err := a.startEventListener(); err != nil {
+		return err
+	}
+
+	for i := 0; i < eventWorkerCount; i++ {
+		go a.eventWorker(i)
+	}
+
+	go a.destroyFlusher()
+	go a.startHealthMonitoring()
+
+	// No initial table dump is performed, so the snapshot state is final as
+	// soon as the listener is up. Set it before returning rather than from a
+	// detached goroutine, so the write is ordered before anything the caller
+	// does next.
+	a.initialSnapshotComplete = true
+	a.initialSnapshotError = nil
+
+	return nil
+}
+
+// startEventListener handles real-time conntrack events, pushing into bounded eventsCh.
+//
+// It subscribes listenCli to the NEW, DESTROY and UPDATE conntrack groups and
+// starts a relay goroutine that forwards every library event into eventsCh
+// with a non-blocking send: when eventsCh is full the event is dropped and
+// counted in missedEvents. The relay runs until stopCh is closed.
+//
+// An error is returned only when the Listen call itself fails.
+func (a *ZoneMarkAggregator) startEventListener() error {
+	// Minimum spacing between publications of lastEventTime/eventRate. The
+	// relay is a single goroutine on the hot path, so it must not take a.mu
+	// for every event; destroyFlusher samples eventRate no more often than
+	// every 50ms, so 10ms of staleness is not observable there.
+	const statsPublishIntvl = 10 * time.Millisecond
+
+	libEvents := make(chan conntrack.Event, 8192)
+	groups := []netfilter.NetlinkGroup{
+		netfilter.GroupCTNew,
+		netfilter.GroupCTDestroy,
+		netfilter.GroupCTUpdate,
+	}
+
+	errCh, err := a.listenCli.Listen(libEvents, 10, groups)
+	if err != nil {
+		return fmt.Errorf("failed to listen to conntrack events: %w", err)
+	}
+
+	go func() {
+		var (
+			relayed     int64     // events forwarded by this relay; owned by this goroutine
+			rate        float64   // latest sliding-window estimate in events/sec
+			haveRate    bool      // false until the first estimate has been computed
+			lastPublish time.Time // when the shared metrics were last written
+		)
+		rateWindow := make([]time.Time, 0, 100)
+
+		for {
+			select {
+			case <-a.stopCh:
+				log.Printf("Stopping lib->bounded relay after %d lib events", relayed)
+				return
+			case e := <-errCh:
+				if e != nil {
+					log.Printf("conntrack listener error: %v", e)
+					atomic.AddInt64(&a.missedEvents, 1)
+				}
+			case ev := <-libEvents:
+				select {
+				case a.eventsCh <- ev:
+					relayed++
+					atomic.StoreInt64(&a.eventCount, relayed)
+
+					// Sliding window over the timestamps of the last 100
+					// relayed events.
+					now := time.Now()
+					rateWindow = append(rateWindow, now)
+					if len(rateWindow) > 100 {
+						rateWindow = rateWindow[1:]
+					}
+					if len(rateWindow) > 1 {
+						if span := now.Sub(rateWindow[0]); span > 0 {
+							rate = float64(len(rateWindow)-1) / span.Seconds()
+							haveRate = true
+						}
+					}
+
+					// destroyFlusher reads eventRate while holding a.mu, so
+					// the shared metrics have to be written under the same
+					// lock to avoid a data race.
+					if now.Sub(lastPublish) >= statsPublishIntvl {
+						a.mu.Lock()
+						a.lastEventTime = now
+						if haveRate {
+							a.eventRate = rate
+						}
+						a.mu.Unlock()
+						lastPublish = now
+					}
+				default:
+					// eventsCh is full: drop the event instead of blocking
+					// the relay. Use the value returned by the increment so
+					// the logged number is the one that triggered the log.
+					if missed := atomic.AddInt64(&a.missedEvents, 1); missed%100 == 0 {
+						log.Printf("Warning: eventsCh full, missedEvents=%d", missed)
+					}
+				}
+			}
+		}
+	}()
+
+	return nil
+}
+
+// eventWorker consumes events from eventsCh and handles them until stopCh is
+// closed. id is only used to label the log line written on shutdown.
+func (a *ZoneMarkAggregator) eventWorker(id int) {
+	processedCount := 0
+
+	for {
+		select {
+		case <-a.stopCh:
+			log.Printf("Event worker %d stopping (processed %d events)", id, processedCount)
+			return
+		case ev := <-a.eventsCh:
+			a.handleEvent(ev)
+			processedCount++
+			// Yield whenever the global relayed-event counter is a multiple
+			// of 100, giving other goroutines a chance to run.
+			if atomic.LoadInt64(&a.eventCount)%100 == 0 {
+				runtime.Gosched()
+			}
+		}
+	}
+}
+
+// handleEvent processes a single event.
+//
+// NEW events increment counts[key] directly. DESTROY events are accumulated
+// in destroyDeltas and subtracted from counts later, either by the periodic
+// flusher or immediately once more than destroyEarlyFlushSize distinct keys
+// are pending. All other event types (including UPDATE) are ignored.
+func (a *ZoneMarkAggregator) handleEvent(ev conntrack.Event) {
+	// Number of distinct pending keys above which a worker flushes inline.
+	const destroyEarlyFlushSize = 50000
+
+	f := ev.Flow
+	if f == nil {
+		// Nothing to account for without a flow.
+		return
+	}
+	key := ZmKey{Zone: f.Zone, Mark: f.Mark}
+
+	switch ev.Type {
+	case conntrack.EventNew:
+		a.mu.Lock()
+		a.counts[key]++
+		a.mu.Unlock()
+
+	case conntrack.EventDestroy:
+		a.deltaMu.Lock()
+
+		// Safety net: refuse to grow the delta map past destroyDeltaCap.
+		// With the inline flush below the map is emptied long before it can
+		// reach the cap, so this branch is not expected to run.
+		if size := len(a.destroyDeltas); size >= destroyDeltaCap {
+			a.deltaMu.Unlock()
+			if missed := atomic.AddInt64(&a.missedEvents, 1); missed%dropsWarnThreshold == 0 {
+				log.Printf("Warning: destroyDeltas saturated (size=%d). missedEvents=%d", size, missed)
+			}
+			return
+		}
+
+		_, seen := a.destroyDeltas[key]
+		a.destroyDeltas[key]++
+		pending := len(a.destroyDeltas)
+
+		if pending > destroyEarlyFlushSize {
+			deltas := a.destroyDeltas
+			a.destroyDeltas = make(map[ZmKey]int)
+			// Take mu before releasing deltaMu: this is the same lock order
+			// as flushDestroyDeltas and guarantees the swapped-out deltas are
+			// applied before any later batch.
+			a.mu.Lock()
+			a.deltaMu.Unlock()
+			a.applyDeltasImmediatelyUnsafe(deltas)
+			a.mu.Unlock()
+			return
+		}
+		a.deltaMu.Unlock()
+
+		// Progress log, emitted once each time a new key grows the map to a
+		// multiple of 1000. Repeated DESTROYs for existing keys do not change
+		// the size and therefore must not log again.
+		if !seen && pending%1000 == 0 {
+			log.Printf("DESTROY events: %d entries in destroyDeltas (zone=%d, mark=%d)", pending, key.Zone, key.Mark)
+		}
+	}
+}
+
+// applyDeltasImmediatelyUnsafe subtracts the given DESTROY deltas from counts.
+// The caller must already hold a.mu for writing.
+//
+// For each key: if counts has no entry, the delta is added to missedEvents;
+// if the stored count is less than or equal to the delta, the entry is
+// deleted; otherwise the count is reduced by the delta.
+func (a *ZoneMarkAggregator) applyDeltasImmediatelyUnsafe(deltas map[ZmKey]int) {
+	for k, cnt := range deltas {
+		existing, ok := a.counts[k]
+		switch {
+		case !ok:
+			atomic.AddInt64(&a.missedEvents, int64(cnt))
+		case existing <= cnt:
+			delete(a.counts, k)
+		default:
+			a.counts[k] = existing - cnt
+		}
+	}
+}
+
+// destroyFlusher periodically applies the aggregated DESTROY deltas into counts.
+// After every flush the ticker interval is re-chosen from the current
+// eventRate. On stopCh it performs one final flush and returns.
+//
+// NOTE(review): the >10K/sec tier uses 200ms, which is longer than the
+// default destroyFlushIntvl (100ms); confirm this is intended.
+func (a *ZoneMarkAggregator) destroyFlusher() {
+	ticker := time.NewTicker(destroyFlushIntvl)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-a.stopCh:
+			log.Printf("Destroy flusher stopping, final flush...")
+			a.flushDestroyDeltas()
+			return
+		case <-ticker.C:
+			a.mu.RLock()
+			eventRate := a.eventRate
+			a.mu.RUnlock()
+
+			a.flushDestroyDeltas()
+
+			next := destroyFlushIntvl // normal interval
+			switch {
+			case eventRate > 500000: // very high event rate (>500K/sec)
+				next = 50 * time.Millisecond
+			case eventRate > 100000: // high event rate (>100K/sec)
+				next = 100 * time.Millisecond
+			case eventRate > 10000: // medium event rate (>10K/sec)
+				next = 200 * time.Millisecond
+			}
+			ticker.Reset(next)
+		}
+	}
+}
+
+// flushDestroyDeltas swaps out the pending delta map and applies it to counts.
+// It is a no-op when nothing is pending.
+//
+// deltaMu is acquired first and stays held while mu is held, so DESTROY
+// ingestion is paused for the duration of the flush. The deferred unlocks run
+// in reverse order: mu is released before deltaMu.
+func (a *ZoneMarkAggregator) flushDestroyDeltas() {
+	a.deltaMu.Lock()
+	defer a.deltaMu.Unlock()
+
+	if len(a.destroyDeltas) == 0 {
+		return
+	}
+	deltas := a.destroyDeltas
+	a.destroyDeltas = make(map[ZmKey]int)
+
+	a.mu.Lock()
+	defer a.mu.Unlock()
+	a.applyDeltasImmediatelyUnsafe(deltas)
+}
+
+// Snapshot returns a safe copy of counts.
+//
+// Pending DESTROY deltas are flushed first. Only entries with a positive
+// count are copied; the result is never nil and is owned by the caller.
+func (a *ZoneMarkAggregator) Snapshot() map[ZmKey]int {
+	a.flushDestroyDeltas()
+
+	a.mu.RLock()
+	defer a.mu.RUnlock()
+
+	out := make(map[ZmKey]int, len(a.counts))
+	for k, c := range a.counts {
+		if c > 0 {
+			out[k] = c
+		}
+	}
+	return out
+}
+
+// startHealthMonitoring runs performHealthCheck once a minute until stopCh is
+// closed.
+func (a *ZoneMarkAggregator) startHealthMonitoring() {
+	ticker := time.NewTicker(1 * time.Minute)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-a.stopCh:
+			return
+		case <-ticker.C:
+			a.performHealthCheck()
+		}
+	}
+}
+
+// performHealthCheck restarts the listener when more than dropsWarnThreshold
+// events have been missed since the last successful restart. missedEvents is
+// reset to zero only if the restart succeeds. lastHealthCheck is updated on
+// every call.
+func (a *ZoneMarkAggregator) performHealthCheck() {
+	missed := atomic.LoadInt64(&a.missedEvents)
+
+	if missed > dropsWarnThreshold {
+		if err := a.RestartListener(); err != nil {
+			log.Printf("Health check: RestartListener failed: %v", err)
+		} else {
+			atomic.StoreInt64(&a.missedEvents, 0)
+			log.Printf("Health check: Listener restarted successfully")
+		}
+	}
+	a.lastHealthCheck = time.Now()
+}
+
+// Stop cancels listening and closes the connection.
+//
+// It closes stopCh to signal every goroutine started by Start, waits briefly,
+// closes the listening connection and applies any pending DESTROY deltas.
+//
+// Calling Stop again after it has returned is a no-op, so an explicit Stop
+// followed by Client.Close does not panic. Stop must not be called from
+// several goroutines at the same time.
+func (a *ZoneMarkAggregator) Stop() {
+	select {
+	case <-a.stopCh:
+		// Already stopped; closing stopCh a second time would panic.
+		return
+	default:
+	}
+	close(a.stopCh)
+
+	// Give the goroutines a moment to observe stopCh before the connection
+	// they use is closed.
+	time.Sleep(20 * time.Millisecond)
+	if a.listenCli != nil {
+		if err := a.listenCli.Close(); err != nil {
+			log.Printf("Error closing listenCli during cleanup: %v", err)
+		}
+	}
+	a.flushDestroyDeltas()
+}
+
+// RestartListener attempts to restart the conntrack event listener.
+//
+// The current connection is closed, a new one is dialed and given the same
+// 64 MiB receive buffer as in NewZoneMarkAggregator, and the event listener is
+// started on it. An error is returned if the aggregator has been stopped, if
+// dialing fails, or if the listener cannot be started.
+//
+// NOTE(review): listenCli is replaced without synchronisation, and the relay
+// goroutine started by the previous startEventListener call exits only when
+// stopCh is closed, so it outlives the restart. Confirm both are acceptable.
+func (a *ZoneMarkAggregator) RestartListener() error {
+	select {
+	case <-a.stopCh:
+		// Dialing now would open a socket that nothing closes any more.
+		return fmt.Errorf("aggregator is stopped")
+	default:
+	}
+
+	if a.listenCli != nil {
+		// Best effort: the connection is being replaced either way.
+		_ = a.listenCli.Close()
+	}
+
+	listenCli, err := conntrack.Dial(nil)
+	if err != nil {
+		return fmt.Errorf("failed to create new listening connection: %w", err)
+	}
+	// A fresh socket starts with the default receive buffer; restore the
+	// size configured by the constructor so the restarted listener does not
+	// overflow sooner than the original one.
+	if err := listenCli.SetReadBuffer(64 * 1024 * 1024); err != nil {
+		log.Printf("Warning: Failed to set read buffer size: %v", err)
+	}
+
+	a.listenCli = listenCli
+	return a.startEventListener()
+}
diff --git a/ovsnl/conntrack_linux_test.go b/ovsnl/conntrack_linux_test.go
new file mode 100644
index 0000000..06f019f
--- /dev/null
+++ b/ovsnl/conntrack_linux_test.go
@@ -0,0 +1,116 @@
+// Copyright 2017 DigitalOcean.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build linux
+
+package ovsnl
+
+import (
+	"testing"
+)
+
+// TestZoneMarkAggregator checks that a freshly created aggregator can produce
+// a snapshot and be stopped. It is skipped when the conntrack socket cannot be
+// opened.
+func TestZoneMarkAggregator(t *testing.T) {
+	agg, err := NewZoneMarkAggregator()
+	if err != nil {
+		// Report the test as skipped instead of letting it pass without
+		// having exercised anything.
+		t.Skipf("Expected failure in test environment: NewZoneMarkAggregator() error = %v", err)
+	}
+	if agg == nil {
+		t.Fatal("NewZoneMarkAggregator() returned nil aggregator")
+	}
+	// Registered before any assertion so the connection is released even
+	// when the test fails.
+	t.Cleanup(agg.Stop)
+
+	if snapshot := agg.Snapshot(); snapshot == nil {
+		t.Fatal("Snapshot() returned nil")
+	}
+}
+
+// TestZoneMarkAggregatorSnapshot checks that every entry returned by Snapshot
+// has a positive count. It is skipped when the conntrack socket cannot be
+// opened.
+func TestZoneMarkAggregatorSnapshot(t *testing.T) {
+	agg, err := NewZoneMarkAggregator()
+	if err != nil {
+		t.Skipf("Expected failure in test environment: NewZoneMarkAggregator() error = %v", err)
+	}
+	if agg == nil {
+		t.Fatal("NewZoneMarkAggregator() returned nil aggregator")
+	}
+	t.Cleanup(agg.Stop)
+
+	snapshot := agg.Snapshot()
+	if snapshot == nil {
+		t.Fatal("Snapshot() returned nil")
+	}
+	if len(snapshot) == 0 {
+		t.Log("Snapshot is empty (expected in test environment)")
+	}
+
+	for key, count := range snapshot {
+		if count <= 0 {
+			t.Errorf("Invalid count %d for key %+v", count, key)
+		}
+		t.Logf("Zone: %d, Mark: %d, Count: %d", key.Zone, key.Mark, count)
+	}
+}
+
+// TestZmKeyComparison checks that ZmKey has value semantics under == and when
+// used as a map key.
+func TestZmKeyComparison(t *testing.T) {
+	key1 := ZmKey{Zone: 1, Mark: 100}
+	key2 := ZmKey{Zone: 1, Mark: 100}
+	key3 := ZmKey{Zone: 2, Mark: 100}
+	key4 := ZmKey{Zone: 1, Mark: 200}
+
+	// Equality
+	if key1 != key2 {
+		t.Error("Identical ZmKey structs should be equal")
+	}
+
+	// Inequality
+	if key1 == key3 {
+		t.Error("Different zone ZmKey structs should not be equal")
+	}
+	if key1 == key4 {
+		t.Error("Different mark ZmKey structs should not be equal")
+	}
+
+	// Map-key behaviour
+	testMap := map[ZmKey]int{
+		key1: 5,
+		key3: 10,
+	}
+
+	if testMap[key1] != 5 {
+		t.Error("ZmKey should work as map key")
+	}
+	if testMap[key2] != 5 {
+		t.Error("Equal ZmKey structs should map to same value")
+	}
+	if testMap[key3] != 10 {
+		t.Error("Different ZmKey should map to different value")
+	}
+}
diff --git a/ovsnl/conntrack_stub.go b/ovsnl/conntrack_stub.go
new file mode 100644
index 0000000..488710b
--- /dev/null
+++ b/ovsnl/conntrack_stub.go
@@ -0,0 +1,56 @@
+// Copyright 2017 DigitalOcean.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build !linux
+
+package ovsnl
+
+import (
+	"context"
+	"errors"
+)
+
+// ZoneMarkAggregator keeps live counts (ZmKey -> count).
+// On non-Linux platforms it is an empty placeholder type.
+type ZoneMarkAggregator struct {
+	// No implementation on non-Linux platforms
+}
+
+// NewZoneMarkAggregator creates a new aggregator.
+// On non-Linux platforms, this returns a nil aggregator and
+// errors.ErrUnsupported.
+func NewZoneMarkAggregator() (*ZoneMarkAggregator, error) {
+	return nil, errors.ErrUnsupported
+}
+
+// Start subscribes to conntrack events and maintains counts.
+// On non-Linux platforms, this returns errors.ErrUnsupported.
+func (a *ZoneMarkAggregator) Start() error {
+	return errors.ErrUnsupported
+}
+
+// Stop cancels listening.
+// On non-Linux platforms, this does nothing.
+func (a *ZoneMarkAggregator) Stop() {
+	// No-op on non-Linux platforms
+}
+
+// Snapshot returns a safe copy of counts.
+// On non-Linux platforms, this returns an empty map.
+// A new, non-nil map is allocated on every call.
+func (a *ZoneMarkAggregator) Snapshot() map[ZmKey]int {
+	return make(map[ZmKey]int)
+}
+
+// PrimeSnapshot tries a guarded one-shot dump to seed counts for long-lived flows.
+// On non-Linux platforms, this returns an error.
+// The error is errors.ErrUnsupported; ctx and maxEntries are ignored.
+//
+// NOTE(review): the Linux implementation added in this change does not define
+// PrimeSnapshot, so the method set differs between platforms; confirm whether
+// that is intended.
+func (a *ZoneMarkAggregator) PrimeSnapshot(ctx context.Context, maxEntries int) error {
+	return errors.ErrUnsupported
+}
diff --git a/ovsnl/datapath_linux_integration_test.go b/ovsnl/datapath_linux_integration_test.go
new file mode 100644
index 0000000..8685127
--- /dev/null
+++ b/ovsnl/datapath_linux_integration_test.go
@@ -0,0 +1,51 @@
+// Copyright 2017 DigitalOcean.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +//go:build linux && integration + +package ovsnl + +import ( + "testing" +) + +// TestClientDatapathListIntegration tests datapath listing with real Open vSwitch +func TestClientDatapathListIntegration(t *testing.T) { + // Skip if not running in integration test environment + if testing.Short() { + t.Skip("skipping integration test") + } + + c, err := New() + if err != nil { + t.Skipf("skipping integration test: %v", err) + } + defer c.Close() + + dps, err := c.Datapath.List() + if err != nil { + t.Fatalf("failed to list datapaths: %v", err) + } + + if len(dps) == 0 { + t.Log("no datapaths found (Open vSwitch may not be running)") + return + } + + // Verify we can list datapaths + t.Logf("found %d datapaths", len(dps)) + for _, dp := range dps { + t.Logf("datapath: %s (index: %d)", dp.Name, dp.Index) + } +} diff --git a/ovsnl/datapath_linux_test.go b/ovsnl/datapath_linux_test.go index c40b145..1e55587 100644 --- a/ovsnl/datapath_linux_test.go +++ b/ovsnl/datapath_linux_test.go @@ -26,26 +26,33 @@ import ( "github.com/mdlayher/genetlink/genltest" "github.com/mdlayher/netlink" "github.com/mdlayher/netlink/nlenc" + "golang.org/x/sys/unix" ) func TestClientDatapathListShortHeader(t *testing.T) { - conn := genltest.Dial(ovsFamilies(func(greq genetlink.Message, nreq netlink.Message) ([]genetlink.Message, error) { - // Not enough data for ovsh.Header. 
- return []genetlink.Message{ - { - Data: []byte{0xff, 0xff}, - }, - }, nil - })) + conn := genltest.Dial(genltest.Func(ovsFamilies(func(greq genetlink.Message, nreq netlink.Message) ([]genetlink.Message, error) { + + // Check if this is the datapath list command + if greq.Header.Command == ovsh.DpCmdGet { + // Return deliberately short data for datapath list + shortData := []byte{0xff, 0xff} + return []genetlink.Message{ + {Data: shortData}, + }, nil + } + + return []genetlink.Message{}, nil + }))) - c, err := newClient(conn) + c, err := newTestClient(conn) if err != nil { t.Fatalf("failed to create client: %v", err) } defer c.Close() - _, err = c.Datapath.List() - if err == nil { + _, errDatapath := c.Datapath.List() + if errDatapath == nil { + t.Fatalf("expected an error, but none occurred") } @@ -53,22 +60,22 @@ func TestClientDatapathListShortHeader(t *testing.T) { } func TestClientDatapathListBadStats(t *testing.T) { - conn := genltest.Dial(ovsFamilies(func(greq genetlink.Message, nreq netlink.Message) ([]genetlink.Message, error) { + conn := genltest.Dial(genltest.Func(ovsFamilies(func(greq genetlink.Message, nreq netlink.Message) ([]genetlink.Message, error) { // Valid header; not enough data for ovsh.DPStats. return []genetlink.Message{{ Data: append( - // ovsh.Header. + // ovsh.Header (4 bytes). []byte{0xff, 0xff, 0xff, 0xff}, // netlink attributes. 
mustMarshalAttributes([]netlink.Attribute{{ Type: ovsh.DpAttrStats, - Data: []byte{0xff}, + Data: []byte{0xff}, // Only 1 byte, but sizeofDPStats is 32 bytes }})..., ), }}, nil - })) + }))) - c, err := newClient(conn) + c, err := newTestClient(conn) if err != nil { t.Fatalf("failed to create client: %v", err) } @@ -83,22 +90,22 @@ func TestClientDatapathListBadStats(t *testing.T) { } func TestClientDatapathListBadMegaflowStats(t *testing.T) { - conn := genltest.Dial(ovsFamilies(func(greq genetlink.Message, nreq netlink.Message) ([]genetlink.Message, error) { + conn := genltest.Dial(genltest.Func(ovsFamilies(func(greq genetlink.Message, nreq netlink.Message) ([]genetlink.Message, error) { // Valid header; not enough data for ovsh.DPMegaflowStats. return []genetlink.Message{{ Data: append( - // ovsh.Header. + // ovsh.Header (4 bytes). []byte{0xff, 0xff, 0xff, 0xff}, // netlink attributes. mustMarshalAttributes([]netlink.Attribute{{ Type: ovsh.DpAttrMegaflowStats, - Data: []byte{0xff}, + Data: []byte{0xff}, // Only 1 byte, but sizeofDPMegaflowStats is 32 bytes }})..., ), }}, nil - })) + }))) - c, err := newClient(conn) + c, err := newTestClient(conn) if err != nil { t.Fatalf("failed to create client: %v", err) } @@ -129,7 +136,7 @@ func TestClientDatapathListOK(t *testing.T) { }, } - conn := genltest.Dial(ovsFamilies(func(greq genetlink.Message, nreq netlink.Message) ([]genetlink.Message, error) { + conn := genltest.Dial(genltest.Func(ovsFamilies(func(greq genetlink.Message, nreq netlink.Message) ([]genetlink.Message, error) { // Ensure we are querying the "ovs_datapath" family with the // correct parameters. 
if diff := cmp.Diff(ovsh.DpCmdGet, int(greq.Header.Command)); diff != "" { @@ -150,9 +157,9 @@ func TestClientDatapathListOK(t *testing.T) { Data: mustMarshalDatapath(system), }, }, nil - })) + }))) - c, err := newClient(conn) + c, err := newTestClient(conn) if err != nil { t.Fatalf("failed to create client: %v", err) } @@ -217,3 +224,21 @@ func mustMarshalDatapath(dp Datapath) []byte { return append(hb[:], ab...) } + +// ovsFamilies creates a test handler that returns OVS family messages +type handlerFn func(genetlink.Message, netlink.Message) ([]genetlink.Message, error) + +func ovsFamilies(handler handlerFn) handlerFn { + return func(greq genetlink.Message, nreq netlink.Message) ([]genetlink.Message, error) { + + // Handle family listing requests (CTRL family) + if nreq.Header.Type == unix.GENL_ID_CTRL && greq.Header.Command == unix.CTRL_CMD_GETFAMILY { + return familyMessages([]string{ + ovsh.DatapathFamily, + }), nil + } + + // Handle actual datapath requests (OVS datapath family) + return handler(greq, nreq) + } +}