diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..b61a9c5 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,18 @@ +root = true + +[*] +end_of_line = lf +insert_final_newline = true +indent_style = tab +indent_size = 4 +trim_trailing_whitespace = true + +[*.txt] +indent_style = space + +[*.conf] +indent_style = space + +[*.yml] +indent_size = 2 +indent_style = space diff --git a/go.mod b/go.mod index f8172ed..7a7a43b 100644 --- a/go.mod +++ b/go.mod @@ -4,13 +4,16 @@ go 1.23.3 require ( github.com/charmbracelet/glamour v0.10.0 + github.com/google/go-cmp v0.7.0 github.com/miekg/dns v1.1.66 github.com/pion/stun/v3 v3.0.0 + github.com/quic-go/quic-go v0.53.0 github.com/rbmk-project/common v0.21.0 - github.com/rbmk-project/dnscore v0.13.0 github.com/rbmk-project/x v0.0.0-20250625213336-5718c136805c github.com/spf13/pflag v1.0.6 github.com/stretchr/testify v1.10.0 + golang.org/x/net v0.41.0 + golang.org/x/sys v0.33.0 mvdan.cc/sh/v3 v3.11.0 ) @@ -37,7 +40,7 @@ require ( github.com/pion/logging v0.2.4 // indirect github.com/pion/transport/v3 v3.0.7 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/quic-go/quic-go v0.53.0 // indirect + github.com/rbmk-project/dnscore v0.13.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/wlynxg/anet v0.0.5 // indirect @@ -48,9 +51,7 @@ require ( golang.org/x/crypto v0.39.0 // indirect golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect golang.org/x/mod v0.25.0 // indirect - golang.org/x/net v0.41.0 // indirect golang.org/x/sync v0.15.0 // indirect - golang.org/x/sys v0.33.0 // indirect golang.org/x/term v0.32.0 // indirect golang.org/x/text v0.26.0 // indirect golang.org/x/tools v0.34.0 // indirect diff --git a/pkg/cli/curl/curl.go b/pkg/cli/curl/curl.go index f710d5c..ebb1dfb 100644 --- a/pkg/cli/curl/curl.go +++ b/pkg/cli/curl/curl.go @@ -13,9 +13,9 @@ import ( "time" "github.com/rbmk-project/common/cliutils" - "github.com/rbmk-project/common/closepool" "github.com/rbmk-project/common/fsx" "github.com/rbmk-project/rbmk/internal/markdown" + "github.com/rbmk-project/rbmk/pkg/common/closepool" "github.com/spf13/pflag" ) diff --git a/pkg/cli/curl/httplog.go b/pkg/cli/curl/httplog.go index 9ccf5a7..b1c9231 100644 --- a/pkg/cli/curl/httplog.go +++ b/pkg/cli/curl/httplog.go @@ -8,8 +8,8 @@ import ( "net/netip" "time" - "github.com/rbmk-project/common/httpconntrace" - "github.com/rbmk-project/common/httpslog" + "github.com/rbmk-project/rbmk/pkg/common/httpconntrace" + "github.com/rbmk-project/rbmk/pkg/common/httpslog" ) // httpDoAndLog performs the request and emits structured logs. diff --git a/pkg/cli/curl/task.go b/pkg/cli/curl/task.go index 840ddbf..e5f418e 100644 --- a/pkg/cli/curl/task.go +++ b/pkg/cli/curl/task.go @@ -11,10 +11,10 @@ import ( "net/http" "time" - "github.com/rbmk-project/common/closepool" "github.com/rbmk-project/common/dialonce" - "github.com/rbmk-project/dnscore" "github.com/rbmk-project/rbmk/internal/testable" + "github.com/rbmk-project/rbmk/pkg/common/closepool" + "github.com/rbmk-project/rbmk/pkg/dns/dnscore" "github.com/rbmk-project/x/netcore" ) diff --git a/pkg/cli/dig/dig.go b/pkg/cli/dig/dig.go index 272922c..6c62715 100644 --- a/pkg/cli/dig/dig.go +++ b/pkg/cli/dig/dig.go @@ -12,9 +12,9 @@ import ( "strings" "github.com/rbmk-project/common/cliutils" - "github.com/rbmk-project/common/closepool" "github.com/rbmk-project/common/fsx" "github.com/rbmk-project/rbmk/internal/markdown" + "github.com/rbmk-project/rbmk/pkg/common/closepool" "github.com/spf13/pflag" ) diff --git a/pkg/cli/dig/task.go b/pkg/cli/dig/task.go index 27dd190..e40abdf 100644 --- a/pkg/cli/dig/task.go +++ b/pkg/cli/dig/task.go @@ -16,9 +16,9 @@ import ( "time" "github.com/miekg/dns" - "github.com/rbmk-project/common/closepool" - "github.com/rbmk-project/dnscore" "github.com/rbmk-project/rbmk/internal/testable" + "github.com/rbmk-project/rbmk/pkg/common/closepool" + "github.com/rbmk-project/rbmk/pkg/dns/dnscore" "github.com/rbmk-project/x/netcore" ) diff --git a/pkg/cli/nc/nc.go b/pkg/cli/nc/nc.go index 24b987c..90d571e 100644 --- a/pkg/cli/nc/nc.go +++ b/pkg/cli/nc/nc.go @@ -12,9 +12,9 @@ import ( "time" "github.com/rbmk-project/common/cliutils" - "github.com/rbmk-project/common/closepool" "github.com/rbmk-project/common/fsx" "github.com/rbmk-project/rbmk/internal/markdown" + "github.com/rbmk-project/rbmk/pkg/common/closepool" "github.com/spf13/pflag" ) diff --git a/pkg/cli/nc/task.go b/pkg/cli/nc/task.go index 8801db8..ab19412 100644 --- a/pkg/cli/nc/task.go +++ b/pkg/cli/nc/task.go @@ -12,8 +12,8 @@ import ( "net" "time" - "github.com/rbmk-project/common/closepool" "github.com/rbmk-project/rbmk/internal/testable" + "github.com/rbmk-project/rbmk/pkg/common/closepool" "github.com/rbmk-project/x/netcore" ) diff --git a/pkg/cli/stun/stun.go b/pkg/cli/stun/stun.go index 20691ea..6aeb350 100644 --- a/pkg/cli/stun/stun.go +++ b/pkg/cli/stun/stun.go @@ -12,9 +12,9 @@ import ( "time" "github.com/rbmk-project/common/cliutils" - "github.com/rbmk-project/common/closepool" "github.com/rbmk-project/common/fsx" "github.com/rbmk-project/rbmk/internal/markdown" + "github.com/rbmk-project/rbmk/pkg/common/closepool" "github.com/spf13/pflag" ) diff --git a/pkg/cli/stun/task.go b/pkg/cli/stun/task.go index 0591a1d..14452ba 100644 --- a/pkg/cli/stun/task.go +++ b/pkg/cli/stun/task.go @@ -12,8 +12,8 @@ import ( "time" "github.com/pion/stun/v3" - "github.com/rbmk-project/common/closepool" "github.com/rbmk-project/rbmk/internal/testable" + "github.com/rbmk-project/rbmk/pkg/common/closepool" "github.com/rbmk-project/x/netcore" ) diff --git a/pkg/cli/tar/tar.go b/pkg/cli/tar/tar.go index c751283..bfed3f8 100644 --- a/pkg/cli/tar/tar.go +++ b/pkg/cli/tar/tar.go @@ -19,8 +19,8 @@ import ( "path/filepath" "github.com/rbmk-project/common/cliutils" - "github.com/rbmk-project/common/closepool" "github.com/rbmk-project/rbmk/internal/markdown" + "github.com/rbmk-project/rbmk/pkg/common/closepool" "github.com/spf13/pflag" ) diff --git a/pkg/common/closepool/closepool.go b/pkg/common/closepool/closepool.go new file mode 100644 index 0000000..3c70a1e --- /dev/null +++ b/pkg/common/closepool/closepool.go @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +// Package closepool allows pooling [io.Closer] instances +// and closing them in a single operation. +package closepool + +import ( + "errors" + "io" + "slices" + "sync" +) + +// CloserFunc allows to turn a suitable function into an [io.Closer]. +type CloserFunc func() error + +// Ensure that [CloserFunc] implements [io.Closer]. +var _ io.Closer = CloserFunc(nil) + +// Close implements io.Closer. +func (fx CloserFunc) Close() error { + return fx() +} + +// Pool allows pooling a set of [io.Closer]. +// +// The zero value is ready to use. +type Pool struct { + // handles contains the [io.Closer] to close. + handles []io.Closer + + // mu provides mutual exclusion. + mu sync.Mutex +} + +// Add adds a given [io.Closer] to the pool. +func (p *Pool) Add(conn io.Closer) { + p.mu.Lock() + p.handles = append(p.handles, conn) + p.mu.Unlock() +} + +// Close closes all the [io.Closer] inside the pool iterating +// in backward order. Therefore, if one registers a TCP connection +// and then the corresponding TLS connection, the TLS connection +// is closed first. The returned error is the join of all the +// errors that occurred when closing connections. +func (p *Pool) Close() error { + // Lock and copy the [io.Closer] to close. + p.mu.Lock() + handles := p.handles + p.handles = nil + p.mu.Unlock() + + // Close all the [io.Closer]. + var errv []error + for _, handle := range slices.Backward(handles) { + if err := handle.Close(); err != nil { + errv = append(errv, err) + } + } + return errors.Join(errv...) +} diff --git a/pkg/common/closepool/closepool_test.go b/pkg/common/closepool/closepool_test.go new file mode 100644 index 0000000..da3ee8c --- /dev/null +++ b/pkg/common/closepool/closepool_test.go @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package closepool_test + +import ( + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/rbmk-project/rbmk/pkg/common/closepool" +) + +// mockCloser implements io.Closer for testing +type mockCloser struct { + closed atomic.Int64 + err error +} + +// t0 is the time when we started running +var t0 = time.Now() + +func (m *mockCloser) Close() error { + m.closed.Add(int64(time.Since(t0))) + return m.err +} + +func TestCloserFunc(t *testing.T) { + var closed bool + pool := &closepool.Pool{} + pool.Add(closepool.CloserFunc(func() error { + closed = true + return nil + })) + pool.Close() + if !closed { + t.Error("expected closer to be closed") + } +} + +func TestPool(t *testing.T) { + t.Run("successful close", func(t *testing.T) { + pool := closepool.Pool{} + m1 := &mockCloser{} + m2 := &mockCloser{} + + pool.Add(m1) + pool.Add(m2) + + err := pool.Close() + if err != nil { + t.Errorf("expected no error, got %v", err) + } + + if m1.closed.Load() <= 0 { + t.Error("first closer was not closed") + } + if m2.closed.Load() <= 0 { + t.Error("second closer was not closed") + } + }) + + t.Run("close order", func(t *testing.T) { + pool := closepool.Pool{} + + m1 := &mockCloser{ + err: nil, + } + m2 := &mockCloser{ + err: nil, + } + + pool.Add(m1) // Added first + pool.Add(m2) // Added second + + // Should close in reverse order + err := pool.Close() + if err != nil { + t.Errorf("expected no error, got %v", err) + } + + if m1.closed.Load() <= m2.closed.Load() { + t.Error("expected m1 to be closed after m2") + } + }) + + t.Run("error handling", func(t *testing.T) { + pool := closepool.Pool{} + expectedErr1 := errors.New("close error #1") + expectedErr2 := errors.New("close error #2") + + m1 := &mockCloser{err: expectedErr1} + m2 := &mockCloser{err: expectedErr2} + + pool.Add(m1) + pool.Add(m2) + + err := pool.Close() + if err == nil { + t.Fatalf("expected error, got nil") + } + + t.Log(err) + if errors.Join(expectedErr2, expectedErr1).Error() != err.Error() { + t.Errorf("expected error to include both errors, got %v", err) + } + }) + + t.Run("concurrent usage", func(t *testing.T) { + pool := closepool.Pool{} + done := make(chan struct{}) + + // Concurrently add closers + go func() { + for i := 0; i < 100; i++ { + pool.Add(&mockCloser{}) + } + close(done) + }() + + // Add more closers from main goroutine + for i := 0; i < 100; i++ { + pool.Add(&mockCloser{}) + } + + <-done // Wait for goroutine to finish + + err := pool.Close() + if err != nil { + t.Errorf("expected no error, got %v", err) + } + }) +} diff --git a/pkg/common/doc.go b/pkg/common/doc.go new file mode 100644 index 0000000..85907d6 --- /dev/null +++ b/pkg/common/doc.go @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +/* +Package common contains simple, common packages used by other packages. + +See [dd-001-common.md] for more information. + +[dd-001-common.md]: https://github.com/rbmk-project/rbmk-project.github.io/blob/main/docs/design/dd-001-common.md +*/ +package common diff --git a/pkg/common/errclass/errclass.go b/pkg/common/errclass/errclass.go new file mode 100644 index 0000000..4e3b9d9 --- /dev/null +++ b/pkg/common/errclass/errclass.go @@ -0,0 +1,247 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +/* +Package errclass implements error classification. + +The general idea is to classify golang errors to an enum of strings +with names resembling standard Unix error names. + +# Design Principles + +1. Preserve original error in `err` in the structured logs. + +2. Add the classified error as the `errClass` field. + +3. Use [errors.Is] and [errors.As] for classification. + +4. Use string-based classification for readability. + +5. Follow Unix-like naming where appropriate. + +6. Prefix subsystem-specific errors (`EDNS_`, `ETLS_`). + +7. Keep full names for clarity over brevity. + +8. Map the nil error to an empty string. + +# System and Network Errors + +- [ETIMEDOUT] for [context.DeadlineExceeded], [os.ErrDeadlineExceeded] + +- [EINTR] for [context.Canceled], [net.ErrClosed] + +- [EEOF] for (unexpected) [io.EOF] and [io.ErrUnexpectedEOF] errors + +- [ECONNRESET], [ECONNREFUSED], ... for respective syscall errors + +The actual system error constants are defined in platform-specific files: + +- unix.go for Unix-like systems using x/sys/unix + +- windows.go for Windows systems using x/sys/windows + +This ensures proper mapping between the standardized error classes and +platform-specific error constants. + +# DNS Errors + +- [EDNS_NONAME] for errors with the "no such host" suffix + +- [EDNS_NODATA] for errors with the "no answer" suffix + +# TLS + +- [ETLS_HOSTNAME_MISMATCH] for hostname verification failure + +- [ETLS_CA_UNKNOWN] for unknown certificate authority + +- [ETLS_CERT_INVALID] for invalid certificate + +# Fallback + +- [EGENERIC] for unclassified errors +*/ +package errclass + +import ( + "context" + "crypto/x509" + "errors" + "io" + "net" + "os" + "strings" +) + +const ( + // + // Errors that we can map using [errors.Is]: + // + + // EADDRNOTAVAIL is the address not available error. + EADDRNOTAVAIL = "EADDRNOTAVAIL" + + // EADDRINUSE is the address in use error. + EADDRINUSE = "EADDRINUSE" + + // ECONNABORTED is the connection aborted error. + ECONNABORTED = "ECONNABORTED" + + // ECONNREFUSED is the connection refused error. + ECONNREFUSED = "ECONNREFUSED" + + // ECONNRESET is the connection reset by peer error. + ECONNRESET = "ECONNRESET" + + // EHOSTUNREACH is the host unreachable error. + EHOSTUNREACH = "EHOSTUNREACH" + + // EEOF indicates an unexpected EOF. + EEOF = "EEOF" + + // EINVAL is the invalid argument error. + EINVAL = "EINVAL" + + // EINTR is the interrupted system call error. + EINTR = "EINTR" + + // ENETDOWN is the network is down error. + ENETDOWN = "ENETDOWN" + + // ENETUNREACH is the network unreachable error. + ENETUNREACH = "ENETUNREACH" + + // ENOBUFS is the no buffer space available error. + ENOBUFS = "ENOBUFS" + + // ENOTCONN is the not connected error. + ENOTCONN = "ENOTCONN" + + // EPROTONOSUPPORT is the protocol not supported error. + EPROTONOSUPPORT = "EPROTONOSUPPORT" + + // ETIMEDOUT is the operation timed out error. + ETIMEDOUT = "ETIMEDOUT" + + // + // Errors that we can map using the error message suffix: + // + + // EDNS_NONAME is the DNS error for "no such host". + EDNS_NONAME = "EDNS_NONAME" + + // EDNS_NODATA is the DNS error for "no answer". + EDNS_NODATA = "EDNS_NODATA" + + // + // Errors that we can map using [errors.As]: + // + + // ETLS_HOSTNAME_MISMATCH is the TLS error for hostname verification failure. + ETLS_HOSTNAME_MISMATCH = "ETLS_HOSTNAME_MISMATCH" + + // ETLS_CA_UNKNOWN is the TLS error for unknown certificate authority. + ETLS_CA_UNKNOWN = "ETLS_CA_UNKNOWN" + + // ETLS_CERT_INVALID is the TLS error for invalid certificate. + ETLS_CERT_INVALID = "ETLS_CERT_INVALID" + + // + // Fallback errors: + // + + // EGENERIC is the generic, unclassified error. + EGENERIC = "EGENERIC" +) + +// errorsIsMap contains the errors that we can map with [errors.Is]. +var errorsIsMap = map[error]string{ + context.DeadlineExceeded: ETIMEDOUT, + context.Canceled: EINTR, + errEADDRNOTAVAIL: EADDRNOTAVAIL, + errEADDRINUSE: EADDRINUSE, + errECONNABORTED: ECONNABORTED, + errECONNREFUSED: ECONNREFUSED, + errECONNRESET: ECONNRESET, + errEHOSTUNREACH: EHOSTUNREACH, + io.EOF: EEOF, + io.ErrUnexpectedEOF: EEOF, + errEINVAL: EINVAL, + errEINTR: EINTR, + errENETDOWN: ENETDOWN, + errENETUNREACH: ENETUNREACH, + errENOBUFS: ENOBUFS, + errENOTCONN: ENOTCONN, + errEPROTONOSUPPORT: EPROTONOSUPPORT, + errETIMEDOUT: ETIMEDOUT, + net.ErrClosed: EINTR, + os.ErrDeadlineExceeded: ETIMEDOUT, +} + +// stringSuffixMap contains the errors that we can map using the error message suffix. +var stringSuffixMap = map[string]string{ + "no answer from DNS server": EDNS_NODATA, + "no such host": EDNS_NONAME, +} + +// errorsAsList contains the errors that we can map with [errors.As]. +var errorsAsList = []struct { + as func(err error) bool + class string +}{ + { + as: func(err error) bool { + var candidate x509.HostnameError + return errors.As(err, &candidate) + }, + class: ETLS_HOSTNAME_MISMATCH, + }, + + { + as: func(err error) bool { + var candidate x509.UnknownAuthorityError + return errors.As(err, &candidate) + }, + class: ETLS_CA_UNKNOWN, + }, + + { + as: func(err error) bool { + var candidate x509.CertificateInvalidError + return errors.As(err, &candidate) + }, + class: ETLS_CERT_INVALID, + }, +} + +// New creates a new error class from the given error. +func New(err error) string { + // exclude the nil error case first + if err == nil { + return "" + } + + // attemp direct mapping using the [errors.Is] func + for candidate, class := range errorsIsMap { + if errors.Is(err, candidate) { + return class + } + } + + // attempt indirect mapping using the [errors.As] func + for _, entry := range errorsAsList { + if entry.as(err) { + return entry.class + } + } + + // fallback to attempt matching with the string suffix + for suffix, class := range stringSuffixMap { + if strings.HasSuffix(err.Error(), suffix) { + return class + } + } + + // we don't known this error + return EGENERIC +} diff --git a/pkg/common/errclass/errclass_test.go b/pkg/common/errclass/errclass_test.go new file mode 100644 index 0000000..e9c33dc --- /dev/null +++ b/pkg/common/errclass/errclass_test.go @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package errclass + +import ( + "crypto/x509" + "errors" + "fmt" + "testing" +) + +func TestNew(t *testing.T) { + // testcase is a test case implemented by this function. + type testcase struct { + input error + expect string + } + + // start with a test case for the nil error + var tests = []testcase{ + { + input: nil, + expect: "", + }, + } + + // add tests for cases we can test with errors.Is + for key, value := range errorsIsMap { + tests = append(tests, testcase{ + input: key, + expect: value, + }) + } + + // add tests for cases we can test with string suffix matching + for suffix, class := range stringSuffixMap { + tests = append(tests, testcase{ + input: errors.New("some error message " + suffix), + expect: class, + }) + } + + // add tests for cases we can test with errors.As + tests = append(tests, testcase{ + input: x509.HostnameError{ + Certificate: &x509.Certificate{}, + Host: "", + }, + expect: ETLS_HOSTNAME_MISMATCH, + }) + tests = append(tests, testcase{ + input: x509.UnknownAuthorityError{ + Cert: &x509.Certificate{}, + }, + expect: ETLS_CA_UNKNOWN, + }) + tests = append(tests, testcase{ + input: x509.CertificateInvalidError{ + Cert: &x509.Certificate{}, + Reason: 0, + Detail: "", + }, + expect: ETLS_CERT_INVALID, + }) + + // add test for unknown error + tests = append(tests, testcase{ + input: errors.New("unknown error"), + expect: EGENERIC, + }) + + // run all tests + for _, tt := range tests { + t.Run(fmt.Sprintf("%v", tt.input), func(t *testing.T) { + got := New(tt.input) + if got != tt.expect { + t.Errorf("New(%v) = %v; want %v", tt.input, got, tt.expect) + } + }) + } +} diff --git a/pkg/common/errclass/unix.go b/pkg/common/errclass/unix.go new file mode 100644 index 0000000..1e20ba7 --- /dev/null +++ b/pkg/common/errclass/unix.go @@ -0,0 +1,24 @@ +//go:build unix + +// SPDX-License-Identifier: GPL-3.0-or-later + +package errclass + +import "golang.org/x/sys/unix" + +const ( + errEADDRNOTAVAIL = unix.EADDRNOTAVAIL + errEADDRINUSE = unix.EADDRINUSE + errECONNABORTED = unix.ECONNABORTED + errECONNREFUSED = unix.ECONNREFUSED + errECONNRESET = unix.ECONNRESET + errEHOSTUNREACH = unix.EHOSTUNREACH + errEINVAL = unix.EINVAL + errEINTR = unix.EINTR + errENETDOWN = unix.ENETDOWN + errENETUNREACH = unix.ENETUNREACH + errENOBUFS = unix.ENOBUFS + errENOTCONN = unix.ENOTCONN + errEPROTONOSUPPORT = unix.EPROTONOSUPPORT + errETIMEDOUT = unix.ETIMEDOUT +) diff --git a/pkg/common/errclass/windows.go b/pkg/common/errclass/windows.go new file mode 100644 index 0000000..db0eb00 --- /dev/null +++ b/pkg/common/errclass/windows.go @@ -0,0 +1,24 @@ +//go:build windows + +// SPDX-License-Identifier: GPL-3.0-or-later + +package errclass + +import "golang.org/x/sys/windows" + +const ( + errEADDRNOTAVAIL = windows.WSAEADDRNOTAVAIL + errEADDRINUSE = windows.WSAEADDRINUSE + errECONNABORTED = windows.WSAECONNABORTED + errECONNREFUSED = windows.WSAECONNREFUSED + errECONNRESET = windows.WSAECONNRESET + errEHOSTUNREACH = windows.WSAEHOSTUNREACH + errEINVAL = windows.WSAEINVAL + errEINTR = windows.WSAEINTR + errENETDOWN = windows.WSAENETDOWN + errENETUNREACH = windows.WSAENETUNREACH + errENOBUFS = windows.WSAENOBUFS + errENOTCONN = windows.WSAENOTCONN + errEPROTONOSUPPORT = windows.WSAEPROTONOSUPPORT + errETIMEDOUT = windows.WSAETIMEDOUT +) diff --git a/pkg/common/httpconntrace/httpconntrace.go b/pkg/common/httpconntrace/httpconntrace.go new file mode 100644 index 0000000..1de8ef8 --- /dev/null +++ b/pkg/common/httpconntrace/httpconntrace.go @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +/* +Package httpconntrace provides a way to trace the local and remote endpoints +used by an HTTP connection while performing an [*http.Client] request. + +Internally, we use [net/http/httptrace] to collect the connection [*Endpoints]. + +Operationally, you need to use [Do] where you would otherwise call +[*http.Client.Do] method. The [*Endpoints] are returned along with the response. + +Collecting the connection [*Endpoints] is important to map the HTTP response +with the connection that actually serviced the request. +*/ +package httpconntrace + +import ( + "context" + "net" + "net/http" + "net/http/httptrace" + "net/netip" + "sync" +) + +// Endpoints contains the connection endpoints extacted by [Do]. +type Endpoints struct { + // LocalAddr is the local address of the connection. + LocalAddr netip.AddrPort + + // RemoteAddr is the remote address of the connection. + RemoteAddr netip.AddrPort +} + +// Do performs an HTTP request using [*http.Client.Do] and uses [net/http/httptrace] to +// extract the local and remote [*Endpoints] used by the connection. +// +// Internally, this function creates a new context for tracing purposes, to avoid +// accidentally composing the [net/http/httptrace] trace with other possible context traces +// that may have already been present in the request context. Obviously, this means that +// using this function prevents one to observe connection events with a trace. +// +// Note that this function assumes we're using TCP and casts the connection addresses +// to [*net.TCPAddr] to extract the endpoints. If the we're not using TCP, the returned +// [*Endpoint] will contain zero initialized (i.e., invalid) addresses. +// +// We return *Endpoints rather than Endpoints because the structure is larger than 32 bytes +// and could possibly be further extended in the future to include additional fields. +func Do(client *http.Client, req *http.Request) (*http.Response, *Endpoints, error) { + // Prepare to collect info in a goroutine-safe way. + var ( + laddr netip.AddrPort + mu sync.Mutex + raddr netip.AddrPort + ) + + // Create clean context for tracing where "clean" means + // we don't compose with other possible context traces + traceCtx, cancel := context.WithCancel(context.Background()) + + // Configure the trace for extracting laddr, raddr + trace := &httptrace.ClientTrace{ + GotConn: func(info httptrace.GotConnInfo) { + mu.Lock() + defer mu.Unlock() + if addr, ok := info.Conn.LocalAddr().(*net.TCPAddr); ok { + laddr = addr.AddrPort() + } + if addr, ok := info.Conn.RemoteAddr().(*net.TCPAddr); ok { + raddr = addr.AddrPort() + } + }, + } + req = req.WithContext(httptrace.WithClientTrace(traceCtx, trace)) + + // Arrange for the inner context to be canceled + // when the outer context is done. + // + // This must be after req.WithContext to avoid + // a data race in the context itself. + go func() { + defer cancel() + select { + case <-req.Context().Done(): + case <-traceCtx.Done(): + } + }() + + // Perform the request + resp, err := client.Do(req) + + // Gather the local and remote endpoints while holding the mutex + // to avoid data-racing with the tracing goroutine. + mu.Lock() + epnts := &Endpoints{LocalAddr: laddr, RemoteAddr: raddr} + mu.Unlock() + + // Return the results to the caller. + return resp, epnts, err +} diff --git a/pkg/common/httpconntrace/httpconntrace_test.go b/pkg/common/httpconntrace/httpconntrace_test.go new file mode 100644 index 0000000..be912c1 --- /dev/null +++ b/pkg/common/httpconntrace/httpconntrace_test.go @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package httpconntrace_test + +import ( + "fmt" + "net/http" + "net/http/httptest" + + "github.com/rbmk-project/rbmk/pkg/common/httpconntrace" +) + +func Example() { + // Create a test server that just echoes back + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello, World!")) + })) + defer ts.Close() + + // Create and send request + req, err := http.NewRequest("GET", ts.URL, nil) + if err != nil { + fmt.Printf("failed to create request: %s\n", err) + return + } + + // Use Do instead of client.Do to get connection endpoints + resp, endpoints, err := httpconntrace.Do(http.DefaultClient, req) + if err != nil { + fmt.Printf("request failed: %s\n", err) + return + } + defer resp.Body.Close() + + // Print the endpoints we collected + fmt.Printf("Local: %v\n", endpoints.LocalAddr.IsValid()) + fmt.Printf("Remote: %v\n", endpoints.RemoteAddr.IsValid()) + + // Output: + // Local: true + // Remote: true +} diff --git a/pkg/common/httpslog/httpslog.go b/pkg/common/httpslog/httpslog.go new file mode 100644 index 0000000..0ff1bb7 --- /dev/null +++ b/pkg/common/httpslog/httpslog.go @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +// Package httpslog implements structured logging for HTTP clients. +package httpslog + +import ( + "log/slog" + "net/http" + "net/netip" + "time" + + "github.com/rbmk-project/rbmk/pkg/common/errclass" +) + +// MaybeLogRoundTripStart logs the start of a round trip if the +// given logger is not nil, otherwise it does nothing. +func MaybeLogRoundTripStart( + logger *slog.Logger, + localAddr netip.AddrPort, + protocol string, + remoteAddr netip.AddrPort, + req *http.Request, + t0 time.Time, +) { + if logger != nil { + logger.InfoContext( + req.Context(), + "httpRoundTripStart", + slog.String("httpMethod", req.Method), + slog.String("httpUrl", req.URL.String()), + slog.Any("httpRequestHeaders", req.Header), + slog.String("localAddr", localAddr.String()), + slog.String("protocol", protocol), + slog.String("remoteAddr", remoteAddr.String()), + slog.Time("t", t0), + ) + } +} + +// MaybeLogRoundTripDone logs the end of a round trip if the given +// logger is not nil, otherwise it does nothing. +func MaybeLogRoundTripDone( + logger *slog.Logger, + localAddr netip.AddrPort, + protocol string, + remoteAddr netip.AddrPort, + req *http.Request, + resp *http.Response, + err error, + t0 time.Time, + t time.Time, +) { + if logger != nil { + if err != nil { + logger.InfoContext( + req.Context(), + "httpRoundTripDone", + slog.Any("err", err), + slog.Any("errClass", errclass.New(err)), + slog.String("httpMethod", req.Method), + slog.String("httpUrl", req.URL.String()), + slog.Any("httpRequestHeaders", req.Header), + slog.String("localAddr", localAddr.String()), + slog.String("protocol", protocol), + slog.String("remoteAddr", remoteAddr.String()), + slog.Time("t0", t0), + slog.Time("t", t), + ) + return + } + logger.InfoContext( + req.Context(), + "httpRoundTripDone", + slog.String("httpMethod", req.Method), + slog.String("httpUrl", req.URL.String()), + slog.Any("httpRequestHeaders", req.Header), + slog.Int("httpResponseStatusCode", resp.StatusCode), + slog.Any("httpResponseHeaders", resp.Header), + slog.String("localAddr", localAddr.String()), + slog.String("protocol", protocol), + slog.String("remoteAddr", remoteAddr.String()), + slog.Time("t", t0), + ) + } +} diff --git a/pkg/common/httpslog/httpslog_test.go b/pkg/common/httpslog/httpslog_test.go new file mode 100644 index 0000000..3434eab --- /dev/null +++ b/pkg/common/httpslog/httpslog_test.go @@ -0,0 +1,181 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package httpslog + +import ( + "bytes" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestMaybeLogRoundTripStart(t *testing.T) { + tests := []struct { + name string + newLogger func(w io.Writer) *slog.Logger + expectTime time.Time + expectLog string + }{ + { + name: "Logger set", + newLogger: func(w io.Writer) *slog.Logger { + return slog.New(slog.NewJSONHandler(w, &slog.HandlerOptions{ + Level: slog.LevelDebug, + ReplaceAttr: func(groups []string, attr slog.Attr) slog.Attr { + if attr.Key == slog.TimeKey { + return slog.Attr{} + } + return attr + }, + })) + }, + expectTime: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + expectLog: `{"level":"INFO","msg":"httpRoundTripStart","httpMethod":"GET",` + + `"httpUrl":"https://example.com","httpRequestHeaders":{},"localAddr":"127.0.0.1:0",` + + `"protocol":"tcp","remoteAddr":"93.184.216.34:443","t":"2020-01-01T00:00:00Z"}` + "\n", + }, + { + name: "Logger not set", + newLogger: func(w io.Writer) *slog.Logger { return nil }, + expectTime: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + expectLog: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var out bytes.Buffer + logger := tt.newLogger(&out) + + req, err := http.NewRequest("GET", "https://example.com", nil) + assert.NoError(t, err) + + localAddr := netip.MustParseAddrPort("127.0.0.1:0") + remoteAddr := netip.MustParseAddrPort("93.184.216.34:443") + + MaybeLogRoundTripStart( + logger, + localAddr, + "tcp", + remoteAddr, + req, + tt.expectTime, + ) + + actualLog := out.String() + assert.Equal(t, tt.expectLog, actualLog) + }) + } +} + +func TestMaybeLogRoundTripDone(t *testing.T) { + tests := []struct { + name string + newLogger func(w io.Writer) *slog.Logger + withError bool + expectTime time.Time + expectLog string + }{ + { + name: "Logger set with success", + newLogger: func(w io.Writer) *slog.Logger { + return slog.New(slog.NewJSONHandler(w, &slog.HandlerOptions{ + Level: slog.LevelDebug, + ReplaceAttr: func(groups []string, attr slog.Attr) slog.Attr { + if attr.Key == slog.TimeKey { + return slog.Attr{} + } + return attr + }, + })) + }, + withError: false, + expectTime: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + expectLog: `{"level":"INFO","msg":"httpRoundTripDone","httpMethod":"GET",` + + `"httpUrl":"https://example.com","httpRequestHeaders":{},` + + `"httpResponseStatusCode":200,"httpResponseHeaders":{},` + + `"localAddr":"127.0.0.1:0","protocol":"tcp","remoteAddr":"93.184.216.34:443",` + + `"t":"2020-01-01T00:00:00Z"}` + "\n", + }, + { + name: "Logger set with error", + newLogger: func(w io.Writer) *slog.Logger { + return slog.New(slog.NewJSONHandler(w, &slog.HandlerOptions{ + Level: slog.LevelDebug, + ReplaceAttr: func(groups []string, attr slog.Attr) slog.Attr { + if attr.Key == slog.TimeKey { + return slog.Attr{} + } + return attr + }, + })) + }, + withError: true, + expectTime: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + expectLog: `{"level":"INFO","msg":"httpRoundTripDone","err":"assert.AnError general error for testing",` + + `"errClass":"EGENERIC","httpMethod":"GET","httpUrl":"https://example.com",` + + `"httpRequestHeaders":{},"localAddr":"127.0.0.1:0","protocol":"tcp",` + + `"remoteAddr":"93.184.216.34:443","t0":"2020-01-01T00:00:00Z","t":"2020-01-01T00:00:00Z"}` + "\n", + }, + { + name: "Logger not set", + newLogger: func(w io.Writer) *slog.Logger { return nil }, + withError: false, + expectTime: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + expectLog: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var out bytes.Buffer + logger := tt.newLogger(&out) + + req, err := http.NewRequest("GET", "https://example.com", nil) + assert.NoError(t, err) + + var resp *http.Response + var roundTripErr error + + if !tt.withError { + resp = &http.Response{ + StatusCode: 200, + Header: make(http.Header), + } + } else { + roundTripErr = assert.AnError + } + + localAddr := netip.MustParseAddrPort("127.0.0.1:0") + remoteAddr := netip.MustParseAddrPort("93.184.216.34:443") + + MaybeLogRoundTripDone( + logger, + localAddr, + "tcp", + remoteAddr, + req, + resp, + roundTripErr, + tt.expectTime, + tt.expectTime, + ) + + actualLog := out.String() + assert.Equal(t, tt.expectLog, actualLog) + + // Verify JSON is valid when there's output + if actualLog != "" { + var jsonMap map[string]interface{} + err := json.Unmarshal([]byte(actualLog), &jsonMap) + assert.NoError(t, err) + } + }) + } +} diff --git a/pkg/common/internal/fsmodel/fsmodel.go b/pkg/common/internal/fsmodel/fsmodel.go new file mode 100644 index 0000000..bea8bbe --- /dev/null +++ b/pkg/common/internal/fsmodel/fsmodel.go @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +// Package fsmodel provides an abstract file system model. +package fsmodel + +import ( + "io" + "io/fs" + "net" + "os" + "time" +) + +// Forward file system constants. +const ( + O_CREATE = os.O_CREATE + O_RDONLY = os.O_RDONLY + O_RDWR = os.O_RDWR + O_TRUNC = os.O_TRUNC + O_WRONLY = os.O_WRONLY + O_APPEND = os.O_APPEND +) + +// File represents a file in the filesystem. +// +// We use a simplified view of the full interface implemented by +// [os.File] to allow for easier mocking and testing. +type File io.ReadWriteCloser + +// Ensure [*os.File] implements [File]. +var _ File = &os.File{} + +// FS is the filesystem interface. +// +// Any simulated or real filesystem should implement this interface. +type FS interface { + // Chmod changes the mode of the named file to mode. + Chmod(name string, mode fs.FileMode) error + + // Chown changes the uid and gid of the named file. + Chown(name string, uid, gid int) error + + // Chtimes changes the access and modification times of the named file. + Chtimes(name string, atime time.Time, mtime time.Time) error + + // Create creates a file in the filesystem, returning the file or an error. + Create(name string) (File, error) + + // DialUnix connects to a Unix-domain socket using the given file name. + DialUnix(name string) (net.Conn, error) + + // ListenUnix creates a listening Unix-domain socket using the given file name. + ListenUnix(name string) (net.Listener, error) + + // Lstat is like Stat but does not follow symbolic links. + Lstat(name string) (fs.FileInfo, error) + + // Mkdir creates a directory in the filesystem, possibly returning an error. + Mkdir(name string, perm fs.FileMode) error + + // MkdirAll creates a directory path and all parents that does not exist yet. + MkdirAll(path string, perm fs.FileMode) error + + // Open opens a file, returning it or an error, if any. + Open(name string) (File, error) + + // OpenFile opens a file using the given flags and the given mode. + OpenFile(name string, flag int, perm fs.FileMode) (File, error) + + // ReadDir reads and returns the content of a given directory. + ReadDir(dirname string) ([]fs.DirEntry, error) + + // Remove removes a file identified by name, returning an error, if any. + Remove(name string) error + + // RemoveAll removes a directory path and any children it contains. It + // does not fail if the path does not exist (returns nil). + RemoveAll(path string) error + + // Rename renames a file. + Rename(oldname, newname string) error + + // Stat returns a FileInfo describing the named file, or an error. + Stat(name string) (fs.FileInfo, error) +} diff --git a/pkg/common/mocks/conn.go b/pkg/common/mocks/conn.go new file mode 100644 index 0000000..8078288 --- /dev/null +++ b/pkg/common/mocks/conn.go @@ -0,0 +1,81 @@ +// +// SPDX-License-Identifier: GPL-3.0-or-later +// +// Adapted from: https://github.com/ooni/probe-cli/blob/v3.20.1/internal/mocks/dialer.go +// + +package mocks + +import ( + "net" + "time" +) + +// Conn is a mockable [net.Conn]. +type Conn struct { + // MockRead is the function to call when Read is called. + MockRead func(b []byte) (int, error) + + // MockWrite is the function to call when Write is called. + MockWrite func(b []byte) (int, error) + + // MockClose is the function to call when Close is called. + MockClose func() error + + // MockLocalAddr is the function to call when LocalAddr is called. + MockLocalAddr func() net.Addr + + // MockRemoteAddr is the function to call when RemoteAddr is called. + MockRemoteAddr func() net.Addr + + // MockSetDeadline is the function to call when SetDeadline is called. + MockSetDeadline func(t time.Time) error + + // MockSetReadDeadline is the function to call when SetReadDeadline is called. + MockSetReadDeadline func(t time.Time) error + + // MockSetWriteDeadline is the function to call when SetWriteDeadline is called. + MockSetWriteDeadline func(t time.Time) error +} + +var _ net.Conn = &Conn{} + +// Read calls MockRead. +func (c *Conn) Read(b []byte) (int, error) { + return c.MockRead(b) +} + +// Write calls MockWrite. +func (c *Conn) Write(b []byte) (int, error) { + return c.MockWrite(b) +} + +// Close calls MockClose. +func (c *Conn) Close() error { + return c.MockClose() +} + +// LocalAddr calls MockLocalAddr. +func (c *Conn) LocalAddr() net.Addr { + return c.MockLocalAddr() +} + +// RemoteAddr calls MockRemoteAddr. +func (c *Conn) RemoteAddr() net.Addr { + return c.MockRemoteAddr() +} + +// SetDeadline calls MockSetDeadline. +func (c *Conn) SetDeadline(t time.Time) error { + return c.MockSetDeadline(t) +} + +// SetReadDeadline calls MockSetReadDeadline. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.MockSetReadDeadline(t) +} + +// SetWriteDeadline calls MockSetWriteDeadline. +func (c *Conn) SetWriteDeadline(t time.Time) error { + return c.MockSetWriteDeadline(t) +} diff --git a/pkg/common/mocks/conn_test.go b/pkg/common/mocks/conn_test.go new file mode 100644 index 0000000..4ef209a --- /dev/null +++ b/pkg/common/mocks/conn_test.go @@ -0,0 +1,134 @@ +// +// SPDX-License-Identifier: GPL-3.0-or-later +// +// Adapted from: https://github.com/ooni/probe-cli/blob/v3.20.1/internal/mocks/dialer_test.go +// + +package mocks + +import ( + "errors" + "net" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +func TestConn(t *testing.T) { + t.Run("Read", func(t *testing.T) { + expected := errors.New("mocked error") + c := &Conn{ + MockRead: func(b []byte) (int, error) { + return 0, expected + }, + } + count, err := c.Read(make([]byte, 128)) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if count != 0 { + t.Fatal("expected 0 bytes") + } + }) + + t.Run("Write", func(t *testing.T) { + expected := errors.New("mocked error") + c := &Conn{ + MockWrite: func(b []byte) (int, error) { + return 0, expected + }, + } + count, err := c.Write(make([]byte, 128)) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if count != 0 { + t.Fatal("expected 0 bytes") + } + }) + + t.Run("Close", func(t *testing.T) { + expected := errors.New("mocked error") + c := &Conn{ + MockClose: func() error { + return expected + }, + } + err := c.Close() + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("LocalAddr", func(t *testing.T) { + expected := &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 1234, + } + c := &Conn{ + MockLocalAddr: func() net.Addr { + return expected + }, + } + out := c.LocalAddr() + if diff := cmp.Diff(expected, out); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("RemoteAddr", func(t *testing.T) { + expected := &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 1234, + } + c := &Conn{ + MockRemoteAddr: func() net.Addr { + return expected + }, + } + out := c.RemoteAddr() + if diff := cmp.Diff(expected, out); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("SetDeadline", func(t *testing.T) { + expected := errors.New("mocked error") + c := &Conn{ + MockSetDeadline: func(t time.Time) error { + return expected + }, + } + err := c.SetDeadline(time.Time{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + }) + + t.Run("SetReadDeadline", func(t *testing.T) { + expected := errors.New("mocked error") + c := &Conn{ + MockSetReadDeadline: func(t time.Time) error { + return expected + }, + } + err := c.SetReadDeadline(time.Time{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + }) + + t.Run("SetWriteDeadline", func(t *testing.T) { + expected := errors.New("mocked error") + c := &Conn{ + MockSetWriteDeadline: func(t time.Time) error { + return expected + }, + } + err := c.SetWriteDeadline(time.Time{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + }) +} diff --git a/pkg/common/mocks/doc.go b/pkg/common/mocks/doc.go new file mode 100644 index 0000000..8f924c9 --- /dev/null +++ b/pkg/common/mocks/doc.go @@ -0,0 +1,4 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +// Package mocks contains mocks for standard library types. +package mocks diff --git a/pkg/common/mocks/fsmodel.go b/pkg/common/mocks/fsmodel.go new file mode 100644 index 0000000..26a5840 --- /dev/null +++ b/pkg/common/mocks/fsmodel.go @@ -0,0 +1,179 @@ +package mocks + +import ( + "io/fs" + "net" + "time" + + "github.com/rbmk-project/rbmk/pkg/common/internal/fsmodel" +) + +// FsmodelFS is an alias for [fsmodel.FS]. +type FsmodelFS = fsmodel.FS + +// FsmodelFile is an alias for [fsmodel.File]. +type FsmodelFile = fsmodel.File + +// FS implements [FsmodelFS] for testing +type FS struct { + // MockChmod implements Chmod + MockChmod func(name string, mode fs.FileMode) error + + // MockChown implements Chown + MockChown func(name string, uid, gid int) error + + // MockChtimes implements Chtimes + MockChtimes func(name string, atime time.Time, mtime time.Time) error + + // MockCreate implements Create + MockCreate func(name string) (FsmodelFile, error) + + // MockDialUnix implements DialUnix + MockDialUnix func(name string) (net.Conn, error) + + // MockListenUnix implements ListenUnix + MockListenUnix func(name string) (net.Listener, error) + + // MockLstat implements Lstat + MockLstat func(name string) (fs.FileInfo, error) + + // MockMkdir implements Mkdir + MockMkdir func(name string, perm fs.FileMode) error + + // MockMkdirAll implements MkdirAll + MockMkdirAll func(path string, perm fs.FileMode) error + + // MockOpen implements Open + MockOpen func(name string) (FsmodelFile, error) + + // MockOpenFile implements OpenFile + MockOpenFile func(name string, flag int, perm fs.FileMode) (FsmodelFile, error) + + // MockReadDir implements ReadDir + MockReadDir func(dirname string) ([]fs.DirEntry, error) + + // MockRemove implements Remove + MockRemove func(name string) error + + // MockRemoveAll implements RemoveAll + MockRemoveAll func(path string) error + + // MockRename implements Rename + MockRename func(oldname, newname string) error + + // MockStat implements Stat + MockStat func(name string) (fs.FileInfo, error) +} + +// Ensure [FS] implements [FsmodelFS] +var _ FsmodelFS = &FS{} + +// Chmod calls MockChmod +func (m *FS) Chmod(name string, mode fs.FileMode) error { + return m.MockChmod(name, mode) +} + +// Chown calls MockChown +func (m *FS) Chown(name string, uid, gid int) error { + return m.MockChown(name, uid, gid) +} + +// Chtimes calls MockChtimes +func (m *FS) Chtimes(name string, atime, mtime time.Time) error { + return m.MockChtimes(name, atime, mtime) +} + +// Create calls MockCreate +func (m *FS) Create(name string) (FsmodelFile, error) { + return m.MockCreate(name) +} + +// DialUnix calls MockDialUnix +func (m *FS) DialUnix(name string) (net.Conn, error) { + return m.MockDialUnix(name) +} + +// ListenUnix calls MockListenUnix +func (m *FS) ListenUnix(name string) (net.Listener, error) { + return m.MockListenUnix(name) +} + +// Lstat calls MockLstat +func (m *FS) Lstat(name string) (fs.FileInfo, error) { + return m.MockLstat(name) +} + +// Mkdir calls MockMkdir +func (m *FS) Mkdir(name string, perm fs.FileMode) error { + return m.MockMkdir(name, perm) +} + +// MkdirAll calls MockMkdirAll +func (m *FS) MkdirAll(path string, perm fs.FileMode) error { + return m.MockMkdirAll(path, perm) +} + +// Open calls MockOpen +func (m *FS) Open(name string) (FsmodelFile, error) { + return m.MockOpen(name) +} + +// OpenFile calls MockOpenFile +func (m *FS) OpenFile(name string, flag int, perm fs.FileMode) (FsmodelFile, error) { + return m.MockOpenFile(name, flag, perm) +} + +// ReadDir calls MockReadDir +func (m *FS) ReadDir(dirname string) ([]fs.DirEntry, error) { + return m.MockReadDir(dirname) +} + +// Remove calls MockRemove +func (m *FS) Remove(name string) error { + return m.MockRemove(name) +} + +// RemoveAll calls MockRemoveAll +func (m *FS) RemoveAll(path string) error { + return m.MockRemoveAll(path) +} + +// Rename calls MockRename +func (m *FS) Rename(oldname, newname string) error { + return m.MockRename(oldname, newname) +} + +// Stat calls MockStat +func (m *FS) Stat(name string) (fs.FileInfo, error) { + return m.MockStat(name) +} + +// File implements [FsmodelFile] for testing +type File struct { + // MockRead implements Read + MockRead func(b []byte) (int, error) + + // MockWrite implements Write + MockWrite func(b []byte) (int, error) + + // MockClose implements Close + MockClose func() error +} + +// Ensure [File] implements [FsmodelFile]. +var _ FsmodelFile = &File{} + +// Read calls MockRead +func (m *File) Read(b []byte) (int, error) { + return m.MockRead(b) +} + +// Write calls MockWrite +func (m *File) Write(b []byte) (int, error) { + return m.MockWrite(b) +} + +// Close calls MockClose +func (m *File) Close() error { + return m.MockClose() +} diff --git a/pkg/common/mocks/fsmodel_test.go b/pkg/common/mocks/fsmodel_test.go new file mode 100644 index 0000000..14e5486 --- /dev/null +++ b/pkg/common/mocks/fsmodel_test.go @@ -0,0 +1,269 @@ +package mocks_test + +import ( + "errors" + "io/fs" + "net" + "testing" + "time" + + "github.com/rbmk-project/rbmk/pkg/common/internal/fsmodel" + "github.com/rbmk-project/rbmk/pkg/common/mocks" +) + +func TestFS(t *testing.T) { + t.Run("Chmod", func(t *testing.T) { + expected := errors.New("mocked error") + fs := &mocks.FS{ + MockChmod: func(name string, mode fs.FileMode) error { + return expected + }, + } + err := fs.Chmod("test.txt", 0644) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("Chown", func(t *testing.T) { + expected := errors.New("mocked error") + fs := &mocks.FS{ + MockChown: func(name string, uid, gid int) error { + return expected + }, + } + err := fs.Chown("test.txt", 1000, 1000) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("Chtimes", func(t *testing.T) { + expected := errors.New("mocked error") + fs := &mocks.FS{ + MockChtimes: func(name string, atime, mtime time.Time) error { + return expected + }, + } + err := fs.Chtimes("test.txt", time.Now(), time.Now()) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("Create", func(t *testing.T) { + expected := errors.New("mocked error") + fs := &mocks.FS{ + MockCreate: func(name string) (fsmodel.File, error) { + return nil, expected + }, + } + _, err := fs.Create("test.txt") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("DialUnix", func(t *testing.T) { + expected := errors.New("mocked error") + fs := &mocks.FS{ + MockDialUnix: func(name string) (net.Conn, error) { + return nil, expected + }, + } + _, err := fs.DialUnix("test.sock") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("ListenUnix", func(t *testing.T) { + expected := errors.New("mocked error") + fs := &mocks.FS{ + MockListenUnix: func(name string) (net.Listener, error) { + return nil, expected + }, + } + _, err := fs.ListenUnix("test.sock") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("Lstat", func(t *testing.T) { + expected := errors.New("mocked error") + fs := &mocks.FS{ + MockLstat: func(name string) (fs.FileInfo, error) { + return nil, expected + }, + } + _, err := fs.Lstat("test.txt") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("Mkdir", func(t *testing.T) { + expected := errors.New("mocked error") + fs := &mocks.FS{ + MockMkdir: func(name string, perm fs.FileMode) error { + return expected + }, + } + err := fs.Mkdir("testdir", 0755) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("MkdirAll", func(t *testing.T) { + expected := errors.New("mocked error") + fs := &mocks.FS{ + MockMkdirAll: func(path string, perm fs.FileMode) error { + return expected + }, + } + err := fs.MkdirAll("testdir/subdir", 0755) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("Open", func(t *testing.T) { + expected := errors.New("mocked error") + fs := &mocks.FS{ + MockOpen: func(name string) (fsmodel.File, error) { + return nil, expected + }, + } + _, err := fs.Open("test.txt") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("OpenFile", func(t *testing.T) { + expected := errors.New("mocked error") + fs := &mocks.FS{ + MockOpenFile: func(name string, flag int, perm fs.FileMode) (fsmodel.File, error) { + return nil, expected + }, + } + _, err := fs.OpenFile("test.txt", fsmodel.O_CREATE|fsmodel.O_WRONLY, 0644) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("ReadDir", func(t *testing.T) { + expected := errors.New("mocked error") + fs := &mocks.FS{ + MockReadDir: func(dirname string) ([]fs.DirEntry, error) { + return nil, expected + }, + } + _, err := fs.ReadDir("testdir") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("Remove", func(t *testing.T) { + expected := errors.New("mocked error") + fs := &mocks.FS{ + MockRemove: func(name string) error { + return expected + }, + } + err := fs.Remove("test.txt") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("RemoveAll", func(t *testing.T) { + expected := errors.New("mocked error") + fs := &mocks.FS{ + MockRemoveAll: func(path string) error { + return expected + }, + } + err := fs.RemoveAll("testdir") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("Rename", func(t *testing.T) { + expected := errors.New("mocked error") + fs := &mocks.FS{ + MockRename: func(oldname, newname string) error { + return expected + }, + } + err := fs.Rename("old.txt", "new.txt") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("Stat", func(t *testing.T) { + expected := errors.New("mocked error") + fs := &mocks.FS{ + MockStat: func(name string) (fs.FileInfo, error) { + return nil, expected + }, + } + _, err := fs.Stat("test.txt") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) +} + +func TestFile(t *testing.T) { + t.Run("Read", func(t *testing.T) { + expected := errors.New("mocked error") + file := &mocks.File{ + MockRead: func(b []byte) (int, error) { + return 0, expected + }, + } + count, err := file.Read(make([]byte, 128)) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if count != 0 { + t.Fatal("expected 0 bytes") + } + }) + + t.Run("Write", func(t *testing.T) { + expected := errors.New("mocked error") + file := &mocks.File{ + MockWrite: func(b []byte) (int, error) { + return 0, expected + }, + } + count, err := file.Write(make([]byte, 128)) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if count != 0 { + t.Fatal("expected 0 bytes") + } + }) + + t.Run("Close", func(t *testing.T) { + expected := errors.New("mocked error") + file := &mocks.File{ + MockClose: func() error { + return expected + }, + } + err := file.Close() + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) +} diff --git a/pkg/common/mocks/http.go b/pkg/common/mocks/http.go new file mode 100644 index 0000000..22d2331 --- /dev/null +++ b/pkg/common/mocks/http.go @@ -0,0 +1,20 @@ +// +// SPDX-License-Identifier: GPL-3.0-or-later +// +// Adapted from: https://github.com/ooni/probe-cli/blob/v3.20.1/internal/mocks/http.go +// + +package mocks + +import "net/http" + +// HTTPTransport mocks [http.RoundTripper]. +type HTTPTransport struct { + // MockRoundTrip is the function to call when RoundTrip is called. + MockRoundTrip func(req *http.Request) (*http.Response, error) +} + +// RoundTrip calls MockRoundTrip. +func (txp *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return txp.MockRoundTrip(req) +} diff --git a/pkg/common/mocks/http_test.go b/pkg/common/mocks/http_test.go new file mode 100644 index 0000000..f46dd66 --- /dev/null +++ b/pkg/common/mocks/http_test.go @@ -0,0 +1,31 @@ +// +// SPDX-License-Identifier: GPL-3.0-or-later +// +// Adapted from: https://github.com/ooni/probe-cli/blob/v3.20.1/internal/mocks/http_test.go +// + +package mocks + +import ( + "errors" + "net/http" + "testing" +) + +func TestHTTPTransport(t *testing.T) { + t.Run("RoundTrip", func(t *testing.T) { + expected := errors.New("mocked error") + txp := &HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return nil, expected + }, + } + resp, err := txp.RoundTrip(&http.Request{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response here") + } + }) +} diff --git a/pkg/common/mocks/packetconn.go b/pkg/common/mocks/packetconn.go new file mode 100644 index 0000000..cccc5d7 --- /dev/null +++ b/pkg/common/mocks/packetconn.go @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package mocks + +import ( + "net" + "time" +) + +// PacketConn is a mockable [net.PacketConn]. +type PacketConn struct { + // MockReadFrom is the function to call when ReadFrom is called. + MockReadFrom func(p []byte) (int, net.Addr, error) + + // MockWriteTo is the function to call when WriteTo is called. + MockWriteTo func(p []byte, addr net.Addr) (int, error) + + // MockClose is the function to call when Close is called. + MockClose func() error + + // MockLocalAddr is the function to call when LocalAddr is called. + MockLocalAddr func() net.Addr + + // MockSetDeadline is the function to call when SetDeadline is called. + MockSetDeadline func(t time.Time) error + + // MockSetReadDeadline is the function to call when SetReadDeadline is called. + MockSetReadDeadline func(t time.Time) error + + // MockSetWriteDeadline is the function to call when SetWriteDeadline is called. + MockSetWriteDeadline func(t time.Time) error +} + +var _ net.PacketConn = &PacketConn{} + +// ReadFrom calls MockReadFrom. +func (pc *PacketConn) ReadFrom(p []byte) (int, net.Addr, error) { + return pc.MockReadFrom(p) +} + +// WriteTo calls MockWriteTo. +func (pc *PacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { + return pc.MockWriteTo(p, addr) +} + +// Close calls MockClose. +func (pc *PacketConn) Close() error { + return pc.MockClose() +} + +// LocalAddr calls MockLocalAddr. +func (pc *PacketConn) LocalAddr() net.Addr { + return pc.MockLocalAddr() +} + +// SetDeadline calls MockSetDeadline. +func (pc *PacketConn) SetDeadline(t time.Time) error { + return pc.MockSetDeadline(t) +} + +// SetReadDeadline calls MockSetReadDeadline. +func (pc *PacketConn) SetReadDeadline(t time.Time) error { + return pc.MockSetReadDeadline(t) +} + +// SetWriteDeadline calls MockSetWriteDeadline. +func (pc *PacketConn) SetWriteDeadline(t time.Time) error { + return pc.MockSetWriteDeadline(t) +} diff --git a/pkg/common/mocks/packetconn_test.go b/pkg/common/mocks/packetconn_test.go new file mode 100644 index 0000000..a0f4115 --- /dev/null +++ b/pkg/common/mocks/packetconn_test.go @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package mocks + +import ( + "errors" + "net" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +func TestPacketConn(t *testing.T) { + t.Run("ReadFrom", func(t *testing.T) { + expected := errors.New("mocked error") + expectedAddr := &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 8080, + } + pc := &PacketConn{ + MockReadFrom: func(p []byte) (int, net.Addr, error) { + return 0, expectedAddr, expected + }, + } + count, addr, err := pc.ReadFrom(make([]byte, 128)) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if count != 0 { + t.Fatal("expected 0 bytes") + } + if diff := cmp.Diff(expectedAddr, addr); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("WriteTo", func(t *testing.T) { + expected := errors.New("mocked error") + addr := &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 8080, + } + pc := &PacketConn{ + MockWriteTo: func(p []byte, addr net.Addr) (int, error) { + return 0, expected + }, + } + count, err := pc.WriteTo(make([]byte, 128), addr) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if count != 0 { + t.Fatal("expected 0 bytes") + } + }) + + t.Run("Close", func(t *testing.T) { + expected := errors.New("mocked error") + pc := &PacketConn{ + MockClose: func() error { + return expected + }, + } + err := pc.Close() + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("LocalAddr", func(t *testing.T) { + expected := &net.UDPAddr{ + IP: net.IPv6loopback, + Port: 1234, + } + pc := &PacketConn{ + MockLocalAddr: func() net.Addr { + return expected + }, + } + out := pc.LocalAddr() + if diff := cmp.Diff(expected, out); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("SetDeadline", func(t *testing.T) { + expected := errors.New("mocked error") + pc := &PacketConn{ + MockSetDeadline: func(t time.Time) error { + return expected + }, + } + err := pc.SetDeadline(time.Time{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("SetReadDeadline", func(t *testing.T) { + expected := errors.New("mocked error") + pc := &PacketConn{ + MockSetReadDeadline: func(t time.Time) error { + return expected + }, + } + err := pc.SetReadDeadline(time.Time{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("SetWriteDeadline", func(t *testing.T) { + expected := errors.New("mocked error") + pc := &PacketConn{ + MockSetWriteDeadline: func(t time.Time) error { + return expected + }, + } + err := pc.SetWriteDeadline(time.Time{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) +} diff --git a/pkg/common/mocks/tlsconn.go b/pkg/common/mocks/tlsconn.go new file mode 100644 index 0000000..d192fbf --- /dev/null +++ b/pkg/common/mocks/tlsconn.go @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package mocks + +import ( + "context" + "crypto/tls" +) + +// TLSConn is a mockable TLS connection. +type TLSConn struct { + // We embed *Conn to handle the net.Conn interface. + *Conn + + // MockConnectionState is the function to call when ConnectionState is called. + MockConnectionState func() tls.ConnectionState + + // MockHandshakeContext is the function to call when HandshakeContext is called. + MockHandshakeContext func(ctx context.Context) error +} + +// ConnectionState calls MockConnectionState. +func (c *TLSConn) ConnectionState() tls.ConnectionState { + return c.MockConnectionState() +} + +// HandshakeContext calls MockHandshakeContext. +func (c *TLSConn) HandshakeContext(ctx context.Context) error { + return c.MockHandshakeContext(ctx) +} diff --git a/pkg/common/mocks/tlsconn_test.go b/pkg/common/mocks/tlsconn_test.go new file mode 100644 index 0000000..3ae1381 --- /dev/null +++ b/pkg/common/mocks/tlsconn_test.go @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package mocks + +import ( + "context" + "crypto/tls" + "errors" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +func TestTLSConn(t *testing.T) { + t.Run("ConnectionState", func(t *testing.T) { + expectedState := tls.ConnectionState{ + Version: tls.VersionTLS13, + HandshakeComplete: true, + DidResume: false, + CipherSuite: tls.TLS_AES_128_GCM_SHA256, + NegotiatedProtocol: "h2", + ServerName: "example.com", + PeerCertificates: nil, + VerifiedChains: nil, + SignedCertificateTimestamps: nil, + OCSPResponse: nil, + } + + conn := &TLSConn{ + MockConnectionState: func() tls.ConnectionState { + return expectedState + }, + } + + state := conn.ConnectionState() + if diff := cmp.Diff(expectedState, state, + cmpopts.IgnoreUnexported(tls.ConnectionState{})); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("HandshakeContext", func(t *testing.T) { + expected := errors.New("mocked handshake error") + conn := &TLSConn{ + MockHandshakeContext: func(ctx context.Context) error { + return expected + }, + } + + err := conn.HandshakeContext(context.Background()) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("Embedded Conn methods", func(t *testing.T) { + expected := errors.New("mocked read error") + conn := &TLSConn{ + Conn: &Conn{ + MockRead: func(b []byte) (int, error) { + return 0, expected + }, + }, + } + + count, err := conn.Read(make([]byte, 128)) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if count != 0 { + t.Fatal("expected 0 bytes") + } + }) +} diff --git a/pkg/common/netipx/netipx.go b/pkg/common/netipx/netipx.go new file mode 100644 index 0000000..78e596f --- /dev/null +++ b/pkg/common/netipx/netipx.go @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +// Package netipx contains [net/netip] extensions. +package netipx + +import ( + "net" + "net/netip" +) + +// AddrToAddrPort converts a [net.Addr] to a [netip.AddrPort]. +// +// If the input is nil or neither a [*net.TCPAddr] nor [*net.UDPAddr], +// returns an unspecified IPv6 address with port 0. +// +// For [*net.TCPAddr] and [*net.UDPAddr] addresses, returns their +// corresponding [netip.AddrPort] representation. +func AddrToAddrPort(addr net.Addr) netip.AddrPort { + if addr == nil { + return netip.AddrPortFrom(netip.IPv6Unspecified(), 0) + } + if tcp, ok := addr.(*net.TCPAddr); ok { + return tcp.AddrPort() + } + if udp, ok := addr.(*net.UDPAddr); ok { + return udp.AddrPort() + } + return netip.AddrPortFrom(netip.IPv6Unspecified(), 0) +} diff --git a/pkg/common/netipx/netipx_test.go b/pkg/common/netipx/netipx_test.go new file mode 100644 index 0000000..afc5058 --- /dev/null +++ b/pkg/common/netipx/netipx_test.go @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package netipx_test + +import ( + "net" + "net/netip" + "testing" + + "github.com/rbmk-project/rbmk/pkg/common/netipx" + "github.com/stretchr/testify/assert" +) + +func TestAddrToAddrPort(t *testing.T) { + tests := []struct { + name string + addr net.Addr + want netip.AddrPort + }{ + { + name: "nil address", + addr: nil, + want: netip.AddrPortFrom(netip.IPv6Unspecified(), 0), + }, + + { + name: "TCP address", + addr: &net.TCPAddr{ + IP: net.ParseIP("2001:db8::1"), + Port: 1234, + }, + want: netip.MustParseAddrPort("[2001:db8::1]:1234"), + }, + + { + name: "UDP address", + addr: &net.UDPAddr{ + IP: net.ParseIP("2001:db8::2"), + Port: 5678, + }, + want: netip.MustParseAddrPort("[2001:db8::2]:5678"), + }, + + { + name: "other address type", + addr: &net.UnixAddr{}, + want: netip.AddrPortFrom(netip.IPv6Unspecified(), 0), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := netipx.AddrToAddrPort(tt.addr) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/common/runtimex/runtimex.go b/pkg/common/runtimex/runtimex.go new file mode 100644 index 0000000..14c09de --- /dev/null +++ b/pkg/common/runtimex/runtimex.go @@ -0,0 +1,54 @@ +// +// SPDX-License-Identifier: GPL-3.0-or-later +// +// Adapted from: https://github.com/ooni/probe-cli/blob/v3.20.1/internal/runtimex/runtimex.go +// + +// Package runtimex contains runtime extensions. +// +// This package is inspired to https://pkg.go.dev/github.com/m-lab/go/rtx, except that it's simpler. +package runtimex + +import ( + "errors" + "fmt" +) + +// PanicOnError calls panic() if err is not nil. The type passed +// to panic is an error type wrapping the original error. +func PanicOnError(err error, message string) { + if err != nil { + panic(fmt.Errorf("%s: %w", message, err)) + } +} + +// Assert calls panic if assertion is false. The type passed to +// panic is an error constructed using errors.New(message). +func Assert(assertion bool, message string) { + if !assertion { + panic(errors.New(message)) + } +} + +// Try0 calls [runtimex.PanicOnError] if err is not nil. +func Try0(err error) { + PanicOnError(err, "Try0") +} + +// Try1 is like [Try0] but supports functions returning one values and an error. +func Try1[T1 any](v1 T1, err error) T1 { + PanicOnError(err, "Try1") + return v1 +} + +// Try2 is like [Try1] but supports functions returning two values and an error. +func Try2[T1, T2 any](v1 T1, v2 T2, err error) (T1, T2) { + PanicOnError(err, "Try2") + return v1, v2 +} + +// Try3 is like [Try2] but supports functions returning three values and an error. +func Try3[T1, T2, T3 any](v1 T1, v2 T2, v3 T3, err error) (T1, T2, T3) { + PanicOnError(err, "Try3") + return v1, v2, v3 +} diff --git a/pkg/common/runtimex/runtimex_test.go b/pkg/common/runtimex/runtimex_test.go new file mode 100644 index 0000000..fc632ba --- /dev/null +++ b/pkg/common/runtimex/runtimex_test.go @@ -0,0 +1,156 @@ +// +// SPDX-License-Identifier: GPL-3.0-or-later +// +// Adapted from: https://github.com/ooni/probe-cli/blob/v3.20.1/internal/runtimex/runtimex_test.go +// + +package runtimex_test + +import ( + "errors" + "testing" + + "github.com/rbmk-project/rbmk/pkg/common/runtimex" +) + +func TestPanicOnError(t *testing.T) { + badfunc := func(in error) (out error) { + defer func() { + out = recover().(error) + }() + runtimex.PanicOnError(in, "we expect this assertion to fail") + return + } + + t.Run("error is nil", func(t *testing.T) { + runtimex.PanicOnError(nil, "this assertion should not fail") + }) + + t.Run("error is not nil", func(t *testing.T) { + expected := errors.New("mocked error") + if !errors.Is(badfunc(expected), expected) { + t.Fatal("not the error we expected") + } + }) +} + +func TestAssert(t *testing.T) { + badfunc := func(in bool, message string) (out error) { + defer func() { + out = recover().(error) + }() + runtimex.Assert(in, message) + return + } + + t.Run("assertion is true", func(t *testing.T) { + runtimex.Assert(true, "this assertion should not fail") + }) + + t.Run("assertion is false", func(t *testing.T) { + message := "mocked error" + err := badfunc(false, message) + if err == nil || err.Error() != message { + t.Fatal("not the error we expected", err) + } + }) +} + +func TestTry(t *testing.T) { + t.Run("Try0", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + runtimex.Try0(nil) + }) + + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + var got error + func() { + defer func() { + if r := recover(); r != nil { + got = r.(error) + } + }() + runtimex.Try0(expected) + }() + if !errors.Is(got, expected) { + t.Fatal("unexpected error") + } + }) + }) + + t.Run("Try1", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + v1 := runtimex.Try1(17, nil) + if v1 != 17 { + t.Fatal("unexpected value") + } + }) + + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + var got error + func() { + defer func() { + if r := recover(); r != nil { + got = r.(error) + } + }() + runtimex.Try1(17, expected) + }() + if !errors.Is(got, expected) { + t.Fatal("unexpected error") + } + }) + }) + + t.Run("Try2", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + v1, v2 := runtimex.Try2(17, true, nil) + if v1 != 17 || !v2 { + t.Fatal("unexpected value") + } + }) + + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + var got error + func() { + defer func() { + if r := recover(); r != nil { + got = r.(error) + } + }() + runtimex.Try2(17, true, expected) + }() + if !errors.Is(got, expected) { + t.Fatal("unexpected error") + } + }) + }) + + t.Run("Try3", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + v1, v2, v3 := runtimex.Try3(17, true, 44.0, nil) + if v1 != 17 || !v2 || v3 != 44.0 { + t.Fatal("unexpected value") + } + }) + + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + var got error + func() { + defer func() { + if r := recover(); r != nil { + got = r.(error) + } + }() + runtimex.Try3(17, true, 44.0, expected) + }() + if !errors.Is(got, expected) { + t.Fatal("unexpected error") + } + }) + }) +} diff --git a/pkg/common/selfsignedcert/selfsignedcert.go b/pkg/common/selfsignedcert/selfsignedcert.go new file mode 100644 index 0000000..58f0b82 --- /dev/null +++ b/pkg/common/selfsignedcert/selfsignedcert.go @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +// Package selfsignedcert helps to create self-signed certificates. +package selfsignedcert + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "os" + "path/filepath" + "time" + + "github.com/rbmk-project/common/runtimex" +) + +// Config contains configuration for [New]. +type Config struct { + // CommonName is the certificate common name. + CommonName string + + // DNSNames contains the alternative DNS names to include in the certificate. + DNSNames []string + + // IPAddrs contains the IP addrs for which the certificate is valid. + IPAddrs []net.IP +} + +// NewConfigExampleCom creates a [*Config] for example.com +// using the www.example.com, 127.0.0.1, and ::1 sans. +func NewConfigExampleCom() *Config { + config := &Config{ + CommonName: "example.com", + DNSNames: []string{"www.example.com"}, + IPAddrs: []net.IP{ + net.ParseIP("127.0.0.1"), + net.ParseIP("::1"), + }, + } + return config +} + +// Cert is the self-signed certificate. +type Cert struct { + // CertPEM is the certificate encoded using PEM. + CertPEM []byte + + // KeyPEM is the secret key encoded using PEM. + KeyPEM []byte +} + +// WriteFiles writes CertPEM to `cert.pem` and KeyPEM to `key.pem`. +// +// This method panics on failure. +func (c *Cert) WriteFiles(baseDir string) { + runtimex.Try0(os.WriteFile(filepath.Join(baseDir, "cert.pem"), c.CertPEM, 0600)) + runtimex.Try0(os.WriteFile(filepath.Join(baseDir, "key.pem"), c.KeyPEM, 0600)) +} + +// New generates a self-signed certificate and key with SANs. +// +// This function panics on failure. +func New(config *Config) *Cert { + // Generate the private key + priv := runtimex.Try1(ecdsa.GenerateKey(elliptic.P256(), rand.Reader)) + + // Build the certificate template + notBefore := time.Now() + notAfter := notBefore.Add(365 * 24 * time.Hour) + serialNumber := runtimex.Try1(rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))) + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"RBMK Project"}, + CommonName: config.CommonName, + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + // Add SANs to the certificate + template.DNSNames = config.DNSNames + template.IPAddresses = config.IPAddrs + + // Generate the certificate proper and encoded to PEM + certDER := runtimex.Try1(x509.CreateCertificate( + rand.Reader, &template, &template, &priv.PublicKey, priv)) + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + + // Generate the private key in PEM format + keyPEM := runtimex.Try1(x509.MarshalECPrivateKey(priv)) + keyPEMBytes := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyPEM}) + + // Return the results + return &Cert{CertPEM: certPEM, KeyPEM: keyPEMBytes} +} diff --git a/pkg/common/selfsignedcert/selfsignedcert_test.go b/pkg/common/selfsignedcert/selfsignedcert_test.go new file mode 100644 index 0000000..bfc02fe --- /dev/null +++ b/pkg/common/selfsignedcert/selfsignedcert_test.go @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package selfsignedcert_test + +import ( + "bytes" + "crypto/tls" + "crypto/x509" + "io" + "net/http" + "net/url" + "testing" + + "github.com/rbmk-project/rbmk/pkg/common/runtimex" + "github.com/rbmk-project/rbmk/pkg/common/selfsignedcert" +) + +func TestSelfSignedCert(t *testing.T) { + // 1. generate the certificate and private key + cert := selfsignedcert.New(selfsignedcert.NewConfigExampleCom()) + cert.WriteFiles("testdata") + + // 2. create a suitable TLS listener + serverConfig := &tls.Config{Certificates: []tls.Certificate{ + runtimex.Try1(tls.X509KeyPair(cert.CertPEM, cert.KeyPEM)), + }} + listener := runtimex.Try1(tls.Listen("tcp", "127.0.0.1:0", serverConfig)) + defer listener.Close() + + // 3. create a listening HTTP server using the testdata files + expectByes := []byte("Bonsoir, Elliot!\n") + srv := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write(expectByes) + }), + } + go srv.Serve(listener) + + // 4. create a suitable HTTP client + pool := x509.NewCertPool() + runtimex.Assert(pool.AppendCertsFromPEM(cert.CertPEM), "cannot append PEM cert") + clientConfig := &tls.Config{RootCAs: pool} + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: clientConfig, + }, + } + + // 5. perform an HTTP round trip + URL := &url.URL{Scheme: "https", Host: listener.Addr().String(), Path: "/"} + resp, err := client.Get(URL.String()) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + // 6. make sure the response is correct + if resp.StatusCode != http.StatusOK { + t.Fatal("expected 200, got", resp.StatusCode) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(expectByes, body) { + t.Fatal("expected", expectByes, ", got", body) + } +} diff --git a/pkg/common/selfsignedcert/testdata/.gitignore b/pkg/common/selfsignedcert/testdata/.gitignore new file mode 100644 index 0000000..fc4d927 --- /dev/null +++ b/pkg/common/selfsignedcert/testdata/.gitignore @@ -0,0 +1,2 @@ +/cert.pem +/key.pem diff --git a/pkg/dns/dnscore/answer.go b/pkg/dns/dnscore/answer.go new file mode 100644 index 0000000..0b3c0ce --- /dev/null +++ b/pkg/dns/dnscore/answer.go @@ -0,0 +1,49 @@ +// +// SPDX-License-Identifier: BSD-3-Clause +// +// Adapted from: https://github.com/ooni/probe-engine/blob/v0.23.0/netx/resolver/decoder.go +// +// Answer RRs decoder +// + +package dnscore + +import "github.com/miekg/dns" + +// DecodeLookupA decodes RRs from a lookup A response. +func DecodeLookupA(rrs []dns.RR) (addrs []string, cname string, err error) { + for _, answer := range rrs { + switch answer := answer.(type) { + case *dns.A: + addrs = append(addrs, answer.A.String()) + + case *dns.CNAME: + cname = answer.Target + } + } + + if len(addrs) <= 0 { + return nil, "", ErrNoData + } + + return +} + +// DecodeLookupAAAA decodes RRs from a lookup AAAA response. +func DecodeLookupAAAA(rrs []dns.RR) (addrs []string, cname string, err error) { + for _, answer := range rrs { + switch answer := answer.(type) { + case *dns.AAAA: + addrs = append(addrs, answer.AAAA.String()) + + case *dns.CNAME: + cname = answer.Target + } + } + + if len(addrs) <= 0 { + return nil, "", ErrNoData + } + + return +} diff --git a/pkg/dns/dnscore/answer_test.go b/pkg/dns/dnscore/answer_test.go new file mode 100644 index 0000000..f79d95e --- /dev/null +++ b/pkg/dns/dnscore/answer_test.go @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore + +import ( + "net" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" +) + +func TestDecodeLookupA(t *testing.T) { + tests := []struct { + name string + rrs []dns.RR + expected []string + cname string + err error + }{ + { + name: "Single A record", + rrs: []dns.RR{ + &dns.A{A: net.ParseIP("192.0.2.1")}, + }, + expected: []string{"192.0.2.1"}, + cname: "", + err: nil, + }, + + { + name: "Single CNAME record", + rrs: []dns.RR{ + &dns.CNAME{Target: "example.com."}, + }, + expected: nil, + cname: "", + err: ErrNoData, + }, + + { + name: "Multiple A records without CNAME", + rrs: []dns.RR{ + &dns.A{A: net.ParseIP("192.0.2.1")}, + &dns.A{A: net.ParseIP("192.0.2.2")}, + }, + expected: []string{"192.0.2.1", "192.0.2.2"}, + cname: "", + err: nil, + }, + + { + name: "Multiple A records with CNAME", + rrs: []dns.RR{ + &dns.A{A: net.ParseIP("192.0.2.1")}, + &dns.A{A: net.ParseIP("192.0.2.2")}, + &dns.CNAME{Target: "example.com."}, + }, + expected: []string{"192.0.2.1", "192.0.2.2"}, + cname: "example.com.", + err: nil, + }, + + { + name: "No A records", + rrs: []dns.RR{}, + expected: nil, + cname: "", + err: ErrNoData, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addrs, cname, err := DecodeLookupA(tt.rrs) + assert.Equal(t, tt.expected, addrs) + assert.Equal(t, tt.cname, cname) + assert.Equal(t, tt.err, err) + }) + } +} + +func TestDecodeLookupAAAA(t *testing.T) { + tests := []struct { + name string + rrs []dns.RR + expected []string + cname string + err error + }{ + { + name: "Single AAAA record", + rrs: []dns.RR{ + &dns.AAAA{AAAA: net.ParseIP("2001:db8::1")}, + }, + expected: []string{"2001:db8::1"}, + cname: "", + err: nil, + }, + + { + name: "Single CNAME record", + rrs: []dns.RR{ + &dns.CNAME{Target: "example.com."}, + }, + expected: nil, + cname: "", + err: ErrNoData, + }, + + { + name: "Multiple AAAA records without CNAME", + rrs: []dns.RR{ + &dns.AAAA{AAAA: net.ParseIP("2001:db8::1")}, + &dns.AAAA{AAAA: net.ParseIP("2001:db8::2")}, + }, + expected: []string{"2001:db8::1", "2001:db8::2"}, + cname: "", + err: nil, + }, + + { + name: "Multiple AAAA records with CNAME", + rrs: []dns.RR{ + &dns.AAAA{AAAA: net.ParseIP("2001:db8::1")}, + &dns.AAAA{AAAA: net.ParseIP("2001:db8::2")}, + &dns.CNAME{Target: "example.com."}, + }, + expected: []string{"2001:db8::1", "2001:db8::2"}, + cname: "example.com.", + err: nil, + }, + + { + name: "No AAAA records", + rrs: []dns.RR{}, + expected: nil, + cname: "", + err: ErrNoData, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addrs, cname, err := DecodeLookupAAAA(tt.rrs) + assert.Equal(t, tt.expected, addrs) + assert.Equal(t, tt.cname, cname) + assert.Equal(t, tt.err, err) + }) + } +} diff --git a/pkg/dns/dnscore/doc.go b/pkg/dns/dnscore/doc.go new file mode 100644 index 0000000..d07d5bd --- /dev/null +++ b/pkg/dns/dnscore/doc.go @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +/* +Package dnscore provides a DNS resolver, a DNS transport, a query builder, +and a DNS response parser. + +This package is designed to facilitate DNS measurements and queries +by providing both high-level and low-level APIs. It aims to be flexible, +extensible, and easy to integrate with existing Go code. + +The high-level [*Resolver] API provides a DNS resolver that is compatible with the +[*net.Resolver] struct from the [net] package. The low-level [*Transport] API +allows users to send and receive DNS messages using different protocols +and dialers. The package also includes utilities for creating and validating +DNS messages. + +# Features + +- High-level [*Resolver] API compatible with [*net.Resolver] for easy integration. + +- Low-level [*Transport] API allowing granular control over DNS requests and responses. + +- Support for multiple DNS protocols, including UDP, TCP, DoT, DoH, and DoQ. + +- Utilities for creating and validating DNS messages. + +- Optional logging for structured diagnostic events through [log/slog]. + +- Handling of duplicate responses for DNS over UDP to measure censorship. + +The package is structured to allow users to compose their own workflows +by providing building blocks for DNS queries and responses. It uses +the widely-used [github.com/miekg/dns] library for DNS message parsing +and serialization. + +# Design Documents + +The [dd-000-dnscore.md] document describes the design of this package. + +The [df-000-dns.md] document describes the data format generated by this +package when using [log/slog] to emit structured diagnostic events. + +[dd-000-dnscore.md]: https://github.com/rbmk-project/rbmk-project.github.io/blob/main/docs/design/dd-000-dnscore.md +[df-000-dns.md]: https://github.com/rbmk-project/rbmk-project.github.io/blob/main/docs/spec/data-format/df-000-dns.md +*/ +package dnscore diff --git a/pkg/dns/dnscore/dohttps.go b/pkg/dns/dnscore/dohttps.go new file mode 100644 index 0000000..3cbb1e5 --- /dev/null +++ b/pkg/dns/dnscore/dohttps.go @@ -0,0 +1,148 @@ +// +// SPDX-License-Identifier: BSD-3-Clause +// +// Adapted from: https://github.com/ooni/probe-engine/blob/v0.23.0/netx/resolver/dnsoverhttps.go +// +// DNS-over-HTTPS implementation +// + +package dnscore + +import ( + "bytes" + "context" + "io" + "net/http" + "net/netip" + + "github.com/miekg/dns" + "github.com/rbmk-project/rbmk/pkg/common/httpconntrace" + "github.com/rbmk-project/rbmk/pkg/common/httpslog" +) + +// newHTTPRequestWithContext is a helper function that creates a new HTTP request +// using the namesake transport function or the stdlib if the such a function is nil. +func (t *Transport) newHTTPRequestWithContext( + ctx context.Context, method, URL string, body io.Reader) (*http.Request, error) { + if t.NewHTTPRequestWithContext != nil { + return t.NewHTTPRequestWithContext(ctx, method, URL, body) + } + return http.NewRequestWithContext(ctx, method, URL, body) +} + +// httpClient is a helper function that returns the HTTP client using the +// specific transport field or the stdlib if the given field is nil. +func (t *Transport) httpClient() *http.Client { + if t.HTTPClient != nil { + return t.HTTPClient + } + return http.DefaultClient +} + +// httpClientDo performs an HTTP request using one of two methods: +// +// 1. if HTTPClientDo is not nil, use it directly; +// +// 2. otherwise use [*Transport.httpClient] to obtain a suitable +// [*http.Client] and perform the request with it. +func (t *Transport) httpClientDo(req *http.Request) (*http.Response, netip.AddrPort, netip.AddrPort, error) { + // If HTTPClientDo isn't nil, use it directly. + if t.HTTPClientDo != nil { + return t.HTTPClientDo(req) + } + + // Otherwise use httpconntrace.Do to perform the request + resp, endpoints, err := httpconntrace.Do(t.httpClient(), req) + return resp, endpoints.LocalAddr, endpoints.RemoteAddr, err +} + +// readAllContext is a helper function that reads all from the reader using the +// namesake transport function or the stdlib if the given function is nil. +func (t *Transport) readAllContext(ctx context.Context, r io.Reader, c io.Closer) ([]byte, error) { + if t.ReadAllContext != nil { + return t.ReadAllContext(ctx, r, c) + } + return io.ReadAll(r) +} + +// queryHTTPS implements [*Transport.Query] for DNS over HTTPS. +func (t *Transport) queryHTTPS(ctx context.Context, + addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + // 0. immediately fail if the context is already done, which + // is useful to write unit tests + if ctx.Err() != nil { + return nil, ctx.Err() + } + + // 1. Serialize the query and possibly log that we're sending it. + rawQuery, err := query.Pack() + if err != nil { + return nil, err + } + t0 := t.maybeLogQuery(ctx, addr, rawQuery) + + // 2. The query is sent as the body of a POST request. The content-type + // header must be set. Otherwise servers may respond with 400. + req, err := t.newHTTPRequestWithContext(ctx, "POST", addr.Address, bytes.NewReader(rawQuery)) + if err != nil { + return nil, err + } + req.Header.Set("content-type", "application/dns-message") + + // 3. Log the HTTP request we're sending. + httpslog.MaybeLogRoundTripStart( + t.Logger, + netip.MustParseAddrPort("[::]:0"), // not yet known + "tcp", + netip.MustParseAddrPort("[::]:0"), // not yet known + req, + t0, + ) + + // 4. Receive the response headers making sure we close + // the body, the response code is 200, and the content type + // is the expected one. Since servers always include the + // content type, we don't need to be flexible here. + httpResp, laddr, raddr, err := t.httpClientDo(req) + + // 5. Log the result of the HTTP transfer. + httpslog.MaybeLogRoundTripDone( + t.Logger, + laddr, + "tcp", + raddr, + req, + httpResp, + err, + t0, + t.timeNow(), + ) + + // 6. Make sure we close the body, the response code is 200, + // and the content type is the expected one. Since servers + // always include the content type, we don't need to be flexible here. + if err != nil { + return nil, err + } + defer httpResp.Body.Close() + if httpResp.StatusCode != 200 { + return nil, ErrServerMisbehaving + } + if httpResp.Header.Get("content-type") != "application/dns-message" { + return nil, ErrServerMisbehaving + } + + // 7. Now that headers are OK, we read the whole raw response + // body, decode it, and possibly log it. + reader := io.LimitReader(httpResp.Body, int64(edns0MaxResponseSize(query))) + rawResp, err := t.readAllContext(ctx, reader, httpResp.Body) + if err != nil { + return nil, err + } + resp := new(dns.Msg) + if err := resp.Unpack(rawResp); err != nil { + return nil, err + } + t.maybeLogResponseAddrPort(ctx, addr, t0, rawQuery, rawResp, laddr, raddr) + return resp, nil +} diff --git a/pkg/dns/dnscore/dohttps_test.go b/pkg/dns/dnscore/dohttps_test.go new file mode 100644 index 0000000..d785635 --- /dev/null +++ b/pkg/dns/dnscore/dohttps_test.go @@ -0,0 +1,559 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore + +import ( + "bytes" + "context" + "errors" + "io" + "net" + "net/http" + "net/http/httptrace" + "net/netip" + "testing" + + "github.com/miekg/dns" + "github.com/rbmk-project/rbmk/pkg/common/mocks" + "github.com/rbmk-project/rbmk/pkg/common/runtimex" + "github.com/stretchr/testify/assert" +) + +func TestTransport_newHTTPRequestWithContext(t *testing.T) { + tests := []struct { + name string + setupTransport func() *Transport + method string + url string + body io.Reader + expectedError error + }{ + { + name: "Successful request with custom function", + setupTransport: func() *Transport { + return &Transport{ + NewHTTPRequestWithContext: func(ctx context.Context, method, URL string, body io.Reader) (*http.Request, error) { + return http.NewRequestWithContext(ctx, method, URL, body) + }, + } + }, + method: "GET", + url: "https://example.com", + body: nil, + expectedError: nil, + }, + + { + name: "Successful request with default function", + setupTransport: func() *Transport { + return &Transport{} + }, + method: "GET", + url: "https://example.com", + body: nil, + expectedError: nil, + }, + + { + name: "Invalid URL", + setupTransport: func() *Transport { + return &Transport{} + }, + method: "GET", + url: "https://example.com\t", + body: nil, + expectedError: errors.New("parse \"https://example.com\\t\": net/url: invalid control character in URL"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := tt.setupTransport() + _, err := transport.newHTTPRequestWithContext(context.Background(), tt.method, tt.url, tt.body) + if tt.expectedError != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedError.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestTransport_httpClient(t *testing.T) { + tests := []struct { + name string + setupTransport func() *Transport + expectedClient *http.Client + }{ + { + name: "Custom HTTP client", + setupTransport: func() *Transport { + return &Transport{ + HTTPClient: &http.Client{}, + } + }, + expectedClient: &http.Client{}, + }, + + { + name: "Default HTTP client", + setupTransport: func() *Transport { + return &Transport{} + }, + expectedClient: http.DefaultClient, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := tt.setupTransport() + client := transport.httpClient() + assert.Equal(t, tt.expectedClient, client) + }) + } +} + +func TestTransport_httpClientDo(t *testing.T) { + tests := []struct { + name string + setupTransport func() *Transport + expectedError error + expectedLocalAddr netip.AddrPort + expectedRemoteAddr netip.AddrPort + }{ + { + name: "HTTPClientDo takes precedence", + setupTransport: func() *Transport { + return &Transport{ + HTTPClientDo: func(req *http.Request) (*http.Response, netip.AddrPort, netip.AddrPort, error) { + return &http.Response{StatusCode: 200}, netip.AddrPort{}, netip.AddrPort{}, nil + }, + HTTPClient: &http.Client{ + Transport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return nil, errors.New("should not be called") + }, + }, + }, + } + }, + expectedError: nil, + expectedLocalAddr: netip.AddrPort{}, + expectedRemoteAddr: netip.AddrPort{}, + }, + + { + name: "HTTPClientDo returns error", + setupTransport: func() *Transport { + return &Transport{ + HTTPClientDo: func(req *http.Request) (*http.Response, netip.AddrPort, netip.AddrPort, error) { + return nil, netip.AddrPort{}, netip.AddrPort{}, errors.New("custom error") + }, + } + }, + expectedError: errors.New("custom error"), + expectedLocalAddr: netip.AddrPort{}, + expectedRemoteAddr: netip.AddrPort{}, + }, + + { + name: "Fallback to HTTPClient success", + setupTransport: func() *Transport { + return &Transport{ + HTTPClient: &http.Client{ + Transport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return &http.Response{StatusCode: 200}, nil + }, + }, + }, + } + }, + expectedError: nil, + expectedLocalAddr: netip.AddrPort{}, + expectedRemoteAddr: netip.AddrPort{}, + }, + + { + name: "Fallback to HTTPClient failure", + setupTransport: func() *Transport { + return &Transport{ + HTTPClient: &http.Client{ + Transport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return nil, errors.New("http error") + }, + }, + }, + } + }, + expectedError: errors.New("Get \"https://example.com\": http error"), + expectedLocalAddr: netip.AddrPort{}, + expectedRemoteAddr: netip.AddrPort{}, + }, + + { + name: "Fallback to HTTPClient collects addresses", + setupTransport: func() *Transport { + return &Transport{ + HTTPClient: &http.Client{ + Transport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + trace := httptrace.ContextClientTrace(req.Context()) + if trace != nil && trace.GotConn != nil { + trace.GotConn(httptrace.GotConnInfo{ + Conn: &mocks.Conn{ + MockLocalAddr: func() net.Addr { + return &net.TCPAddr{ + IP: net.ParseIP("::1"), + Port: 12345, + } + }, + MockRemoteAddr: func() net.Addr { + return &net.TCPAddr{ + IP: net.ParseIP("::2"), + Port: 443, + } + }, + }, + }) + } + return &http.Response{StatusCode: 200}, nil + }, + }, + }, + } + }, + expectedError: nil, + expectedLocalAddr: netip.MustParseAddrPort("[::1]:12345"), + expectedRemoteAddr: netip.MustParseAddrPort("[::2]:443"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := tt.setupTransport() + req := runtimex.Try1(http.NewRequest("GET", "https://example.com", nil)) + resp, la, ra, err := transport.httpClientDo(req) + if tt.expectedError != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedError.Error(), err.Error()) + assert.Nil(t, resp) + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + } + assert.Equal(t, tt.expectedLocalAddr, la) + assert.Equal(t, tt.expectedRemoteAddr, ra) + }) + } +} + +func TestTransport_readAllContext(t *testing.T) { + tests := []struct { + name string + setupTransport func() *Transport + reader io.Reader + closer io.Closer + expectedData []byte + expectedError error + }{ + { + name: "Successful read with custom function", + setupTransport: func() *Transport { + return &Transport{ + ReadAllContext: func(ctx context.Context, r io.Reader, c io.Closer) ([]byte, error) { + return io.ReadAll(r) + }, + } + }, + reader: bytes.NewReader([]byte("test data")), + closer: io.NopCloser(nil), + expectedData: []byte("test data"), + expectedError: nil, + }, + + { + name: "Successful read with default function", + setupTransport: func() *Transport { + return &Transport{} + }, + reader: bytes.NewReader([]byte("test data")), + closer: io.NopCloser(nil), + expectedData: []byte("test data"), + expectedError: nil, + }, + + { + name: "Read failure", + setupTransport: func() *Transport { + return &Transport{} + }, + reader: &mocks.Conn{MockRead: func(b []byte) (int, error) { return 0, errors.New("read failed") }}, + closer: io.NopCloser(nil), + expectedData: nil, + expectedError: errors.New("read failed"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := tt.setupTransport() + data, err := transport.readAllContext(context.Background(), tt.reader, tt.closer) + if tt.expectedError != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedError.Error(), err.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedData, data) + } + }) + } +} + +func TestTransport_queryHTTPS(t *testing.T) { + tests := []struct { + name string + setupTransport func() *Transport + questionName string + url string + expectedError error + }{ + { + name: "Successful query", + setupTransport: func() *Transport { + return &Transport{ + HTTPClient: &http.Client{ + Transport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + dnsResp := &dns.Msg{} + rawDnsResp, err := dnsResp.Pack() + if err != nil { + panic(err) + } + resp := &http.Response{ + StatusCode: 200, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(rawDnsResp)), + } + resp.Header.Set("content-type", "application/dns-message") + return resp, nil + }, + }, + }, + } + }, + questionName: "example.com.", + url: "https://dns.google/dns-query", + expectedError: nil, + }, + + { + name: "HTTP request failure", + setupTransport: func() *Transport { + return &Transport{ + HTTPClient: &http.Client{ + Transport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return nil, errors.New("http request failed") + }, + }, + }, + } + }, + questionName: "example.com.", + url: "https://dns.google/dns-query", + expectedError: errors.New("Post \"https://dns.google/dns-query\": http request failed"), + }, + + { + name: "Non-200 HTTP status code", + setupTransport: func() *Transport { + return &Transport{ + HTTPClient: &http.Client{ + Transport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: 500, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte{})), + } + return resp, nil + }, + }, + }, + } + }, + questionName: "example.com.", + url: "https://dns.google/dns-query", + expectedError: ErrServerMisbehaving, + }, + + { + name: "Invalid content type", + setupTransport: func() *Transport { + return &Transport{ + HTTPClient: &http.Client{ + Transport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: 200, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte{})), + } + resp.Header.Set("content-type", "text/plain") + return resp, nil + }, + }, + }, + } + }, + questionName: "example.com.", + url: "https://dns.google/dns-query", + expectedError: ErrServerMisbehaving, + }, + + { + name: "Invalid DNS response", + setupTransport: func() *Transport { + return &Transport{ + HTTPClient: &http.Client{ + Transport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: 200, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte{0xFF})), + } + resp.Header.Set("content-type", "application/dns-message") + return resp, nil + }, + }, + }, + } + }, + questionName: "example.com.", + url: "https://dns.google/dns-query", + expectedError: errors.New("dns: overflow unpacking uint16"), + }, + + { + name: "Non-FQDN domain name", + setupTransport: func() *Transport { + return &Transport{ + HTTPClient: &http.Client{ + Transport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return nil, nil + }, + }, + }, + } + }, + questionName: "example", + url: "https://dns.google/dns-query", + expectedError: errors.New("dns: domain must be fully qualified"), + }, + + { + name: "Invalid URL", + setupTransport: func() *Transport { + return &Transport{ + HTTPClient: &http.Client{ + Transport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return nil, nil + }, + }, + }, + } + }, + questionName: "example.com.", + url: "https://dns.google/dns-query\t", + expectedError: errors.New("parse \"https://dns.google/dns-query\\t\": net/url: invalid control character in URL"), + }, + + { + name: "Fail reading response body", + setupTransport: func() *Transport { + return &Transport{ + HTTPClient: &http.Client{ + Transport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: 200, + Header: make(http.Header), + Body: &mocks.Conn{ + MockRead: func(b []byte) (int, error) { + return 0, errors.New("read failed") + }, + MockClose: func() error { + return nil + }, + }, + } + resp.Header.Set("content-type", "application/dns-message") + return resp, nil + }, + }, + }, + } + }, + questionName: "example.com.", + url: "https://dns.google/dns-query", + expectedError: errors.New("read failed"), + }, + + { + name: "HTTPClientDo takes precedence over HTTPClient", + setupTransport: func() *Transport { + return &Transport{ + // Should be used + HTTPClientDo: func(req *http.Request) (*http.Response, netip.AddrPort, netip.AddrPort, error) { + dnsResp := &dns.Msg{} + rawDnsResp, err := dnsResp.Pack() + if err != nil { + panic(err) + } + resp := &http.Response{ + StatusCode: 200, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(rawDnsResp)), + } + resp.Header.Set("content-type", "application/dns-message") + return resp, netip.AddrPort{}, netip.AddrPort{}, nil + }, + // Should not be used + HTTPClient: &http.Client{ + Transport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return nil, errors.New("HTTPClient should not be used") + }, + }, + }, + } + }, + questionName: "example.com.", + url: "https://dns.google/dns-query", + expectedError: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := tt.setupTransport() + addr := &ServerAddr{Address: tt.url, Protocol: ProtocolDoH} + query := new(dns.Msg) + query.SetQuestion(tt.questionName, dns.TypeA) + + _, err := transport.queryHTTPS(context.Background(), addr, query) + if tt.expectedError != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedError.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/pkg/dns/dnscore/doquic.go b/pkg/dns/dnscore/doquic.go new file mode 100644 index 0000000..ba72137 --- /dev/null +++ b/pkg/dns/dnscore/doquic.go @@ -0,0 +1,151 @@ +// +// SPDX-License-Identifier: GPL-3.0-or-later +// +// DNS-over-QUIC implementation +// +// Written by @roopeshsn and @bassosimone +// +// See https://github.com/rbmk-project/dnscore/pull/18 +// +// See https://datatracker.ietf.org/doc/rfc9250/ +// + +package dnscore + +import ( + "context" + "crypto/tls" + "net" + "time" + + "github.com/miekg/dns" + "github.com/quic-go/quic-go" + "github.com/rbmk-project/rbmk/pkg/common/closepool" +) + +func (t *Transport) queryQUIC(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + // 0. immediately fail if the context is already done, which + // is useful to write unit tests + if ctx.Err() != nil { + return nil, ctx.Err() + } + + // 1. Fill the TLS configuration + hostname, _, err := net.SplitHostPort(addr.Address) + if err != nil { + return nil, err + } + tlsConfig := &tls.Config{ + NextProtos: []string{"doq"}, + ServerName: hostname, + RootCAs: t.RootCAs, + } + + // 2. Create a connection pool to close all opened connections + // and ensure we don't leak resources by using defer. + connPool := &closepool.Pool{} + defer connPool.Close() + + // TODO(bassosimone,roopeshsn): for TCP connections, we abstract + // this process of combining the DNS lookup and dialing a connection, + // which, in turn, allows for better unit testing and also allows + // rbmk-project/rbmk to use rbmk-project/x/netcore for dialing. + // + // We should probably see to create a similar dialing interface in + // rbmk-project/x/netcore for QUIC connections. We started discussing + // this in https://github.com/rbmk-project/dnscore/pull/18. + + // 3. Open the UDP connection for supporting QUIC + listenConfig := &net.ListenConfig{} + udpConn, err := listenConfig.ListenPacket(ctx, "udp", ":0") + if err != nil { + return nil, err + } + connPool.Add(udpConn) + + // 4. Map the UDP address, which may possibly contain a domain + // name, to an actual UDP address structure to dial with + udpAddr, err := net.ResolveUDPAddr("udp", addr.Address) + if err != nil { + return nil, err + } + + // 5. Establish a QUIC connection. Note that the default + // configuration implies a 5s timeout for handshaking and + // a 30s idle connection timeout. + tr := &quic.Transport{ + Conn: udpConn, + } + connPool.Add(tr) + quicConfig := &quic.Config{} + quicConn, err := tr.Dial(ctx, udpAddr, tlsConfig, quicConfig) + if err != nil { + return nil, err + } + connPool.Add(closepool.CloserFunc(func() error { + // Closing w/o specific error -- RFC 9250 Sect. 4.3 + const doq_no_error = 0x00 + return quicConn.CloseWithError(doq_no_error, "") + })) + + // 6. Open a stream for sending the DoQ query and wrap it into + // an adapter that makes it usable by DNS-over-stream code + quicStream, err := quicConn.OpenStream() + if err != nil { + return nil, err + } + stream := &quicStreamAdapter{ + Stream: quicStream, + localAddr: quicConn.LocalAddr(), + remoteAddr: quicConn.RemoteAddr(), + } + connPool.Add(stream) + + // 7. Ensure that we tear down everything which we have set up + // in the case in which the context is canceled + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + defer connPool.Close() + <-ctx.Done() + }() + + // 8. defer to queryStream. Note that this method TAKES OWNERSHIP of + // the stream and closes it after we've sent the query, honouring the + // expectations for DoQ queries -- see RFC 9250 Sect. 4.2. + return t.queryStream(ctx, addr, query, stream) +} + +// quicStreamAdapter ensures a QUIC stream implements [dnsStream]. +type quicStreamAdapter struct { + Stream *quic.Stream + localAddr net.Addr + remoteAddr net.Addr +} + +// Make sure we actually implement [dnsStream]. +var _ dnsStream = &quicStreamAdapter{} + +func (qsw *quicStreamAdapter) Read(p []byte) (int, error) { + return qsw.Stream.Read(p) +} + +func (qsw *quicStreamAdapter) Write(p []byte) (int, error) { + return qsw.Stream.Write(p) +} + +func (qsw *quicStreamAdapter) Close() error { + return qsw.Stream.Close() +} + +func (qsw *quicStreamAdapter) SetDeadline(t time.Time) error { + return qsw.Stream.SetDeadline(t) +} + +func (qsw *quicStreamAdapter) LocalAddr() net.Addr { + return qsw.localAddr +} + +func (qsw *quicStreamAdapter) RemoteAddr() net.Addr { + return qsw.remoteAddr +} diff --git a/pkg/dns/dnscore/doquic_test.go b/pkg/dns/dnscore/doquic_test.go new file mode 100644 index 0000000..2c195fd --- /dev/null +++ b/pkg/dns/dnscore/doquic_test.go @@ -0,0 +1,51 @@ +package dnscore + +import ( + "context" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" +) + +func TestTransport_queryQUIC(t *testing.T) { + // TODO(bassosimone,roopeshsn): currently this is an integration test + // using the network w/ real servers but we should instead have: + // + // 1. an integration test using the network but using a QUIC server running + // locally (a test which should live inside integration_test.go) + // + // 2. unit tests using mocking like we do for, e.g.m dohttps_test.go + + tests := []struct { + name string + setupTransport func() *Transport + expectedError error + }{ + { + name: "Successful query", + setupTransport: func() *Transport { + return &Transport{} + }, + expectedError: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := tt.setupTransport() + addr := NewServerAddr(ProtocolDoQ, "dns.adguard.com:853") + query := new(dns.Msg) + query.SetQuestion("example.com.", dns.TypeA) + + _, err := transport.queryQUIC(context.Background(), addr, query) + + if tt.expectedError != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedError.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/pkg/dns/dnscore/dotcp.go b/pkg/dns/dnscore/dotcp.go new file mode 100644 index 0000000..73aef4c --- /dev/null +++ b/pkg/dns/dnscore/dotcp.go @@ -0,0 +1,156 @@ +// +// SPDX-License-Identifier: BSD-3-Clause +// +// Adapted from: https://github.com/ooni/probe-engine/blob/v0.23.0/netx/resolver/dnsovertcp.go +// +// DNS-over-TCP implementation. Includes generic code to +// send queries over streams used by DoT and DoQ. +// + +package dnscore + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "math" + "net" + "time" + + "github.com/miekg/dns" +) + +// dnsStream is the interface expected by [*Transport.queryStream], +type dnsStream interface { + io.ReadWriteCloser + SetDeadline(t time.Time) error + LocalAddr() net.Addr + RemoteAddr() net.Addr +} + +// queryTCP implements [*Transport.Query] for DNS over TCP. +func (t *Transport) queryTCP(ctx context.Context, + addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + // 0. immediately fail if the context is already done, which + // is useful to write unit tests + if ctx.Err() != nil { + return nil, ctx.Err() + } + + // 1. Dial the connection + conn, err := t.dialContext(ctx, "tcp", addr.Address) + + // 2. Handle dialing failure + if err != nil { + return nil, err + } + + // 3. Transfer conn ownership and perform the round trip + return t.queryStream(ctx, addr, query, conn) +} + +// ErrQueryTooLargeForTransport indicates that a query is too large for the transport. +var ErrQueryTooLargeForTransport = errors.New("query too large for transport") + +// queryMsg is an interface modeling [*dns.Msg] to allow for +// testing [*Transport.queryStream] more easily. +type queryMsg interface { + Pack() ([]byte, error) +} + +// queryStream performs the round trip over the given TCP/TLS stream. +// +// This method TAKES OWNERSHIP of the provided connection and is +// responsible for closing it when done. +func (t *Transport) queryStream(ctx context.Context, + addr *ServerAddr, query queryMsg, conn dnsStream) (*dns.Msg, error) { + + // 1. Use a single connection for request, which is what the standard library + // does as well for TCP and is more robust in terms of residual censorship. + // + // In the future, we may want to reuse a TLS connection for multiple queries + // + // Make sure we react to context being canceled early. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + defer conn.Close() + <-ctx.Done() + }() + + // 2. Use the context deadline to limit the query lifetime + // as documented in the [*Transport.Query] function. + if deadline, ok := ctx.Deadline(); ok { + _ = conn.SetDeadline(deadline) + } + + // 3. Serialize the query and possibly log that we're sending it. + rawQuery, err := query.Pack() + if err != nil { + return nil, err + } + t0 := t.maybeLogQuery(ctx, addr, rawQuery) + + // 4. Wrap the query into a frame + rawQueryFrame, err := newRawMsgFrame(addr, rawQuery) + if err != nil { + return nil, err + } + + // 5. Send the query. Do not bother with logging the write call + // since that should be done by a custom dialer that wraps the + // returned connection and implements the desired logging. + if _, err := conn.Write(rawQueryFrame); err != nil { + return nil, err + } + + // 5b. Ensure we close the stream when using DoQ to signal the + // upstream server that it is okay to send a response. + // + // RFC 9250 is very clear in this respect: + // + // 4.2. Stream Mapping and Usage + // client MUST send the DNS query over the selected stream and MUST + // indicate through the STREAM FIN mechanism that no further data will + // be sent on that stream. + // + // Empirical testing during https://github.com/rbmk-project/dnscore/pull/18 + // showed that, in fact, some servers misbehave if we don't do this. + if _, ok := conn.(*quicStreamAdapter); ok { + _ = conn.Close() + } + + // 6. Wrap the conn to avoid issuing too many reads + // then read the response header and query + br := bufio.NewReader(conn) + header := make([]byte, 2) + if _, err := io.ReadFull(br, header); err != nil { + return nil, err + } + length := int(header[0])<<8 | int(header[1]) + rawResp := make([]byte, length) + if _, err := io.ReadFull(br, rawResp); err != nil { + return nil, err + } + + // 7. Parse the response and possibly log that we received it. + resp := new(dns.Msg) + if err := resp.Unpack(rawResp); err != nil { + return nil, err + } + t.maybeLogResponseConn(ctx, addr, t0, rawQuery, rawResp, conn) + return resp, nil +} + +// newRawMsgFrame creates a new raw frame for sending a message over TCP or TLS. +func newRawMsgFrame(addr *ServerAddr, rawMsg []byte) ([]byte, error) { + if len(rawMsg) > math.MaxUint16 { + return nil, fmt.Errorf("%w: %s", ErrQueryTooLargeForTransport, addr.Protocol) + } + rawMsgFrame := []byte{byte(len(rawMsg) >> 8)} + rawMsgFrame = append(rawMsgFrame, byte(len(rawMsg))) + rawMsgFrame = append(rawMsgFrame, rawMsg...) + return rawMsgFrame, nil +} diff --git a/pkg/dns/dnscore/dotcp_test.go b/pkg/dns/dnscore/dotcp_test.go new file mode 100644 index 0000000..7159c63 --- /dev/null +++ b/pkg/dns/dnscore/dotcp_test.go @@ -0,0 +1,381 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "math" + "net" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/rbmk-project/rbmk/pkg/common/mocks" + "github.com/stretchr/testify/assert" +) + +// MockDNSMsg is a mock implementation of the dns.Msg interface. +type MockDNSMsg struct { + MockPack func() ([]byte, error) +} + +func (m *MockDNSMsg) Pack() ([]byte, error) { + return m.MockPack() +} + +func newValidRawRespFrame() []byte { + resp := &dns.Msg{} + rawResp, err := resp.Pack() + if err != nil { + panic(err) + } + rawRespFrame, err := newRawMsgFrame(&ServerAddr{}, rawResp) + if err != nil { + panic(err) + } + return rawRespFrame +} + +func newGarbageRawRespFrame() []byte { + rawRespFrame, err := newRawMsgFrame(&ServerAddr{}, []byte{0xFF}) + if err != nil { + panic(err) + } + return rawRespFrame +} + +func TestTransport_queryTCP(t *testing.T) { + tests := []struct { + name string + setupTransport func() *Transport + expectedError error + }{ + { + name: "Successful query", + setupTransport: func() *Transport { + return &Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: (bytes.NewReader(newValidRawRespFrame())).Read, + MockClose: func() error { + return nil + }, + }, nil + }, + } + }, + expectedError: nil, + }, + + { + name: "Dial failure", + setupTransport: func() *Transport { + return &Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, errors.New("dial failed") + }, + } + }, + expectedError: errors.New("dial failed"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := tt.setupTransport() + addr := NewServerAddr(ProtocolTCP, "8.8.8.8:53") + query := new(dns.Msg) + query.SetQuestion("example.com.", dns.TypeA) + + _, err := transport.queryTCP(context.Background(), addr, query) + + if tt.expectedError != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedError.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestTransport_queryStream(t *testing.T) { + tests := []struct { + name string + query queryMsg + setupTransport func() *Transport + setupConn func(deadlineset *bool) net.Conn + expectedError error + expectDeadline bool + }{ + { + name: "Successful query", + query: &dns.Msg{ + Question: []dns.Question{ + {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + }, + }, + setupTransport: func() *Transport { + return &Transport{} + }, + setupConn: func(_ *bool) net.Conn { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: bytes.NewReader(newValidRawRespFrame()).Read, + MockClose: func() error { + return nil + }, + } + }, + expectedError: nil, + }, + + { + name: "Write failure", + query: &dns.Msg{ + Question: []dns.Question{ + {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + }, + }, + setupTransport: func() *Transport { + return &Transport{} + }, + setupConn: func(_ *bool) net.Conn { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return 0, errors.New("write failed") + }, + MockClose: func() error { + return nil + }, + } + }, + expectedError: errors.New("write failed"), + }, + + { + name: "Read header failure", + query: &dns.Msg{ + Question: []dns.Question{ + {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + }, + }, + setupTransport: func() *Transport { + return &Transport{} + }, + setupConn: func(_ *bool) net.Conn { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: func(b []byte) (int, error) { + return 0, errors.New("read header failed") + }, + MockClose: func() error { + return nil + }, + } + }, + expectedError: errors.New("read header failed"), + }, + + { + name: "Read body failure", + query: &dns.Msg{ + Question: []dns.Question{ + {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + }, + }, + setupTransport: func() *Transport { + return &Transport{} + }, + setupConn: func(_ *bool) net.Conn { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: bytes.NewReader([]byte{0, 4}).Read, + MockClose: func() error { + return nil + }, + } + }, + expectedError: io.EOF, + }, + + { + name: "Unpack failure", + query: &dns.Msg{ + Question: []dns.Question{ + {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + }, + }, + setupTransport: func() *Transport { + return &Transport{} + }, + setupConn: func(_ *bool) net.Conn { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: bytes.NewReader(newGarbageRawRespFrame()).Read, + MockClose: func() error { + return nil + }, + } + }, + expectedError: errors.New("dns: overflow unpacking uint16"), + }, + + { + name: "Context deadline set", + query: &dns.Msg{ + Question: []dns.Question{ + {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + }, + }, + setupTransport: func() *Transport { + return &Transport{} + }, + setupConn: func(deadlineset *bool) net.Conn { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: bytes.NewReader(newValidRawRespFrame()).Read, + MockClose: func() error { + return nil + }, + MockSetDeadline: func(t time.Time) error { + *deadlineset = true + return nil + }, + } + }, + expectedError: nil, + expectDeadline: true, + }, + + { + name: "Non-FQDN query", + query: &dns.Msg{ + Question: []dns.Question{ + {Name: "invalid-domain", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + }, + }, + setupTransport: func() *Transport { + return &Transport{} + }, + setupConn: func(_ *bool) net.Conn { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockClose: func() error { + return nil + }, + } + }, + expectedError: errors.New("dns: domain must be fully qualified"), + }, + + { + name: "Query too large for transport", + query: &MockDNSMsg{ + MockPack: func() ([]byte, error) { + return make([]byte, math.MaxUint16+1), nil + }, + }, + setupTransport: func() *Transport { + return &Transport{} + }, + setupConn: func(_ *bool) net.Conn { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: bytes.NewReader(newValidRawRespFrame()).Read, + MockClose: func() error { + return nil + }, + } + }, + expectedError: fmt.Errorf("%w: %s", ErrQueryTooLargeForTransport, ProtocolTCP), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := tt.setupTransport() + addr := NewServerAddr(ProtocolTCP, "8.8.8.8:53") + + var deadlineset bool + conn := tt.setupConn(&deadlineset) + + ctx := context.Background() + if tt.expectDeadline { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, time.Now().Add(1*time.Hour)) + defer cancel() + } + + _, err := transport.queryStream(ctx, addr, tt.query, conn) + + if tt.expectedError != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedError.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + if tt.expectDeadline { + assert.True(t, deadlineset) + } + }) + } +} + +func Test_newRawMsgFrame(t *testing.T) { + tests := []struct { + name string + rawMsg []byte + expectedFrame []byte + expectedError error + }{ + { + name: "Valid message frame", + rawMsg: []byte{0, 1, 2, 3}, + expectedFrame: []byte{0, 4, 0, 1, 2, 3}, + expectedError: nil, + }, + { + name: "Message too large", + rawMsg: make([]byte, math.MaxUint16+1), + expectedFrame: nil, + expectedError: ErrQueryTooLargeForTransport, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr := &ServerAddr{Protocol: ProtocolTCP} + frame, err := newRawMsgFrame(addr, tt.rawMsg) + + if tt.expectedError != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedError, errors.Unwrap(err)) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedFrame, frame) + } + }) + } +} diff --git a/pkg/dns/dnscore/dotls.go b/pkg/dns/dnscore/dotls.go new file mode 100644 index 0000000..17c1d54 --- /dev/null +++ b/pkg/dns/dnscore/dotls.go @@ -0,0 +1,60 @@ +// +// SPDX-License-Identifier: GPL-3.0-or-later +// +// DNS-over-TLS implementation +// + +package dnscore + +import ( + "context" + "crypto/tls" + "net" + + "github.com/miekg/dns" +) + +// dialTLSContext is a helper function that dials a network address using the +// given dialer or the default dialer if the given dialer is nil. +func (t *Transport) dialTLSContext(ctx context.Context, network, address string) (net.Conn, error) { + if t.DialTLSContext != nil { + return t.DialTLSContext(ctx, network, address) + } + + // Fill in a default TLS config + hostname, _, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + config := &tls.Config{ + InsecureSkipVerify: false, + NextProtos: []string{"dot"}, + RootCAs: t.RootCAs, + ServerName: hostname, + } + + // Defer to the stdlib TLS dialer + dialer := &tls.Dialer{Config: config} + return dialer.DialContext(ctx, network, address) +} + +// queryTLS implements [*Transport.Query] for DNS over TLS. +func (t *Transport) queryTLS(ctx context.Context, + addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + // 0. immediately fail if the context is already done, which + // is useful to write unit tests + if ctx.Err() != nil { + return nil, ctx.Err() + } + + // 1. Dial the TLS connection + conn, err := t.dialTLSContext(ctx, "tcp", addr.Address) + + // 2. Handle dialing failure + if err != nil { + return nil, err + } + + // 3. Transfer conn ownership and perform the round trip + return t.queryStream(ctx, addr, query, conn) +} diff --git a/pkg/dns/dnscore/dotls_test.go b/pkg/dns/dnscore/dotls_test.go new file mode 100644 index 0000000..ebbfdef --- /dev/null +++ b/pkg/dns/dnscore/dotls_test.go @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore + +import ( + "bytes" + "context" + "errors" + "net" + "testing" + + "github.com/miekg/dns" + "github.com/rbmk-project/rbmk/pkg/common/mocks" + "github.com/stretchr/testify/assert" +) + +func TestTransport_dialTLSContext(t *testing.T) { + tests := []struct { + name string + setupTransport func() *Transport + address string + expectedError error + }{ + { + name: "Invalid address", + setupTransport: func() *Transport { + return &Transport{} + }, + address: "invalid-address", + expectedError: errors.New("address invalid-address: missing port in address"), + }, + + { + name: "Override DialTLSContext", + setupTransport: func() *Transport { + return &Transport{ + DialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{}, nil + }, + } + }, + address: "example.com:853", + expectedError: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := tt.setupTransport() + ctx := context.Background() + + _, err := transport.dialTLSContext(ctx, "tcp", tt.address) + + if tt.expectedError != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedError.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestTransport_queryTLS(t *testing.T) { + tests := []struct { + name string + setupTransport func() *Transport + expectedError error + }{ + { + name: "Successful query", + setupTransport: func() *Transport { + return &Transport{ + DialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: (bytes.NewReader(newValidRawRespFrame())).Read, + MockClose: func() error { + return nil + }, + }, nil + }, + } + }, + expectedError: nil, + }, + + { + name: "Dial failure", + setupTransport: func() *Transport { + return &Transport{ + DialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, errors.New("dial failed") + }, + } + }, + expectedError: errors.New("dial failed"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := tt.setupTransport() + addr := NewServerAddr(ProtocolDoT, "8.8.8.8:853") + query := new(dns.Msg) + query.SetQuestion("example.com.", dns.TypeA) + + _, err := transport.queryTLS(context.Background(), addr, query) + + if tt.expectedError != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedError.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/pkg/dns/dnscore/doudp.go b/pkg/dns/dnscore/doudp.go new file mode 100644 index 0000000..66759b7 --- /dev/null +++ b/pkg/dns/dnscore/doudp.go @@ -0,0 +1,211 @@ +// +// SPDX-License-Identifier: GPL-3.0-or-later +// +// DNS over UDP protocol. +// +// Adapted from: https://github.com/ooni/probe-engine/blob/v0.23.0/netx/resolver/dnsoverudp.go +// + +package dnscore + +import ( + "context" + "net" + "time" + + "github.com/miekg/dns" +) + +// dialContext is a helper function that dials a network address using the +// given dialer or the default dialer if the given dialer is nil. +func (t *Transport) dialContext(ctx context.Context, network, address string) (net.Conn, error) { + if t.DialContext != nil { + return t.DialContext(ctx, network, address) + } + dialer := &net.Dialer{} + return dialer.DialContext(ctx, network, address) +} + +// timeNow is a helper function that returns the current time using the +// given function or the stdlib if the given function is nil. +func (t *Transport) timeNow() time.Time { + if t.TimeNow != nil { + return t.TimeNow() + } + return time.Now() +} + +// sendQueryUDP dials a connection, sends and logs the query and +// returns the following values: +// +// - conn: the connection to the server. +// +// - t0: the time when the query was sent. +// +// - rawQuery: the raw query bytes sent to the server. +// +// - err: any error that occurred during the process. +// +// On success, the caller TAKES OWNERSHIP of the returned connection +// and is responsible for closing it when done. +func (t *Transport) sendQueryUDP(ctx context.Context, addr *ServerAddr, + query *dns.Msg) (conn net.Conn, t0 time.Time, rawQuery []byte, err error) { + // 1. Dial the connection and handle failure. We do not handle retries at this + // level and instead rely on the caller to retry the query if needed. This allows + // the [*Resolver] to cycle through multiple servers in case of failure. + conn, err = t.dialContext(ctx, "udp", addr.Address) + if err != nil { + return + } + + // 2. Use the context deadline to limit the query lifetime + // as documented in the [*Transport.Query] function. + if deadline, ok := ctx.Deadline(); ok { + _ = conn.SetDeadline(deadline) + } + + // 3. Serialize the query and possibly log that we're sending it. + rawQuery, err = query.Pack() + if err != nil { + return + } + t0 = t.maybeLogQuery(ctx, addr, rawQuery) + + // 4. Send the query. Do not bother with logging the write call + // since that should be done by a custom dialer that wraps the + // returned connection and implements the desired logging. + _, err = conn.Write(rawQuery) + return +} + +// edns0MaxResponseSize returns the maximum response size that the client +// did configure using EDNS(0) or the default size of 512 bytes. +func edns0MaxResponseSize(query *dns.Msg) (maxSize uint16) { + for _, rr := range query.Extra { + if opt, ok := rr.(*dns.OPT); ok { + maxSize = opt.UDPSize() + break + } + } + if maxSize <= 0 { + maxSize = 512 + } + return +} + +// recvResponseUDP reads and parses the response from the server and +// possibly logs the response. It returns the parsed response or an error. +func (t *Transport) recvResponseUDP(ctx context.Context, addr *ServerAddr, conn net.Conn, + t0 time.Time, query *dns.Msg, rawQuery []byte) (*dns.Msg, error) { + // 1. Read the corresponding raw response + buffer := make([]byte, edns0MaxResponseSize(query)) + count, err := conn.Read(buffer) + if err != nil { + return nil, err + } + rawResp := buffer[:count] + + // 2. Parse the raw response and possibly log that we received it. + resp := &dns.Msg{} + if err := resp.Unpack(rawResp); err != nil { + return nil, err + } + t.maybeLogResponseConn(ctx, addr, t0, rawQuery, rawResp, conn) + return resp, nil +} + +// queryUDP implements [*Transport.Query] for DNS over UDP. +func (t *Transport) queryUDP(ctx context.Context, + addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + // 0. immediately fail if the context is already done, which + // is useful to write unit tests + if ctx.Err() != nil { + return nil, ctx.Err() + } + + // Send the query and log the query if needed. + conn, t0, rawQuery, err := t.sendQueryUDP(ctx, addr, query) + if err != nil { + return nil, err + } + + // Use a single connection for request, which is what the standard library + // does as well and is more robust in terms of residual censorship. + // + // Make sure we react to context being canceled early. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + defer conn.Close() + <-ctx.Done() + }() + + // Read and parse the response and log it if needed. + return t.recvResponseUDP(ctx, addr, conn, t0, query, rawQuery) +} + +// emitMessageOrError sends a message or error to the output channel +// or drops the message if the context is done. +func (t *Transport) emitMessageOrError(ctx context.Context, + msg *dns.Msg, err error, out chan *MessageOrError) { + var messageOrError *MessageOrError + if err != nil { + messageOrError = &MessageOrError{Err: err} + } else { + messageOrError = &MessageOrError{Msg: msg} + } + + select { + case out <- messageOrError: + case <-ctx.Done(): + } +} + +// queryUDPWithDuplicates implements [*Transport.Query] for DNS over UDP with +func (t *Transport) queryUDPWithDuplicates(ctx context.Context, + addr *ServerAddr, query *dns.Msg) <-chan *MessageOrError { + out := make(chan *MessageOrError, 4) + + // Immediately fail if the context is already done, which + // is useful to write unit tests + if ctx.Err() != nil { + out <- &MessageOrError{Err: ctx.Err()} + close(out) + return out + } + + go func() { + // Ensure the channel is closed when we're done + defer close(out) + + // Send the query and log the query if needed. + conn, t0, rawQuery, err := t.sendQueryUDP(ctx, addr, query) + if err != nil { + t.emitMessageOrError(ctx, nil, err, out) + return + } + + // Use a single connection for request, which is what the standard library + // does as well and is more robust in terms of residual censorship. + // + // Make sure we react to context being canceled early. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + defer conn.Close() + <-ctx.Done() + }() + + // Loop collecting responses and emitting them until the context is done. + for { + resp, err := t.recvResponseUDP(ctx, addr, conn, t0, query, rawQuery) + if err != nil { + t.emitMessageOrError(ctx, nil, err, out) + return + } + + t.emitMessageOrError(ctx, resp, nil, out) + } + }() + return out +} diff --git a/pkg/dns/dnscore/doudp_test.go b/pkg/dns/dnscore/doudp_test.go new file mode 100644 index 0000000..0ef28c6 --- /dev/null +++ b/pkg/dns/dnscore/doudp_test.go @@ -0,0 +1,660 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore + +import ( + "context" + "errors" + "net" + "os" + "sync/atomic" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/rbmk-project/rbmk/pkg/common/mocks" + "github.com/stretchr/testify/assert" +) + +func TestTransport_dialContext(t *testing.T) { + tests := []struct { + name string + dialContext func(ctx context.Context, network, address string) (net.Conn, error) + expectedError error + }{ + { + name: "Custom dialer success", + dialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{}, nil + }, + expectedError: nil, + }, + + { + name: "Custom dialer failure", + dialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, errors.New("dial failed") + }, + expectedError: errors.New("dial failed"), + }, + + { + // note: this is still a unit test because dialing a UDP + // connection doesn't involve any network activity + name: "Default dialer success", + dialContext: nil, + expectedError: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := &Transport{ + DialContext: tt.dialContext, + } + _, err := transport.dialContext(context.Background(), "udp", "8.8.8.8:53") + if tt.expectedError != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedError.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestTransport_timeNow(t *testing.T) { + tests := []struct { + name string + timeNow func() time.Time + expected time.Time + }{ + { + name: "Custom time function", + timeNow: func() time.Time { + return time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) + }, + expected: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + }, + + { + name: "Default time function", + timeNow: nil, + expected: time.Now(), // This will be close to the current time + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := &Transport{ + TimeNow: tt.timeNow, + } + actual := transport.timeNow() + if tt.timeNow != nil { + assert.Equal(t, tt.expected, actual) + } else { + assert.WithinDuration(t, tt.expected, actual, 5*time.Second) + } + }) + } +} + +func TestTransport_sendQueryUDP(t *testing.T) { + tests := []struct { + name string + questionName string + setupTransport func(setDeadlineCalled *bool) *Transport + expectedError error + expectDeadline bool + }{ + { + name: "Successful send", + questionName: "example.com.", + setupTransport: func(setDeadlineCalled *bool) *Transport { + return &Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockClose: func() error { + return nil + }, + MockSetDeadline: func(t time.Time) error { + *setDeadlineCalled = true + return nil + }, + }, nil + }, + } + }, + expectedError: nil, + expectDeadline: true, + }, + + { + name: "Dial failure", + questionName: "example.com.", + setupTransport: func(_ *bool) *Transport { + return &Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, errors.New("dial failed") + }, + } + }, + expectedError: errors.New("dial failed"), + expectDeadline: false, + }, + + { + name: "Write failure", + questionName: "example.com.", + setupTransport: func(setDeadlineCalled *bool) *Transport { + return &Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return 0, errors.New("write failed") + }, + MockClose: func() error { + return nil + }, + MockSetDeadline: func(t time.Time) error { + *setDeadlineCalled = true + return nil + }, + }, nil + }, + } + }, + expectedError: errors.New("write failed"), + expectDeadline: true, + }, + + { + name: "Cannot pack query", + questionName: "nameThatIsNotCanonicalFQDN", + setupTransport: func(setDeadlineCalled *bool) *Transport { + return &Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockClose: func() error { + return nil + }, + MockSetDeadline: func(t time.Time) error { + *setDeadlineCalled = true + return nil + }, + }, nil + }, + } + }, + expectedError: errors.New("dns: domain must be fully qualified"), + expectDeadline: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var setDeadlineCalled bool + + transport := tt.setupTransport(&setDeadlineCalled) + addr := &ServerAddr{Address: "8.8.8.8:53", Protocol: ProtocolUDP} + query := new(dns.Msg) + query.SetQuestion(tt.questionName, dns.TypeA) + + ctx := context.Background() + if tt.expectDeadline { + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, time.Now().Add(1*time.Hour)) + defer cancel() + } + + conn, _, _, err := transport.sendQueryUDP(ctx, addr, query) + + if tt.expectedError != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedError.Error(), err.Error()) + } else { + assert.NoError(t, err) + if tt.expectDeadline { + assert.NotNil(t, conn) + assert.True(t, setDeadlineCalled) + } + } + }) + } +} + +func Test_edns0MaxResponseSize(t *testing.T) { + tests := []struct { + name string + query *dns.Msg + expected uint16 + }{ + { + name: "EDNS0 option set", + query: func() *dns.Msg { + msg := new(dns.Msg) + opt := new(dns.OPT) + opt.SetUDPSize(4096) + msg.Extra = append(msg.Extra, opt) + return msg + }(), + expected: 4096, + }, + + { + name: "No EDNS0 option set", + query: func() *dns.Msg { + return new(dns.Msg) + }(), + expected: 512, + }, + + { + name: "EDNS0 option with zero size", + query: func() *dns.Msg { + msg := new(dns.Msg) + opt := new(dns.OPT) + opt.SetUDPSize(0) + msg.Extra = append(msg.Extra, opt) + return msg + }(), + expected: 512, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := edns0MaxResponseSize(tt.query) + assert.Equal(t, tt.expected, actual) + }) + } +} + +func TestTransport_recvResponseUDP(t *testing.T) { + tests := []struct { + name string + setupTransport func() *Transport + expectedError error + }{ + { + name: "Successful receive", + setupTransport: func() *Transport { + return &Transport{} + }, + expectedError: nil, + }, + + { + name: "Read failure", + setupTransport: func() *Transport { + return &Transport{} + }, + expectedError: errors.New("read failed"), + }, + + { + name: "Unpack failure", + setupTransport: func() *Transport { + return &Transport{} + }, + expectedError: errors.New("dns: overflow unpacking uint16"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := tt.setupTransport() + addr := &ServerAddr{Address: "8.8.8.8:53", Protocol: ProtocolUDP} + query := new(dns.Msg) + query.SetQuestion("example.com.", dns.TypeA) + conn := &mocks.Conn{ + MockRead: func(b []byte) (int, error) { + if tt.expectedError != nil { + return 0, tt.expectedError + } + copy(b, []byte{0, 0, 0, 0}) + return len(b), nil + }, + MockClose: func() error { + return nil + }, + } + + ctx := context.Background() + _, err := transport.recvResponseUDP( + ctx, addr, conn, time.Now(), query, []byte{0, 0, 0, 0}) + + if tt.expectedError != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedError.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestTransport_queryUDP(t *testing.T) { + tests := []struct { + name string + setupTransport func() *Transport + expectedError error + }{ + { + name: "Successful query", + setupTransport: func() *Transport { + return &Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: func(b []byte) (int, error) { + copy(b, []byte{0, 0, 0, 0}) + return len(b), nil + }, + MockClose: func() error { + return nil + }, + }, nil + }, + } + }, + expectedError: nil, + }, + + { + name: "Dial failure", + setupTransport: func() *Transport { + return &Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, errors.New("dial failed") + }, + } + }, + expectedError: errors.New("dial failed"), + }, + + { + name: "Write failure", + setupTransport: func() *Transport { + return &Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return 0, errors.New("write failed") + }, + MockClose: func() error { + return nil + }, + }, nil + }, + } + }, + expectedError: errors.New("write failed"), + }, + + { + name: "Read failure", + setupTransport: func() *Transport { + return &Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: func(b []byte) (int, error) { + return 0, errors.New("read failed") + }, + MockClose: func() error { + return nil + }, + }, nil + }, + } + }, + expectedError: errors.New("read failed"), + }, + + { + name: "Send query failure", + setupTransport: func() *Transport { + return &Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return 0, errors.New("send query failed") + }, + MockClose: func() error { + return nil + }, + }, nil + }, + } + }, + expectedError: errors.New("send query failed"), + }, + + { + name: "Garbage response", + setupTransport: func() *Transport { + return &Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: func(b []byte) (int, error) { + copy(b, []byte{0xFF}) + return 1, nil + }, + MockClose: func() error { + return nil + }, + }, nil + }, + } + }, + expectedError: errors.New("dns: overflow unpacking uint16"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := tt.setupTransport() + addr := NewServerAddr(ProtocolUDP, "8.8.8.8:53") + query := new(dns.Msg) + query.SetQuestion("example.com.", dns.TypeA) + + _, err := transport.queryUDP(context.Background(), addr, query) + if tt.expectedError != nil { + assert.Error(t, err) + assert.Equal(t, tt.expectedError.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestTransport_emitMessageOrError(t *testing.T) { + tests := []struct { + name string + msg *dns.Msg + err error + expectedError error + }{ + { + name: "Send message", + msg: new(dns.Msg), + err: nil, + expectedError: nil, + }, + + { + name: "Send error", + msg: nil, + err: errors.New("test error"), + expectedError: errors.New("test error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := &Transport{} + out := make(chan *MessageOrError, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + transport.emitMessageOrError(ctx, tt.msg, tt.err, out) + messageOrError := <-out + + if tt.expectedError != nil { + assert.Error(t, messageOrError.Err) + assert.Equal(t, tt.expectedError.Error(), messageOrError.Err.Error()) + } else { + assert.NoError(t, messageOrError.Err) + assert.Equal(t, tt.msg, messageOrError.Msg) + } + }) + } +} + +func TestTransport_queryUDPWithDuplicates(t *testing.T) { + tests := []struct { + name string + setupTransport func() *Transport + expectedError error + }{ + { + name: "Successful query with duplicates", + setupTransport: func() *Transport { + count := &atomic.Int64{} + return &Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: func(b []byte) (int, error) { + if count.Add(1) > 3 { + return 0, os.ErrDeadlineExceeded + } + copy(b, []byte{0, 0, 0, 0}) + return len(b), nil + }, + MockClose: func() error { + return nil + }, + }, nil + }, + } + }, + expectedError: os.ErrDeadlineExceeded, + }, + + { + name: "Dial failure", + setupTransport: func() *Transport { + return &Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, errors.New("dial failed") + }, + } + }, + expectedError: errors.New("dial failed"), + }, + + { + name: "Write failure", + setupTransport: func() *Transport { + return &Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return 0, errors.New("write failed") + }, + MockClose: func() error { + return nil + }, + }, nil + }, + } + }, + expectedError: errors.New("write failed"), + }, + + { + name: "Read failure", + setupTransport: func() *Transport { + return &Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: func(b []byte) (int, error) { + return 0, errors.New("read failed") + }, + MockClose: func() error { + return nil + }, + }, nil + }, + } + }, + expectedError: errors.New("read failed"), + }, + + { + name: "Garbage response", + setupTransport: func() *Transport { + return &Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: func(b []byte) (int, error) { + copy(b, []byte{0xFF}) + return 1, nil + }, + MockClose: func() error { + return nil + }, + }, nil + }, + } + }, + expectedError: errors.New("dns: overflow unpacking uint16"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := tt.setupTransport() + addr := NewServerAddr(ProtocolUDP, "8.8.8.8:53") + query := new(dns.Msg) + query.SetQuestion("example.com.", dns.TypeA) + + ch := transport.queryUDPWithDuplicates(context.Background(), addr, query) + messages := []*MessageOrError{} + for msgOrErr := range ch { + messages = append(messages, msgOrErr) + } + if len(messages) <= 0 { + t.Fatal("No messages received") + } + last := messages[len(messages)-1] + if tt.expectedError != nil { + assert.Error(t, last.Err) + assert.Equal(t, tt.expectedError.Error(), last.Err.Error()) + } else { + assert.NoError(t, last.Err) + } + }) + } +} diff --git a/pkg/dns/dnscore/example_https_test.go b/pkg/dns/dnscore/example_https_test.go new file mode 100644 index 0000000..6bad3e2 --- /dev/null +++ b/pkg/dns/dnscore/example_https_test.go @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore_test + +import ( + "context" + "fmt" + "log" + "slices" + "strings" + "time" + + "github.com/miekg/dns" + "github.com/rbmk-project/rbmk/pkg/common/runtimex" + "github.com/rbmk-project/rbmk/pkg/dns/dnscore" +) + +func ExampleTransport_dnsOverHTTPS() { + // create transport, server addr, and query + txp := &dnscore.Transport{} + serverAddr := &dnscore.ServerAddr{ + Protocol: dnscore.ProtocolDoH, + Address: "https://8.8.8.8/dns-query", + } + options := []dnscore.QueryOption{ + dnscore.QueryOptionEDNS0( + dnscore.EDNS0SuggestedMaxResponseSizeOtherwise, + dnscore.EDNS0FlagDO|dnscore.EDNS0FlagBlockLengthPadding, + ), + } + query, err := dnscore.NewQueryWithServerAddr(serverAddr, "dns.google", dns.TypeA, options...) + if err != nil { + log.Fatal(err) + } + + // issue the query and get the response + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + resp, err := txp.Query(ctx, serverAddr, query) + if err != nil { + log.Fatal(err) + } + + // validate the response + if err := dnscore.ValidateResponse(query, resp); err != nil { + log.Fatal(err) + } + runtimex.Assert(len(query.Question) > 0, "expected at least one question") + rrs, err := dnscore.ValidAnswers(query.Question[0], resp) + if err != nil { + log.Fatal(err) + } + + // print the results + var addrs []string + for _, rr := range rrs { + switch rr := rr.(type) { + case *dns.A: + addrs = append(addrs, rr.A.String()) + } + } + slices.Sort(addrs) + fmt.Printf("%s\n", strings.Join(addrs, "\n")) + + // Output: + // 8.8.4.4 + // 8.8.8.8 +} diff --git a/pkg/dns/dnscore/example_quic_test.go b/pkg/dns/dnscore/example_quic_test.go new file mode 100644 index 0000000..9e36991 --- /dev/null +++ b/pkg/dns/dnscore/example_quic_test.go @@ -0,0 +1,66 @@ +package dnscore_test + +import ( + "context" + "fmt" + "log" + "slices" + "strings" + "time" + + "github.com/miekg/dns" + "github.com/rbmk-project/rbmk/pkg/common/runtimex" + "github.com/rbmk-project/rbmk/pkg/dns/dnscore" +) + +func ExampleTransport_dnsOverQUIC() { + // create transport, server addr, and query + txp := &dnscore.Transport{} + serverAddr := &dnscore.ServerAddr{ + Protocol: dnscore.ProtocolDoQ, + Address: "dns0.eu:853", + } + options := []dnscore.QueryOption{ + dnscore.QueryOptionEDNS0( + dnscore.EDNS0SuggestedMaxResponseSizeOtherwise, + dnscore.EDNS0FlagDO|dnscore.EDNS0FlagBlockLengthPadding, + ), + } + query, err := dnscore.NewQueryWithServerAddr(serverAddr, "dns.google", dns.TypeA, options...) + if err != nil { + log.Fatal(err) + } + + // issue the query and get the response + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + resp, err := txp.Query(ctx, serverAddr, query) + if err != nil { + log.Fatal(err) + } + + // validate the response + if err := dnscore.ValidateResponse(query, resp); err != nil { + log.Fatal(err) + } + runtimex.Assert(len(query.Question) > 0, "expected at least one question") + rrs, err := dnscore.ValidAnswers(query.Question[0], resp) + if err != nil { + log.Fatal(err) + } + + // print the results + var addrs []string + for _, rr := range rrs { + switch rr := rr.(type) { + case *dns.A: + addrs = append(addrs, rr.A.String()) + } + } + slices.Sort(addrs) + fmt.Printf("%s\n", strings.Join(addrs, "\n")) + + // Output: + // 8.8.4.4 + // 8.8.8.8 +} diff --git a/pkg/dns/dnscore/example_resolver_test.go b/pkg/dns/dnscore/example_resolver_test.go new file mode 100644 index 0000000..ad3fce6 --- /dev/null +++ b/pkg/dns/dnscore/example_resolver_test.go @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore_test + +import ( + "context" + "fmt" + "log" + "slices" + "strings" + + "github.com/rbmk-project/rbmk/pkg/dns/dnscore" +) + +func ExampleResolver() { + // create resolver + reso := &dnscore.Resolver{} + + // issue the queries and merge the responses + addrs, err := reso.LookupHost(context.Background(), "dns.google") + if err != nil { + log.Fatal(err) + } + + // print the results + slices.Sort(addrs) + fmt.Printf("%s\n", strings.Join(addrs, "\n")) + + // Output: + // 2001:4860:4860::8844 + // 2001:4860:4860::8888 + // 8.8.4.4 + // 8.8.8.8 +} diff --git a/pkg/dns/dnscore/example_tcp_test.go b/pkg/dns/dnscore/example_tcp_test.go new file mode 100644 index 0000000..c22144b --- /dev/null +++ b/pkg/dns/dnscore/example_tcp_test.go @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore_test + +import ( + "context" + "fmt" + "log" + "slices" + "strings" + "time" + + "github.com/miekg/dns" + "github.com/rbmk-project/rbmk/pkg/common/runtimex" + "github.com/rbmk-project/rbmk/pkg/dns/dnscore" +) + +func ExampleTransport_dnsOverTCP() { + // create transport, server addr, and query + txp := &dnscore.Transport{} + serverAddr := &dnscore.ServerAddr{ + Protocol: dnscore.ProtocolTCP, + Address: "8.8.8.8:53", + } + options := []dnscore.QueryOption{ + dnscore.QueryOptionEDNS0( + dnscore.EDNS0SuggestedMaxResponseSizeOtherwise, + 0, + ), + } + query, err := dnscore.NewQueryWithServerAddr(serverAddr, "dns.google", dns.TypeA, options...) + if err != nil { + log.Fatal(err) + } + + // issue the query and get the response + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + resp, err := txp.Query(ctx, serverAddr, query) + if err != nil { + log.Fatal(err) + } + + // validate the response + if err := dnscore.ValidateResponse(query, resp); err != nil { + log.Fatal(err) + } + runtimex.Assert(len(query.Question) > 0, "expected at least one question") + rrs, err := dnscore.ValidAnswers(query.Question[0], resp) + if err != nil { + log.Fatal(err) + } + + // print the results + var addrs []string + for _, rr := range rrs { + switch rr := rr.(type) { + case *dns.A: + addrs = append(addrs, rr.A.String()) + } + } + slices.Sort(addrs) + fmt.Printf("%s\n", strings.Join(addrs, "\n")) + + // Output: + // 8.8.4.4 + // 8.8.8.8 +} diff --git a/pkg/dns/dnscore/example_tls_test.go b/pkg/dns/dnscore/example_tls_test.go new file mode 100644 index 0000000..8efc36f --- /dev/null +++ b/pkg/dns/dnscore/example_tls_test.go @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore_test + +import ( + "context" + "fmt" + "log" + "slices" + "strings" + "time" + + "github.com/miekg/dns" + "github.com/rbmk-project/rbmk/pkg/common/runtimex" + "github.com/rbmk-project/rbmk/pkg/dns/dnscore" +) + +func ExampleTransport_dnsOverTLS() { + // create transport, server addr, and query + txp := &dnscore.Transport{} + serverAddr := &dnscore.ServerAddr{ + Protocol: dnscore.ProtocolDoT, + Address: "8.8.8.8:853", + } + options := []dnscore.QueryOption{ + dnscore.QueryOptionEDNS0( + dnscore.EDNS0SuggestedMaxResponseSizeOtherwise, + dnscore.EDNS0FlagDO|dnscore.EDNS0FlagBlockLengthPadding, + ), + } + query, err := dnscore.NewQueryWithServerAddr(serverAddr, "dns.google", dns.TypeA, options...) + if err != nil { + log.Fatal(err) + } + + // issue the query and get the response + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + resp, err := txp.Query(ctx, serverAddr, query) + if err != nil { + log.Fatal(err) + } + + // validate the response + if err := dnscore.ValidateResponse(query, resp); err != nil { + log.Fatal(err) + } + runtimex.Assert(len(query.Question) > 0, "expected at least one question") + rrs, err := dnscore.ValidAnswers(query.Question[0], resp) + if err != nil { + log.Fatal(err) + } + + // print the results + var addrs []string + for _, rr := range rrs { + switch rr := rr.(type) { + case *dns.A: + addrs = append(addrs, rr.A.String()) + } + } + slices.Sort(addrs) + fmt.Printf("%s\n", strings.Join(addrs, "\n")) + + // Output: + // 8.8.4.4 + // 8.8.8.8 +} diff --git a/pkg/dns/dnscore/example_udp_test.go b/pkg/dns/dnscore/example_udp_test.go new file mode 100644 index 0000000..bca3cd6 --- /dev/null +++ b/pkg/dns/dnscore/example_udp_test.go @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore_test + +import ( + "context" + "fmt" + "log" + "slices" + "strings" + "time" + + "github.com/miekg/dns" + "github.com/rbmk-project/rbmk/pkg/common/runtimex" + "github.com/rbmk-project/rbmk/pkg/dns/dnscore" +) + +func ExampleTransport_dnsOverUDP() { + // create transport, server addr, and query + txp := &dnscore.Transport{} + serverAddr := &dnscore.ServerAddr{ + Protocol: dnscore.ProtocolUDP, + Address: "8.8.8.8:53", + } + options := []dnscore.QueryOption{ + dnscore.QueryOptionEDNS0( + dnscore.EDNS0SuggestedMaxResponseSizeUDP, + 0, + ), + } + query, err := dnscore.NewQueryWithServerAddr(serverAddr, "dns.google", dns.TypeA, options...) + if err != nil { + log.Fatal(err) + } + + // issue the query and get the response + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + resp, err := txp.Query(ctx, serverAddr, query) + if err != nil { + log.Fatal(err) + } + + // validate the response + if err := dnscore.ValidateResponse(query, resp); err != nil { + log.Fatal(err) + } + runtimex.Assert(len(query.Question) > 0, "expected at least one question") + rrs, err := dnscore.ValidAnswers(query.Question[0], resp) + if err != nil { + log.Fatal(err) + } + + // print the results + var addrs []string + for _, rr := range rrs { + switch rr := rr.(type) { + case *dns.A: + addrs = append(addrs, rr.A.String()) + } + } + slices.Sort(addrs) + fmt.Printf("%s\n", strings.Join(addrs, "\n")) + + // Output: + // 8.8.4.4 + // 8.8.8.8 +} diff --git a/pkg/dns/dnscore/integration_test.go b/pkg/dns/dnscore/integration_test.go new file mode 100644 index 0000000..fe5133f --- /dev/null +++ b/pkg/dns/dnscore/integration_test.go @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore_test + +import ( + "context" + "crypto/tls" + "net/http" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/rbmk-project/rbmk/pkg/dns/dnscore" + "github.com/rbmk-project/rbmk/pkg/dns/dnscoretest" + "github.com/stretchr/testify/assert" +) + +func checkResult(t *testing.T, resp *dns.Msg, err error) { + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, 1, len(resp.Answer)) + assert.Equal(t, "example.com.", resp.Answer[0].Header().Name) + assert.Equal(t, dns.TypeA, resp.Answer[0].Header().Rrtype) + assert.Equal( + t, dnscoretest.ExampleComAddrA.String(), + resp.Answer[0].(*dns.A).A.String(), + ) +} + +func TestTransport_RoundTrip_UDP(t *testing.T) { + // create and start a testing server + server := &dnscoretest.Server{} + handler := dnscoretest.NewExampleComHandler() + <-server.StartUDP(handler) + defer server.Close() + + // create transport, server addr, and query + txp := &dnscore.Transport{} + serverAddr := &dnscore.ServerAddr{ + Protocol: dnscore.ProtocolUDP, + Address: server.Addr, + } + options := []dnscore.QueryOption{ + dnscore.QueryOptionEDNS0( + dnscore.EDNS0SuggestedMaxResponseSizeUDP, + 0, + ), + } + query, err := dnscore.NewQueryWithServerAddr(serverAddr, "example.com", dns.TypeA, options...) + if err != nil { + t.Fatal(err) + } + + // issue the query and get the response + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + resp, err := txp.Query(ctx, serverAddr, query) + + // verify the results + checkResult(t, resp, err) +} + +func TestTransport_RoundTrip_TCP(t *testing.T) { + // create and start a testing server + server := &dnscoretest.Server{} + handler := dnscoretest.NewExampleComHandler() + <-server.StartTCP(handler) + defer server.Close() + + // create transport, server addr, and query + txp := &dnscore.Transport{} + serverAddr := &dnscore.ServerAddr{ + Protocol: dnscore.ProtocolTCP, + Address: server.Addr, + } + options := []dnscore.QueryOption{ + dnscore.QueryOptionEDNS0( + dnscore.EDNS0SuggestedMaxResponseSizeOtherwise, + 0, + ), + } + query, err := dnscore.NewQueryWithServerAddr(serverAddr, "example.com", dns.TypeA, options...) + if err != nil { + t.Fatal(err) + } + + // issue the query and get the response + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + resp, err := txp.Query(ctx, serverAddr, query) + + // verify the results + checkResult(t, resp, err) +} + +func TestTransport_RoundTrip_TLS(t *testing.T) { + // create and start a testing server + server := &dnscoretest.Server{} + handler := dnscoretest.NewExampleComHandler() + <-server.StartTLS(handler) + defer server.Close() + + // create transport, server addr, and query + txp := &dnscore.Transport{RootCAs: server.RootCAs} + serverAddr := &dnscore.ServerAddr{ + Protocol: dnscore.ProtocolDoT, + Address: server.Addr, + } + options := []dnscore.QueryOption{ + dnscore.QueryOptionEDNS0( + dnscore.EDNS0SuggestedMaxResponseSizeOtherwise, + dnscore.EDNS0FlagDO|dnscore.EDNS0FlagBlockLengthPadding, + ), + } + query, err := dnscore.NewQueryWithServerAddr(serverAddr, "example.com", dns.TypeA, options...) + if err != nil { + t.Fatal(err) + } + + // issue the query and get the response + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + resp, err := txp.Query(ctx, serverAddr, query) + + // verify the results + checkResult(t, resp, err) +} + +func TestTransport_RoundTrip_HTTPS(t *testing.T) { + // create and start a testing server + server := &dnscoretest.Server{} + handler := dnscoretest.NewExampleComHandler() + <-server.StartHTTPS(handler) + defer server.Close() + + // create transport, server addr, and query + txp := &dnscore.Transport{ + HTTPClient: &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: server.RootCAs, + }, + }, + }, + } + serverAddr := &dnscore.ServerAddr{ + Protocol: dnscore.ProtocolDoH, + Address: server.URL, + } + options := []dnscore.QueryOption{ + dnscore.QueryOptionEDNS0( + dnscore.EDNS0SuggestedMaxResponseSizeOtherwise, + dnscore.EDNS0FlagDO|dnscore.EDNS0FlagBlockLengthPadding, + ), + } + query, err := dnscore.NewQueryWithServerAddr(serverAddr, "example.com", dns.TypeA, options...) + if err != nil { + t.Fatal(err) + } + + // issue the query and get the response + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + resp, err := txp.Query(ctx, serverAddr, query) + + // verify the results + checkResult(t, resp, err) +} + +// TODO(bassosimone,roopeshsn): add integration tests for DoQ diff --git a/pkg/dns/dnscore/lookup.go b/pkg/dns/dnscore/lookup.go new file mode 100644 index 0000000..6e8962e --- /dev/null +++ b/pkg/dns/dnscore/lookup.go @@ -0,0 +1,103 @@ +// +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// SPDX-License-Identifier: BSD-3-Clause +// +// Adapted from: https://github.com/golang/go/blob/go1.21.10/src/net/dnsclient_unix.go +// +// Resolver code to send queries and receive responses +// along with BSD-licensed code from the stdlib. +// + +package dnscore + +import ( + "context" + "errors" + + "github.com/miekg/dns" +) + +// transport returns the tranport to use for resolving queries, which is +// either the transport specified in the resolver or the default. +func (r *Resolver) transport() ResolverTransport { + if r.Transport != nil { + return r.Transport + } + return DefaultTransport +} + +// exchange implements [*Resolver.lookup] with a specific server. +func (r *Resolver) exchange(ctx context.Context, + name string, qtype uint16, server resolverConfigServer) ([]dns.RR, error) { + // Handle the case of domains that should not be resolved + labels := dns.SplitDomainName(dns.CanonicalName(name)) + if len(labels) > 0 && labels[len(labels)-1] == "onion" { + return nil, ErrNoData + } + + // Enforce an operation timeout + if server.timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, server.timeout) + defer cancel() + } + + // Encode the query + query, err := NewQueryWithServerAddr(server.address, name, qtype, server.queryOptions...) + if err != nil { + return nil, err + } + q0 := query.Question[0] // we know it's present because we just created it + + // Obtain the transport and perform the query + resp, err := r.transport().Query(ctx, server.address, query) + if err != nil { + return nil, err + } + + // Validate the response, check for errors and extract RRs + if err := ValidateResponse(query, resp); err != nil { + return nil, err + } + if err := RCodeToError(resp); err != nil { + return nil, err + } + return ValidAnswers(q0, resp) +} + +// lookup is the internal implementation of the Lookup* functions. +func (r *Resolver) lookup(ctx context.Context, + name string, qtype uint16) ([]dns.RR, error) { + // by default, on failure, we return the EAI_NODATA equivalent + lastErr := ErrNoData + + // obtain the list of servers and prepare to walk it + var ( + config = r.config() + attempts = config.Attempts() + servers = config.servers() + ) + for idx := 0; len(servers) > 0 && idx < attempts; idx++ { + // select a server and exchange the query + server := servers[uint32(idx)%uint32(len(servers))] + rrs, err := r.exchange(ctx, name, qtype, server) + + // immediately handle success and stop on NXDOMAIN + // + // note: it's not so common to use NXDOMAIN for censorship + // so this is a trade off to privilege fast convergence + if err == nil { + return rrs, nil + } + if errors.Is(err, ErrNoName) { + return nil, err + } + + lastErr = err + } + + return nil, lastErr +} diff --git a/pkg/dns/dnscore/lookup_test.go b/pkg/dns/dnscore/lookup_test.go new file mode 100644 index 0000000..02c6557 --- /dev/null +++ b/pkg/dns/dnscore/lookup_test.go @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "github.com/miekg/dns" +) + +func TestResolver_transport(t *testing.T) { + t.Run("default transport", func(t *testing.T) { + resolver := &Resolver{} + if resolver.transport() != DefaultTransport { + t.Fatal("unexpected transport: got non-default transport, want DefaultTransport") + } + }) + + t.Run("custom transport", func(t *testing.T) { + expectedTransport := &MockResolverTransport{} + resolver := &Resolver{Transport: expectedTransport} + if resolver.transport() != expectedTransport { + t.Fatal("unexpected transport: got different transport, want expectedTransport") + } + }) +} + +func TestResolver_exchange(t *testing.T) { + t.Run("successful query", func(t *testing.T) { + expectedRR := &dns.A{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300}, + A: net.ParseIP("192.0.2.1"), + } + mockTransport := &MockResolverTransport{ + MockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + resp := &dns.Msg{} + resp.SetReply(query) + resp.Answer = append(resp.Answer, expectedRR) + return resp, nil + }, + } + resolver := &Resolver{Transport: mockTransport} + server := resolverConfigServer{ + address: &ServerAddr{Address: "8.8.8.8:53"}, + } + rrs, err := resolver.exchange(context.Background(), "example.com", dns.TypeA, server) + if err != nil { + t.Fatal("unexpected error:", err) + } + if len(rrs) != 1 || rrs[0].String() != expectedRR.String() { + t.Fatalf("unexpected result: got %v, want %v", rrs, expectedRR) + } + }) + + t.Run("query timeout", func(t *testing.T) { + mockTransport := &MockResolverTransport{ + MockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + time.Sleep(100 * time.Millisecond) + return nil, context.DeadlineExceeded + }, + } + resolver := &Resolver{Transport: mockTransport} + server := resolverConfigServer{ + address: &ServerAddr{Address: "8.8.8.8:53"}, + timeout: 10 * time.Millisecond, + } + _, err := resolver.exchange(context.Background(), "example.com", dns.TypeA, server) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("unexpected error: got %v, want %v", err, context.DeadlineExceeded) + } + }) + + t.Run("onion domain", func(t *testing.T) { + resolver := &Resolver{} + server := resolverConfigServer{ + address: &ServerAddr{Address: "8.8.8.8:53"}, + } + _, err := resolver.exchange(context.Background(), "example.onion", dns.TypeA, server) + if !errors.Is(err, ErrNoData) { + t.Fatalf("unexpected error: got %v, want %v", err, ErrNoData) + } + }) + + t.Run("cannot encode query", func(t *testing.T) { + resolver := &Resolver{} + server := resolverConfigServer{ + address: &ServerAddr{Address: "8.8.8.8:53"}, + } + _, err := resolver.exchange(context.Background(), "\t\t\t", dns.TypeA, server) + if err == nil || err.Error() != "idna: disallowed rune U+0009" { + t.Fatalf("unexpected error: %s", err) + } + }) + + t.Run("invalid response", func(t *testing.T) { + mockTransport := &MockResolverTransport{ + MockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + return &dns.Msg{}, nil + }, + } + resolver := &Resolver{Transport: mockTransport} + server := resolverConfigServer{ + address: &ServerAddr{Address: "8.8.8.8:53"}, + } + _, err := resolver.exchange(context.Background(), "example.com", dns.TypeA, server) + if !errors.Is(err, ErrInvalidResponse) { + t.Fatalf("unexpected error: got %v, want %v", err, ErrInvalidResponse) + } + }) +} + +func TestResolver_lookup(t *testing.T) { + t.Run("successful lookup", func(t *testing.T) { + expectedRR := &dns.A{ + Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300}, + A: net.ParseIP("192.0.2.1"), + } + mockTransport := &MockResolverTransport{ + MockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + resp := &dns.Msg{} + resp.SetReply(query) + resp.Answer = append(resp.Answer, expectedRR) + return resp, nil + }, + } + resolver := &Resolver{Transport: mockTransport} + config := &ResolverConfig{ + attempts: DefaultAttempts, + list: []resolverConfigServer{ + {address: &ServerAddr{Address: "8.8.8.8:53"}}, + }, + } + resolver.Config = config + rrs, err := resolver.lookup(context.Background(), "example.com", dns.TypeA) + if err != nil { + t.Fatal("unexpected error:", err) + } + if len(rrs) != 1 || rrs[0].String() != expectedRR.String() { + t.Fatalf("unexpected result: got %v, want %v", rrs, expectedRR) + } + }) + + t.Run("lookup with no data", func(t *testing.T) { + mockTransport := &MockResolverTransport{ + MockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + resp := &dns.Msg{} + resp.SetReply(query) + return resp, nil + }, + } + resolver := &Resolver{Transport: mockTransport} + config := &ResolverConfig{ + attempts: DefaultAttempts, + list: []resolverConfigServer{ + {address: &ServerAddr{Address: "8.8.8.8:53"}}, + }, + } + resolver.Config = config + _, err := resolver.lookup(context.Background(), "example.com", dns.TypeA) + if !errors.Is(err, ErrNoData) { + t.Fatalf("unexpected error: got %v, want %v", err, ErrNoData) + } + }) + + t.Run("lookup with NXDOMAIN", func(t *testing.T) { + mockTransport := &MockResolverTransport{ + MockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + resp := &dns.Msg{} + resp.SetReply(query) + resp.Rcode = dns.RcodeNameError + return resp, nil + }, + } + resolver := &Resolver{Transport: mockTransport} + config := &ResolverConfig{ + attempts: DefaultAttempts, + list: []resolverConfigServer{ + {address: &ServerAddr{Address: "8.8.8.8:53"}}, + }, + } + resolver.Config = config + _, err := resolver.lookup(context.Background(), "example.com", dns.TypeA) + if !errors.Is(err, ErrNoName) { + t.Fatalf("unexpected error: got %v, want %v", err, ErrNoName) + } + }) +} diff --git a/pkg/dns/dnscore/query.go b/pkg/dns/dnscore/query.go new file mode 100644 index 0000000..195117a --- /dev/null +++ b/pkg/dns/dnscore/query.go @@ -0,0 +1,153 @@ +// +// SPDX-License-Identifier: BSD-3-Clause +// +// Adapted from: https://github.com/ooni/probe-engine/blob/v0.23.0/netx/resolver/encoder.go +// +// Query implementation +// + +package dnscore + +import ( + "github.com/miekg/dns" + "golang.org/x/net/idna" +) + +// QueryOption is a function that modifies a DNS query. +type QueryOption func(*dns.Msg) error + +const ( + // EDNS0FlagDO enables DNSSEC by setting the DNSSSEC OK (DO) bit. + EDNS0FlagDO = 1 << iota + + // EDNS0FlagBlockLengthPadding enables block-length padding as defined + // by https://datatracker.ietf.org/doc/html/rfc8467#section-4.1. + // + // This helps protect against size-based traffic analysis by padding + // DNS queries to a standard block size (128 bytes). + // + // This flag implies [QueryFlagEDNS0]. + EDNS0FlagBlockLengthPadding +) + +// EDNS0SuggestedMaxResponseSizeUDP is the suggested max-response size +// to use for the DNS over UDP transport. This value is same as the one +// used by the [net] package in the standard library. +const EDNS0SuggestedMaxResponseSizeUDP = 1232 + +// END0SSuggestedMaxResponseSizeOtherwise is the suggested max-response size +// when not using the DNS over UDP transport. +const EDNS0SuggestedMaxResponseSizeOtherwise = 4096 + +// QueryOptionEDNS0 configures the EDNS(0) options. +// +// You can configure: +// +// 1. The maximum acceptable response size. +// +// 2. DNSSEC using [EDNS0FlagDO]. +// +// 3. Block-length padding using [EDNS0FlagBlockLengthPadding]. +func QueryOptionEDNS0(maxResponseSize uint16, flags int) QueryOption { + return func(q *dns.Msg) error { + // 1. DNSSEC OK (DO) + q.SetEdns0(maxResponseSize, flags&EDNS0FlagDO != 0) + + // 2. padding + // + // Clients SHOULD pad queries to the closest multiple of + // 128 octets RFC8467#section-4.1. We inflate the query + // length by the size of the option (i.e. 4 octets). The + // cast to uint is necessary to make the modulus operation + // work as intended when the desiredBlockSize is smaller + // than (query.Len()+4) ¯\_(ツ)_/¯. + if flags&EDNS0FlagBlockLengthPadding != 0 { + const desiredBlockSize = 128 + remainder := (desiredBlockSize - uint16(q.Len()+4)) % desiredBlockSize + opt := new(dns.EDNS0_PADDING) + opt.Padding = make([]byte, remainder) + q.IsEdns0().Option = append(q.IsEdns0().Option, opt) + } + return nil + } +} + +// QueryOptionID allows setting an arbitrary query ID. +// +// Otherwise, the default is using [dns.Id] for all protocols +// except DNS-over-HTTPS and DNS-over-QUIC, where we use +// zero, thus following RFC 9250 Sect 4.2.1. +func QueryOptionID(id uint16) QueryOption { + return func(q *dns.Msg) error { + q.Id = id + return nil + } +} + +// NewQueryWithServerAddr constructs a [*dns.Message] containing a +// query for the given domain, query type and [*ServerAddr]. We use +// the [*ServerAddr] to enforce protocol-specific query settings, +// such as, that DoH SHOULD use a zero query ID. +// +// This function takes care of IDNA encoding the domain name and +// fails if the domain name is invalid. +// +// Additionally, [NewQuery] ensures the given name is fully qualified. +// +// Use constants such as [dns.TypeAAAA] to specify the query type. +// +// The [QueryOption] functions can be used to set additional options. +func NewQueryWithServerAddr(serverAddr *ServerAddr, name string, qtype uint16, + options ...QueryOption) (*dns.Msg, error) { + // IDNA encode the domain name. + punyName, err := idna.Lookup.ToASCII(name) + if err != nil { + return nil, err + } + + // Ensure the domain name is fully qualified. + if !dns.IsFqdn(punyName) { + punyName = dns.Fqdn(punyName) + } + + // Create the query message. + question := dns.Question{ + Name: punyName, + Qtype: qtype, + Qclass: dns.ClassINET, + } + query := new(dns.Msg) + query.RecursionDesired = true + query.Question = make([]dns.Question, 1) + query.Question[0] = question + + // Only set the queryID for protocols that actually + // require a nonzero queryID to be set. + switch serverAddr.Protocol { + case ProtocolDoH, ProtocolDoQ: + // for DoH/DoQ, by default we leave the query ID to + // zero, which is what the RFCs suggest/require. + default: + query.Id = dns.Id() + } + + // Apply the query options. + for _, option := range options { + if err := option(query); err != nil { + return nil, err + } + } + return query, nil +} + +// NewQuery is equivalent to calling [NewQueryWithServerAddr] with +// a zero-initialized [*ServerAddr]. We retain this function for backward +// compatibility with the previous API. Existing code that is using this +// function SHOULD use [NewQueryWithServerAddr] with DoH (and MUST with +// DoQ) such that we correctly set the query ID to zero. Other protocols +// are not impacted by this issue and may continue using [NewQuery]. +// +// Deprecated: use [NewQueryWithServerAddr] instead. +func NewQuery(name string, qtype uint16, options ...QueryOption) (*dns.Msg, error) { + return NewQueryWithServerAddr(&ServerAddr{}, name, qtype, options...) +} diff --git a/pkg/dns/dnscore/query_test.go b/pkg/dns/dnscore/query_test.go new file mode 100644 index 0000000..774afd2 --- /dev/null +++ b/pkg/dns/dnscore/query_test.go @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore + +import ( + "errors" + "testing" + + "github.com/miekg/dns" +) + +func TestNewQueryWithServerAddr(t *testing.T) { + // Override the `dns.Id` factory for testing purposes + const expectedNonZeroQueryID = 4 + savedId := dns.Id + dns.Id = func() uint16 { return expectedNonZeroQueryID } + defer func() { dns.Id = savedId }() + + tests := []struct { + name string + serverAddr *ServerAddr + qname string + qtype uint16 + options []QueryOption + wantName string + wantId uint16 + wantErr bool + }{ + { + name: "standard UDP query", + serverAddr: NewServerAddr(ProtocolUDP, "8.8.8.8:53"), + qname: "www.example.com", + qtype: dns.TypeA, + wantName: "www.example.com.", + wantId: expectedNonZeroQueryID, + }, + { + name: "standard TCP query", + serverAddr: NewServerAddr(ProtocolTCP, "8.8.8.8:53"), + qname: "www.example.com", + qtype: dns.TypeA, + wantName: "www.example.com.", + wantId: expectedNonZeroQueryID, + }, + { + name: "standard TLS query", + serverAddr: NewServerAddr(ProtocolTLS, "8.8.8.8:53"), + qname: "www.example.com", + qtype: dns.TypeA, + wantName: "www.example.com.", + wantId: expectedNonZeroQueryID, + }, + { + name: "DoH query should have zero ID", + serverAddr: NewServerAddr(ProtocolHTTPS, "https://dns.google/dns-query"), + qname: "example.com", + qtype: dns.TypeAAAA, + wantName: "example.com.", + wantId: 0, + }, + { + name: "invalid domain", + serverAddr: NewServerAddr(ProtocolUDP, "8.8.8.8:53"), + qname: "invalid domain", + qtype: dns.TypeA, + wantErr: true, + }, + { + name: "with failing option", + serverAddr: NewServerAddr(ProtocolUDP, "8.8.8.8:53"), + qname: "www.example.com", + qtype: dns.TypeA, + options: []QueryOption{mockedFailingOption}, + wantErr: true, + }, + { + name: "DoQ query should have zero ID", + serverAddr: NewServerAddr(ProtocolQUIC, "dns.adguard-dns.com:853"), + qname: "example.com", + qtype: dns.TypeAAAA, + wantName: "example.com.", + wantId: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewQueryWithServerAddr(tt.serverAddr, tt.qname, tt.qtype, tt.options...) + if (err != nil) != tt.wantErr { + t.Errorf("NewQueryWithServerAddr() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err != nil { + return + } + if got.Question[0].Name != tt.wantName { + t.Errorf("NewQueryWithServerAddr() name = %v, want %v", got.Question[0].Name, tt.wantName) + } + if tt.wantId == 0 && got.Id != 0 { + t.Errorf("NewQueryWithServerAddr() id = %v, want 0", got.Id) + } + if tt.wantId != 0 && got.Id == 0 { + t.Errorf("NewQueryWithServerAddr() id = 0, want non-zero") + } + }) + } +} + +func TestNewQuery(t *testing.T) { + // Note: NewQuery has been deprecated on 2025-02-20 + tests := []struct { + name string + qtype uint16 + options []QueryOption + wantName string + wantErr bool + }{ + {"www.example.com", dns.TypeA, nil, "www.example.com.", false}, + {"example.com", dns.TypeAAAA, nil, "example.com.", false}, + {"invalid domain", dns.TypeA, nil, "", true}, + {"www.mocked-failure.com", dns.TypeA, []QueryOption{mockedFailingOption}, "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewQuery(tt.name, tt.qtype, tt.options...) + if (err != nil) != tt.wantErr { + t.Errorf("NewQuery() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err == nil && got.Question[0].Name != tt.wantName { + t.Errorf("NewQuery() = %v, want %v", got.Question[0].Name, tt.wantName) + } + }) + } +} + +func mockedFailingOption(q *dns.Msg) error { + return errors.New("mocked option failure") +} + +func TestQueryOptionEDNS0(t *testing.T) { + query := new(dns.Msg) + option := QueryOptionEDNS0(4096, EDNS0FlagDO|EDNS0FlagBlockLengthPadding) + if err := option(query); err != nil { + t.Errorf("QueryOptionEDNS0() error = %v", err) + } + if query.IsEdns0() == nil { + t.Errorf("QueryOptionEDNS0() did not set EDNS0 options") + } + if len(query.IsEdns0().Option) == 0 { + t.Errorf("QueryOptionEDNS0() did not set padding option") + } +} + +func TestQueryOptionID(t *testing.T) { + query := new(dns.Msg) + option := QueryOptionID(42) + if err := option(query); err != nil { + t.Errorf("QueryOptionID() error = %v", err) + } + if query.Id != 42 { + t.Errorf("QueryOptionID() did not set ID") + } +} diff --git a/pkg/dns/dnscore/resolver.go b/pkg/dns/dnscore/resolver.go new file mode 100644 index 0000000..3544612 --- /dev/null +++ b/pkg/dns/dnscore/resolver.go @@ -0,0 +1,158 @@ +// +// SPDX-License-Identifier: GPL-3.0-or-later +// +// Adapted from: https://github.com/ooni/probe-cli/blob/v3.20.1/internal/netxlite/resolverparallel.go +// + +package dnscore + +import ( + "context" + "errors" + "net" + "strings" + + "github.com/miekg/dns" +) + +// ResolverTransport is the interface defining the [*Transport] +// methods used by the [*Resolver] struct. +// +// The [*Transport] type implements this interface. +type ResolverTransport interface { + Query(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) +} + +// Resolver is a DNS resolver. This struct is API compatible with +// the [*net.Resolver] struct from the [net] package. +// +// The zero value is ready to use. +type Resolver struct { + // Config is the optional resolver configuration. + // + // If nil, we use an empty [*ResolverConfig]. + Config *ResolverConfig + + // Transport is the optional DNS transport to use for resolving queries. + // + // If nil, we use [DefaultTransport]. + Transport ResolverTransport +} + +// config returns the resolver configuration or a default one. +func (r *Resolver) config() *ResolverConfig { + if r.Config == nil { + return NewConfig() + } + return r.Config +} + +// resolverLookupResult is the result of a lookup operation. +type resolverLookupResult struct { + addrs []string + err error +} + +// LookupHost looks up the given host named using the DNS resolver. +func (r *Resolver) LookupHost(ctx context.Context, host string) ([]string, error) { + // start A and AAAA lookups in the background to speed up the process + // then wait for both of them to terminate + // + // note: when the context is canceled, the lookup terminates immediately + ach := make(chan *resolverLookupResult) + go func() { + var result resolverLookupResult + result.addrs, result.err = r.LookupA(ctx, host) + ach <- &result + }() + aaaach := make(chan *resolverLookupResult) + go func() { + var result resolverLookupResult + result.addrs, result.err = r.LookupAAAA(ctx, host) + aaaach <- &result + }() + ares, aaaares := <-ach, <-aaaach + + // merge addresses to return a single list to the caller + addrs := append(append([]string{}, ares.addrs...), aaaares.addrs...) + + // handle the case of no addresses + // + // if there's an error, give priority to the A error because not all + // domains have AAAA records; as a fallback, when there's no error just + // say that the queries returned no data + if len(addrs) < 1 { + if ares.err != nil && !errors.Is(ares.err, ErrNoData) { + return nil, ares.err + } + if aaaares.err != nil && !errors.Is(aaaares.err, ErrNoData) { + return nil, aaaares.err + } + return nil, ErrNoData + } + + // deduplicate addresses and sort IPv4 before IPv6 + addrs = resolverDedupAndSort(addrs) + return addrs, nil +} + +// resolverDedupAndSort deduplicates a list of addresses and sorts IPv4 +// addresses before IPv6 addresses. In principle, DNS resolvers should not +// return duplicates, but, with censorship, it is possible that the AAAA +// query answer is actually a censored A answer. Additionally, since we +// don't implement RFC6724, we sort IPv4 addresses before IPv6 addresses, +// given that everyone supports IPv4 and not everyone supports IPv6. +func resolverDedupAndSort(addrs []string) []string { + uniq := make(map[string]struct{}) + var dedupA, dedupAAAA []string + for _, addr := range addrs { + if _, ok := uniq[addr]; !ok { + uniq[addr] = struct{}{} + if strings.Contains(addr, ":") { + dedupAAAA = append(dedupAAAA, addr) + continue + } + dedupA = append(dedupA, addr) + } + } + result := make([]string, 0, len(dedupA)+len(dedupAAAA)) + result = append(result, dedupA...) + result = append(result, dedupAAAA...) + return result +} + +// LookupA resolves the IPv4 addresses of a given domain. +func (r *Resolver) LookupA(ctx context.Context, host string) ([]string, error) { + // Behave like getaddrinfo when the host is an IP address. + if net.ParseIP(host) != nil { + return []string{host}, nil + } + + // Obtain the RRs + rrs, err := r.lookup(ctx, host, dns.TypeA) + if err != nil { + return nil, err + } + + // Decode as IPv4 addresses and CNAME + addrs, _, err := DecodeLookupA(rrs) + return addrs, err +} + +// LookupAAAA resolves the IPv6 addresses of a given domain. +func (r *Resolver) LookupAAAA(ctx context.Context, host string) ([]string, error) { + // Behave like getaddrinfo when the host is an IP address. + if net.ParseIP(host) != nil { + return []string{host}, nil + } + + // Obtain the RRs + rrs, err := r.lookup(ctx, host, dns.TypeAAAA) + if err != nil { + return nil, err + } + + // Decode as IPv6 addresses and CNAME + addrs, _, err := DecodeLookupAAAA(rrs) + return addrs, err +} diff --git a/pkg/dns/dnscore/resolver_test.go b/pkg/dns/dnscore/resolver_test.go new file mode 100644 index 0000000..6b7632f --- /dev/null +++ b/pkg/dns/dnscore/resolver_test.go @@ -0,0 +1,447 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore + +import ( + "context" + "errors" + "io" + "net" + "testing" + + "github.com/miekg/dns" +) + +// MockResolverTransport allows mocking a [ResolverTransport]. +type MockResolverTransport struct { + MockQuery func(ctx context.Context, + addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) +} + +func (rtm *MockResolverTransport) Query(ctx context.Context, + addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + return rtm.MockQuery(ctx, addr, query) +} + +func TestResolverTransportMock(t *testing.T) { + t.Run("Query", func(t *testing.T) { + expected := errors.New("mocked error") + rtm := &MockResolverTransport{ + MockQuery: func(ctx context.Context, + addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + return nil, expected + }, + } + resp, err := rtm.Query(context.Background(), nil, nil) + if !errors.Is(err, expected) { + t.Fatal("unexpected error") + } + if resp != nil { + t.Fatal("unexpected response") + } + }) +} + +func TestResolver_config(t *testing.T) { + tests := []struct { + name string + config *ResolverConfig + attempts int + }{ + { + name: "Nil config returns default", + config: nil, + attempts: DefaultAttempts, + }, + { + name: "Non-nil config returns the same config", + config: &ResolverConfig{attempts: 128}, + attempts: 128, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resolver := &Resolver{Config: tt.config} + result := resolver.config() + if result.attempts != tt.attempts { + t.Fatalf("expected attempts %d, got %d", tt.attempts, result.attempts) + } + }) + } +} + +func TestResolverDedupAndSort(t *testing.T) { + tests := []struct { + name string + input []string + expected []string + }{ + { + name: "No duplicates, mixed IPv4 and IPv6", + input: []string{"192.0.2.1", "2001:db8::1"}, + expected: []string{"192.0.2.1", "2001:db8::1"}, + }, + + { + name: "With duplicates, mixed IPv4 and IPv6", + input: []string{"192.0.2.1", "2001:db8::1", "192.0.2.1", "2001:db8::1"}, + expected: []string{"192.0.2.1", "2001:db8::1"}, + }, + + { + name: "Only IPv4 addresses", + input: []string{"192.0.2.1", "192.0.2.2"}, + expected: []string{"192.0.2.1", "192.0.2.2"}, + }, + + { + name: "Only IPv6 addresses", + input: []string{"2001:db8::1", "2001:db8::2"}, + expected: []string{"2001:db8::1", "2001:db8::2"}, + }, + + { + name: "Mixed IPv4 and IPv6 with duplicates", + input: []string{"192.0.2.1", "2001:db8::1", "192.0.2.1", "2001:db8::1", "192.0.2.2"}, + expected: []string{"192.0.2.1", "192.0.2.2", "2001:db8::1"}, + }, + + { + name: "Empty input", + input: []string{}, + expected: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := resolverDedupAndSort(tt.input) + if len(result) != len(tt.expected) { + t.Fatalf("expected %d addresses, got %d", len(tt.expected), len(result)) + } + for i, addr := range result { + if addr != tt.expected[i] { + t.Fatalf("expected address %s, got %s", tt.expected[i], addr) + } + } + }) + } +} + +func TestResolver_LookupA(t *testing.T) { + tests := []struct { + name string + host string + mockQuery func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) + expected []string + expectedErr error + }{ + { + name: "Valid IPv4 address", + host: "192.0.2.1", + mockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + return nil, nil + }, + expected: []string{"192.0.2.1"}, + expectedErr: nil, + }, + + { + name: "Valid A record", + host: "example.com", + mockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + msg := &dns.Msg{} + msg.SetReply(query) + msg.Answer = append(msg.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: query.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300}, + A: net.ParseIP("192.0.2.1"), + }) + return msg, nil + }, + expected: []string{"192.0.2.1"}, + expectedErr: nil, + }, + + { + name: "No A record", + host: "example.com", + mockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + msg := &dns.Msg{} + msg.SetReply(query) + return msg, nil + }, + expected: nil, + expectedErr: ErrNoData, + }, + + { + name: "DNS query error", + host: "example.com", + mockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + return nil, io.EOF + }, + expected: nil, + expectedErr: io.EOF, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rtm := &MockResolverTransport{ + MockQuery: tt.mockQuery, + } + resolver := &Resolver{} + resolver.Transport = rtm + + addrs, err := resolver.LookupA(context.Background(), tt.host) + + if !errors.Is(err, tt.expectedErr) { + t.Fatalf("expected error %v, got %v", tt.expectedErr, err) + } + if len(addrs) != len(tt.expected) { + t.Fatalf("expected %d addresses, got %d", len(tt.expected), len(addrs)) + } + for i, addr := range addrs { + if addr != tt.expected[i] { + t.Fatalf("expected address %s, got %s", tt.expected[i], addr) + } + } + }) + } +} + +func TestResolver_LookupAAAA(t *testing.T) { + tests := []struct { + name string + host string + mockQuery func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) + expected []string + expectedErr error + }{ + { + name: "Valid IPv6 address", + host: "2001:db8::1", + mockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + return nil, nil + }, + expected: []string{"2001:db8::1"}, + expectedErr: nil, + }, + + { + name: "Valid AAAA record", + host: "example.com", + mockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + msg := &dns.Msg{} + msg.SetReply(query) + msg.Answer = append(msg.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{Name: query.Question[0].Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 300}, + AAAA: net.ParseIP("2001:db8::1"), + }) + return msg, nil + }, + expected: []string{"2001:db8::1"}, + expectedErr: nil, + }, + + { + name: "No AAAA record", + host: "example.com", + mockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + msg := &dns.Msg{} + msg.SetReply(query) + return msg, nil + }, + expected: nil, + expectedErr: ErrNoData, + }, + + { + name: "DNS query error", + host: "example.com", + mockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + return nil, io.EOF + }, + expected: nil, + expectedErr: io.EOF, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rtm := &MockResolverTransport{ + MockQuery: tt.mockQuery, + } + resolver := &Resolver{} + resolver.Transport = rtm + + addrs, err := resolver.LookupAAAA(context.Background(), tt.host) + + if !errors.Is(err, tt.expectedErr) { + t.Fatalf("expected error %v, got %v", tt.expectedErr, err) + } + if len(addrs) != len(tt.expected) { + t.Fatalf("expected %d addresses, got %d", len(tt.expected), len(addrs)) + } + for i, addr := range addrs { + if addr != tt.expected[i] { + t.Fatalf("expected address %s, got %s", tt.expected[i], addr) + } + } + }) + } +} + +func TestResolver_LookupHost(t *testing.T) { + skipAll := false // Set this to true to skip all tests except those with dontSkip set to true + + tests := []struct { + name string + host string + mockQuery func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) + expected []string + expectedErr error + dontSkip bool + }{ + { + name: "Valid A and AAAA records", + host: "example.com", + mockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + msg := &dns.Msg{} + msg.SetReply(query) + switch query.Question[0].Qtype { + case dns.TypeA: + msg.Answer = append(msg.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: query.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300}, + A: net.ParseIP("192.0.2.1"), + }) + case dns.TypeAAAA: + msg.Answer = append(msg.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{Name: query.Question[0].Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 300}, + AAAA: net.ParseIP("2001:db8::1"), + }) + } + return msg, nil + }, + expected: []string{"192.0.2.1", "2001:db8::1"}, + expectedErr: nil, + dontSkip: false, + }, + + { + name: "No A and AAAA records", + host: "example.com", + mockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + msg := &dns.Msg{} + msg.SetReply(query) + return msg, nil + }, + expected: nil, + expectedErr: ErrNoData, + dontSkip: false, + }, + + { + name: "DNS query error for A record", + host: "example.com", + mockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + if query.Question[0].Qtype == dns.TypeA { + return nil, io.EOF + } + msg := &dns.Msg{} + msg.SetReply(query) + msg.Answer = append(msg.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{Name: query.Question[0].Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 300}, + AAAA: net.ParseIP("2001:db8::1"), + }) + return msg, nil + }, + expected: []string{"2001:db8::1"}, + expectedErr: nil, + dontSkip: false, + }, + + { + name: "DNS query error for AAAA record", + host: "example.com", + mockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + if query.Question[0].Qtype == dns.TypeAAAA { + return nil, io.EOF + } + msg := &dns.Msg{} + msg.SetReply(query) + msg.Answer = append(msg.Answer, &dns.A{ + Hdr: dns.RR_Header{Name: query.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 300}, + A: net.ParseIP("192.0.2.1"), + }) + return msg, nil + }, + expected: []string{"192.0.2.1"}, + expectedErr: nil, + dontSkip: false, + }, + + { + name: "A returns error, AAAA returns no error and no addresses", + host: "example.com", + mockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + msg := &dns.Msg{} + if query.Question[0].Qtype == dns.TypeA { + msg.SetRcode(query, dns.RcodeNameError) + } else { + msg.SetReply(query) + } + return msg, nil + }, + expected: nil, + expectedErr: ErrNoName, + dontSkip: false, + }, + + { + name: "A returns no error and no addresses, AAAA returns error", + host: "example.com", + mockQuery: func(ctx context.Context, addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + msg := &dns.Msg{} + if query.Question[0].Qtype == dns.TypeAAAA { + msg.SetRcode(query, dns.RcodeNameError) + } else { + msg.SetReply(query) + } + return msg, nil + }, + expected: nil, + expectedErr: ErrNoName, + dontSkip: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if skipAll && !tt.dontSkip { + t.Skip("Skipping test as skipAll is true and dontSkip is false") + } + + rtm := &MockResolverTransport{ + MockQuery: tt.mockQuery, + } + resolver := &Resolver{} + resolver.Transport = rtm + + addrs, err := resolver.LookupHost(context.Background(), tt.host) + + if !errors.Is(err, tt.expectedErr) { + t.Fatalf("expected error %v, got %v", tt.expectedErr, err) + } + if len(addrs) != len(tt.expected) { + t.Fatalf("expected %d addresses, got %d", len(tt.expected), len(addrs)) + } + for i, addr := range addrs { + if addr != tt.expected[i] { + t.Fatalf("expected address %s, got %s", tt.expected[i], addr) + } + } + }) + } +} diff --git a/pkg/dns/dnscore/resolverconfig.go b/pkg/dns/dnscore/resolverconfig.go new file mode 100644 index 0000000..7d44dd0 --- /dev/null +++ b/pkg/dns/dnscore/resolverconfig.go @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore + +import ( + "net" + "sync" + "time" +) + +// ResolverConfig contains configuration for the resolver. +// +// Construct using [NewConfig]. +// +// This struct is safe for concurrent use by multiple goroutines. +// +// If the configuration is empty, it uses the "8.8.8.8:53/udp" +// and "8.8.4.4:53/udp" servers as the default servers. +type ResolverConfig struct { + // attempts is the number of attempts to make for each query. + attempts int + + // list contains the list of configured servers. + list []resolverConfigServer + + // mu is the mutex for the config. + mu sync.RWMutex +} + +// DefaultAttempts is the default number of attempts to make for each query. +const DefaultAttempts = 2 + +// NewConfig creates a new resolver configuration. +func NewConfig() *ResolverConfig { + return &ResolverConfig{ + attempts: DefaultAttempts, + list: []resolverConfigServer{}, + mu: sync.RWMutex{}, + } +} + +// SetAttempts sets the number of attempts to make for each query. +func (c *ResolverConfig) SetAttempts(attempts int) { + c.mu.Lock() + c.attempts = attempts + c.mu.Unlock() +} + +// Attempts returns the number of attempts to make for each query. +func (c *ResolverConfig) Attempts() int { + c.mu.RLock() + defer c.mu.RUnlock() + return c.attempts +} + +// resolverConfigServer contains configuration for a single resolver server. +// +// Construct a new instance using [newResolverConfigServer]. +type resolverConfigServer struct { + // address is the address of the server. + address *ServerAddr + + // queryOptions is the list of query options to use + // for constructing queries to this server. + queryOptions []QueryOption + + // timeout is the timeout for each query. + timeout time.Duration +} + +// AddServerOption is an option for adding a server to the resolver configuration. +type AddServerOption func(*resolverConfigServer) + +// ServerOptionQueryOptions sets the query options to use for constructing queries +// to this specific server.If this option is not used, we use the default query options +// suitable for the protocol used by the server. Specifically, we enable DNSSEC +// validation and block-length padding for DoT, DoH, and DoQ. +func ServerOptionQueryOptions(queryOptions ...QueryOption) AddServerOption { + return func(s *resolverConfigServer) { + s.queryOptions = queryOptions + } +} + +// DefaultQueryTimeout is the default timeout for each query. +const DefaultQueryTimeout = 5 * time.Second + +// ServerOptionQueryTimeout sets the timeout for each query. +// +// If this option is not used, we use the [DefaultQueryTimeout] default. +func ServerOptionQueryTimeout(timeout time.Duration) AddServerOption { + return func(s *resolverConfigServer) { + s.timeout = timeout + } +} + +// newResolverConfigServer creates a new resolver server configuration. +func newResolverConfigServer(address *ServerAddr, options ...AddServerOption) resolverConfigServer { + server := resolverConfigServer{ + address: address, + queryOptions: []QueryOption{}, + timeout: DefaultQueryTimeout, + } + + // apply the default query options suitable for the protocol used by the server + switch address.Protocol { + case ProtocolDoH, ProtocolDoT, ProtocolDoQ: + server.queryOptions = append(server.queryOptions, QueryOptionEDNS0( + EDNS0SuggestedMaxResponseSizeOtherwise, + EDNS0FlagDO|EDNS0FlagBlockLengthPadding)) + + case ProtocolTCP: + server.queryOptions = append(server.queryOptions, QueryOptionEDNS0( + EDNS0SuggestedMaxResponseSizeOtherwise, 0)) + + case ProtocolUDP: + server.queryOptions = append(server.queryOptions, QueryOptionEDNS0( + EDNS0SuggestedMaxResponseSizeUDP, 0)) + } + + // apply the user-provided resolver config options + for _, option := range options { + option(&server) + } + + return server +} + +// AddServer adds a new server to the resolver configuration. +func (c *ResolverConfig) AddServer(address *ServerAddr, options ...AddServerOption) { + c.mu.Lock() + c.list = append(c.list, newResolverConfigServer(address, options...)) + c.mu.Unlock() +} + +// servers returns the list of configured servers. +func (c *ResolverConfig) servers() []resolverConfigServer { + // copy the list of servers + c.mu.RLock() + list := append([]resolverConfigServer(nil), c.list...) + c.mu.RUnlock() + + // if empty, create the default servers + if len(list) == 0 { + defaultAddrs := []string{"8.8.8.8", "8.8.4.4"} + for _, addr := range defaultAddrs { + // TODO(bassosimone): double check whether this is causing + // us to always use the max UDP response size also for + // encrypted transports. I think this may be the case just + // by reading the current code. + list = append(list, newResolverConfigServer( + NewServerAddr(ProtocolUDP, net.JoinHostPort(addr, "53")), + ServerOptionQueryOptions(QueryOptionEDNS0( + EDNS0SuggestedMaxResponseSizeUDP, 0)))) + } + } + return list +} diff --git a/pkg/dns/dnscore/resolverconfig_test.go b/pkg/dns/dnscore/resolverconfig_test.go new file mode 100644 index 0000000..2d40bb4 --- /dev/null +++ b/pkg/dns/dnscore/resolverconfig_test.go @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore + +import ( + "testing" + "time" + + "github.com/miekg/dns" +) + +func TestNewConfig(t *testing.T) { + config := NewConfig() + if config == nil { + t.Fatal("Expected non-nil config") + } + if config.Attempts() != DefaultAttempts { + t.Fatalf("Expected %d attempts, got %d", DefaultAttempts, config.Attempts()) + } + if len(config.servers()) != 2 { + t.Fatalf("Expected 2 default servers, got %d", len(config.servers())) + } +} + +func TestSetAttempts(t *testing.T) { + config := NewConfig() + config.SetAttempts(3) + if config.Attempts() != 3 { + t.Fatalf("Expected 3 attempts, got %d", config.Attempts()) + } +} + +func TestAddServer(t *testing.T) { + config := NewConfig() + addr := NewServerAddr(ProtocolUDP, "1.1.1.1:53") + config.AddServer(addr) + servers := config.servers() + if len(servers) != 1 { + t.Fatalf("Expected 1 server, got %d", len(servers)) + } + if servers[0].address.Address != "1.1.1.1:53" { + t.Fatalf("Expected server address 1.1.1.1:53, got %s", servers[0].address.Address) + } +} + +func TestServerOptionQueryTimeout(t *testing.T) { + addr := NewServerAddr(ProtocolUDP, "1.1.1.1:53") + server := newResolverConfigServer(addr, ServerOptionQueryTimeout(5*time.Second)) + if server.timeout != 5*time.Second { + t.Fatalf("Expected timeout 5s, got %s", server.timeout) + } +} + +func TestServerOptionQueryOptions(t *testing.T) { + tests := []struct { + protocol Protocol + expectedLength int + expectedFlags int + }{ + {ProtocolUDP, 1, 0}, + {ProtocolTCP, 1, 0}, + {ProtocolDoT, 1, EDNS0FlagDO | EDNS0FlagBlockLengthPadding}, + {ProtocolDoH, 1, EDNS0FlagDO | EDNS0FlagBlockLengthPadding}, + } + + for _, test := range tests { + addr := NewServerAddr(test.protocol, "1.1.1.1:53") + server := newResolverConfigServer(addr) + if len(server.queryOptions) != test.expectedLength { + t.Fatalf("Expected %d query option(s) for protocol %s, got %d", test.expectedLength, test.protocol, len(server.queryOptions)) + } + if test.expectedLength > 0 { + option := server.queryOptions[0] + msg := &dns.Msg{} + if err := option(msg); err != nil { + t.Fatalf("Failed to apply query option for protocol %s: %s", test.protocol, err) + } + if msg.IsEdns0() == nil { + t.Fatalf("Expected EDNS0 option for protocol %s", test.protocol) + } + if msg.IsEdns0().Do() != (test.expectedFlags&EDNS0FlagDO != 0) { + t.Fatalf("Expected DO flag %v for protocol %s, got %v", test.expectedFlags&EDNS0FlagDO != 0, test.protocol, msg.IsEdns0().Do()) + } + if len(msg.IsEdns0().Option) > 0 { + if _, ok := msg.IsEdns0().Option[0].(*dns.EDNS0_PADDING); !ok && (test.expectedFlags&EDNS0FlagBlockLengthPadding != 0) { + t.Fatalf("Expected block length padding for protocol %s", test.protocol) + } + } + } + } +} diff --git a/pkg/dns/dnscore/response.go b/pkg/dns/dnscore/response.go new file mode 100644 index 0000000..acf615c --- /dev/null +++ b/pkg/dns/dnscore/response.go @@ -0,0 +1,193 @@ +// +// SPDX-License-Identifier: BSD-3-Clause +// +// Adapted from: +// +// - https://github.com/ooni/probe-engine/blob/v0.23.0/netx/resolver/decoder.go +// +// - https://github.com/golang/go/blob/go1.21.10/src/net/dnsclient_unix.go +// +// Response implementation +// + +package dnscore + +import ( + "errors" + + "github.com/miekg/dns" +) + +// Additional errors emitted by [ValidateResponse]. +var ( + // ErrInvalidQuery means that the query does not contain a single question. + ErrInvalidQuery = errors.New("invalid query") +) + +// ValidateResponse validates a given DNS response +// message for a given query message. +func ValidateResponse(query, resp *dns.Msg) error { + // 1. make sure the message is actually a response + if !resp.Response { + return ErrInvalidResponse + } + + // 2. make sure the response ID matches the query ID + if resp.Id != query.Id { + return ErrInvalidResponse + } + + // 3. make sure the query and the response contains a question + if len(resp.Question) != 1 { + return ErrInvalidResponse + } + resp0 := resp.Question[0] + if len(query.Question) != 1 { + return ErrInvalidQuery + } + query0 := query.Question[0] + + // 4. make sure the question name is correct + if !equalASCIIName(resp0.Name, query0.Name) { + return ErrInvalidResponse + } + if resp0.Qclass != query0.Qclass { + return ErrInvalidResponse + } + if resp0.Qtype != query0.Qtype { + return ErrInvalidResponse + } + return nil +} + +func equalASCIIName(x, y string) bool { + if len(x) != len(y) { + return false + } + for i := 0; i < len(x); i++ { + a := x[i] + b := y[i] + if 'A' <= a && a <= 'Z' { + a += 0x20 + } + if 'A' <= b && b <= 'Z' { + b += 0x20 + } + if a != b { + return false + } + } + return true +} + +// These error messages use the same suffixes used by the Go standard library. +var ( + // ErrCannotUnmarshalMessage indicates that we cannot unmarshal a DNS message. + ErrCannotUnmarshalMessage = errors.New("cannot unmarshal DNS message") + + // ErrInvalidResponse means that the response is not a response message + // or does not contain a single question matching the query. + ErrInvalidResponse = errors.New("invalid DNS response") + + // ErrNoName indicates that the server response code is NXDOMAIN. + ErrNoName = errors.New("no such host") + + // ErrServerMisbehaving indicates that the server response code is + // neither 0, nor NXDOMAIN, nor SERVFAIL. + ErrServerMisbehaving = errors.New("server misbehaving") + + // ErrServerTemporarilyMisbehaving indicates that the server answer is SERVFAIL. + // + // The error message is same as [ErrServerMisbehaving] for compatibility with the + // Go standard library, which assigns the same error string to both errors. + ErrServerTemporarilyMisbehaving = errors.New("server misbehaving") + + // ErrNoData indicates that there is no pertinent answer in the response. + ErrNoData = errors.New("no answer from DNS server") +) + +// RCodeToError maps an RCODE inside a valid DNS response +// to an error string using a suffix compatible with the +// error strings returned by [*net.Resolver]. +// +// For example, if a domain does not exist, the error +// will use the "no such host" suffix. +// +// If the RCODE is zero, this function returns nil. +// +// Before invoking this function, make sure the response is valid +// for the request by calling [ValidateResponse]. +func RCodeToError(resp *dns.Msg) error { + // 1. handle NXDOMAIN case by mapping it to EAI_NONAME + if resp.Rcode == dns.RcodeNameError { + return ErrNoName + } + + // 2. handle the case of lame referral by mapping it to EAI_NODATA + if resp.Rcode == dns.RcodeSuccess && + !resp.Authoritative && + !resp.RecursionAvailable && + len(resp.Answer) == 0 { + return ErrNoData + } + + // 3. handle any other error by mapping to EAI_FAIL + if resp.Rcode != dns.RcodeSuccess { + if resp.Rcode == dns.RcodeServerFailure { + return ErrServerTemporarilyMisbehaving + } + return ErrServerMisbehaving + } + return nil +} + +// ValidAnswers extracts valid RRs from the response considering +// the DNS question that was asked. Before invoking this function, +// make sure the response is valid using [ValidateResponse]. +// +// The list of valid RRs is returned in the same order as they appear +// in the response message. If the response does not contain any valid +// RRs, this function returns an empty list. +func ValidAnswers(q0 dns.Question, resp *dns.Msg) ([]dns.RR, error) { + // 1. figure out whether the resolver has encountered any + // alias along the way, or we should filter for the original + // question name. As mentioned in the Go standard library: + // + // [...] RFC 1034 section 4.3.1 says that "the recursive + // response to a query will be... The answer to the query, + // possibly preface by one or more CNAME RRs that specify + // aliases encountered on the way to an answer." + // + // Therefore, if we want to validate, we need to do so + // against the last CNAME record in the answers. + expectName := q0.Name + for _, answer := range resp.Answer { + switch answer := answer.(type) { + case *dns.CNAME: + expectName = answer.Target + } + } + + // 3. build list of valid answers + valid := []dns.RR{} + for _, answer := range resp.Answer { + header := answer.Header() + if !equalASCIIName(expectName, header.Name) { + continue + } + if q0.Qclass != header.Class { + continue + } + // Note: there may be several RR types for a given query so we + // should not check for the type here + valid = append(valid, answer) + } + + // 3. handle the case of no valid answers + if len(valid) < 1 { + return nil, ErrNoData + } + + // 4. return the possibly empty list. + return valid, nil +} diff --git a/pkg/dns/dnscore/response_test.go b/pkg/dns/dnscore/response_test.go new file mode 100644 index 0000000..1507ad8 --- /dev/null +++ b/pkg/dns/dnscore/response_test.go @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore + +import ( + "net" + "testing" + + "github.com/miekg/dns" +) + +func TestValidateResponse(t *testing.T) { + tests := []struct { + name string + modify func(*dns.Msg, *dns.Msg) + expected error + }{ + { + name: "ValidResponse", + modify: func(query, resp *dns.Msg) { + // No modification needed, valid response + }, + expected: nil, + }, + + { + name: "InvalidResponseID", + modify: func(query, resp *dns.Msg) { + resp.Id = query.Id + 1 + }, + expected: ErrInvalidResponse, + }, + + { + name: "InvalidResponseNotAResponse", + modify: func(query, resp *dns.Msg) { + resp.Response = false + }, + expected: ErrInvalidResponse, + }, + + { + name: "InvalidQueryNoQuestion", + modify: func(query, resp *dns.Msg) { + query.Question = nil + }, + expected: ErrInvalidQuery, + }, + + { + name: "InvalidResponseNoQuestion", + modify: func(query, resp *dns.Msg) { + resp.Question = nil + }, + expected: ErrInvalidResponse, + }, + + { + name: "InvalidResponseQuestionName", + modify: func(query, resp *dns.Msg) { + resp.Question[0].Name = "invalid.com." + }, + expected: ErrInvalidResponse, + }, + + { + name: "InvalidResponseQuestionClass", + modify: func(query, resp *dns.Msg) { + resp.Question[0].Qclass = dns.ClassCHAOS + }, + expected: ErrInvalidResponse, + }, + + { + name: "InvalidResponseQuestionType", + modify: func(query, resp *dns.Msg) { + resp.Question[0].Qtype = dns.TypeAAAA + }, + expected: ErrInvalidResponse, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + query := new(dns.Msg) + query.SetQuestion("example.com.", dns.TypeA) + + resp := new(dns.Msg) + resp.SetReply(query) + + tt.modify(query, resp) + + if err := ValidateResponse(query, resp); err != tt.expected { + t.Fatalf("expected %v, got %v", tt.expected, err) + } + }) + } +} + +func Test_equalASCIIName(t *testing.T) { + tests := []struct { + name string + x string + y string + expected bool + }{ + {"EqualNames", "example.com.", "example.com.", true}, + {"EqualNamesDifferentCase", "Example.COM.", "exaMple.com.", true}, + {"DifferentNames", "example.com.", "example.org.", false}, + {"DifferentLengths", "example.com.", "example.co.uk.", false}, + {"OnlyPrefixMatch", "example.co.", "example.co.uk.", false}, + {"EmptyStrings", "", "", true}, + {"OneEmptyString", "example.com.", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if result := equalASCIIName(tt.x, tt.y); result != tt.expected { + t.Fatalf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestRCodeToError(t *testing.T) { + tests := []struct { + name string + rcode int + expected error + }{ + {"NameError", dns.RcodeNameError, ErrNoName}, + {"ServerFailure", dns.RcodeServerFailure, ErrServerTemporarilyMisbehaving}, + {"LameReferral", dns.RcodeSuccess, ErrNoData}, + {"Success", dns.RcodeSuccess, nil}, + {"Refused", dns.RcodeRefused, ErrServerMisbehaving}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp := new(dns.Msg) + resp.Rcode = tt.rcode + + switch tt.name { + case "LameReferral": + resp.Authoritative = false + resp.RecursionAvailable = false + resp.Answer = nil + case "Success": + resp.Authoritative = true + resp.RecursionAvailable = true + resp.Answer = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: "example.com.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: net.IPv4(127, 0, 0, 1), + }} + } + + if err := RCodeToError(resp); err != tt.expected { + t.Fatalf("expected %v, got %v", tt.expected, err) + } + }) + } +} + +func TestValidAnswers(t *testing.T) { + tests := []struct { + name string + query *dns.Msg + resp *dns.Msg + expected int + err error + }{ + { + name: "ValidAnswerWithoutCNAME", + query: func() *dns.Msg { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + return m + }(), + resp: func() *dns.Msg { + m := new(dns.Msg) + m.SetReply(new(dns.Msg)) + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: "example.com.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: net.IPv4(127, 0, 0, 1), + }) + return m + }(), + expected: 1, + err: nil, + }, + + { + name: "ValidAnswerWithCNAME", + query: func() *dns.Msg { + m := new(dns.Msg) + m.SetQuestion("example.co.uk.", dns.TypeA) + return m + }(), + resp: func() *dns.Msg { + m := new(dns.Msg) + m.SetReply(new(dns.Msg)) + m.Answer = append(m.Answer, &dns.CNAME{ + Hdr: dns.RR_Header{ + Name: "example.co.uk.", + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + }, + Target: "example.com.", + }) + m.Answer = append(m.Answer, &dns.CNAME{ + Hdr: dns.RR_Header{ + Name: "example.com.", + Rrtype: dns.TypeCNAME, + Class: dns.ClassINET, + }, + Target: "example.org.", + }) + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: "example.org.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: net.IPv4(127, 0, 0, 1), + }) + return m + }(), + expected: 1, + err: nil, + }, + + { + name: "NoAnswers", + query: func() *dns.Msg { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + return m + }(), + resp: func() *dns.Msg { + m := new(dns.Msg) + m.SetReply(new(dns.Msg)) + return m + }(), + expected: 0, + err: ErrNoData, + }, + + { + name: "MismatchedName", + query: func() *dns.Msg { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + return m + }(), + resp: func() *dns.Msg { + m := new(dns.Msg) + m.SetReply(new(dns.Msg)) + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: "example.org.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: net.IPv4(127, 0, 0, 1), + }) + return m + }(), + expected: 0, + err: ErrNoData, + }, + + { + name: "MismatchedClass", + query: func() *dns.Msg { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + return m + }(), + resp: func() *dns.Msg { + m := new(dns.Msg) + m.SetReply(new(dns.Msg)) + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: "example.com.", + Rrtype: dns.TypeA, + Class: dns.ClassCHAOS, + }, + A: net.IPv4(127, 0, 0, 1), + }) + return m + }(), + expected: 0, + err: ErrNoData, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + answers, err := ValidAnswers(tt.query.Question[0], tt.resp) + if err != tt.err { + t.Fatalf("expected error %v, got %v", tt.err, err) + } + if len(answers) != tt.expected { + t.Fatalf("expected %d answers, got %d", tt.expected, len(answers)) + } + }) + } +} diff --git a/pkg/dns/dnscore/serveraddr.go b/pkg/dns/dnscore/serveraddr.go new file mode 100644 index 0000000..fa352ed --- /dev/null +++ b/pkg/dns/dnscore/serveraddr.go @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore + +// Protocol is a transport protocol. +type Protocol string + +// All the implemented DNS protocols. +const ( + // ProtocolUDP is DNS over UDP. + ProtocolUDP = Protocol("udp") + + // ProtocolTCP is DNS over TCP. + ProtocolTCP = Protocol("tcp") + + // ProtocolDoT is DNS over TLS. + ProtocolDoT = Protocol("dot") + + // ProtocolDoH is DNS over HTTPS. + ProtocolDoH = Protocol("doh") + + // ProtocolDoQ is DNS over QUIC. + ProtocolDoQ = Protocol("doq") +) + +// Name aliases for DNS protocols. +const ( + // ProtocolTLS is an alias for ProtocolDoT. + ProtocolTLS = ProtocolDoT + + // ProtocolHTTPS is an alias for ProtocolDoH. + ProtocolHTTPS = ProtocolDoH + + // ProtocolQUIC is an alias for ProtocolDoQ. + ProtocolQUIC = ProtocolDoQ +) + +// ServerAddr is a DNS server address. +// +// While currently minimal, ServerAddr is designed as a pointer type to +// allow for future extensions of server-specific properties (e.g., custom +// headers for DoH) without requiring breaking API changes. +// +// Construct using [NewServerAddr]. +type ServerAddr struct { + // Protocol is the transport protocol to use. + // + // Use one of: + // + // - [ProtocolUDP] + // - [ProtocolTCP] + // - [ProtocolDoT] + // - [ProtocolDoH] + // - [ProtocolDoQ] + Protocol Protocol + + // Address is the network address of the server. + // + // For [ProtocolUDP], [ProtocolTCP], and [ProtocolDoT] this is + // a string in the form returned by [net.JoinHostPort]. + // + // For [ProtocolDoH] this is a URL. + Address string +} + +// NewServerAddr constructs a new [*ServerAddr] with the given protocol and address. +func NewServerAddr(protocol Protocol, address string) *ServerAddr { + return &ServerAddr{ + Protocol: protocol, + Address: address, + } +} diff --git a/pkg/dns/dnscore/serveraddr_test.go b/pkg/dns/dnscore/serveraddr_test.go new file mode 100644 index 0000000..3c88d31 --- /dev/null +++ b/pkg/dns/dnscore/serveraddr_test.go @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore + +import "testing" + +func TestNewServerAddr(t *testing.T) { + protocol := ProtocolUDP + address := "8.8.8.8:53" + serverAddr := NewServerAddr(protocol, address) + + if serverAddr.Protocol != protocol { + t.Errorf("Expected protocol %s, got %s", protocol, serverAddr.Protocol) + } + + if serverAddr.Address != address { + t.Errorf("Expected address %s, got %s", address, serverAddr.Address) + } +} diff --git a/pkg/dns/dnscore/slog.go b/pkg/dns/dnscore/slog.go new file mode 100644 index 0000000..6ea4f89 --- /dev/null +++ b/pkg/dns/dnscore/slog.go @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore + +import ( + "context" + "log/slog" + "net/netip" + "time" + + "github.com/rbmk-project/rbmk/pkg/common/netipx" +) + +// addrToAddrPort is an alias for [common.AddrToAddrPort]. +var addrToAddrPort = netipx.AddrToAddrPort + +// protocolMap maps the DNS protocol to the corresponding network protocol. +var protocolMap = map[Protocol]string{ + ProtocolDoH: "tcp", + ProtocolTCP: "tcp", + ProtocolDoT: "tcp", + ProtocolUDP: "udp", + ProtocolDoQ: "udp", +} + +// maybeLogQuery is a helper function that logs the query if the logger is set +// and returns the current time for subsequent logging. +func (t *Transport) maybeLogQuery( + ctx context.Context, addr *ServerAddr, rawQuery []byte) time.Time { + t0 := t.timeNow() + if t.Logger != nil { + t.Logger.InfoContext( + ctx, + "dnsQuery", + slog.Any("dnsRawQuery", rawQuery), + slog.String("serverAddr", addr.Address), + slog.String("serverProtocol", string(addr.Protocol)), + slog.Time("t", t0), + slog.String("protocol", protocolMap[addr.Protocol]), + ) + } + return t0 +} + +// maybeLogResponseAddrPort is a helper function that logs the response if the logger is set. +func (t *Transport) maybeLogResponseAddrPort(ctx context.Context, + addr *ServerAddr, t0 time.Time, rawQuery, rawResp []byte, + laddr, raddr netip.AddrPort) { + if t.Logger != nil { + // Convert zero values to unspecified + if !laddr.IsValid() { + laddr = netip.AddrPortFrom(netip.IPv6Unspecified(), 0) + } + if !raddr.IsValid() { + raddr = netip.AddrPortFrom(netip.IPv6Unspecified(), 0) + } + + t.Logger.InfoContext( + ctx, + "dnsResponse", + slog.String("localAddr", laddr.String()), + slog.Any("dnsRawQuery", rawQuery), + slog.Any("dnsRawResponse", rawResp), + slog.String("remoteAddr", raddr.String()), + slog.String("serverAddr", addr.Address), + slog.String("serverProtocol", string(addr.Protocol)), + slog.Time("t0", t0), + slog.Time("t", t.timeNow()), + slog.String("protocol", protocolMap[addr.Protocol]), + ) + } +} + +// maybeLogResponseConn is a helper function that logs the response if the logger is set. +func (t *Transport) maybeLogResponseConn(ctx context.Context, + addr *ServerAddr, t0 time.Time, rawQuery, rawResp []byte, + conn dnsStream) { + if t.Logger != nil { + t.maybeLogResponseAddrPort( + ctx, + addr, + t0, + rawQuery, + rawResp, + addrToAddrPort(conn.LocalAddr()), + addrToAddrPort(conn.RemoteAddr()), + ) + } +} diff --git a/pkg/dns/dnscore/slog_test.go b/pkg/dns/dnscore/slog_test.go new file mode 100644 index 0000000..d09cc96 --- /dev/null +++ b/pkg/dns/dnscore/slog_test.go @@ -0,0 +1,254 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore + +import ( + "bytes" + "context" + "io" + "log/slog" + "net" + "net/netip" + "testing" + "time" + + "github.com/rbmk-project/rbmk/pkg/common/mocks" + "github.com/stretchr/testify/assert" +) + +func TestTransport_maybeLogQuery(t *testing.T) { + tests := []struct { + name string + newLogger func(w io.Writer) *slog.Logger + expectTime time.Time + expectLog string + }{ + { + name: "Logger set", + newLogger: func(w io.Writer) *slog.Logger { + return slog.New(slog.NewJSONHandler(w, &slog.HandlerOptions{ + Level: slog.LevelDebug, + ReplaceAttr: func(groups []string, attr slog.Attr) slog.Attr { + if attr.Key == slog.TimeKey { + return slog.Attr{} + } + return attr + }, + })) + }, + expectTime: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + expectLog: "{\"level\":\"INFO\",\"msg\":\"dnsQuery\",\"dnsRawQuery\":\"AAAAAA==\",\"serverAddr\":\"8.8.8.8:53\",\"serverProtocol\":\"udp\",\"t\":\"2020-01-01T00:00:00Z\",\"protocol\":\"udp\"}\n", + }, + + { + name: "Logger not set", + newLogger: func(w io.Writer) *slog.Logger { return nil }, + expectTime: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + expectLog: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var out bytes.Buffer + transport := &Transport{ + Logger: tt.newLogger(&out), + TimeNow: func() time.Time { + return tt.expectTime + }, + } + + addr := &ServerAddr{Address: "8.8.8.8:53", Protocol: ProtocolUDP} + rawQuery := []byte{0, 0, 0, 0} + + ctx := context.Background() + actualTime := transport.maybeLogQuery(ctx, addr, rawQuery) + + assert.Equal(t, tt.expectTime, actualTime) + + actualLog := out.String() + assert.Equal(t, tt.expectLog, actualLog) + }) + } +} + +func TestTransport_maybeLogResponseAddrPort(t *testing.T) { + tests := []struct { + name string + newLogger func(w io.Writer) *slog.Logger + laddr netip.AddrPort + raddr netip.AddrPort + expectLog string + }{ + { + name: "Logger set with valid addresses", + newLogger: func(w io.Writer) *slog.Logger { + return slog.New(slog.NewJSONHandler(w, &slog.HandlerOptions{ + Level: slog.LevelDebug, + ReplaceAttr: func(groups []string, attr slog.Attr) slog.Attr { + if attr.Key == slog.TimeKey { + return slog.Attr{} + } + return attr + }, + })) + }, + laddr: netip.MustParseAddrPort("[2001:db8::1]:1234"), + raddr: netip.MustParseAddrPort("[2001:db8::2]:443"), + expectLog: "{\"level\":\"INFO\",\"msg\":\"dnsResponse\",\"localAddr\":\"[2001:db8::1]:1234\",\"dnsRawQuery\":\"AAAAAA==\",\"dnsRawResponse\":\"AQEBAQ==\",\"remoteAddr\":\"[2001:db8::2]:443\",\"serverAddr\":\"8.8.8.8:53\",\"serverProtocol\":\"udp\",\"t0\":\"2020-01-01T00:00:00Z\",\"t\":\"2020-01-01T00:00:11Z\",\"protocol\":\"udp\"}\n", + }, + + { + name: "Logger set with invalid addresses", + newLogger: func(w io.Writer) *slog.Logger { + return slog.New(slog.NewJSONHandler(w, &slog.HandlerOptions{ + Level: slog.LevelDebug, + ReplaceAttr: func(groups []string, attr slog.Attr) slog.Attr { + if attr.Key == slog.TimeKey { + return slog.Attr{} + } + return attr + }, + })) + }, + laddr: netip.AddrPort{}, // invalid + raddr: netip.AddrPort{}, // invalid + expectLog: "{\"level\":\"INFO\",\"msg\":\"dnsResponse\",\"localAddr\":\"[::]:0\",\"dnsRawQuery\":\"AAAAAA==\",\"dnsRawResponse\":\"AQEBAQ==\",\"remoteAddr\":\"[::]:0\",\"serverAddr\":\"8.8.8.8:53\",\"serverProtocol\":\"udp\",\"t0\":\"2020-01-01T00:00:00Z\",\"t\":\"2020-01-01T00:00:11Z\",\"protocol\":\"udp\"}\n", + }, + + { + name: "Logger not set", + newLogger: func(w io.Writer) *slog.Logger { return nil }, + expectLog: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var out bytes.Buffer + transport := &Transport{ + Logger: tt.newLogger(&out), + TimeNow: func() time.Time { + return time.Date(2020, 1, 1, 0, 0, 11, 0, time.UTC) + }, + } + + addr := &ServerAddr{Address: "8.8.8.8:53", Protocol: ProtocolUDP} + rawQuery := []byte{0, 0, 0, 0} + rawResponse := []byte{1, 1, 1, 1} + + ctx := context.Background() + transport.maybeLogResponseAddrPort( + ctx, + addr, + time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + rawQuery, + rawResponse, + tt.laddr, + tt.raddr, + ) + + actualLog := out.String() + assert.Equal(t, tt.expectLog, actualLog) + }) + } +} + +func TestTransport_maybeLogResponseConn(t *testing.T) { + tests := []struct { + name string + newLogger func(w io.Writer) *slog.Logger + conn net.Conn + expectLog string + }{ + { + name: "Logger set with TCP addresses", + newLogger: func(w io.Writer) *slog.Logger { + return slog.New(slog.NewJSONHandler(w, &slog.HandlerOptions{ + Level: slog.LevelDebug, + ReplaceAttr: func(groups []string, attr slog.Attr) slog.Attr { + if attr.Key == slog.TimeKey { + return slog.Attr{} + } + return attr + }, + })) + }, + conn: &mocks.Conn{ + MockLocalAddr: func() net.Addr { + return &net.TCPAddr{ + IP: net.ParseIP("2001:db8::1"), + Port: 1234, + } + }, + MockRemoteAddr: func() net.Addr { + return &net.TCPAddr{ + IP: net.ParseIP("2001:db8::2"), + Port: 443, + } + }, + }, + expectLog: "{\"level\":\"INFO\",\"msg\":\"dnsResponse\",\"localAddr\":\"[2001:db8::1]:1234\",\"dnsRawQuery\":\"AAAAAA==\",\"dnsRawResponse\":\"AQEBAQ==\",\"remoteAddr\":\"[2001:db8::2]:443\",\"serverAddr\":\"8.8.8.8:53\",\"serverProtocol\":\"udp\",\"t0\":\"2020-01-01T00:00:00Z\",\"t\":\"2020-01-01T00:00:11Z\",\"protocol\":\"udp\"}\n", + }, + + { + name: "Logger set with non-TCP addresses", + newLogger: func(w io.Writer) *slog.Logger { + return slog.New(slog.NewJSONHandler(w, &slog.HandlerOptions{ + Level: slog.LevelDebug, + ReplaceAttr: func(groups []string, attr slog.Attr) slog.Attr { + if attr.Key == slog.TimeKey { + return slog.Attr{} + } + return attr + }, + })) + }, + conn: &mocks.Conn{ + MockLocalAddr: func() net.Addr { + return &net.UnixAddr{Name: "/tmp/local.sock", Net: "unix"} + }, + MockRemoteAddr: func() net.Addr { + return &net.UnixAddr{Name: "/tmp/remote.sock", Net: "unix"} + }, + }, + expectLog: "{\"level\":\"INFO\",\"msg\":\"dnsResponse\",\"localAddr\":\"[::]:0\",\"dnsRawQuery\":\"AAAAAA==\",\"dnsRawResponse\":\"AQEBAQ==\",\"remoteAddr\":\"[::]:0\",\"serverAddr\":\"8.8.8.8:53\",\"serverProtocol\":\"udp\",\"t0\":\"2020-01-01T00:00:00Z\",\"t\":\"2020-01-01T00:00:11Z\",\"protocol\":\"udp\"}\n", + }, + + { + name: "Logger not set", + newLogger: func(w io.Writer) *slog.Logger { return nil }, + conn: &mocks.Conn{}, + expectLog: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var out bytes.Buffer + transport := &Transport{ + Logger: tt.newLogger(&out), + TimeNow: func() time.Time { + return time.Date(2020, 1, 1, 0, 0, 11, 0, time.UTC) + }, + } + + addr := &ServerAddr{Address: "8.8.8.8:53", Protocol: ProtocolUDP} + rawQuery := []byte{0, 0, 0, 0} + rawResponse := []byte{1, 1, 1, 1} + + ctx := context.Background() + transport.maybeLogResponseConn( + ctx, + addr, + time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + rawQuery, + rawResponse, + tt.conn, + ) + + actualLog := out.String() + assert.Equal(t, tt.expectLog, actualLog) + }) + } +} diff --git a/pkg/dns/dnscore/transport.go b/pkg/dns/dnscore/transport.go new file mode 100644 index 0000000..0efa3e3 --- /dev/null +++ b/pkg/dns/dnscore/transport.go @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore + +import ( + "context" + "crypto/x509" + "errors" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "net/netip" + "time" + + "github.com/miekg/dns" +) + +// Transport allows sending and receiving DNS messages. +// +// The zero value is ready to use. +// +// A [*Transport] is safe for concurrent use by multiple goroutines +// as long as you don't modify its fields after construction and the +// underlying fields you may set (e.g., DialContext) are also safe. +type Transport struct { + // DialContext is the optional dialer for creating new + // TCP and UDP connections. If this field is nil, the default + // dialer from the [net] package will be used. + DialContext func(ctx context.Context, network, address string) (net.Conn, error) + + // DialTLSContext is like DialContext but for creating new + // TLS connections. If this field is nil, we will configure + // a suitable [*tls.Config] and use [*tls.Dialer]. + DialTLSContext func(ctx context.Context, network, address string) (net.Conn, error) + + // HTTPClient is the optional HTTP client to use for DNS-over-HTTPS. + // If this field is nil, we use the default HTTP client from [net/http]. + // + // When HTTPClientDo is nil and this field is not nil, we use this client to + // perform queries and http/httptrace to obtain connection information. + HTTPClient *http.Client + + // HTTPClientDo optionally allows full control over how HTTP requests + // are performed and how to obtain connection information. When this + // field is non-nil, it takes precedence over HTTPClient. + // + // This field is mainly useful for measurement scenarios where you need + // precise control over connection handling and addressing information. + HTTPClientDo func(req *http.Request) (*http.Response, netip.AddrPort, netip.AddrPort, error) + + // Logger is the optional structured logger for emitting + // structured diagnostic events. If this field is nil, we + // will not be emitting structured logs. + Logger *slog.Logger + + // NewHTTPRequestWithContext is an optional function that creates a new + // HTTP request with the given context. If this field is nil, the + // [http.NewRequestWithContext] function will be used. + NewHTTPRequestWithContext func(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) + + // ReadAllContext is the optional function to read the whole HTTP response + // body in DNS-over-HTTPS. If this field is nil, we use the [io.ReadAll] function + // instead. Compared to [io.ReadAll], this function has a context argument + // and an [io.Closer] argument, which SHOULD be used to close the connection + // when the context is cancelled. In general, this is not useful, but in censored + // places censorship may desync the TCP connection, making context-based + // interruption useful to avoid being blocked ~forever. + ReadAllContext func(ctx context.Context, r io.Reader, c io.Closer) ([]byte, error) + + // RootCAs contains the [*x509.CertPool] used by DNS-over-TLS + // when the DialTLSContext function pointer is nil. Leaving this + // field nil implies using the system's root CAs. + RootCAs *x509.CertPool + + // TimeNow is an optional function that returns the current time. + // If this field is nil, the [time.Now] function will be used. + TimeNow func() time.Time +} + +// DefaultTransport is the default transport used by the package. +var DefaultTransport = &Transport{} + +// ErrNoSuchTransportProtocol is returned when the given protocol is not supported. +var ErrNoSuchTransportProtocol = errors.New("no such transport protocol") + +// Query sends a DNS query to the given server address and returns the response. +// +// The context is used to control the query lifetime. If the context is +// cancelled or times out, the query will be aborted and an error will +// be immediately returned to the caller. +// +// The returned DNS message is the first message received from the server and +// it is not guaranteed to be valid for the query. You will still need to +// validate the response using the [ValidateResponse] function. +func (t *Transport) Query(ctx context.Context, + addr *ServerAddr, query *dns.Msg) (*dns.Msg, error) { + switch addr.Protocol { + case ProtocolUDP: + return t.queryUDP(ctx, addr, query) + + case ProtocolTCP: + return t.queryTCP(ctx, addr, query) + + case ProtocolDoT: + return t.queryTLS(ctx, addr, query) + + case ProtocolDoH: + return t.queryHTTPS(ctx, addr, query) + + case ProtocolDoQ: + return t.queryQUIC(ctx, addr, query) + + default: + return nil, fmt.Errorf("%w: %s", ErrNoSuchTransportProtocol, addr.Protocol) + } +} + +// MessageOrError contains either a DNS message or an error. +type MessageOrError struct { + Err error + Msg *dns.Msg +} + +// ErrTransportCannotReceiveDuplicates is returned when the transport cannot receive duplicates. +var ErrTransportCannotReceiveDuplicates = errors.New("transport cannot receive duplicates") + +// QueryWithDuplicates sends a DNS query to the given server address +// and returns the received responses. Use this method when you expect +// duplicate responses possibly caused by censorship. For example, +// the GFW (Great Firewall of China) typically causes duplicate responses +// with different addresses when a given domain is censored. +// +// This method only works with [ProtocolUDP]. +// +// As for [*Transport.Query], the context is used to control the query +// lifetime. If the context is cancelled or times out, the query will be +// aborted and the returned channel will be then closed. +// +// The returned DNS messages are the responses received from the server and +// they are not guaranteed to be valid for the query. You will still need to +// validate the responses using the [ValidateResponse] function. +func (t *Transport) QueryWithDuplicates(ctx context.Context, + addr *ServerAddr, query *dns.Msg) <-chan *MessageOrError { + + if addr.Protocol != ProtocolUDP { + ch := make(chan *MessageOrError, 1) + ch <- &MessageOrError{Err: fmt.Errorf( + "%w: %s", ErrTransportCannotReceiveDuplicates, addr.Protocol)} + close(ch) + return ch + } + + return t.queryUDPWithDuplicates(ctx, addr, query) +} diff --git a/pkg/dns/dnscore/transport_test.go b/pkg/dns/dnscore/transport_test.go new file mode 100644 index 0000000..85b5327 --- /dev/null +++ b/pkg/dns/dnscore/transport_test.go @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscore + +import ( + "context" + "errors" + "testing" + + "github.com/miekg/dns" +) + +func TestTransportQuery(t *testing.T) { + // create a canceled context so that we do not actually perform the query + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + tests := []struct { + protocol Protocol + expectErr error + }{ + {protocol: ProtocolUDP, expectErr: context.Canceled}, + {protocol: ProtocolTCP, expectErr: context.Canceled}, + {protocol: ProtocolDoT, expectErr: context.Canceled}, + {protocol: ProtocolDoH, expectErr: context.Canceled}, + {protocol: "", expectErr: ErrNoSuchTransportProtocol}, + } + + for _, tt := range tests { + t.Run(string(tt.protocol), func(t *testing.T) { + txp := &Transport{} + query := &dns.Msg{} + addr := NewServerAddr(tt.protocol, "") + + resp, err := txp.Query(ctx, addr, query) + + if !errors.Is(err, tt.expectErr) { + t.Errorf("expected %v error, got %v", tt.expectErr, err) + } + if resp != nil { + t.Errorf("expected nil response, got %v", resp) + } + }) + } +} + +func TestTransportQueryWithDuplicates(t *testing.T) { + // create a canceled context so that we do not actually perform the query + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + tests := []struct { + protocol Protocol + expectErr error + }{ + {protocol: ProtocolUDP, expectErr: context.Canceled}, + {protocol: ProtocolTCP, expectErr: ErrTransportCannotReceiveDuplicates}, + {protocol: ProtocolDoT, expectErr: ErrTransportCannotReceiveDuplicates}, + {protocol: ProtocolDoH, expectErr: ErrTransportCannotReceiveDuplicates}, + } + + for _, tt := range tests { + t.Run(string(tt.protocol), func(t *testing.T) { + txp := &Transport{} + query := &dns.Msg{} + addr := NewServerAddr(tt.protocol, "") + + out := txp.QueryWithDuplicates(ctx, addr, query) + + var results []*MessageOrError + for result := range out { + results = append(results, result) + } + + if len(results) != 1 { + t.Errorf("expected 1 result, got %d", len(results)) + } + + r0 := results[0] + if !errors.Is(r0.Err, tt.expectErr) { + t.Errorf("expected %v error, got %v", tt.expectErr, r0.Err) + } + if r0.Msg != nil { + t.Errorf("expected nil response, got %v", r0.Msg) + } + }) + } +} diff --git a/pkg/dns/dnscoretest/cert.pem b/pkg/dns/dnscoretest/cert.pem new file mode 100644 index 0000000..da1cdc1 --- /dev/null +++ b/pkg/dns/dnscoretest/cert.pem @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE----- +MIIBwTCCAWegAwIBAgIQUX7GxXqHiY45USlUfm9/ajAKBggqhkjOPQQDAjAtMRUw +EwYDVQQKEwxSQk1LIFByb2plY3QxFDASBgNVBAMTC2V4YW1wbGUuY29tMB4XDTI0 +MTEwNzEwMzY0MloXDTI1MTEwNzEwMzY0MlowLTEVMBMGA1UEChMMUkJNSyBQcm9q +ZWN0MRQwEgYDVQQDEwtleGFtcGxlLmNvbTBZMBMGByqGSM49AgEGCCqGSM49AwEH +A0IABHbFYUQflAxpw42YrHHXHdJZWoTbZsyRLmhiHY4OSUG/Ke6NFgrpDbYy5NFI ++8q3Qyvy4mXsJsdMKYMCHfLbtD+jaTBnMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUE +DDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMDIGA1UdEQQrMCmCD3d3dy5leGFt +cGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATAKBggqhkjOPQQDAgNIADBF +AiEAxmrTP9s03S3Xi+n1k+9OInWXxI6LDAvsSctRpC98EKkCIBhyMU/qrLqdmqtu +BaPqYzn+6wv6vzlhbt42aHyjuAsu +-----END CERTIFICATE----- diff --git a/pkg/dns/dnscoretest/doc.go b/pkg/dns/dnscoretest/doc.go new file mode 100644 index 0000000..3f6f89b --- /dev/null +++ b/pkg/dns/dnscoretest/doc.go @@ -0,0 +1,4 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +// Package dnscoretest contains fake servers to test dnscore. +package dnscoretest diff --git a/pkg/dns/dnscoretest/dohttps.go b/pkg/dns/dnscoretest/dohttps.go new file mode 100644 index 0000000..9621e98 --- /dev/null +++ b/pkg/dns/dnscoretest/dohttps.go @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscoretest + +import ( + "crypto/tls" + "crypto/x509" + "io" + "net/http" + "net/url" + + "github.com/rbmk-project/rbmk/pkg/common/runtimex" +) + +// StartHTTPS starts an HTTPS server and handles incoming DNS queries. +// +// This method panics in case of failure. +func (s *Server) StartHTTPS(handler Handler) <-chan struct{} { + runtimex.Assert(!s.started, "already started") + ready := make(chan struct{}) + go func() { + cert := runtimex.Try1(tls.X509KeyPair(certPEM, keyPEM)) + config := &tls.Config{Certificates: []tls.Certificate{cert}} + listener := runtimex.Try1(s.listenTLS("tcp", "127.0.0.1:0", config)) + s.Addr = listener.Addr().String() + s.RootCAs = x509.NewCertPool() + runtimex.Assert(s.RootCAs.AppendCertsFromPEM(certPEM), "cannot append PEM cert") + s.URL = (&url.URL{Scheme: "https", Host: s.Addr, Path: "/dns-query"}).String() + s.ioclosers = append(s.ioclosers, listener) + s.started = true + srv := &http.Server{ + Handler: newHTTPHandler(handler), + } + close(ready) + _ = srv.Serve(listener) + }() + return ready +} + +func newHTTPHandler(handler Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rawQuery := runtimex.Try1(io.ReadAll(r.Body)) + rw := &responseWriterHTTPS{w} + handler.Handle(rw, rawQuery) + }) +} + +// responseWriterHTTPS is a response writer for HTTPS. +type responseWriterHTTPS struct { + w http.ResponseWriter +} + +// Ensure responseWriterHTTPS implements ResponseWriter. +var _ ResponseWriter = (*responseWriterHTTPS)(nil) + +// Write implements ResponseWriter. +func (r *responseWriterHTTPS) Write(rawResp []byte) (int, error) { + r.w.Header().Add("Content-Type", "application/dns-message") + return r.w.Write(rawResp) +} diff --git a/pkg/dns/dnscoretest/dotcp.go b/pkg/dns/dnscoretest/dotcp.go new file mode 100644 index 0000000..9eeffee --- /dev/null +++ b/pkg/dns/dnscoretest/dotcp.go @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscoretest + +import ( + "bufio" + "io" + "math" + "net" + + "github.com/rbmk-project/rbmk/pkg/common/runtimex" +) + +// StartTCP starts a TCP listener and listens for incoming DNS queries. +// +// This method panics in case of failure. +func (s *Server) StartTCP(handler Handler) <-chan struct{} { + runtimex.Assert(!s.started, "already started") + ready := make(chan struct{}) + go func() { + listener := runtimex.Try1(s.listen("tcp", "127.0.0.1:0")) + s.Addr = listener.Addr().String() + s.ioclosers = append(s.ioclosers, listener) + s.started = true + close(ready) + for { + conn, err := listener.Accept() + if err != nil { + return + } + s.serveConn(handler, conn) + } + }() + return ready +} + +// listen either used the stdlib or the custom Listen func. +func (s *Server) listen(network, address string) (net.Listener, error) { + if s.Listen != nil { + return s.Listen(network, address) + } + return net.Listen(network, address) +} + +// serveConn serves a single DNS query over TCP or TLS. +func (s *Server) serveConn(handler Handler, conn net.Conn) { + // Close the connection when done serving + defer conn.Close() + + // Wrap the conn into a bufio.Reader and read the whole message + br := bufio.NewReader(conn) + header := make([]byte, 2) + _ = runtimex.Try1(io.ReadFull(br, header)) + length := int(header[0])<<8 | int(header[1]) + rawQuery := make([]byte, length) + _ = runtimex.Try1(io.ReadFull(br, rawQuery)) + + // Wrap into a response writer and serve + rw := &responseWriterStream{conn: conn} + handler.Handle(rw, rawQuery) +} + +// responseWriterStream is a response writer for TCP or TLS. +type responseWriterStream struct { + conn net.Conn +} + +// Ensure responseWriterStream implements ResponseWriter. +var _ ResponseWriter = (*responseWriterStream)(nil) + +// Write implements ResponseWriter. +func (r *responseWriterStream) Write(rawMsg []byte) (int, error) { + runtimex.Assert(len(rawMsg) <= math.MaxUint16, "message too large") + rawMsgFrame := []byte{byte(len(rawMsg) >> 8)} + rawMsgFrame = append(rawMsgFrame, byte(len(rawMsg))) + rawMsgFrame = append(rawMsgFrame, rawMsg...) + return r.conn.Write(rawMsgFrame) +} diff --git a/pkg/dns/dnscoretest/dotls.go b/pkg/dns/dnscoretest/dotls.go new file mode 100644 index 0000000..74d6652 --- /dev/null +++ b/pkg/dns/dnscoretest/dotls.go @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscoretest + +import ( + "crypto/tls" + "crypto/x509" + _ "embed" + "net" + + "github.com/rbmk-project/rbmk/pkg/common/runtimex" +) + +var ( + //go:embed cert.pem + certPEM []byte + + //go:embed key.pem + keyPEM []byte +) + +// StartTLS starts a TLS listener and listens for incoming DNS queries. +// +// This method panics in case of failure. +func (s *Server) StartTLS(handler Handler) <-chan struct{} { + runtimex.Assert(!s.started, "already started") + ready := make(chan struct{}) + go func() { + cert := runtimex.Try1(tls.X509KeyPair(certPEM, keyPEM)) + config := &tls.Config{Certificates: []tls.Certificate{cert}} + listener := runtimex.Try1(s.listenTLS("tcp", "127.0.0.1:0", config)) + s.Addr = listener.Addr().String() + s.RootCAs = x509.NewCertPool() + runtimex.Assert(s.RootCAs.AppendCertsFromPEM(certPEM), "cannot append PEM cert") + s.ioclosers = append(s.ioclosers, listener) + s.started = true + close(ready) + for { + conn, err := listener.Accept() + if err != nil { + return + } + s.serveConn(handler, conn) + } + }() + return ready +} + +// listenTLS either uses the stdlib or the custom ListenTLS func. +func (s *Server) listenTLS(network, address string, config *tls.Config) (net.Listener, error) { + if s.ListenTLS != nil { + return s.ListenTLS(network, address, config) + } + return tls.Listen(network, address, config) +} diff --git a/pkg/dns/dnscoretest/doudp.go b/pkg/dns/dnscoretest/doudp.go new file mode 100644 index 0000000..0fa799b --- /dev/null +++ b/pkg/dns/dnscoretest/doudp.go @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscoretest + +import ( + "net" + + "github.com/rbmk-project/rbmk/pkg/common/runtimex" +) + +// StartUDP starts an UDP listener and listens for incoming DNS queries. +// +// This method panics in case of failure. +func (s *Server) StartUDP(handler Handler) <-chan struct{} { + runtimex.Assert(!s.started, "already started") + ready := make(chan struct{}) + go func() { + pconn := runtimex.Try1(s.listenPacket("udp", "127.0.0.1:0")) + s.Addr = pconn.LocalAddr().String() + s.ioclosers = append(s.ioclosers, pconn) + s.started = true + close(ready) + for s.servePacketConn(handler, pconn) == nil { + // nothing + } + }() + return ready +} + +// listenPacket either uses the standard library or the custom ListenPacket func. +func (s *Server) listenPacket(network, address string) (net.PacketConn, error) { + if s.ListenPacket != nil { + return s.ListenPacket(network, address) + } + return net.ListenPacket(network, address) +} + +// servePacketConn serves a single DNS query over UDP. +func (s *Server) servePacketConn(handler Handler, pconn net.PacketConn) error { + buf := make([]byte, 4096) + count, addr, err := pconn.ReadFrom(buf) + if err != nil { + return err + } + rawQuery := buf[:count] + rw := &responseWriterUDP{pconn: pconn, addr: addr} + handler.Handle(rw, rawQuery) + return nil +} + +// responseWriterUDP is a response writer for UDP. +type responseWriterUDP struct { + pconn net.PacketConn + addr net.Addr +} + +// Ensure responseWriterUDP implements ResponseWriter. +var _ ResponseWriter = (*responseWriterUDP)(nil) + +// Write implements ResponseWriter. +func (r *responseWriterUDP) Write(rawMsg []byte) (int, error) { + return r.pconn.WriteTo(rawMsg, r.addr) +} diff --git a/pkg/dns/dnscoretest/example_test.go b/pkg/dns/dnscoretest/example_test.go new file mode 100644 index 0000000..f387347 --- /dev/null +++ b/pkg/dns/dnscoretest/example_test.go @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscoretest_test + +import ( + "fmt" + "log" + "slices" + "strings" + + "github.com/miekg/dns" + "github.com/rbmk-project/rbmk/pkg/dns/dnscoretest" +) + +func ExampleServer_udp() { + // Create a fake UDP server using the example.com handler + server := &dnscoretest.Server{} + handler := dnscoretest.NewExampleComHandler() + <-server.StartUDP(handler) + defer server.Close() + + // Create a DNS client + client := &dns.Client{Net: "udp"} + query := new(dns.Msg) + query.SetQuestion("example.com.", dns.TypeA) + + // Send the query to the fake server + resp, _, err := client.Exchange(query, server.Addr) + if err != nil { + log.Fatal(err) + } + + // print the results + var addrs []string + for _, rr := range resp.Answer { + switch rr := rr.(type) { + case *dns.A: + addrs = append(addrs, rr.A.String()) + } + } + slices.Sort(addrs) + fmt.Printf("%s\n", strings.Join(addrs, "\n")) + + // Output: + // 93.184.215.14 +} diff --git a/pkg/dns/dnscoretest/handler.go b/pkg/dns/dnscoretest/handler.go new file mode 100644 index 0000000..7707704 --- /dev/null +++ b/pkg/dns/dnscoretest/handler.go @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscoretest + +import ( + "io" + "net" + + "github.com/miekg/dns" + "github.com/rbmk-project/rbmk/pkg/common/runtimex" +) + +// ResponseWriter allows writing raw DNS responses. +type ResponseWriter interface { + io.Writer +} + +// Handler is a function that handles a DNS query. +type Handler interface { + Handle(rw ResponseWriter, rawQuery []byte) +} + +// HandlerFunc is an adapter to allow the use of ordinary functions as DNS handlers. +type HandlerFunc func(rw ResponseWriter, rawQuery []byte) + +// Ensure HandlerFunc implements Handler. +var _ Handler = HandlerFunc(nil) + +// Handle implements Handler. +func (hf HandlerFunc) Handle(rw ResponseWriter, rawQuery []byte) { + hf(rw, rawQuery) +} + +// ExampleComAddrA is the A address of example.com. +var ExampleComAddrA = net.IPv4(93, 184, 215, 14) + +// NewExampleComHandler returns a handler that responds with a valid DNS response for example.com. +func NewExampleComHandler() Handler { + return HandlerFunc(func(rw ResponseWriter, rawQuery []byte) { + query := &dns.Msg{} + runtimex.Try0(query.Unpack(rawQuery)) + resp := &dns.Msg{} + resp.SetReply(query) + resp.Answer = append(resp.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: "example.com.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 3600, + Rdlength: 0, + }, + A: ExampleComAddrA, + }) + rawResp := runtimex.Try1(resp.Pack()) + _ = runtimex.Try1(rw.Write(rawResp)) + }) +} diff --git a/pkg/dns/dnscoretest/integration_test.go b/pkg/dns/dnscoretest/integration_test.go new file mode 100644 index 0000000..61cecfe --- /dev/null +++ b/pkg/dns/dnscoretest/integration_test.go @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscoretest_test + +import ( + "bytes" + "crypto/tls" + "io" + "net/http" + "testing" + + "github.com/miekg/dns" + "github.com/rbmk-project/rbmk/pkg/common/runtimex" + "github.com/rbmk-project/rbmk/pkg/dns/dnscoretest" + "github.com/stretchr/testify/assert" +) + +func checkResult(t *testing.T, resp *dns.Msg, err error) { + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, 1, len(resp.Answer)) + assert.Equal(t, "example.com.", resp.Answer[0].Header().Name) + assert.Equal(t, dns.TypeA, resp.Answer[0].Header().Rrtype) + assert.Equal( + t, dnscoretest.ExampleComAddrA.String(), + resp.Answer[0].(*dns.A).A.String(), + ) +} + +func TestFakeDNSServer_UDP(t *testing.T) { + // Create a fake UDP server using the example.com handler + server := &dnscoretest.Server{} + handler := dnscoretest.NewExampleComHandler() + <-server.StartUDP(handler) + defer server.Close() + + // Create a DNS client + client := &dns.Client{Net: "udp"} + query := new(dns.Msg) + query.SetQuestion("example.com.", dns.TypeA) + + // Send the query to the fake server + resp, _, err := client.Exchange(query, server.Addr) + + // Validate the results + checkResult(t, resp, err) +} + +func TestFakeDNSServer_TCP(t *testing.T) { + // Create a fake TCP server using the example.com handler + server := &dnscoretest.Server{} + handler := dnscoretest.NewExampleComHandler() + <-server.StartTCP(handler) + defer server.Close() + + // Create a DNS client + client := &dns.Client{Net: "tcp"} + query := new(dns.Msg) + query.SetQuestion("example.com.", dns.TypeA) + + // Send the query to the fake server + resp, _, err := client.Exchange(query, server.Addr) + + // Validate the results + checkResult(t, resp, err) +} + +func TestFakeDNSServer_TLS(t *testing.T) { + // Create a fake TLS server using the example.com handler + server := &dnscoretest.Server{} + handler := dnscoretest.NewExampleComHandler() + <-server.StartTLS(handler) + defer server.Close() + + // Create a DNS client with TLS configuration + tlsConfig := &tls.Config{ + RootCAs: server.RootCAs, + } + client := &dns.Client{ + Net: "tcp-tls", + TLSConfig: tlsConfig, + } + query := new(dns.Msg) + query.SetQuestion("example.com.", dns.TypeA) + + // Send the query to the fake server + resp, _, err := client.Exchange(query, server.Addr) + + // Validate the results + checkResult(t, resp, err) +} + +func TestFakeDNSServer_HTTPS(t *testing.T) { + // Create a fake HTTPS server using the example.com handler + server := &dnscoretest.Server{} + handler := dnscoretest.NewExampleComHandler() + <-server.StartHTTPS(handler) + defer server.Close() + + // Create an HTTP client with TLS configuration + tlsConfig := &tls.Config{ + RootCAs: server.RootCAs, + } + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + } + + // Create the HTTP request + query := new(dns.Msg) + query.SetQuestion("example.com.", dns.TypeA) + rawQuery := runtimex.Try1(query.Pack()) + httpReq := runtimex.Try1(http.NewRequest( + "POST", server.URL, bytes.NewReader(rawQuery))) + + // Send the query to the fake server + httpResp, err := client.Do(httpReq) + + // Validate the HTTPS response + if err != nil { + t.Fatal(err) + } + defer httpResp.Body.Close() + if httpResp.StatusCode != http.StatusOK { + t.Fatal("expected 200, got", httpResp.StatusCode) + } + rawResp := runtimex.Try1(io.ReadAll(httpResp.Body)) + resp := &dns.Msg{} + if err := resp.Unpack(rawResp); err != nil { + t.Fatal(err) + } + + // Validate the results + checkResult(t, resp, err) +} diff --git a/pkg/dns/dnscoretest/key.pem b/pkg/dns/dnscoretest/key.pem new file mode 100644 index 0000000..cfaaaaa --- /dev/null +++ b/pkg/dns/dnscoretest/key.pem @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIOwpwvZNNf0zFsZtAKmYGW5P9oZmvv2pecngYoHVeB3RoAoGCCqGSM49 +AwEHoUQDQgAEdsVhRB+UDGnDjZiscdcd0llahNtmzJEuaGIdjg5JQb8p7o0WCukN +tjLk0Uj7yrdDK/LiZewmx0wpgwId8tu0Pw== +-----END EC PRIVATE KEY----- diff --git a/pkg/dns/dnscoretest/server.go b/pkg/dns/dnscoretest/server.go new file mode 100644 index 0000000..733bed2 --- /dev/null +++ b/pkg/dns/dnscoretest/server.go @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscoretest + +import ( + "crypto/tls" + "crypto/x509" + "io" + "net" +) + +// Server is a fake DNS server. +// +// The zero value is a valid server. +type Server struct { + // Addr is the address of the server for DNS-over-UDP, + // DNS-over-TCP, and DNS-over-TLS. + Addr string + + // Listen is an optional func to override the default + // function used to create a [net.Listener]. + Listen func(network, address string) (net.Listener, error) + + // ListenPacket is an optional func to override the default + // function used to create a listening [net.PacketConn]. + ListenPacket func(network, address string) (net.PacketConn, error) + + // ListenTLS is an optional func to override the default + // function used to listen using TLS. + ListenTLS func(network, address string, config *tls.Config) (net.Listener, error) + + // RootCAs contains the cert pool the client should use + // for DNS-over-TLS and DNS-over-HTTPS. + RootCAs *x509.CertPool + + // URL is the URL for DNS-over-HTTPS. + URL string + + // ioclosers is a list of ioclosers to close when the server is closed. + ioclosers []io.Closer + + // started indicates that the server has started. + started bool +} + +// Close closes the server. +func (s *Server) Close() error { + var err error + for _, c := range s.ioclosers { + if cerr := c.Close(); cerr != nil { + err = cerr + } + } + return err +} diff --git a/pkg/dns/dnscoretest/server_test.go b/pkg/dns/dnscoretest/server_test.go new file mode 100644 index 0000000..394eac4 --- /dev/null +++ b/pkg/dns/dnscoretest/server_test.go @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package dnscoretest + +import ( + "crypto/tls" + "errors" + "net" + "testing" + + "github.com/rbmk-project/rbmk/pkg/common/mocks" + "github.com/stretchr/testify/assert" +) + +func TestServer_listen(t *testing.T) { + expectedErr := errors.New("mocked error") + srv := &Server{ + Listen: func(network, address string) (net.Listener, error) { + return nil, expectedErr + }, + } + _, err := srv.listen("tcp", "127.0.0.1:0") + assert.ErrorIs(t, err, expectedErr) +} + +func TestServer_listenPacket(t *testing.T) { + expectedErr := errors.New("mocked error") + srv := &Server{ + ListenPacket: func(network, address string) (net.PacketConn, error) { + return nil, expectedErr + }, + } + _, err := srv.listenPacket("udp", "127.0.0.1:0") + assert.ErrorIs(t, err, expectedErr) +} + +func TestServer_listenTLS(t *testing.T) { + expectedErr := errors.New("mocked error") + srv := &Server{ + ListenTLS: func(network, address string, config *tls.Config) (net.Listener, error) { + return nil, expectedErr + }, + } + _, err := srv.listenTLS("tcp", "127.0.0.1:0", nil) + assert.ErrorIs(t, err, expectedErr) +} + +func TestServer_Close(t *testing.T) { + expected := errors.New("mocked error") + srv := &Server{} + + srv.ioclosers = append(srv.ioclosers, &mocks.Conn{ + MockClose: func() error { + return nil + }, + }) + + srv.ioclosers = append(srv.ioclosers, &mocks.Conn{ + MockClose: func() error { + return expected + }, + }) + + srv.ioclosers = append(srv.ioclosers, &mocks.Conn{ + MockClose: func() error { + return nil + }, + }) + + if err := srv.Close(); !errors.Is(err, expected) { + t.Fatal("expected", expected, ", got", err) + } +} diff --git a/pkg/dns/doc.go b/pkg/dns/doc.go new file mode 100644 index 0000000..2dca390 --- /dev/null +++ b/pkg/dns/doc.go @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +/* +Package dns contains DNS-related functionality. + +See [dd-000-dnscore.md] for more information. + +[dd-000-dnscore.md]: https://github.com/rbmk-project/rbmk-project.github.io/blob/main/docs/design/dd-000-dnscore.md +*/ +package common diff --git a/pkg/dns/internal/cmd/mkcert/main.go b/pkg/dns/internal/cmd/mkcert/main.go new file mode 100644 index 0000000..cfc9378 --- /dev/null +++ b/pkg/dns/internal/cmd/mkcert/main.go @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +// Command mkcert generates a self-signed certificate for testing purposes. +package main + +import ( + "path/filepath" + + "github.com/rbmk-project/rbmk/pkg/common/selfsignedcert" +) + +var destdir = filepath.Join("pkg", "dns", "dnscoretest") + +func main() { + cert := selfsignedcert.New(selfsignedcert.NewConfigExampleCom()) + cert.WriteFiles(destdir) +} diff --git a/pkg/dns/internal/cmd/mkcert/main_test.go b/pkg/dns/internal/cmd/mkcert/main_test.go new file mode 100644 index 0000000..3345722 --- /dev/null +++ b/pkg/dns/internal/cmd/mkcert/main_test.go @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package main + +import "testing" + +func Test_main(_ *testing.T) { + destdir = "testdata" + main() +} diff --git a/pkg/dns/internal/cmd/mkcert/testdata/.gitignore b/pkg/dns/internal/cmd/mkcert/testdata/.gitignore new file mode 100644 index 0000000..fc4d927 --- /dev/null +++ b/pkg/dns/internal/cmd/mkcert/testdata/.gitignore @@ -0,0 +1,2 @@ +/cert.pem +/key.pem