diff --git a/README.md b/README.md index 3432e4a..41cfbdb 100644 --- a/README.md +++ b/README.md @@ -11,10 +11,10 @@ Go implementation for client to interact with storage nodes in 0G Storage networ Following packages can help applications to integrate with 0g storage network: -- **[core](core)**: provides underlying utilities to build merkle tree for files or iterable data, and defines data padding standard to interact with [Flow contract](contract/contract.go). +- **[core](core)**: provides underlying utilities to build merkle tree for files or iterable data, defines data padding standard to interact with [Flow contract](contract/contract.go), and implements client-side AES-256-CTR encryption for file uploads. - **[node](node)**: defines RPC client structures to facilitate RPC interactions with 0g storage nodes and 0g key-value (KV) nodes. -- **[kv](kv)**: defines structures to interact with 0g storage kv. -- **[transfer](transfer)** : defines data structures and functions for transferring data between local and 0g storage. +- **[kv](kv)**: defines structures to interact with 0g storage kv, with optional stream data encryption via `UploadOption.EncryptionKey`. +- **[transfer](transfer)**: defines data structures and functions for transferring data between local and 0g storage, including encrypted upload/download support via `UploadOption.EncryptionKey` and `Downloader.WithEncryptionKey`. - **[indexer](indexer)**: select storage nodes to upload data from indexer which maintains trusted node list. Besides, allow clients to download files via HTTP GET requests. ## CLI @@ -53,6 +53,14 @@ To generate a file for test purpose, with a fixed file size or random file size The client will submit the data segments to the storage nodes which is determined by the indexer according to their shard configurations. +**Upload with encryption** + +Encrypt files client-side using AES-256-CTR before uploading. The encryption key is a hex-encoded 32-byte key with `0x` prefix: + +``` +./0g-storage-client upload --url --key --indexer --file --encryption-key <0x_hex_encoded_32_byte_key> +``` + **Download file** ``` ./0g-storage-client download --indexer --root --file @@ -60,6 +68,16 @@ The client will submit the data segments to the storage nodes which is determine If you want to verify the **merkle proof** of downloaded segment, please specify `--proof` option. +**Download with decryption** + +To download and decrypt a file that was uploaded with an encryption key: + +``` +./0g-storage-client download --indexer --root --file --encryption-key <0x_hex_encoded_32_byte_key> +``` + +The encryption key must match the one used during upload. + **Write to KV** By indexer: @@ -69,13 +87,21 @@ By indexer: `--stream-keys` and `--stream-values` are comma separated string list and their length must be equal. +**Write to KV with encryption** + +``` +./0g-storage-client kv-write --url --key --indexer --stream-id --stream-keys --stream-values --encryption-key <0x_hex_encoded_32_byte_key> +``` + +The entire stream data is encrypted client-side using AES-256-CTR before uploading. The KV node must be configured with the encryption key to decrypt and replay the data. + **Read from KV** ``` ./0g-storage-client kv-read --node --stream-id --stream-keys ``` -Please pay attention here `--node` is the url of a KV node. +Please pay attention here `--node` is the url of a KV node. If data was written with encryption, the KV node handles decryption during replay — no encryption key is needed for reading. ## Indexer diff --git a/cmd/download.go b/cmd/download.go index 5f0bfb8..2c23289 100644 --- a/cmd/download.go +++ b/cmd/download.go @@ -9,6 +9,7 @@ import ( "github.com/0gfoundation/0g-storage-client/indexer" "github.com/0gfoundation/0g-storage-client/node" "github.com/0gfoundation/0g-storage-client/transfer" + "github.com/ethereum/go-ethereum/common/hexutil" "github.com/sirupsen/logrus" "github.com/spf13/cobra" ) @@ -23,6 +24,8 @@ type downloadArgument struct { roots []string proof bool + encryptionKey string + routines int timeout time.Duration @@ -43,6 +46,8 @@ func bindDownloadFlags(cmd *cobra.Command, args *downloadArgument) { cmd.Flags().BoolVar(&args.proof, "proof", false, "Whether to download with merkle proof for validation") + cmd.Flags().StringVar(&args.encryptionKey, "encryption-key", "", "Hex-encoded 32-byte AES-256 encryption key for file decryption") + cmd.Flags().IntVar(&args.routines, "routines", runtime.GOMAXPROCS(0), "number of go routines for downloading simultaneously") cmd.Flags().DurationVar(&args.timeout, "timeout", 0, "cli task timeout, 0 for no timeout") @@ -100,6 +105,18 @@ func download(*cobra.Command, []string) { logrus.WithError(err).Fatal("Failed to initialize downloader") } downloaderImpl.WithRoutines(downloadArgs.routines) + if downloadArgs.encryptionKey != "" { + keyBytes, err := hexutil.Decode(downloadArgs.encryptionKey) + if err != nil { + closer() + logrus.WithError(err).Fatal("Failed to decode encryption key") + } + if len(keyBytes) != 32 { + closer() + logrus.Fatal("Encryption key must be exactly 32 bytes (64 hex characters)") + } + downloaderImpl.WithEncryptionKey(keyBytes) + } downloader = downloaderImpl defer closer() } diff --git a/cmd/kv_write.go b/cmd/kv_write.go index 43f9c98..e1af278 100644 --- a/cmd/kv_write.go +++ b/cmd/kv_write.go @@ -13,6 +13,7 @@ import ( "github.com/0gfoundation/0g-storage-client/node" "github.com/0gfoundation/0g-storage-client/transfer" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" "github.com/sirupsen/logrus" "github.com/spf13/cobra" ) @@ -39,9 +40,10 @@ var ( fee float64 nonce uint - method string - fullTrusted bool - timeout time.Duration + method string + fullTrusted bool + timeout time.Duration + encryptionKey string } kvWriteCmd = &cobra.Command{ @@ -83,6 +85,7 @@ func init() { kvWriteCmd.Flags().UintVar(&kvWriteArgs.nonce, "nonce", 0, "nonce of upload transaction") kvWriteCmd.Flags().StringVar(&kvWriteArgs.method, "method", "random", "method for selecting nodes, can be max, min, random, or positive number, if provided a number, will fail if the requirement cannot be met") kvWriteCmd.Flags().BoolVar(&kvWriteArgs.fullTrusted, "full-trusted", false, "whether all selected nodes should be from trusted nodes") + kvWriteCmd.Flags().StringVar(&kvWriteArgs.encryptionKey, "encryption-key", "", "Hex-encoded 32-byte AES-256 encryption key for encrypting the stream data") rootCmd.AddCommand(kvWriteCmd) } @@ -111,6 +114,17 @@ func kvWrite(*cobra.Command, []string) { if kvWriteArgs.finalityRequired { finalityRequired = transfer.FileFinalized } + var encryptionKey []byte + if kvWriteArgs.encryptionKey != "" { + var err error + encryptionKey, err = hexutil.Decode(kvWriteArgs.encryptionKey) + if err != nil { + logrus.WithError(err).Fatal("Failed to decode encryption key") + } + if len(encryptionKey) != 32 { + logrus.Fatalf("Encryption key must be 32 bytes, got %d", len(encryptionKey)) + } + } opt := transfer.UploadOption{ FinalityRequired: finalityRequired, TaskSize: kvWriteArgs.taskSize, @@ -120,6 +134,7 @@ func kvWrite(*cobra.Command, []string) { Nonce: nonce, Method: kvWriteArgs.method, FullTrusted: kvWriteArgs.fullTrusted, + EncryptionKey: encryptionKey, } var clients *transfer.SelectedNodes diff --git a/cmd/upload.go b/cmd/upload.go index 109d9ae..8770f2f 100644 --- a/cmd/upload.go +++ b/cmd/upload.go @@ -64,6 +64,8 @@ type uploadArgument struct { timeout time.Duration + encryptionKey string + flowAddress string marketAddress string } @@ -97,6 +99,8 @@ func bindUploadFlags(cmd *cobra.Command, args *uploadArgument) { cmd.Flags().DurationVar(&args.timeout, "timeout", 0, "cli task timeout, 0 for no timeout") + cmd.Flags().StringVar(&args.encryptionKey, "encryption-key", "", "Hex-encoded 32-byte AES-256 encryption key for file encryption") + cmd.Flags().StringVar(&args.flowAddress, "flow-address", "", "Flow contract address (skip storage node status call when set)") cmd.Flags().StringVar(&args.marketAddress, "market-address", "", "Market contract address (optional, skip flow lookup when set)") } @@ -162,6 +166,18 @@ func upload(*cobra.Command, []string) { if uploadArgs.maxGasPrice > 0 { maxGasPrice = big.NewInt(int64(uploadArgs.maxGasPrice)) } + var encryptionKey []byte + if uploadArgs.encryptionKey != "" { + var err error + encryptionKey, err = hexutil.Decode(uploadArgs.encryptionKey) + if err != nil { + logrus.WithError(err).Fatal("Failed to decode encryption key") + } + if len(encryptionKey) != 32 { + logrus.Fatal("Encryption key must be exactly 32 bytes (64 hex characters)") + } + } + opt := transfer.UploadOption{ Submitter: submitter, Tags: hexutil.MustDecode(uploadArgs.tags), @@ -177,6 +193,7 @@ func upload(*cobra.Command, []string) { Method: uploadArgs.method, FullTrusted: uploadArgs.fullTrusted, FastMode: uploadArgs.fastMode, + EncryptionKey: encryptionKey, } file, err := core.Open(uploadArgs.file) diff --git a/core/encrypted_data.go b/core/encrypted_data.go new file mode 100644 index 0000000..4bfb776 --- /dev/null +++ b/core/encrypted_data.go @@ -0,0 +1,113 @@ +package core + +// EncryptedData wraps an IterableData with AES-256-CTR encryption. +// It prepends a 17-byte encryption header (version + nonce) to the data stream +// and encrypts the inner data on-the-fly during reads. +type EncryptedData struct { + inner IterableData + key [32]byte + header *EncryptionHeader + encryptedSize int64 + paddedSize uint64 +} + +var _ IterableData = (*EncryptedData)(nil) + +// NewEncryptedData creates an EncryptedData wrapper around the given data source. +// A random nonce is generated for the encryption header. +func NewEncryptedData(inner IterableData, key [32]byte) (*EncryptedData, error) { + header, err := NewEncryptionHeader() + if err != nil { + return nil, err + } + encryptedSize := inner.Size() + int64(EncryptionHeaderSize) + paddedSize := IteratorPaddedSize(encryptedSize, true) + + return &EncryptedData{ + inner: inner, + key: key, + header: header, + encryptedSize: encryptedSize, + paddedSize: paddedSize, + }, nil +} + +// Header returns the encryption header containing the version and nonce. +func (ed *EncryptedData) Header() *EncryptionHeader { + return ed.header +} + +func (ed *EncryptedData) NumChunks() uint64 { + return NumSplits(ed.encryptedSize, DefaultChunkSize) +} + +func (ed *EncryptedData) NumSegments() uint64 { + return NumSplits(ed.encryptedSize, DefaultSegmentSize) +} + +func (ed *EncryptedData) Size() int64 { + return ed.encryptedSize +} + +func (ed *EncryptedData) PaddedSize() uint64 { + return ed.paddedSize +} + +func (ed *EncryptedData) Offset() int64 { + return 0 +} + +// Read reads encrypted data at the given offset. +// For offsets within the header region (0..16), header bytes are returned. +// For offsets beyond the header, data is read from the inner source and encrypted. +func (ed *EncryptedData) Read(buf []byte, offset int64) (int, error) { + if offset < 0 || offset >= ed.encryptedSize { + return 0, nil + } + + headerSize := int64(EncryptionHeaderSize) + written := 0 + + // If offset falls within the header region + if offset < headerSize { + headerBytes := ed.header.ToBytes() + headerStart := int(offset) + headerEnd := int(headerSize) + if headerEnd > headerStart+len(buf) { + headerEnd = headerStart + len(buf) + } + n := headerEnd - headerStart + copy(buf[:n], headerBytes[headerStart:headerEnd]) + written += n + } + + // If we still have room in buf and there's data beyond the header + if written < len(buf) { + var dataOffset int64 + if offset < headerSize { + dataOffset = 0 + } else { + dataOffset = offset - headerSize + } + + remainingBuf := buf[written:] + innerRead, err := ed.inner.Read(remainingBuf, dataOffset) + if err != nil { + return written, err + } + + // Encrypt the data we just read + if innerRead > 0 { + CryptAt(&ed.key, &ed.header.Nonce, uint64(dataOffset), buf[written:written+innerRead]) + } + + written += innerRead + } + + return written, nil +} + +// Split returns the encrypted data as a single fragment (splitting is not supported for encrypted data). +func (ed *EncryptedData) Split(fragmentSize int64) []IterableData { + return []IterableData{ed} +} diff --git a/core/encrypted_data_test.go b/core/encrypted_data_test.go new file mode 100644 index 0000000..0b63b90 --- /dev/null +++ b/core/encrypted_data_test.go @@ -0,0 +1,196 @@ +package core + +import ( + "fmt" + "os" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEncryptedDataSize(t *testing.T) { + original := make([]byte, 1000) + for i := range original { + original[i] = 1 + } + inner, err := NewDataInMemory(original) + require.NoError(t, err) + + key := [32]byte{} + for i := range key { + key[i] = 0x42 + } + encrypted, err := NewEncryptedData(inner, key) + require.NoError(t, err) + + assert.Equal(t, inner.Size()+int64(EncryptionHeaderSize), encrypted.Size()) +} + +func TestEncryptedDataReadHeader(t *testing.T) { + original := make([]byte, 100) + for i := range original { + original[i] = 1 + } + inner, err := NewDataInMemory(original) + require.NoError(t, err) + + key := [32]byte{} + for i := range key { + key[i] = 0x42 + } + encrypted, err := NewEncryptedData(inner, key) + require.NoError(t, err) + + // Read just the header + buf := make([]byte, EncryptionHeaderSize) + n, err := encrypted.Read(buf, 0) + require.NoError(t, err) + assert.Equal(t, EncryptionHeaderSize, n) + assert.Equal(t, byte(EncryptionVersion), buf[0]) + assert.Equal(t, encrypted.Header().Nonce[:], buf[1:17]) +} + +func TestEncryptedDataRoundtrip(t *testing.T) { + original := []byte("hello world encryption test with EncryptedData wrapper") + inner, err := NewDataInMemory(original) + require.NoError(t, err) + + key := [32]byte{} + for i := range key { + key[i] = 0x42 + } + encrypted, err := NewEncryptedData(inner, key) + require.NoError(t, err) + + // Read full encrypted stream + encryptedSize := int(encrypted.Size()) + encryptedBuf := make([]byte, encryptedSize) + n, err := encrypted.Read(encryptedBuf, 0) + require.NoError(t, err) + assert.Equal(t, encryptedSize, n) + + // Decrypt and verify + decrypted, err := DecryptFile(&key, encryptedBuf) + require.NoError(t, err) + assert.Equal(t, original, decrypted) +} + +func TestEncryptedDataReadAtOffset(t *testing.T) { + original := make([]byte, 500) + for i := range original { + original[i] = 0xAB + } + inner, err := NewDataInMemory(original) + require.NoError(t, err) + + key := [32]byte{} + for i := range key { + key[i] = 0x42 + } + encrypted, err := NewEncryptedData(inner, key) + require.NoError(t, err) + + // Read full encrypted data + encryptedSize := int(encrypted.Size()) + fullBuf := make([]byte, encryptedSize) + encrypted.Read(fullBuf, 0) + + // Read in two parts and verify they match + split := 100 + part1 := make([]byte, split) + part2 := make([]byte, encryptedSize-split) + encrypted.Read(part1, 0) + encrypted.Read(part2, int64(split)) + + assert.Equal(t, fullBuf[:split], part1) + assert.Equal(t, fullBuf[split:], part2) +} + +func TestEncryptedDataMerkleTreeConsistency(t *testing.T) { + // Verify that building a merkle tree on encrypted data works correctly + // and that the same encrypted data produces the same merkle root + original := make([]byte, 300) + for i := range original { + original[i] = 0x55 + } + inner, err := NewDataInMemory(original) + require.NoError(t, err) + + key := [32]byte{} + for i := range key { + key[i] = 0x42 + } + encrypted, err := NewEncryptedData(inner, key) + require.NoError(t, err) + + // Build merkle tree on encrypted data + tree, err := MerkleTree(encrypted) + require.NoError(t, err) + assert.NotEmpty(t, tree.Root()) + + // Read the full encrypted stream and build merkle tree on it as in-memory data + encryptedSize := int(encrypted.Size()) + encryptedBuf := make([]byte, encryptedSize) + n, err := encrypted.Read(encryptedBuf, 0) + require.NoError(t, err) + assert.Equal(t, encryptedSize, n) + + inMem, err := NewDataInMemory(encryptedBuf) + require.NoError(t, err) + inMemTree, err := MerkleTree(inMem) + require.NoError(t, err) + + // Both merkle trees should produce the same root + assert.Equal(t, tree.Root(), inMemTree.Root()) +} + +// TestEncryptedFileSubmissionRootConsistency verifies that MerkleTree root and +// CreateSubmission root match when EncryptedData wraps a File (not DataInMemory). +// This catches the bug where File.Read returned 0 on non-EOF reads, causing +// EncryptedData to skip encryption for partial reads (e.g., 1023-byte files). +func TestEncryptedFileSubmissionRootConsistency(t *testing.T) { + sizes := []int{1023, 1024, 1025, 256*4 - 17, 256*4 - 16, 256 * 5} + for _, size := range sizes { + t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) { + // Write data to a temp file + original := make([]byte, size) + for i := range original { + original[i] = byte(i % 251) + } + tmpFile, err := os.CreateTemp("", "encrypted_test_*") + require.NoError(t, err) + defer os.Remove(tmpFile.Name()) + _, err = tmpFile.Write(original) + require.NoError(t, err) + tmpFile.Close() + + // Open as File (IterableData) + file, err := Open(tmpFile.Name()) + require.NoError(t, err) + defer file.Close() + + key := [32]byte{} + for i := range key { + key[i] = 0x42 + } + encrypted, err := NewEncryptedData(file, key) + require.NoError(t, err) + + // Build MerkleTree root (reads full segments, hits EOF) + tree, err := MerkleTree(encrypted) + require.NoError(t, err) + + // Build submission root (reads in smaller chunks per node) + flow := NewFlow(encrypted, nil) + submission, err := flow.CreateSubmission(common.Address{}) + require.NoError(t, err) + + // These must match; if File.Read returns wrong count, + // encryption is skipped in CreateSubmission reads and roots diverge + assert.Equal(t, tree.Root(), submission.Root(), + "MerkleTree root and Submission root must match for file size %d", size) + }) + } +} diff --git a/core/encryption.go b/core/encryption.go new file mode 100644 index 0000000..10605de --- /dev/null +++ b/core/encryption.go @@ -0,0 +1,167 @@ +package core + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/binary" + "fmt" +) + +const ( + // EncryptionHeaderSize is the size of the encryption header in bytes (1 byte version + 16 bytes nonce). + EncryptionHeaderSize = 17 + + // EncryptionVersion is the current encryption format version. + EncryptionVersion = 1 +) + +// EncryptionHeader stores the version and nonce for AES-256-CTR encryption. +type EncryptionHeader struct { + Version uint8 + Nonce [16]byte +} + +// NewEncryptionHeader creates a new encryption header with a random nonce. +func NewEncryptionHeader() (*EncryptionHeader, error) { + var nonce [16]byte + if _, err := rand.Read(nonce[:]); err != nil { + return nil, fmt.Errorf("failed to generate random nonce: %w", err) + } + return &EncryptionHeader{ + Version: EncryptionVersion, + Nonce: nonce, + }, nil +} + +// ParseEncryptionHeader extracts an encryption header from the given data. +func ParseEncryptionHeader(data []byte) (*EncryptionHeader, error) { + if len(data) < EncryptionHeaderSize { + return nil, fmt.Errorf("data too short for encryption header: %d < %d", len(data), EncryptionHeaderSize) + } + version := data[0] + if version != EncryptionVersion { + return nil, fmt.Errorf("unsupported encryption version: %d", version) + } + var nonce [16]byte + copy(nonce[:], data[1:17]) + return &EncryptionHeader{ + Version: version, + Nonce: nonce, + }, nil +} + +// ToBytes serializes the header to a fixed-size byte array. +func (h *EncryptionHeader) ToBytes() [EncryptionHeaderSize]byte { + var buf [EncryptionHeaderSize]byte + buf[0] = h.Version + copy(buf[1:17], h.Nonce[:]) + return buf +} + +// CryptAt encrypts or decrypts data in-place at a given byte offset within the plaintext stream. +// AES-256-CTR is symmetric: encrypt and decrypt are the same operation. +// The offset is the byte offset within the data stream (not counting the header). +func CryptAt(key *[32]byte, nonce *[16]byte, offset uint64, data []byte) { + if len(data) == 0 { + return + } + + block, err := aes.NewCipher(key[:]) + if err != nil { + panic(fmt.Sprintf("aes.NewCipher: %v", err)) // key is always 32 bytes + } + + blockSize := uint64(aes.BlockSize) + blockOffset := offset / blockSize + byteOffset := offset % blockSize + + // Compute the adjusted counter: nonce + blockOffset (big-endian 128-bit addition) + counter := make([]byte, 16) + copy(counter, nonce[:]) + addToCounter(counter, blockOffset) + + stream := cipher.NewCTR(block, counter) + + // Skip byteOffset bytes of keystream for sub-block alignment + if byteOffset > 0 { + skip := make([]byte, byteOffset) + stream.XORKeyStream(skip, skip) + } + + stream.XORKeyStream(data, data) +} + +// addToCounter adds a uint64 value to a big-endian 128-bit counter. +func addToCounter(counter []byte, val uint64) { + lo := binary.BigEndian.Uint64(counter[8:16]) + hi := binary.BigEndian.Uint64(counter[0:8]) + + newLo := lo + val + if newLo < lo { + hi++ // carry + } + + binary.BigEndian.PutUint64(counter[8:16], newLo) + binary.BigEndian.PutUint64(counter[0:8], hi) +} + +// EncryptBytes encrypts plaintext bytes using AES-256-CTR with a random nonce. +// Returns [1-byte version][16-byte nonce][ciphertext]. +func EncryptBytes(key *[32]byte, plaintext []byte) ([]byte, error) { + header, err := NewEncryptionHeader() + if err != nil { + return nil, err + } + headerBytes := header.ToBytes() + result := make([]byte, EncryptionHeaderSize+len(plaintext)) + copy(result[:EncryptionHeaderSize], headerBytes[:]) + copy(result[EncryptionHeaderSize:], plaintext) + CryptAt(key, &header.Nonce, 0, result[EncryptionHeaderSize:]) + return result, nil +} + +// DecryptBytes decrypts data produced by EncryptBytes: strips the header and decrypts. +// Returns the original plaintext. +func DecryptBytes(key *[32]byte, encrypted []byte) ([]byte, error) { + return DecryptFile(key, encrypted) +} + +// DecryptFile decrypts a full downloaded file: strips the header and decrypts the remaining bytes. +// Returns the decrypted data without the header. +func DecryptFile(key *[32]byte, encrypted []byte) ([]byte, error) { + if len(encrypted) < EncryptionHeaderSize { + return nil, fmt.Errorf("encrypted data too short") + } + header, err := ParseEncryptionHeader(encrypted) + if err != nil { + return nil, err + } + data := make([]byte, len(encrypted)-EncryptionHeaderSize) + copy(data, encrypted[EncryptionHeaderSize:]) + CryptAt(key, &header.Nonce, 0, data) + return data, nil +} + +// DecryptSegment decrypts a single downloaded segment. +// For segment 0: the first EncryptionHeaderSize bytes are the header, the rest is encrypted data starting at offset 0. +// For other segments: all bytes are encrypted data at the correct offset. +// segmentSize is the standard segment size (e.g. DefaultSegmentSize = 256KB). +func DecryptSegment(key *[32]byte, segmentIndex, segmentSize uint64, segmentData []byte, header *EncryptionHeader) []byte { + if segmentIndex == 0 { + // First segment: skip header bytes, decrypt the rest starting at data offset 0 + encrypted := segmentData[EncryptionHeaderSize:] + data := make([]byte, len(encrypted)) + copy(data, encrypted) + CryptAt(key, &header.Nonce, 0, data) + return data + } + + // Other segments: all bytes are encrypted data + // Data offset = segmentIndex * segmentSize - EncryptionHeaderSize + dataOffset := segmentIndex*segmentSize - uint64(EncryptionHeaderSize) + data := make([]byte, len(segmentData)) + copy(data, segmentData) + CryptAt(key, &header.Nonce, dataOffset, data) + return data +} diff --git a/core/encryption_test.go b/core/encryption_test.go new file mode 100644 index 0000000..6005335 --- /dev/null +++ b/core/encryption_test.go @@ -0,0 +1,249 @@ +package core + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHeaderRoundtrip(t *testing.T) { + header, err := NewEncryptionHeader() + require.NoError(t, err) + + bytes := header.ToBytes() + parsed, err := ParseEncryptionHeader(bytes[:]) + require.NoError(t, err) + + assert.Equal(t, uint8(EncryptionVersion), parsed.Version) + assert.Equal(t, header.Nonce, parsed.Nonce) +} + +func TestCryptRoundtrip(t *testing.T) { + key := [32]byte{} + for i := range key { + key[i] = 0x42 + } + nonce := [16]byte{} + for i := range nonce { + nonce[i] = 0x13 + } + original := []byte("hello world encryption test data") + buf := make([]byte, len(original)) + copy(buf, original) + + // Encrypt + CryptAt(&key, &nonce, 0, buf) + assert.NotEqual(t, original, buf) + + // Decrypt (same operation for CTR) + CryptAt(&key, &nonce, 0, buf) + assert.Equal(t, original, buf) +} + +func TestCryptAtOffset(t *testing.T) { + key := [32]byte{} + for i := range key { + key[i] = 0x42 + } + nonce := [16]byte{} + for i := range nonce { + nonce[i] = 0x13 + } + original := make([]byte, 100) + + // Encrypt full + full := make([]byte, 100) + copy(full, original) + CryptAt(&key, &nonce, 0, full) + + // Encrypt in two parts at different offsets + part1 := make([]byte, 50) + part2 := make([]byte, 50) + copy(part1, original[:50]) + copy(part2, original[50:]) + CryptAt(&key, &nonce, 0, part1) + CryptAt(&key, &nonce, 50, part2) + + assert.Equal(t, full[:50], part1) + assert.Equal(t, full[50:], part2) +} + +func TestDecryptFile(t *testing.T) { + key := [32]byte{} + for i := range key { + key[i] = 0x42 + } + original := []byte("test data for encryption") + + // Build encrypted file: header + encrypted data + header, err := NewEncryptionHeader() + require.NoError(t, err) + + encryptedData := make([]byte, len(original)) + copy(encryptedData, original) + CryptAt(&key, &header.Nonce, 0, encryptedData) + + headerBytes := header.ToBytes() + encryptedFile := make([]byte, 0, EncryptionHeaderSize+len(encryptedData)) + encryptedFile = append(encryptedFile, headerBytes[:]...) + encryptedFile = append(encryptedFile, encryptedData...) + + // Decrypt + decrypted, err := DecryptFile(&key, encryptedFile) + require.NoError(t, err) + assert.Equal(t, original, decrypted) +} + +func TestDecryptSegmentZero(t *testing.T) { + key := [32]byte{} + for i := range key { + key[i] = 0x42 + } + header, err := NewEncryptionHeader() + require.NoError(t, err) + segmentSize := uint64(256 * 1024) // 256KB + + // Build segment 0: header + encrypted plaintext + plaintext := make([]byte, int(segmentSize)-EncryptionHeaderSize) + for i := range plaintext { + plaintext[i] = 0xAB + } + encrypted := make([]byte, len(plaintext)) + copy(encrypted, plaintext) + CryptAt(&key, &header.Nonce, 0, encrypted) + + headerBytes := header.ToBytes() + segmentData := make([]byte, 0, segmentSize) + segmentData = append(segmentData, headerBytes[:]...) + segmentData = append(segmentData, encrypted...) + assert.Equal(t, int(segmentSize), len(segmentData)) + + // decrypt_segment for segment 0 returns plaintext without header + decrypted := DecryptSegment(&key, 0, segmentSize, segmentData, header) + assert.Equal(t, plaintext, decrypted) +} + +func TestDecryptSegmentNonzero(t *testing.T) { + key := [32]byte{} + for i := range key { + key[i] = 0x42 + } + header, err := NewEncryptionHeader() + require.NoError(t, err) + segmentSize := uint64(256 * 1024) + + // Segment 1's data offset is segmentSize - HeaderSize + dataOffset := segmentSize - uint64(EncryptionHeaderSize) + plaintext := make([]byte, segmentSize) + for i := range plaintext { + plaintext[i] = 0xCD + } + encrypted := make([]byte, len(plaintext)) + copy(encrypted, plaintext) + CryptAt(&key, &header.Nonce, dataOffset, encrypted) + + decrypted := DecryptSegment(&key, 1, segmentSize, encrypted, header) + assert.Equal(t, plaintext, decrypted) +} + +func TestDecryptSegmentPaddedPreservesHeader(t *testing.T) { + // Simulates what download_segment_padded does for segment 0 + key := [32]byte{} + for i := range key { + key[i] = 0x42 + } + header, err := NewEncryptionHeader() + require.NoError(t, err) + segmentSize := uint64(256 * 1024) + + plaintext := make([]byte, int(segmentSize)-EncryptionHeaderSize) + for i := range plaintext { + plaintext[i] = 0xEF + } + encrypted := make([]byte, len(plaintext)) + copy(encrypted, plaintext) + CryptAt(&key, &header.Nonce, 0, encrypted) + + headerBytes := header.ToBytes() + rawSegment := make([]byte, 0, segmentSize) + rawSegment = append(rawSegment, headerBytes[:]...) + rawSegment = append(rawSegment, encrypted...) + + // Decrypt in-place after header (what download_segment_padded does) + result := make([]byte, len(rawSegment)) + copy(result, rawSegment) + CryptAt(&key, &header.Nonce, 0, result[EncryptionHeaderSize:]) + + // Header preserved, data decrypted + assert.Equal(t, headerBytes[:], result[:EncryptionHeaderSize]) + assert.Equal(t, plaintext, result[EncryptionHeaderSize:]) + assert.Equal(t, int(segmentSize), len(result)) +} + +func TestEncryptDecryptBytes(t *testing.T) { + key := [32]byte{} + for i := range key { + key[i] = 0x42 + } + + tests := []struct { + name string + data []byte + }{ + {"empty", []byte{}}, + {"short", []byte("hello")}, + {"binary", []byte{0x00, 0xff, 0x42, 0x13}}, + {"long", make([]byte, 10000)}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + encrypted, err := EncryptBytes(&key, tt.data) + require.NoError(t, err) + assert.Equal(t, EncryptionHeaderSize+len(tt.data), len(encrypted)) + + decrypted, err := DecryptBytes(&key, encrypted) + require.NoError(t, err) + assert.Equal(t, tt.data, decrypted) + }) + } +} + +func TestMultiSegmentDecryptMatchesFullFile(t *testing.T) { + // Encrypt a file spanning 2 segments, decrypt per-segment, verify matches full decrypt + key := [32]byte{} + for i := range key { + key[i] = 0x42 + } + header, err := NewEncryptionHeader() + require.NoError(t, err) + segmentSize := uint64(256) // Small for testing + + plaintext := make([]byte, int(segmentSize)*2-EncryptionHeaderSize) + for i := range plaintext { + plaintext[i] = 0x77 + } + fullEncrypted := make([]byte, len(plaintext)) + copy(fullEncrypted, plaintext) + CryptAt(&key, &header.Nonce, 0, fullEncrypted) + + // Build encrypted file + headerBytes := header.ToBytes() + fileData := make([]byte, 0, EncryptionHeaderSize+len(fullEncrypted)) + fileData = append(fileData, headerBytes[:]...) + fileData = append(fileData, fullEncrypted...) + + // Segment 0: first segmentSize bytes of the file + seg0Data := fileData[:segmentSize] + seg0Decrypted := DecryptSegment(&key, 0, segmentSize, seg0Data, header) + + // Segment 1: remaining bytes + seg1Data := fileData[segmentSize:] + seg1Decrypted := DecryptSegment(&key, 1, segmentSize, seg1Data, header) + + // Concatenated decrypted segments should equal original plaintext + combined := make([]byte, 0, len(plaintext)) + combined = append(combined, seg0Decrypted...) + combined = append(combined, seg1Decrypted...) + assert.Equal(t, plaintext, combined) +} diff --git a/core/file.go b/core/file.go index 771e549..284577b 100644 --- a/core/file.go +++ b/core/file.go @@ -30,7 +30,7 @@ var _ IterableData = (*File)(nil) func (file *File) Read(buf []byte, offset int64) (int, error) { n, err := file.underlying.ReadAt(buf, file.offset+offset) // unexpected IO error - if !errors.Is(err, io.EOF) { + if err != nil && !errors.Is(err, io.EOF) { return 0, err } return n, nil diff --git a/kv/batcher.go b/kv/batcher.go index 5a7803e..4bc6861 100644 --- a/kv/batcher.go +++ b/kv/batcher.go @@ -5,8 +5,10 @@ import ( zg_common "github.com/0gfoundation/0g-storage-client/common" "github.com/0gfoundation/0g-storage-client/core" + "github.com/0gfoundation/0g-storage-client/node" "github.com/0gfoundation/0g-storage-client/transfer" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" "github.com/openweb3/web3go" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -15,6 +17,7 @@ import ( // Batcher struct to cache and execute KV write and access control operations. type Batcher struct { *streamDataBuilder + kvClient *Client clients *transfer.SelectedNodes w3Client *web3go.Client logger *logrus.Logger @@ -30,6 +33,33 @@ func NewBatcher(version uint64, clients *transfer.SelectedNodes, w3Client *web3g } } +// WithKVClient sets a KV client on the batcher to enable Get (read-your-own-writes with remote fallback). +func (b *Batcher) WithKVClient(kvClient *Client) *Batcher { + b.kvClient = kvClient + return b +} + +// Get returns the value for a key. It first checks the local write cache (uncommitted Set calls), +// and falls back to querying the KV node using the batcher's version. +func (b *Batcher) Get(ctx context.Context, streamId common.Hash, key []byte) (*node.Value, error) { + // Check local writes first + if keys, ok := b.writes[streamId]; ok { + if data, ok := keys[hexutil.Encode(key)]; ok { + return &node.Value{ + Version: b.version, + Data: data, + Size: uint64(len(data)), + }, nil + } + } + + // Fall back to KV node + if b.kvClient == nil { + return nil, errors.New("key not found locally and no KV client configured") + } + return b.kvClient.GetValue(ctx, streamId, key, b.version) +} + // Exec Serialize the cached KV operations in Batcher, then submit the serialized data to 0g storage network. // The submission process is the same as uploading a normal file. The batcher should be dropped after execution. // Note, this may be time consuming operation, e.g. several seconds or even longer. @@ -42,6 +72,11 @@ func (b *Batcher) Exec(ctx context.Context, option ...transfer.UploadOption) (co } encoded, err := streamData.Encode() + logrus.WithFields(logrus.Fields{ + "version": streamData.Version, + "data": encoded, + }).Debug("Built stream data") + if err != nil { return common.Hash{}, errors.WithMessage(err, "Failed to encode data") } diff --git a/kv/client.go b/kv/client.go index b675e88..ae9a957 100644 --- a/kv/client.go +++ b/kv/client.go @@ -50,7 +50,7 @@ func (c *Client) GetValue(ctx context.Context, streamId common.Hash, key []byte, for { var seg *node.Value seg, err = c.node.GetValue(ctx, streamId, key, uint64(len(val.Data)), maxQuerySize, val.Version) - + if err != nil { return } diff --git a/tests/Makefile b/tests/Makefile index 8b27cf5..04c27ad 100644 --- a/tests/Makefile +++ b/tests/Makefile @@ -226,14 +226,10 @@ stop: -@if [ -f "$(DATA_DIR)/reth.pid" ]; then \ kill $$(cat "$(DATA_DIR)/reth.pid") 2>/dev/null || true; \ rm -f "$(DATA_DIR)/reth.pid"; \ - else \ - pkill -f "$(GETH_BIN)" 2>/dev/null || true; \ fi -@if [ -f "$(DATA_DIR)/chaind.pid" ]; then \ kill $$(cat "$(DATA_DIR)/chaind.pid") 2>/dev/null || true; \ rm -f "$(DATA_DIR)/chaind.pid"; \ - else \ - pkill -f "$(CHAIND_BIN)" 2>/dev/null || true; \ fi @sleep 2 @printf "$(GREEN)All processes stopped!$(NC)\n" diff --git a/tests/cli_file_encrypted_upload_download_test.py b/tests/cli_file_encrypted_upload_download_test.py new file mode 100644 index 0000000..8edb5d9 --- /dev/null +++ b/tests/cli_file_encrypted_upload_download_test.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 + +import os +import random +import tempfile + +from config.node_config import GENESIS_ACCOUNT +from utility.utils import ( + wait_until, +) +from client_test_framework.test_framework import ClientTestFramework + +# Fixed 32-byte encryption key (hex-encoded with 0x prefix) +ENCRYPTION_KEY = "0x" + "ab" * 32 + + +class FileEncryptedUploadDownloadTest(ClientTestFramework): + def setup_params(self): + self.num_blockchain_nodes = 1 + self.num_nodes = 4 + self.zgs_node_configs[0] = { + "db_max_num_sectors": 2**30, + "shard_position": "0/4", + } + self.zgs_node_configs[1] = { + "db_max_num_sectors": 2**30, + "shard_position": "1/4", + } + self.zgs_node_configs[2] = { + "db_max_num_sectors": 2**30, + "shard_position": "2/4", + } + self.zgs_node_configs[3] = { + "db_max_num_sectors": 2**30, + "shard_position": "3/4", + } + + def run_test(self): + data_size = [ + 2, + 255, + 256, + 257, + 1023, + 1024, + 1025, + 256 * 960, + 256 * 1023, + 256 * 1024, + 256 * 1025, + 256 * 2048, + 256 * 16385, + 256 * 1024 * 64, + 256 * 480, + 256 * 1024 * 10, + 1000, + 256 * 960, + 256 * 100, + 256 * 960, + ] + + for i, v in enumerate(data_size): + self.__test_encrypted_upload_download_file(v, i + 1) + + def __test_encrypted_upload_download_file(self, size, submission_index): + self.log.info("encrypted file size: %d", size) + + file_to_upload = tempfile.NamedTemporaryFile(dir=self.root_dir, delete=False) + data = random.randbytes(size) + + file_to_upload.write(data) + file_to_upload.close() + + root = self._upload_file_use_cli( + self.blockchain_nodes[0].rpc_url, + GENESIS_ACCOUNT.key, + ",".join([x.rpc_url for x in self.nodes]), + None, + file_to_upload, + skip_tx=False, + encryption_key=ENCRYPTION_KEY, + ) + + self.log.info("root: %s", root) + wait_until(lambda: self.contract.num_submissions() == submission_index) + + for node_idx in range(4): + client = self.nodes[node_idx] + wait_until(lambda: client.zgs_get_file_info(root) is not None) + wait_until(lambda: client.zgs_get_file_info(root)["finalized"]) + + # Download with encryption key and verify decrypted content matches original + file_to_download = os.path.join( + self.root_dir, "download_enc_{}_{}".format(submission_index, size) + ) + self._download_file_use_cli( + ",".join([x.rpc_url for x in self.nodes]), + None, + root, + file_to_download=file_to_download, + with_proof=True, + remove=False, + encryption_key=ENCRYPTION_KEY, + ) + + with open(file_to_download, "rb") as f: + downloaded_data = f.read() + assert downloaded_data == data, "decrypted data mismatch for size %d" % size + os.remove(file_to_download) + + # Also test download without proof + self._download_file_use_cli( + ",".join([x.rpc_url for x in self.nodes]), + None, + root, + with_proof=False, + encryption_key=ENCRYPTION_KEY, + ) + + +if __name__ == "__main__": + FileEncryptedUploadDownloadTest().main() diff --git a/tests/cli_kv_encrypted_test.py b/tests/cli_kv_encrypted_test.py new file mode 100644 index 0000000..1565cab --- /dev/null +++ b/tests/cli_kv_encrypted_test.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 + +from config.node_config import GENESIS_ACCOUNT +from utility.utils import ( + assert_equal, + wait_until, +) +from client_utility.kv import to_stream_id +from client_test_framework.test_framework import ClientTestFramework + +# Fixed 32-byte encryption key (hex-encoded with 0x prefix) +ENCRYPTION_KEY = "0x" + "ab" * 32 +# Same key without 0x prefix for KV node config +ENCRYPTION_KEY_HEX = "ab" * 32 + + +class KVEncryptedTest(ClientTestFramework): + def setup_params(self): + self.num_blockchain_nodes = 1 + self.num_nodes = 1 + + def run_test(self): + # Set up KV node with encryption key configured for replay + self.setup_kv_node( + 0, + [to_stream_id(0)], + updated_config={"encryption_key": ENCRYPTION_KEY_HEX}, + ) + self.setup_indexer(self.nodes[0].rpc_url, self.nodes[0].rpc_url) + + # Write KV data with encryption via direct node + self._kv_write_use_cli( + self.blockchain_nodes[0].rpc_url, + GENESIS_ACCOUNT.key, + self.nodes[0].rpc_url, + None, + to_stream_id(0), + "0,1,2,3,4,5,6,7,8,9,10", + "0,1,2,3,4,5,6,7,8,9,10", + encryption_key=ENCRYPTION_KEY, + ) + + # Write KV data with encryption via indexer + self._kv_write_use_cli( + self.blockchain_nodes[0].rpc_url, + GENESIS_ACCOUNT.key, + None, + self.indexer_rpc_url, + to_stream_id(0), + "11,12,13,14,15,16,17,18,19,20", + "11,12,13,14,15,16,17,18,19,20", + encryption_key=ENCRYPTION_KEY, + ) + + # Wait for KV node to commit both transactions + wait_until(lambda: self.kv_nodes[0].kv_get_trasanction_result(0) == "Commit") + wait_until(lambda: self.kv_nodes[0].kv_get_trasanction_result(1) == "Commit") + + # Read back via CLI — no encryption key needed (KV node decrypts during replay) + res = self._kv_read_use_cli( + self.kv_nodes[0].rpc_url, + to_stream_id(0), + "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21", + ) + for i in range(21): + assert_equal(res[str(i)], str(i)) + # Key 21 was never written, should be empty + assert_equal(res["21"], "") + + +if __name__ == "__main__": + KVEncryptedTest().main() diff --git a/tests/client_test_framework/test_framework.py b/tests/client_test_framework/test_framework.py index 3267f66..96d908f 100644 --- a/tests/client_test_framework/test_framework.py +++ b/tests/client_test_framework/test_framework.py @@ -23,6 +23,7 @@ from utility.utils import ( PortMin, PortCategory, + PORT_RANGE, arrange_port, is_windows_platform, wait_until, @@ -117,6 +118,7 @@ def _upload_file_use_cli( file_to_upload, fragment_size=None, skip_tx=True, + encryption_key=None, ): upload_args = [ self.cli_binary, @@ -138,6 +140,9 @@ def _upload_file_use_cli( if fragment_size is not None: upload_args.append("--fragment-size") upload_args.append(str(fragment_size)) + if encryption_key is not None: + upload_args.append("--encryption-key") + upload_args.append(encryption_key) upload_args.append("--file") self.log.info("upload file with cli: {}".format(upload_args)) @@ -192,6 +197,7 @@ def _download_file_use_cli( file_to_download=None, with_proof=True, remove=True, + encryption_key=None, ): if file_to_download is None: file_to_download = os.path.join( @@ -213,6 +219,10 @@ def _download_file_use_cli( download_args.append("--roots") download_args.append(roots) + if encryption_key is not None: + download_args.append("--encryption-key") + download_args.append(encryption_key) + if node_rpc_url is not None: download_args.append("--node") download_args.append(node_rpc_url) @@ -267,6 +277,7 @@ def _kv_write_use_cli( kv_keys, kv_values, skip_tx=True, + encryption_key=None, ): kv_write_args = [ self.cli_binary, @@ -293,6 +304,9 @@ def _kv_write_use_cli( elif indexer_url is not None: kv_write_args.append("--indexer") kv_write_args.append(indexer_url) + if encryption_key is not None: + kv_write_args.append("--encryption-key") + kv_write_args.append(encryption_key) self.log.info("kv write with cli: {}".format(kv_write_args)) output = tempfile.NamedTemporaryFile( @@ -438,7 +452,9 @@ def setup_indexer(self, trusted, discover_node, discover_ports=None): if discover_node is not None: indexer_args.append("--node") indexer_args.append(discover_node) - self.log.info("start indexer with args: {}".format(indexer_args)) + + indexer_port = arrange_port(PortCategory.ZGS_INDEXER_RPC, 0) + self.log.info("start indexer [RPC: %d] with args: %s", indexer_port, indexer_args) data_dir = os.path.join(self.root_dir, "indexer0") os.mkdir(data_dir) stdout = tempfile.NamedTemporaryFile( @@ -501,6 +517,10 @@ def main(self): self.options = parser.parse_args() PortMin.n = self.options.port_min + # Calculate and log port range + port_max = self.options.port_min + PORT_RANGE + print(f"[PORT INFO] Test assigned port range: {self.options.port_min}-{port_max-1} (range size: {PORT_RANGE})", flush=True) + # Set up temp directory and start logging if self.options.tmpdir: self.options.tmpdir = os.path.abspath(self.options.tmpdir) diff --git a/tests/client_test_framework/zg_node.py b/tests/client_test_framework/zg_node.py deleted file mode 100644 index 089ca65..0000000 --- a/tests/client_test_framework/zg_node.py +++ /dev/null @@ -1,123 +0,0 @@ -import os -import subprocess -import tempfile - -from test_framework.blockchain_node import BlockChainNodeType, BlockchainNode -from utility.utils import ( - PortCategory, - arrange_port, - wait_until, -) -from utility.simple_rpc_proxy import SimpleRpcProxy - - -def _chain_data_dir() -> str: - return os.path.join("tmp", f"data_{arrange_port(PortCategory.ZG_ETH_HTTP, 0)}") - - -def _chain_make_args(root_dir: str, target: str) -> list[str]: - data_dir = _chain_data_dir() - return [ - "make", - target, - f"DATA_DIR={data_dir}", - f"ETH_HTTP_PORT={arrange_port(PortCategory.ZG_ETH_HTTP, 0)}", - f"ETH_WS_PORT={arrange_port(PortCategory.ZG_ETH_WS, 0)}", - f"ETH_METRICS_PORT={arrange_port(PortCategory.ZG_ETH_METRICS, 0)}", - f"AUTHRPC_PORT={arrange_port(PortCategory.ZG_AUTHRPC, 0)}", - f"CONSENSUS_RPC_PORT={arrange_port(PortCategory.ZG_CONSENSUS_RPC, 0)}", - f"CONSENSUS_P2P_PORT={arrange_port(PortCategory.ZG_CONSENSUS_P2P, 0)}", - f"NODE_API_PORT={arrange_port(PortCategory.ZG_NODE_API, 0)}", - f"P2P_PORT={arrange_port(PortCategory.ZG_P2P, 0)}", - f"DISCOVERY_PORT={arrange_port(PortCategory.ZG_DISCOVERY, 0)}", - ] - - -def zg_node_init_genesis(binary: str, root_dir: str, num_nodes: int): - assert num_nodes == 1, "Makefile deploy only supports one blockchain node" - - tests_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) - os.environ.setdefault( - "ZGS_BLOCKCHAIN_RPC_ENDPOINT", - f"http://127.0.0.1:{arrange_port(PortCategory.ZG_ETH_HTTP, 0)}", - ) - - log_file = tempfile.NamedTemporaryFile( - dir=root_dir, delete=False, prefix="init_genesis_", suffix=".log" - ) - ret = subprocess.run( - args=_chain_make_args(root_dir, "deploy"), - cwd=tests_dir, - stdout=log_file, - stderr=log_file, - ) - log_file.close() - - assert ret.returncode == 0, ( - "Failed to deploy 0gchain genesis, see more details in log file: %s" - % log_file.name - ) - - -class ZGNode(BlockchainNode): - def __init__( - self, - index, - root_dir, - binary, - updated_config, - contract_path, - log, - rpc_timeout=10, - ): - assert index == 0, "Makefile start only supports one blockchain node" - - self._root_dir = root_dir - data_dir = os.path.join(root_dir, "0gchaind", "node" + str(index)) - rpc_url = os.environ.get( - "ZGS_BLOCKCHAIN_RPC_ENDPOINT", - f"http://127.0.0.1:{arrange_port(PortCategory.ZG_ETH_HTTP, 0)}", - ) - self._make_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) - - super().__init__( - index, - data_dir, - rpc_url, - binary, - {}, - contract_path, - log, - BlockChainNodeType.ZG, - rpc_timeout, - ) - - def setup_config(self): - """Already initialized by Makefile deploy""" - - def start(self): - self.log.info("Starting 0gchaind via Makefile") - ret = subprocess.run( - args=_chain_make_args(self._root_dir, "start"), - cwd=self._make_dir, - ) - assert ret.returncode == 0, "Failed to start 0gchaind via Makefile" - self.running = True - - def stop(self, expected_stderr="", kill=False, wait=True): - ret = subprocess.run( - args=_chain_make_args(self._root_dir, "stop"), - cwd=self._make_dir, - ) - assert ret.returncode == 0, "Failed to stop 0gchaind via Makefile" - self.running = False - - def wait_for_rpc_connection(self): - rpc = SimpleRpcProxy(self.rpc_url, timeout=self.rpc_timeout) - - def check(): - return rpc.eth_syncing() is False - - wait_until(check, timeout=self.rpc_timeout) - self.rpc_connected = True - self.rpc = rpc diff --git a/tests/client_test_framework/zgs_node.py b/tests/client_test_framework/zgs_node.py deleted file mode 100644 index 6bfbd88..0000000 --- a/tests/client_test_framework/zgs_node.py +++ /dev/null @@ -1,140 +0,0 @@ -import os -import shutil -import base64 - -from config.node_config import ZGS_CONFIG, update_config -from test_framework.blockchain_node import NodeType, TestNode -from utility.utils import ( - PortCategory, - arrange_port, - initialize_toml_config, -) - - -class ZgsNode(TestNode): - def __init__( - self, - index, - root_dir, - binary, - updated_config, - log_contract_address, - mine_contract_address, - reward_contract_address, - log, - rpc_timeout=10, - libp2p_nodes=None, - ): - local_conf = ZGS_CONFIG.copy() - if libp2p_nodes is None: - if index == 0: - libp2p_nodes = [] - else: - libp2p_nodes = [] - for i in range(index): - libp2p_nodes.append( - f"/ip4/127.0.0.1/tcp/{arrange_port(PortCategory.ZGS_P2P, i)}" - ) - - rpc_listen_address = f"127.0.0.1:{arrange_port(PortCategory.ZGS_RPC, index)}" - grpc_listen_address = f"127.0.0.1:{arrange_port(PortCategory.ZGS_GRPC, index)}" - - indexed_config = { - "network_libp2p_port": arrange_port(PortCategory.ZGS_P2P, index), - "network_discovery_port": arrange_port(PortCategory.ZGS_P2P, index), - "rpc": { - "listen_address": rpc_listen_address, - "listen_address_admin": rpc_listen_address, - "listen_address_grpc": grpc_listen_address, - }, - "network_libp2p_nodes": libp2p_nodes, - "log_contract_address": log_contract_address, - "mine_contract_address": mine_contract_address, - "reward_contract_address": reward_contract_address, - "blockchain_rpc_endpoint": os.environ.get( - "ZGS_BLOCKCHAIN_RPC_ENDPOINT", - f"http://127.0.0.1:{arrange_port(PortCategory.ZG_ETH_HTTP, 0)}", - ), - } - # Set configs for this specific node. - update_config(local_conf, indexed_config) - # Overwrite with personalized configs. - update_config(local_conf, updated_config) - data_dir = os.path.join(root_dir, "zgs_node" + str(index)) - rpc_url = "http://" + rpc_listen_address - super().__init__( - NodeType.Zgs, - index, - data_dir, - rpc_url, - binary, - local_conf, - log, - rpc_timeout, - ) - - def setup_config(self): - os.mkdir(self.data_dir) - log_config_path = os.path.join(self.data_dir, self.config["log_config_file"]) - with open(log_config_path, "w") as f: - f.write("trace,hyper=info,h2=info") - - initialize_toml_config(self.config_file, self.config) - - def wait_for_rpc_connection(self): - self._wait_for_rpc_connection(lambda rpc: rpc.zgs_getStatus() is not None) - - def start(self): - self.log.info("Start zerog_storage node %d", self.index) - super().start() - - # rpc - def zgs_get_status(self): - return self.rpc.zgs_getStatus()["connectedPeers"] - - def zgs_upload_segment(self, segment): - return self.rpc.zgs_uploadSegment([segment]) - - def zgs_download_segment(self, data_root, start_index, end_index): - return self.rpc.zgs_downloadSegment([data_root, start_index, end_index]) - - def zgs_download_segment_decoded( - self, data_root: str, start_chunk_index: int, end_chunk_index: int - ) -> bytes: - encodedSegment = self.rpc.zgs_downloadSegment( - [data_root, start_chunk_index, end_chunk_index] - ) - return None if encodedSegment is None else base64.b64decode(encodedSegment) - - def zgs_get_file_info(self, data_root): - return self.rpc.zgs_getFileInfo([data_root, True]) - - def zgs_get_file_info_by_tx_seq(self, tx_seq): - return self.rpc.zgs_getFileInfoByTxSeq([tx_seq]) - - def shutdown(self): - self.rpc.admin_shutdown() - self.wait_until_stopped() - - def admin_start_sync_file(self, tx_seq): - return self.rpc.admin_startSyncFile([tx_seq]) - - def admin_start_sync_chunks( - self, tx_seq: int, start_chunk_index: int, end_chunk_index: int - ): - return self.rpc.admin_startSyncChunks( - [tx_seq, start_chunk_index, end_chunk_index] - ) - - def admin_get_sync_status(self, tx_seq): - return self.rpc.admin_getSyncStatus([tx_seq]) - - def sync_status_is_completed_or_unknown(self, tx_seq): - status = self.rpc.admin_getSyncStatus([tx_seq]) - return status == "Completed" or status == "unknown" - - def admin_get_file_location(self, tx_seq, all_shards=True): - return self.rpc.admin_getFileLocation([tx_seq, all_shards]) - - def clean_data(self): - shutil.rmtree(os.path.join(self.data_dir, "db")) diff --git a/tests/config/cosmos-genesis.json b/tests/config/cosmos-genesis.json index 0f0fd94..369d820 100644 --- a/tests/config/cosmos-genesis.json +++ b/tests/config/cosmos-genesis.json @@ -1,7 +1,7 @@ { "app_name": "0gchaind", "app_version": "v0.2.0-alpha.4-892-g6b920eb40", - "genesis_time": "2026-02-06T06:21:15.944985Z", + "genesis_time": "2026-02-12T09:47:03.20331Z", "chain_id": "0gchaind-local", "initial_height": 1, "app_hash": null, diff --git a/tests/test_all.py b/tests/test_all.py index 396bf89..1fcd5d1 100755 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -16,7 +16,7 @@ run_all( test_dir=os.path.dirname(__file__), - slow_tests={}, + slow_tests={"cli_file_upload_download_test.py", "cli_file_encrypted_upload_download_test.py"}, long_manual_tests={}, skip_tests={}, ) diff --git a/tests/test_framework/test_framework.py b/tests/test_framework/test_framework.py index f440714..48fc519 100644 --- a/tests/test_framework/test_framework.py +++ b/tests/test_framework/test_framework.py @@ -23,7 +23,7 @@ from test_framework.zgs_node import ZgsNode from test_framework.blockchain_node import BlockChainNodeType from test_framework.zg_node import ZGNode, zg_node_init_genesis -from utility.utils import PortMin, is_windows_platform, wait_until, assert_equal +from utility.utils import PortMin, PORT_RANGE, is_windows_platform, wait_until, assert_equal from utility.build_binary import build_cli from utility.submission import create_submission, submit_data @@ -427,6 +427,10 @@ def main(self): self.options = parser.parse_args() PortMin.n = self.options.port_min + # Calculate and log port range + port_max = self.options.port_min + PORT_RANGE + print(f"[PORT INFO] Test assigned port range: {self.options.port_min}-{port_max-1} (range size: {PORT_RANGE})", flush=True) + # Set up temp directory and start logging if self.options.tmpdir: self.options.tmpdir = os.path.abspath(self.options.tmpdir) diff --git a/tests/test_framework/zg_node.py b/tests/test_framework/zg_node.py index c5ee06c..dee4fca 100644 --- a/tests/test_framework/zg_node.py +++ b/tests/test_framework/zg_node.py @@ -38,8 +38,9 @@ def zg_node_init_genesis(binary: str, root_dir: str, num_nodes: int): assert num_nodes == 1, "Makefile deploy only supports one blockchain node" tests_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) - os.environ["ZGS_BLOCKCHAIN_RPC_ENDPOINT"] = ( - f"http://127.0.0.1:{arrange_port(PortCategory.ZG_ETH_HTTP, 0)}" + os.environ.setdefault( + "ZGS_BLOCKCHAIN_RPC_ENDPOINT", + f"http://127.0.0.1:{arrange_port(PortCategory.ZG_ETH_HTTP, 0)}", ) log_file = tempfile.NamedTemporaryFile( @@ -73,12 +74,11 @@ def __init__( assert index == 0, "Makefile start only supports one blockchain node" self._root_dir = root_dir - os.environ.setdefault( + data_dir = os.path.join(root_dir, "0gchaind", "node" + str(index)) + rpc_url = os.environ.get( "ZGS_BLOCKCHAIN_RPC_ENDPOINT", f"http://127.0.0.1:{arrange_port(PortCategory.ZG_ETH_HTTP, 0)}", ) - data_dir = os.path.join(root_dir, "0gchaind", "node" + str(index)) - rpc_url = os.environ.get("ZGS_BLOCKCHAIN_RPC_ENDPOINT", "http://127.0.0.1:8545") self._make_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) super().__init__( @@ -97,7 +97,11 @@ def setup_config(self): """Already initialized by Makefile deploy""" def start(self): - self.log.info("Starting 0gchaind via Makefile") + eth_http = arrange_port(PortCategory.ZG_ETH_HTTP, 0) + eth_ws = arrange_port(PortCategory.ZG_ETH_WS, 0) + consensus_rpc = arrange_port(PortCategory.ZG_CONSENSUS_RPC, 0) + self.log.info("Starting 0gchaind via Makefile [ETH_HTTP: %d, ETH_WS: %d, CONSENSUS_RPC: %d]", + eth_http, eth_ws, consensus_rpc) ret = subprocess.run( args=_chain_make_args(self._root_dir, "start"), cwd=self._make_dir, diff --git a/tests/test_framework/zgs_node.py b/tests/test_framework/zgs_node.py index 6bfbd88..b469bb4 100644 --- a/tests/test_framework/zgs_node.py +++ b/tests/test_framework/zgs_node.py @@ -85,7 +85,9 @@ def wait_for_rpc_connection(self): self._wait_for_rpc_connection(lambda rpc: rpc.zgs_getStatus() is not None) def start(self): - self.log.info("Start zerog_storage node %d", self.index) + rpc_port = arrange_port(PortCategory.ZGS_RPC, self.index) + p2p_port = arrange_port(PortCategory.ZGS_P2P, self.index) + self.log.info("Start zerog_storage node %d [RPC: %d, P2P: %d]", self.index, rpc_port, p2p_port) super().start() # rpc diff --git a/transfer/download_parallel.go b/transfer/download_parallel.go index a0bb452..e8634a8 100644 --- a/transfer/download_parallel.go +++ b/transfer/download_parallel.go @@ -39,8 +39,8 @@ func newSegmentDownloader(downloader *Downloader, info *node.FileInfo, file *dow endSegmentIndex := (info.Tx.StartEntryIndex + core.NumSplits(int64(info.Tx.Size), core.DefaultChunkSize) - 1) / core.DefaultSegmentMaxChunks logrus.WithFields(logrus.Fields{ - "size": info.Tx.Size, - "startEntryIndex": info.Tx.StartEntryIndex, + "size": info.Tx.Size, + "startEntryIndex": info.Tx.StartEntryIndex, "numChunks": core.NumSplits(int64(info.Tx.Size), core.DefaultChunkSize), "startSegmentIndex": startSegmentIndex, "endSegmentIndex": endSegmentIndex, @@ -83,7 +83,7 @@ func (downloader *segmentDownloader) ParallelDo(ctx context.Context, routine, ta // there is no not-aligned & segment-crossed file startIndex := segmentIndex * core.DefaultSegmentMaxChunks endIndex := startIndex + core.DefaultSegmentMaxChunks - + if endIndex > downloader.numChunks { endIndex = downloader.numChunks } diff --git a/transfer/downloader.go b/transfer/downloader.go index 70c817c..e189ecb 100644 --- a/transfer/downloader.go +++ b/transfer/downloader.go @@ -33,6 +33,8 @@ type Downloader struct { routines int + encryptionKey []byte // optional 32-byte AES-256 decryption key + logger *logrus.Logger } @@ -54,6 +56,13 @@ func (downloader *Downloader) WithRoutines(routines int) *Downloader { return downloader } +// WithEncryptionKey sets the encryption key for post-download decryption. +// The key must be exactly 32 bytes (AES-256). +func (downloader *Downloader) WithEncryptionKey(key []byte) *Downloader { + downloader.encryptionKey = key + return downloader +} + func (downloader *Downloader) DownloadFragments(ctx context.Context, roots []string, filename string, withProof bool) error { outFile, err := os.Create(filename) if err != nil { @@ -111,6 +120,13 @@ func (downloader *Downloader) Download(ctx context.Context, root, filename strin return errors.WithMessage(err, "Failed to validate downloaded file") } + // Decrypt the file if an encryption key is set + if len(downloader.encryptionKey) > 0 { + if err = downloader.decryptDownloadedFile(filename); err != nil { + return errors.WithMessage(err, "Failed to decrypt downloaded file") + } + } + return nil } @@ -207,3 +223,29 @@ func (downloader *Downloader) validateDownloadFile(root, filename string, fileSi return nil } + +func (downloader *Downloader) decryptDownloadedFile(filename string) error { + if len(downloader.encryptionKey) != 32 { + return errors.New("encryption key must be 32 bytes") + } + + encrypted, err := os.ReadFile(filename) + if err != nil { + return errors.WithMessage(err, "Failed to read encrypted file") + } + + var key [32]byte + copy(key[:], downloader.encryptionKey) + decrypted, err := core.DecryptFile(&key, encrypted) + if err != nil { + return errors.WithMessage(err, "Failed to decrypt file") + } + + if err := os.WriteFile(filename, decrypted, 0644); err != nil { + return errors.WithMessage(err, "Failed to write decrypted file") + } + + downloader.logger.Info("Succeeded to decrypt the downloaded file") + + return nil +} diff --git a/transfer/uploader.go b/transfer/uploader.go index aee31a0..9ef7121 100644 --- a/transfer/uploader.go +++ b/transfer/uploader.go @@ -59,6 +59,7 @@ type UploadOption struct { Step int64 // step for uploading Method string // method for selecting nodes, can be "max", "random" or certain positive number in string FullTrusted bool // whether to use full trusted nodes + EncryptionKey []byte // optional 32-byte AES-256 encryption key; when set, data is encrypted before upload } // SubmitLogEntryOption option for submitting log entry @@ -223,6 +224,21 @@ func (uploader *Uploader) Upload(ctx context.Context, data core.IterableData, op } opt.Submitter = submitter } + // Wrap data with encryption if an encryption key is provided + if len(opt.EncryptionKey) > 0 { + if len(opt.EncryptionKey) != 32 { + return common.Hash{}, common.Hash{}, errors.New("encryption key must be 32 bytes") + } + var key [32]byte + copy(key[:], opt.EncryptionKey) + encData, err := core.NewEncryptedData(data, key) + if err != nil { + return common.Hash{}, common.Hash{}, errors.WithMessage(err, "Failed to create encrypted data") + } + data = encData + uploader.logger.Info("Data encryption enabled") + } + fastMode := opt.FastMode && data.Size() <= fastUploadMaxSize if opt.FastMode && !fastMode { uploader.logger.WithField("size", data.Size()).Info("Fast mode disabled for data size over limit")