sstable: fix restarts integer overflow
Fix a bug where integer overflow while indexing into blocks with large KVs caused incorrect behavior in SeekGE() and SeekLT() in block_iter. Members of blockEntry that represent offsets within a block now use the type offsetInBlock (defined as int64).

Added a check in rowblk_writer to ensure that a block's size does not exceed MaximumBlockSize before more data is written to it.

Wrote unit tests to verify correct behavior of SeekGE() and SeekLT() on large blocks (>2 GB).
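
For context (not part of the change), the following minimal Go sketch illustrates the arithmetic being fixed; the offsets are made-up numbers. Once a block grows past 2 GiB, carrying offsets in int32 makes an index computation such as restarts+4*h wrap to a negative value, whereas the int64-based offsetInBlock stays in range.

package main

import "fmt"

func main() {
	// Hypothetical offsets for a block just over 2 GiB; the values are
	// illustrative, not taken from a real sstable block.
	restarts := int64(2_200_000_000) // byte offset where the restart array begins
	h := int32(5)                    // restart point chosen by the binary search

	// Old representation: offsets carried in int32 wrap around, yielding a
	// negative index into the block's data slice.
	overflowed := int32(restarts) + 4*h

	// New representation: offsets carried in a signed 64-bit type
	// (offsetInBlock in this commit) stay in range.
	fixed := restarts + 4*int64(h)

	fmt.Println(overflowed) // negative due to wraparound
	fmt.Println(fixed)      // 2200000020
}
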
EdwardX29 committed Jan 21, 2025
1 parent 19b47dc commit 13a815e
Showing 4 changed files with 243 additions and 53 deletions.
89 changes: 57 additions & 32 deletions sstable/block.go
@@ -104,6 +104,11 @@ func (w *blockWriter) storeWithOptionalValuePrefix(
valuePrefix valuePrefix,
setHasSameKeyPrefix bool,
) {
if len(w.buf) >= MaximumBlockSize {
panic(errors.AssertionFailedf("block: adding KV to %d-block; already exceeds %d-byte maximum",
len(w.buf), MaximumBlockSize))
}

shared := 0
if !setHasSameKeyPrefix {
w.setHasSameKeyPrefixSinceLastRestart = false
@@ -281,14 +286,6 @@ func (w *blockWriter) estimatedSize() int {
return len(w.buf) + 4*len(w.restarts) + emptyBlockSize
}

type blockEntry struct {
offset int32
keyStart int32
keyEnd int32
valStart int32
valSize int32
}

// blockIter is an iterator over a single block of data.
//
// A blockIter provides an additional guarantee around key stability when a
@@ -343,10 +340,10 @@ type blockIter struct {
cmp Compare
// offset is the byte index that marks where the current key/value is
// encoded in the block.
offset int32
offset offsetInBlock
// nextOffset is the byte index where the next key/value is encoded in the
// block.
nextOffset int32
nextOffset offsetInBlock
// A "restart point" in a block is a point where the full key is encoded,
// instead of just having a suffix of the key encoded. See readEntry() for
// how prefix compression of keys works. Keys in between two restart points
@@ -359,7 +356,10 @@ type blockIter struct {
// 4 bytes of the block as a uint32 (i.ptr[len(block)-4:]). i.restarts can
// therefore be seen as the point where data in the block ends, and a list
// of offsets of all restart points begins.
restarts int32
//
// int64 is used to prevent overflow and preserve signedness for binary
// search invariants.
restarts offsetInBlock
// Number of restart points in this block. Encoded at the end of the block
// as a uint32.
numRestarts int32
@@ -411,6 +411,27 @@ type blockIter struct {
hideObsoletePoints bool
}

// offsetInBlock represents an offset in a block
//
// While restart points are serialized as uint32's, it is possible for offsets to
// be greater than math.MaxUint32 since they may point to an offset after the KVs.
//
// Previously, offsets were represented as int32, which caused integer
// overflows while indexing into blocks (i.data) with large KVs in SeekGE()
// and SeekLT(). Using int64 prevents such wraparounds. Additionally, the
// signedness of int64 allows representation of iterators that have conducted
// backward iteration, and preserves the binary search invariants in SeekGE()
// and SeekLT().
type offsetInBlock int64

type blockEntry struct {
offset offsetInBlock
keyStart offsetInBlock
keyEnd offsetInBlock
valStart offsetInBlock
valSize uint32
}

// blockIter implements the base.InternalIterator interface.
var _ base.InternalIterator = (*blockIter)(nil)

@@ -431,7 +452,7 @@ func (i *blockIter) init(
return base.CorruptionErrorf("pebble/table: invalid table (block has no restart points)")
}
i.cmp = cmp
i.restarts = int32(len(block)) - 4*(1+numRestarts)
i.restarts = offsetInBlock(len(block)) - 4*(1+offsetInBlock(numRestarts))
i.numRestarts = numRestarts
i.globalSeqNum = globalSeqNum
i.ptr = unsafe.Pointer(&block[0])
@@ -570,7 +591,7 @@ func (i *blockIter) readEntry() {
}
ptr = unsafe.Pointer(uintptr(ptr) + uintptr(unshared))
i.val = getBytes(ptr, int(value))
i.nextOffset = int32(uintptr(ptr)-uintptr(i.ptr)) + int32(value)
i.nextOffset = offsetInBlock(uintptr(ptr)-uintptr(i.ptr)) + offsetInBlock(value)
}

func (i *blockIter) readFirstKey() error {
@@ -665,16 +686,16 @@ func (i *blockIter) clearCache() {
}

func (i *blockIter) cacheEntry() {
var valStart int32
valSize := int32(len(i.val))
var valStart offsetInBlock
valSize := uint32(len(i.val))
if valSize > 0 {
valStart = int32(uintptr(unsafe.Pointer(&i.val[0])) - uintptr(i.ptr))
valStart = offsetInBlock(uintptr(unsafe.Pointer(&i.val[0])) - uintptr(i.ptr))
}

i.cached = append(i.cached, blockEntry{
offset: i.offset,
keyStart: int32(len(i.cachedBuf)),
keyEnd: int32(len(i.cachedBuf) + len(i.key)),
keyStart: offsetInBlock(len(i.cachedBuf)),
keyEnd: offsetInBlock(len(i.cachedBuf) + len(i.key)),
valStart: valStart,
valSize: valSize,
})
@@ -706,8 +727,9 @@ func (i *blockIter) SeekGE(key []byte, flags base.SeekGEFlags) (*InternalKey, ba
upper := i.numRestarts
for index < upper {
h := int32(uint(index+upper) >> 1) // avoid overflow when computing h

// index ≤ h < upper
offset := decodeRestart(i.data[i.restarts+4*h:])
offset := decodeRestart(i.data[i.restarts+4*offsetInBlock(h):])
// For a restart point, there are 0 bytes shared with the previous key.
// The varint encoding of 0 occupies 1 byte.
ptr := unsafe.Pointer(uintptr(i.ptr) + uintptr(offset+1))
@@ -779,7 +801,7 @@ func (i *blockIter) SeekGE(key []byte, flags base.SeekGEFlags) (*InternalKey, ba
// could be equal to the search key. If index == 0, then all keys in this
// block are larger than the key sought, and offset remains at zero.
if index > 0 {
i.offset = decodeRestart(i.data[i.restarts+4*(index-1):])
i.offset = decodeRestart(i.data[i.restarts+4*offsetInBlock(index-1):])
}
i.readEntry()
hiddenPoint := i.decodeInternalKey(i.key)
@@ -839,8 +861,9 @@ func (i *blockIter) SeekLT(key []byte, flags base.SeekLTFlags) (*InternalKey, ba
upper := i.numRestarts
for index < upper {
h := int32(uint(index+upper) >> 1) // avoid overflow when computing h

// index ≤ h < upper
offset := decodeRestart(i.data[i.restarts+4*h:])
offset := decodeRestart(i.data[i.restarts+4*offsetInBlock(h):])
// For a restart point, there are 0 bytes shared with the previous key.
// The varint encoding of 0 occupies 1 byte.
ptr := unsafe.Pointer(uintptr(i.ptr) + uintptr(offset+1))
@@ -914,9 +937,9 @@ func (i *blockIter) SeekLT(key []byte, flags base.SeekLTFlags) (*InternalKey, ba
// are larger than the search key, so there is no match.
targetOffset := i.restarts
if index > 0 {
i.offset = decodeRestart(i.data[i.restarts+4*(index-1):])
i.offset = decodeRestart(i.data[i.restarts+4*offsetInBlock(index-1):])
if index < i.numRestarts {
targetOffset = decodeRestart(i.data[i.restarts+4*(index):])
targetOffset = decodeRestart(i.data[i.restarts+4*offsetInBlock(index):])
}
} else if index == 0 {
// If index == 0 then all keys in this block are larger than the key
@@ -1017,9 +1040,9 @@ func (i *blockIter) First() (*InternalKey, base.LazyValue) {
return &i.ikey, i.lazyValue
}

func decodeRestart(b []byte) int32 {
func decodeRestart(b []byte) offsetInBlock {
_ = b[3] // bounds check hint to compiler; see golang.org/issue/14808
return int32(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 |
return offsetInBlock(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 |
uint32(b[3]&restartMaskLittleEndianHighByteWithoutSetHasSamePrefix)<<24)
}

@@ -1030,7 +1053,7 @@ func (i *blockIter) Last() (*InternalKey, base.LazyValue) {
}

// Seek forward from the last restart point.
i.offset = decodeRestart(i.data[i.restarts+4*(i.numRestarts-1):])
i.offset = decodeRestart(i.data[i.restarts+4*offsetInBlock(i.numRestarts-1):])
if !i.valid() {
return nil, base.LazyValue{}
}
@@ -1227,7 +1250,7 @@ func (i *blockIter) nextPrefixV3(succKey []byte) (*InternalKey, base.LazyValue)
}
// The starting position of the value.
valuePtr := unsafe.Pointer(uintptr(ptr) + uintptr(unshared))
i.nextOffset = int32(uintptr(valuePtr)-uintptr(i.ptr)) + int32(value)
i.nextOffset = offsetInBlock(uintptr(valuePtr)-uintptr(i.ptr)) + offsetInBlock(value)
if invariants.Enabled && unshared < 8 {
// This should not happen since only the key prefix is shared, so even
// if the prefix length is the same as the user key length, the unshared
@@ -1273,8 +1296,9 @@ func (i *blockIter) nextPrefixV3(succKey []byte) (*InternalKey, base.LazyValue)
upper := i.numRestarts
for index < upper {
h := int32(uint(index+upper) >> 1) // avoid overflow when computing h

// index ≤ h < upper
offset := decodeRestart(i.data[i.restarts+4*h:])
offset := decodeRestart(i.data[i.restarts+4*offsetInBlock(h):])
if offset < targetOffset {
index = h + 1 // preserves f(index-1) == false
} else {
@@ -1301,7 +1325,7 @@ func (i *blockIter) nextPrefixV3(succKey []byte) (*InternalKey, base.LazyValue)
// most significant bit of the 3rd byte is what we use for
// encoding the set-has-same-prefix information, the indexing
// below has +3.
i.data[i.restarts+4*index+3]&restartMaskLittleEndianHighByteOnlySetHasSamePrefix != 0 {
i.data[i.restarts+4*offsetInBlock(index)+3]&restartMaskLittleEndianHighByteOnlySetHasSamePrefix != 0 {
// We still have the same prefix, so move to the next restart.
index++
}
@@ -1310,7 +1334,7 @@ func (i *blockIter) nextPrefixV3(succKey []byte) (*InternalKey, base.LazyValue)
// Managed to skip past at least one restart. Resume iteration
// from index-1. Since nextFastCount has been reset to 0, we
// should be able to iterate to the next prefix.
i.offset = decodeRestart(i.data[i.restarts+4*(index-1):])
i.offset = decodeRestart(i.data[i.restarts+4*offsetInBlock(index-1):])
i.readEntry()
}
// Else, unable to skip past any restart. Resume iteration. Since
@@ -1455,8 +1479,9 @@ start:
upper := i.numRestarts
for index < upper {
h := int32(uint(index+upper) >> 1) // avoid overflow when computing h

// index ≤ h < upper
offset := decodeRestart(i.data[i.restarts+4*h:])
offset := decodeRestart(i.data[i.restarts+4*offsetInBlock(h):])
if offset < targetOffset {
// Looking for the first restart that has offset >= targetOffset, so
// ignore h and earlier.
@@ -1477,7 +1502,7 @@ start:
// as the index).
i.offset = 0
if index > 0 {
i.offset = decodeRestart(i.data[i.restarts+4*(index-1):])
i.offset = decodeRestart(i.data[i.restarts+4*offsetInBlock(index-1):])
}
// TODO(sumeer): why is the else case not an error given targetOffset is a
// valid offset.
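
To make the new indexing concrete, here is a self-contained sketch that mirrors the "first restart point with offset >= targetOffset" binary search shown above, using the commit's offsetInBlock arithmetic. The block layout, function name, and decoding are illustrative only; in particular, the real decodeRestart also masks a flag bit out of the high byte, which is omitted here.

package main

import (
	"encoding/binary"
	"fmt"
)

// offsetInBlock mirrors the commit: a signed 64-bit offset, so expressions
// like restarts+4*h cannot wrap even for blocks larger than 2 GiB.
type offsetInBlock int64

// firstRestartAtOrAfter returns the index of the first restart point whose
// stored offset is >= target. data holds the block contents; restarts is the
// byte offset at which the little-endian uint32 restart array begins.
func firstRestartAtOrAfter(
	data []byte, restarts offsetInBlock, numRestarts int32, target offsetInBlock,
) int32 {
	index, upper := int32(0), numRestarts
	for index < upper {
		h := int32(uint(index+upper) >> 1) // avoid overflow when computing h
		off := offsetInBlock(binary.LittleEndian.Uint32(data[restarts+4*offsetInBlock(h):]))
		if off < target {
			index = h + 1 // preserves f(index-1) == false
		} else {
			upper = h // preserves f(upper) == true
		}
	}
	return index
}

func main() {
	// Tiny illustrative block: 16 bytes of (fake) KV data followed by a
	// two-entry restart array pointing at offsets 0 and 8.
	data := make([]byte, 16+8)
	binary.LittleEndian.PutUint32(data[16:], 0)
	binary.LittleEndian.PutUint32(data[20:], 8)

	// The first restart at or after offset 5 is the second one (index 1).
	fmt.Println(firstRestartAtOrAfter(data, 16, 2, 5)) // 1
}
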