Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,6 @@ issues:
linters-settings:
dogsled:
max-blank-identifiers: 3
maligned:
# print struct with more effective memory layout or not, false by default
suggest-new: true
nolintlint:
allow-unused: false
require-explanation: false
Expand Down
97 changes: 88 additions & 9 deletions abci/strategies/codec/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package codec
import (
"bytes"
"compress/zlib"
"fmt"
"io"

cometabci "github.com/cometbft/cometbft/abci/types"
Expand All @@ -12,8 +13,9 @@ import (
)

var (
enc, _ = zstd.NewWriter(nil)
dec, _ = zstd.NewReader(nil)
enc, _ = zstd.NewWriter(nil)
dec, _ = zstd.NewReader(nil)
ErrZLibDecompressionLimit = fmt.Errorf("zlib decompression limit reached")
)

// VoteExtensionCodec is the interface for encoding / decoding vote extensions.
Expand All @@ -39,7 +41,6 @@ type ExtendedCommitCodec interface {
}

// NewDefaultVoteExtensionCodec returns a new DefaultVoteExtensionCodec.

func NewDefaultVoteExtensionCodec() *DefaultVoteExtensionCodec {
return &DefaultVoteExtensionCodec{}
}
Expand All @@ -57,24 +58,70 @@ func (codec *DefaultVoteExtensionCodec) Decode(bz []byte) (vetypes.OracleVoteExt
return ve, ve.Unmarshal(bz)
}

// NewVoteExtensionCodecWithSizeCheck returns a new VoteExtensionCodecWithSizeCheck.
func NewVoteExtensionCodecWithSizeCheck() *VoteExtensionCodecWithSizeCheck {
return &VoteExtensionCodecWithSizeCheck{}
}

// VoteExtensionCodecWithSizeCheck is an implementation of VoteExtensionCodec with a size check
// on Decoding level. For Encode it uses the vanilla Marshal implementation. For Decode it has
// an additional Encode and Compare step for checking whether the incoming bytes are the same as
// the decoded and encoded bytes. This makes sure that the vote extension doesn't have any
// extraneous fields.
type VoteExtensionCodecWithSizeCheck struct{}

func (codec *VoteExtensionCodecWithSizeCheck) Encode(ve vetypes.OracleVoteExtension) ([]byte, error) {
return ve.Marshal()
}

func (codec *VoteExtensionCodecWithSizeCheck) Decode(bz []byte) (vetypes.OracleVoteExtension, error) {
var ve vetypes.OracleVoteExtension
if err := ve.Unmarshal(bz); err != nil {
return vetypes.OracleVoteExtension{}, fmt.Errorf("failed to unmarshal vote extension: %w", err)
}

remarshaled, err := ve.Marshal()
if err != nil {
return vetypes.OracleVoteExtension{}, fmt.Errorf("failed to remarshal vote extension for size check: %w", err)
}
if len(bz) != len(remarshaled) {
return vetypes.OracleVoteExtension{}, fmt.Errorf("incoming bytes size doesn't match the remarshaled bytes size: %d != %d", len(bz), len(remarshaled))
}

return ve, nil
}

type Compressor interface {
Compress([]byte) ([]byte, error)
Decompress([]byte) ([]byte, error)
}

// ZLibCompressor is a Compressor that uses zlib to compress / decompress byte arrays, this object is not thread-safe.
type ZLibCompressor struct{}
type ZLibCompressor struct {
// decompressLimit is the maximum number of bytes that can be decompressed.
// if <=0, no limit is applied.
decompressLimit int
}

// NewZLibCompressorWithLimit returns a new zlibDecompressor with the given decompression limit.
func NewZLibCompressorWithLimit(limit int) *ZLibCompressor {
return &ZLibCompressor{decompressLimit: limit}
}

// NewZLibCompressor returns a new zlibDecompressor.
// NewZLibCompressor returns a new zlibDecompressor with no limit.
func NewZLibCompressor() *ZLibCompressor {
return &ZLibCompressor{}
return NewZLibCompressorWithLimit(0)
}

// Compress compresses the given byte array using zlib. It returns an error if the compression fails.
// This function is not thread-safe, and uses zlib.BestCompression as the compression level.
func (c *ZLibCompressor) Compress(bz []byte) ([]byte, error) {
var b bytes.Buffer

if len(bz) > c.decompressLimit && c.decompressLimit > 0 {
return nil, fmt.Errorf("zlib compression limit reached")
}

// we use the best compression level as size reduction is prioritized
w := zlib.NewWriter(&b)
defer w.Close()
Expand All @@ -95,13 +142,16 @@ func (c *ZLibCompressor) Decompress(bz []byte) ([]byte, error) {
if len(bz) == 0 {
return nil, nil
}
r, err := zlib.NewReader(bytes.NewReader(bz))
zr, err := zlib.NewReader(bytes.NewReader(bz))
if err != nil {
return nil, err
}
r.Close()
defer zr.Close()

// read bytes and return
var r io.Reader = zr
if c.decompressLimit > 0 {
r = newLimitReaderWithError(r, c.decompressLimit)
}
return io.ReadAll(r)
}

Expand Down Expand Up @@ -215,3 +265,32 @@ func (codec *CompressionExtendedCommitCodec) Decode(bz []byte) (cometabci.Extend

return codec.codec.Decode(bz)
}

// limitReaderWithError is a io.Reader that reads up to n bytes from the underlying reader.
// Unlike io.LimitReader, this reader returns ErrZLibDecompressionLimit if >n bytes were read.
type limitReaderWithError struct {
r io.Reader
n int
}

func newLimitReaderWithError(r io.Reader, n int) *limitReaderWithError {
return &limitReaderWithError{r: r, n: n}
}

func (lr *limitReaderWithError) Read(p []byte) (int, error) {
if lr.n <= 0 {
var probe [1]byte
// read one extra byte to detect(trigger) EOF error
_, err := lr.r.Read(probe[:])
if err != nil {
return 0, err
}
return 0, ErrZLibDecompressionLimit
}
if len(p) > lr.n {
p = p[:lr.n]
}
n, err := lr.r.Read(p)
lr.n -= n
return n, err
}
52 changes: 52 additions & 0 deletions abci/strategies/codec/codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,55 @@ func TestCompressionExtendedCommitCodec(t *testing.T) {
require.NoError(t, err)
})
}

func TestZLibCompressor_CompressDecompress_UnderLimit_NoError(t *testing.T) {
origLimit := 1000

comp := compression.NewZLibCompressorWithLimit(origLimit)

// Ensure payload size is strictly below the current limit.
payloadLen := origLimit / 2

payload := make([]byte, payloadLen)
compressed, err := comp.Compress(payload)
require.NoError(t, err)

// Decompress should not be truncated because limit is unchanged.
out, err := comp.Decompress(compressed)
require.NoError(t, err)
require.Equal(t, payload, out)
}

func TestZLibCompressor_Compress_OverLimit_Error(t *testing.T) {
origLimit := 1000

comp := compression.NewZLibCompressorWithLimit(origLimit)

payloadLen := origLimit + 1

payload := make([]byte, payloadLen)
_, err := comp.Compress(payload)
require.Error(t, err)
require.Equal(t, "zlib compression limit reached", err.Error())
}

func TestZLibCompressor_Decompress_OverLimit_Error(t *testing.T) {
// This test checks the expected behavior when the decompressed output would exceed the limit.
origLimit := 1000

comp := compression.NewZLibCompressorWithLimit(origLimit * 2)

payloadLen := origLimit

// Create compressed data that expands to more than decompressLimit bytes.
payload := make([]byte, payloadLen)
compressed, err := comp.Compress(payload)
require.NoError(t, err)

dec := compression.NewZLibCompressorWithLimit(origLimit / 2)

// Lower the decompression limit to trigger the error/short read path.
_, err = dec.Decompress(compressed)
require.Error(t, err)
require.ErrorIs(t, err, compression.ErrZLibDecompressionLimit)
}
3 changes: 2 additions & 1 deletion abci/strategies/codec/mocks/vote_extension_codec.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion abci/strategies/currencypair/mocks/mock_oracle_keeper.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion abci/types/mocks/mock_oracle_keeper.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading