diff --git a/mocks/packetconn.go b/mocks/packetconn.go new file mode 100644 index 0000000..cccc5d7 --- /dev/null +++ b/mocks/packetconn.go @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package mocks + +import ( + "net" + "time" +) + +// PacketConn is a mockable [net.PacketConn]. +type PacketConn struct { + // MockReadFrom is the function to call when ReadFrom is called. + MockReadFrom func(p []byte) (int, net.Addr, error) + + // MockWriteTo is the function to call when WriteTo is called. + MockWriteTo func(p []byte, addr net.Addr) (int, error) + + // MockClose is the function to call when Close is called. + MockClose func() error + + // MockLocalAddr is the function to call when LocalAddr is called. + MockLocalAddr func() net.Addr + + // MockSetDeadline is the function to call when SetDeadline is called. + MockSetDeadline func(t time.Time) error + + // MockSetReadDeadline is the function to call when SetReadDeadline is called. + MockSetReadDeadline func(t time.Time) error + + // MockSetWriteDeadline is the function to call when SetWriteDeadline is called. + MockSetWriteDeadline func(t time.Time) error +} + +var _ net.PacketConn = &PacketConn{} + +// ReadFrom calls MockReadFrom. +func (pc *PacketConn) ReadFrom(p []byte) (int, net.Addr, error) { + return pc.MockReadFrom(p) +} + +// WriteTo calls MockWriteTo. +func (pc *PacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { + return pc.MockWriteTo(p, addr) +} + +// Close calls MockClose. +func (pc *PacketConn) Close() error { + return pc.MockClose() +} + +// LocalAddr calls MockLocalAddr. +func (pc *PacketConn) LocalAddr() net.Addr { + return pc.MockLocalAddr() +} + +// SetDeadline calls MockSetDeadline. +func (pc *PacketConn) SetDeadline(t time.Time) error { + return pc.MockSetDeadline(t) +} + +// SetReadDeadline calls MockSetReadDeadline. +func (pc *PacketConn) SetReadDeadline(t time.Time) error { + return pc.MockSetReadDeadline(t) +} + +// SetWriteDeadline calls MockSetWriteDeadline. +func (pc *PacketConn) SetWriteDeadline(t time.Time) error { + return pc.MockSetWriteDeadline(t) +} diff --git a/mocks/packetconn_test.go b/mocks/packetconn_test.go new file mode 100644 index 0000000..a0f4115 --- /dev/null +++ b/mocks/packetconn_test.go @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package mocks + +import ( + "errors" + "net" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +func TestPacketConn(t *testing.T) { + t.Run("ReadFrom", func(t *testing.T) { + expected := errors.New("mocked error") + expectedAddr := &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 8080, + } + pc := &PacketConn{ + MockReadFrom: func(p []byte) (int, net.Addr, error) { + return 0, expectedAddr, expected + }, + } + count, addr, err := pc.ReadFrom(make([]byte, 128)) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if count != 0 { + t.Fatal("expected 0 bytes") + } + if diff := cmp.Diff(expectedAddr, addr); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("WriteTo", func(t *testing.T) { + expected := errors.New("mocked error") + addr := &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 8080, + } + pc := &PacketConn{ + MockWriteTo: func(p []byte, addr net.Addr) (int, error) { + return 0, expected + }, + } + count, err := pc.WriteTo(make([]byte, 128), addr) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if count != 0 { + t.Fatal("expected 0 bytes") + } + }) + + t.Run("Close", func(t *testing.T) { + expected := errors.New("mocked error") + pc := &PacketConn{ + MockClose: func() error { + return expected + }, + } + err := pc.Close() + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("LocalAddr", func(t *testing.T) { + expected := &net.UDPAddr{ + IP: net.IPv6loopback, + Port: 1234, + } + pc := &PacketConn{ + MockLocalAddr: func() net.Addr { + return expected + }, + } + out := pc.LocalAddr() + if diff := cmp.Diff(expected, out); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("SetDeadline", func(t *testing.T) { + expected := errors.New("mocked error") + pc := &PacketConn{ + MockSetDeadline: func(t time.Time) error { + return expected + }, + } + err := pc.SetDeadline(time.Time{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("SetReadDeadline", func(t *testing.T) { + expected := errors.New("mocked error") + pc := &PacketConn{ + MockSetReadDeadline: func(t time.Time) error { + return expected + }, + } + err := pc.SetReadDeadline(time.Time{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("SetWriteDeadline", func(t *testing.T) { + expected := errors.New("mocked error") + pc := &PacketConn{ + MockSetWriteDeadline: func(t time.Time) error { + return expected + }, + } + err := pc.SetWriteDeadline(time.Time{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) +} diff --git a/mocks/tlsconn.go b/mocks/tlsconn.go new file mode 100644 index 0000000..7da0368 --- /dev/null +++ b/mocks/tlsconn.go @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package mocks + +import "crypto/tls" + +// TLSConn is a mockable TLS connection. +type TLSConn struct { + // We embed Conn to handle the net.Conn interface. + Conn + + // MockConnectionState is the function to call when ConnectionState is called. + MockConnectionState func() tls.ConnectionState + + // MockHandshakeContext is the function to call when HandshakeContext is called. + MockHandshakeContext func() error +} + +// ConnectionState calls MockConnectionState. +func (c *TLSConn) ConnectionState() tls.ConnectionState { + return c.MockConnectionState() +} + +// HandshakeContext calls MockHandshakeContext. +func (c *TLSConn) HandshakeContext() error { + return c.MockHandshakeContext() +} diff --git a/mocks/tlsconn_test.go b/mocks/tlsconn_test.go new file mode 100644 index 0000000..14726b5 --- /dev/null +++ b/mocks/tlsconn_test.go @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: GPL-3.0-or-later + +package mocks + +import ( + "crypto/tls" + "errors" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +func TestTLSConn(t *testing.T) { + t.Run("ConnectionState", func(t *testing.T) { + expectedState := tls.ConnectionState{ + Version: tls.VersionTLS13, + HandshakeComplete: true, + DidResume: false, + CipherSuite: tls.TLS_AES_128_GCM_SHA256, + NegotiatedProtocol: "h2", + ServerName: "example.com", + PeerCertificates: nil, + VerifiedChains: nil, + SignedCertificateTimestamps: nil, + OCSPResponse: nil, + } + + conn := &TLSConn{ + MockConnectionState: func() tls.ConnectionState { + return expectedState + }, + } + + state := conn.ConnectionState() + if diff := cmp.Diff(expectedState, state, + cmpopts.IgnoreUnexported(tls.ConnectionState{})); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("HandshakeContext", func(t *testing.T) { + expected := errors.New("mocked handshake error") + conn := &TLSConn{ + MockHandshakeContext: func() error { + return expected + }, + } + + err := conn.HandshakeContext() + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + }) + + t.Run("Embedded Conn methods", func(t *testing.T) { + expected := errors.New("mocked read error") + conn := &TLSConn{ + Conn: Conn{ + MockRead: func(b []byte) (int, error) { + return 0, expected + }, + }, + } + + count, err := conn.Read(make([]byte, 128)) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if count != 0 { + t.Fatal("expected 0 bytes") + } + }) +}