diff --git a/zstd.go b/zstd.go index 2cf5c61..0b64c1d 100644 --- a/zstd.go +++ b/zstd.go @@ -7,13 +7,30 @@ package zstd #cgo CFLAGS: -DZSTD_LEGACY_SUPPORT=4 -DZSTD_MULTITHREAD=1 #include "zstd.h" + +// These wrappers converting void* arguments to char* are required to convince +// Go that there are no pointers hiding in there, and that slow runtime checks +// are unnecessary. + +size_t ZSwrapper_compress(unsigned char* dst, size_t dstCapacity, const unsigned char* src, size_t srcSize, int compressionLevel) +{ + return ZSTD_compress(dst, dstCapacity, src, srcSize, compressionLevel); +} + +size_t ZSwrapper_decompress(unsigned char* dst, size_t dstCapacity, const unsigned char* src, size_t srcSize) +{ + return ZSTD_decompress(dst, dstCapacity, src, srcSize); +} + +unsigned long long ZSwrapper_getFrameContentSize(const unsigned char *src, size_t srcSize) { + return ZSTD_getFrameContentSize(src, srcSize); +} */ import "C" import ( "bytes" "errors" "io/ioutil" - "unsafe" ) // Defines best and standard values for zstd cli @@ -68,7 +85,7 @@ func decompressSizeHint(src []byte) int { hint := upperBound if len(src) >= zstdFrameHeaderSizeMin { - hint = int(C.ZSTD_getFrameContentSize(unsafe.Pointer(&src[0]), C.size_t(len(src)))) + hint = int(C.ZSwrapper_getFrameContentSize((*C.uchar)(&src[0]), C.size_t(len(src)))) if hint < 0 { // On error, just use upperBound hint = upperBound } @@ -100,22 +117,19 @@ func CompressLevel(dst, src []byte, level int) ([]byte, error) { dst = make([]byte, bound) } - // We need unsafe.Pointer(&src[0]) in the Cgo call to avoid "Go pointer to Go pointer" panics. - // This means we need to special case empty input. See: - // https://github.com/golang/go/issues/14210#issuecomment-346402945 var cWritten C.size_t if len(src) == 0 { - cWritten = C.ZSTD_compress( - unsafe.Pointer(&dst[0]), + cWritten = C.ZSwrapper_compress( + (*C.uchar)(&dst[0]), C.size_t(len(dst)), - unsafe.Pointer(nil), + (*C.uchar)(nil), C.size_t(0), C.int(level)) } else { - cWritten = C.ZSTD_compress( - unsafe.Pointer(&dst[0]), + cWritten = C.ZSwrapper_compress( + (*C.uchar)(&dst[0]), C.size_t(len(dst)), - unsafe.Pointer(&src[0]), + (*C.uchar)(&src[0]), C.size_t(len(src)), C.int(level)) } @@ -143,10 +157,10 @@ func Decompress(dst, src []byte) ([]byte, error) { dst = make([]byte, bound) } - written := int(C.ZSTD_decompress( - unsafe.Pointer(&dst[0]), + written := int(C.ZSwrapper_decompress( + (*C.uchar)(&dst[0]), C.size_t(len(dst)), - unsafe.Pointer(&src[0]), + (*C.uchar)(&src[0]), C.size_t(len(src)))) err := getError(written) if err == nil { diff --git a/zstd_test.go b/zstd_test.go index 0253537..cb57698 100644 --- a/zstd_test.go +++ b/zstd_test.go @@ -336,7 +336,9 @@ func BenchmarkDecompression(b *testing.B) { if err != nil { b.Fatalf("Failed compressing: %s", err) } - b.Logf("Reduced from %v to %v", len(raw), len(dst)) + if b.N == 1 { + b.Logf("Reduced from %v to %v", len(raw), len(dst)) + } b.SetBytes(int64(len(raw))) b.ResetTimer() for i := 0; i < b.N; i++ {