Skip to content

Commit

Permalink
Merge pull request #402 from kubescape/feature/dns-cache
Browse files Browse the repository at this point in the history
Adding cache for dns
  • Loading branch information
amitschendel authored Nov 8, 2024
2 parents f9d51a1 + f15c5e9 commit e9616ee
Show file tree
Hide file tree
Showing 2 changed files with 221 additions and 27 deletions.
73 changes: 62 additions & 11 deletions pkg/dnsmanager/dns_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,89 @@ package dnsmanager

import (
"net"
"time"

"github.com/goradd/maps"
tracerdnstype "github.com/inspektor-gadget/inspektor-gadget/pkg/gadgets/trace/dns/types"
"istio.io/pkg/cache"
)

// DNSManager is used to manage DNS events and save IP resolutions. It exposes an API to resolve IP address to domain name.
// DNSManager is used to manage DNS events and save IP resolutions.
type DNSManager struct {
addressToDomainMap maps.SafeMap[string, string] // this map is used to resolve IP address to domain name
addressToDomainMap maps.SafeMap[string, string]
lookupCache cache.ExpiringCache // Cache for DNS lookups
failureCache cache.ExpiringCache // Cache for failed lookups
}

type cacheEntry struct {
addresses []string
}

const (
defaultPositiveTTL = 1 * time.Minute // Default TTL for successful lookups
defaultNegativeTTL = 5 * time.Second // Default TTL for failed lookups
)

var _ DNSManagerClient = (*DNSManager)(nil)
var _ DNSResolver = (*DNSManager)(nil)

func CreateDNSManager() *DNSManager {
return &DNSManager{}
return &DNSManager{
// Create TTL caches with their respective expiration times
lookupCache: cache.NewTTL(defaultPositiveTTL, defaultPositiveTTL),
failureCache: cache.NewTTL(defaultNegativeTTL, defaultNegativeTTL),
}
}

func (dm *DNSManager) ReportDNSEvent(dnsEvent tracerdnstype.Event) {

// If we have addresses in the event, use them directly
if len(dnsEvent.Addresses) > 0 {
for _, address := range dnsEvent.Addresses {
dm.addressToDomainMap.Set(address, dnsEvent.DNSName)
}
} else {
addresses, err := net.LookupIP(dnsEvent.DNSName)
if err != nil {
return
}
for _, address := range addresses {
dm.addressToDomainMap.Set(address.String(), dnsEvent.DNSName)

// Update the cache with these known good addresses
dm.lookupCache.Set(dnsEvent.DNSName, cacheEntry{
addresses: dnsEvent.Addresses,
})
return
}

// Check if we've recently failed to look up this domain
if _, found := dm.failureCache.Get(dnsEvent.DNSName); found {
return
}

// Check if we have a cached result
if cached, found := dm.lookupCache.Get(dnsEvent.DNSName); found {
entry := cached.(cacheEntry)
// Use cached addresses
for _, addr := range entry.addresses {
dm.addressToDomainMap.Set(addr, dnsEvent.DNSName)
}
return
}

// Only perform lookup if we don't have cached results
addresses, err := net.LookupIP(dnsEvent.DNSName)
if err != nil {
// Cache the failure - we just need to store something, using empty struct
dm.failureCache.Set(dnsEvent.DNSName, struct{}{})
return
}

// Convert addresses to strings and store them
addrStrings := make([]string, 0, len(addresses))
for _, addr := range addresses {
addrStr := addr.String()
addrStrings = append(addrStrings, addrStr)
dm.addressToDomainMap.Set(addrStr, dnsEvent.DNSName)
}

// Cache the successful lookup
dm.lookupCache.Set(dnsEvent.DNSName, cacheEntry{
addresses: addrStrings,
})
}

func (dm *DNSManager) ResolveIPAddress(ipAddr string) (string, bool) {
Expand Down
175 changes: 159 additions & 16 deletions pkg/dnsmanager/dns_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package dnsmanager

import (
"net"
"sync"
"testing"

"math/rand/v2"

tracerdnstype "github.com/inspektor-gadget/inspektor-gadget/pkg/gadgets/trace/dns/types"
)

Expand All @@ -13,6 +16,7 @@ func TestResolveIPAddress(t *testing.T) {
dnsEvent tracerdnstype.Event
ipAddr string
want string
wantOk bool
}{
{
name: "ip found",
Expand All @@ -24,7 +28,8 @@ func TestResolveIPAddress(t *testing.T) {
"67.225.146.248",
},
},
want: "test.com",
want: "test.com",
wantOk: true,
},
{
name: "ip not found",
Expand All @@ -36,57 +41,195 @@ func TestResolveIPAddress(t *testing.T) {
"54.23.332.4",
},
},
want: "",
want: "",
wantOk: false,
},
{
name: "no address",
ipAddr: "67.225.146.248",
dnsEvent: tracerdnstype.Event{
DNSName: "test.com",
NumAnswers: 0, // will not resolve
NumAnswers: 0,
},
want: "",
want: "",
wantOk: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dm := &DNSManager{}
dm := CreateDNSManager()

dm.ReportDNSEvent(tt.dnsEvent)
got, _ := dm.ResolveIPAddress(tt.ipAddr)
if got != tt.want {
t.Errorf("ResolveIPAddress() got = %v, want %v", got, tt.want)
got, ok := dm.ResolveIPAddress(tt.ipAddr)
if got != tt.want || ok != tt.wantOk {
t.Errorf("ResolveIPAddress() got = %v, ok = %v, want = %v, wantOk = %v", got, ok, tt.want, tt.wantOk)
}
})
}
}

func TestResolveIPAddressFallback(t *testing.T) {
// Skip the test if running in CI or without network access
if testing.Short() {
t.Skip("Skipping test that requires network access")
}

tests := []struct {
name string
dnsEvent tracerdnstype.Event
want string
wantOk bool
}{

{
name: "dns resolution fallback",
dnsEvent: tracerdnstype.Event{
DNSName: "test.com",
DNSName: "example.com", // Using example.com as it's guaranteed to exist
NumAnswers: 1,
},
want: "test.com",
want: "example.com",
wantOk: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addresses, _ := net.LookupIP(tt.dnsEvent.DNSName)
dm := &DNSManager{}
dm := CreateDNSManager()

// Perform the actual DNS lookup
addresses, err := net.LookupIP(tt.dnsEvent.DNSName)
if err != nil {
t.Skipf("DNS lookup failed: %v", err)
return
}
if len(addresses) == 0 {
t.Skip("No addresses returned from DNS lookup")
return
}

dm.ReportDNSEvent(tt.dnsEvent)
got, _ := dm.ResolveIPAddress(addresses[0].String())
if got != tt.want {
t.Errorf("ResolveIPAddress() got = %v, want %v", got, tt.want)
got, ok := dm.ResolveIPAddress(addresses[0].String())
if got != tt.want || ok != tt.wantOk {
t.Errorf("ResolveIPAddress() got = %v, ok = %v, want = %v, wantOk = %v", got, ok, tt.want, tt.wantOk)
}
})
}
}

func TestCacheFallbackBehavior(t *testing.T) {
dm := CreateDNSManager()

// Test successful DNS lookup caching
event := tracerdnstype.Event{
DNSName: "test.com",
Addresses: []string{
"1.2.3.4",
},
}
dm.ReportDNSEvent(event)

// Check if the lookup is cached
cached, found := dm.lookupCache.Get(event.DNSName)
if !found {
t.Error("Expected DNS lookup to be cached")
}

entry, ok := cached.(cacheEntry)
if !ok {
t.Error("Cached entry is not of type cacheEntry")
}
if len(entry.addresses) != 1 || entry.addresses[0] != "1.2.3.4" {
t.Error("Cached addresses do not match expected values")
}

// Test failed lookup caching
failEvent := tracerdnstype.Event{
DNSName: "nonexistent.local",
}
dm.ReportDNSEvent(failEvent)

// Check if the failure is cached
_, found = dm.failureCache.Get(failEvent.DNSName)
if !found {
t.Error("Expected failed DNS lookup to be cached")
}

// Test cache hit behavior
hitCount := 0
for i := 0; i < 5; i++ {
if cached, found := dm.lookupCache.Get(event.DNSName); found {
entry := cached.(cacheEntry)
if len(entry.addresses) > 0 {
hitCount++
}
}
}
if hitCount != 5 {
t.Errorf("Expected 5 cache hits, got %d", hitCount)
}
}

func TestConcurrentAccess(t *testing.T) {
dm := CreateDNSManager()
const numGoroutines = 100
const numOperations = 1000

// Create a wait group to synchronize goroutines
var wg sync.WaitGroup
wg.Add(numGoroutines)

// Create some test data
testEvents := []tracerdnstype.Event{
{
DNSName: "test1.com",
Addresses: []string{"1.1.1.1", "2.2.2.2"},
},
{
DNSName: "test2.com",
Addresses: []string{"3.3.3.3", "4.4.4.4"},
},
{
DNSName: "test3.com",
Addresses: []string{"5.5.5.5", "6.6.6.6"},
},
}

// Launch multiple goroutines to concurrently access the cache
for i := 0; i < numGoroutines; i++ {
go func() {
defer wg.Done()

for j := 0; j < numOperations; j++ {
// Randomly choose between writing and reading
if rand.Float32() < 0.5 {
// Write operation
event := testEvents[rand.IntN(len(testEvents))]
dm.ReportDNSEvent(event)
} else {
// Read operation
if cached, found := dm.lookupCache.Get("test1.com"); found {
entry := cached.(cacheEntry)
// Verify the slice hasn't been modified
if len(entry.addresses) != 2 {
t.Errorf("Unexpected number of addresses: %d", len(entry.addresses))
}
}
}
}
}()
}

// Wait for all goroutines to complete
wg.Wait()

// Verify final state
for _, event := range testEvents {
if cached, found := dm.lookupCache.Get(event.DNSName); found {
entry := cached.(cacheEntry)
if len(entry.addresses) != len(event.Addresses) {
t.Errorf("Cache entry for %s has wrong number of addresses: got %d, want %d",
event.DNSName, len(entry.addresses), len(event.Addresses))
}
}
}
}

0 comments on commit e9616ee

Please sign in to comment.