Skip to content
This repository was archived by the owner on Jul 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions mocks/packetconn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// SPDX-License-Identifier: GPL-3.0-or-later

package mocks

import (
"net"
"time"
)

// PacketConn is a mockable [net.PacketConn].
type PacketConn struct {
// MockReadFrom is the function to call when ReadFrom is called.
MockReadFrom func(p []byte) (int, net.Addr, error)

// MockWriteTo is the function to call when WriteTo is called.
MockWriteTo func(p []byte, addr net.Addr) (int, error)

// MockClose is the function to call when Close is called.
MockClose func() error

// MockLocalAddr is the function to call when LocalAddr is called.
MockLocalAddr func() net.Addr

// MockSetDeadline is the function to call when SetDeadline is called.
MockSetDeadline func(t time.Time) error

// MockSetReadDeadline is the function to call when SetReadDeadline is called.
MockSetReadDeadline func(t time.Time) error

// MockSetWriteDeadline is the function to call when SetWriteDeadline is called.
MockSetWriteDeadline func(t time.Time) error
}

var _ net.PacketConn = &PacketConn{}

// ReadFrom calls MockReadFrom.
func (pc *PacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
return pc.MockReadFrom(p)
}

// WriteTo calls MockWriteTo.
func (pc *PacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
return pc.MockWriteTo(p, addr)
}

// Close calls MockClose.
func (pc *PacketConn) Close() error {
return pc.MockClose()
}

// LocalAddr calls MockLocalAddr.
func (pc *PacketConn) LocalAddr() net.Addr {
return pc.MockLocalAddr()
}

// SetDeadline calls MockSetDeadline.
func (pc *PacketConn) SetDeadline(t time.Time) error {
return pc.MockSetDeadline(t)
}

// SetReadDeadline calls MockSetReadDeadline.
func (pc *PacketConn) SetReadDeadline(t time.Time) error {
return pc.MockSetReadDeadline(t)
}

// SetWriteDeadline calls MockSetWriteDeadline.
func (pc *PacketConn) SetWriteDeadline(t time.Time) error {
return pc.MockSetWriteDeadline(t)
}
125 changes: 125 additions & 0 deletions mocks/packetconn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// SPDX-License-Identifier: GPL-3.0-or-later

package mocks

import (
"errors"
"net"
"testing"
"time"

"github.com/google/go-cmp/cmp"
)

func TestPacketConn(t *testing.T) {
t.Run("ReadFrom", func(t *testing.T) {
expected := errors.New("mocked error")
expectedAddr := &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 8080,
}
pc := &PacketConn{
MockReadFrom: func(p []byte) (int, net.Addr, error) {
return 0, expectedAddr, expected
},
}
count, addr, err := pc.ReadFrom(make([]byte, 128))
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if count != 0 {
t.Fatal("expected 0 bytes")
}
if diff := cmp.Diff(expectedAddr, addr); diff != "" {
t.Fatal(diff)
}
})

t.Run("WriteTo", func(t *testing.T) {
expected := errors.New("mocked error")
addr := &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 8080,
}
pc := &PacketConn{
MockWriteTo: func(p []byte, addr net.Addr) (int, error) {
return 0, expected
},
}
count, err := pc.WriteTo(make([]byte, 128), addr)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if count != 0 {
t.Fatal("expected 0 bytes")
}
})

t.Run("Close", func(t *testing.T) {
expected := errors.New("mocked error")
pc := &PacketConn{
MockClose: func() error {
return expected
},
}
err := pc.Close()
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
})

t.Run("LocalAddr", func(t *testing.T) {
expected := &net.UDPAddr{
IP: net.IPv6loopback,
Port: 1234,
}
pc := &PacketConn{
MockLocalAddr: func() net.Addr {
return expected
},
}
out := pc.LocalAddr()
if diff := cmp.Diff(expected, out); diff != "" {
t.Fatal(diff)
}
})

t.Run("SetDeadline", func(t *testing.T) {
expected := errors.New("mocked error")
pc := &PacketConn{
MockSetDeadline: func(t time.Time) error {
return expected
},
}
err := pc.SetDeadline(time.Time{})
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
})

t.Run("SetReadDeadline", func(t *testing.T) {
expected := errors.New("mocked error")
pc := &PacketConn{
MockSetReadDeadline: func(t time.Time) error {
return expected
},
}
err := pc.SetReadDeadline(time.Time{})
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
})

t.Run("SetWriteDeadline", func(t *testing.T) {
expected := errors.New("mocked error")
pc := &PacketConn{
MockSetWriteDeadline: func(t time.Time) error {
return expected
},
}
err := pc.SetWriteDeadline(time.Time{})
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
})
}
27 changes: 27 additions & 0 deletions mocks/tlsconn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// SPDX-License-Identifier: GPL-3.0-or-later

package mocks

import "crypto/tls"

// TLSConn is a mockable TLS connection.
type TLSConn struct {
// We embed Conn to handle the net.Conn interface.
Conn

// MockConnectionState is the function to call when ConnectionState is called.
MockConnectionState func() tls.ConnectionState

// MockHandshakeContext is the function to call when HandshakeContext is called.
MockHandshakeContext func() error
}

// ConnectionState calls MockConnectionState.
func (c *TLSConn) ConnectionState() tls.ConnectionState {
return c.MockConnectionState()
}

// HandshakeContext calls MockHandshakeContext.
func (c *TLSConn) HandshakeContext() error {
return c.MockHandshakeContext()
}
74 changes: 74 additions & 0 deletions mocks/tlsconn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// SPDX-License-Identifier: GPL-3.0-or-later

package mocks

import (
"crypto/tls"
"errors"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)

func TestTLSConn(t *testing.T) {
t.Run("ConnectionState", func(t *testing.T) {
expectedState := tls.ConnectionState{
Version: tls.VersionTLS13,
HandshakeComplete: true,
DidResume: false,
CipherSuite: tls.TLS_AES_128_GCM_SHA256,
NegotiatedProtocol: "h2",
ServerName: "example.com",
PeerCertificates: nil,
VerifiedChains: nil,
SignedCertificateTimestamps: nil,
OCSPResponse: nil,
}

conn := &TLSConn{
MockConnectionState: func() tls.ConnectionState {
return expectedState
},
}

state := conn.ConnectionState()
if diff := cmp.Diff(expectedState, state,
cmpopts.IgnoreUnexported(tls.ConnectionState{})); diff != "" {
t.Fatal(diff)
}
})

t.Run("HandshakeContext", func(t *testing.T) {
expected := errors.New("mocked handshake error")
conn := &TLSConn{
MockHandshakeContext: func() error {
return expected
},
}

err := conn.HandshakeContext()
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
})

t.Run("Embedded Conn methods", func(t *testing.T) {
expected := errors.New("mocked read error")
conn := &TLSConn{
Conn: Conn{
MockRead: func(b []byte) (int, error) {
return 0, expected
},
},
}

count, err := conn.Read(make([]byte, 128))
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
}
if count != 0 {
t.Fatal("expected 0 bytes")
}
})
}