diff --git a/Cargo.lock b/Cargo.lock index 5fd6a45bb..643429656 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4682,7 +4682,7 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" name = "smt_trie" version = "0.1.1" dependencies = [ - "ethereum-types", + "alloy", "hex-literal", "plonky2", "rand", diff --git a/evm_arithmetization/src/witness/operation.rs b/evm_arithmetization/src/witness/operation.rs index c8e2a5376..7f674cdbf 100644 --- a/evm_arithmetization/src/witness/operation.rs +++ b/evm_arithmetization/src/witness/operation.rs @@ -250,7 +250,7 @@ pub(crate) fn generate_poseidon_general>( let hash = hashout2u(poseidon_hash_padded_byte_vec(input.clone())); - push_no_write(generation_state, hash); + push_no_write(generation_state, hash.into()); state.push_poseidon(poseidon_op); diff --git a/evm_arithmetization/src/world.rs b/evm_arithmetization/src/world.rs index d01e7ef26..854edeeb5 100644 --- a/evm_arithmetization/src/world.rs +++ b/evm_arithmetization/src/world.rs @@ -18,7 +18,7 @@ pub struct KeccakHash; impl Hasher for PoseidonHash { fn hash(bytes: &[u8]) -> H256 { - hash_bytecode_h256(bytes) + hash_bytecode_h256(bytes).compat() } } @@ -367,9 +367,8 @@ impl World for Type2World { Ok(()) } fn root(&mut self) -> H256 { - let mut it = [0; 32]; - smt_trie::utils::hashout2u(self.as_smt().root).to_big_endian(&mut it); - H256(it) + let root = smt_trie::utils::hashout2u(self.as_smt().root); + H256::from(root.to_be_bytes()) } } @@ -411,7 +410,7 @@ impl Type2World { ); } for ( - addr, + &addr, Type2Entry { balance, nonce, @@ -430,11 +429,16 @@ impl Type2World { (code_length, key_code_length), ] { if let Some(value) = value { - smt.set(key_fn(*addr), *value); + let addr = addr.compat(); + let value = (*value).compat(); + smt.set(key_fn(addr), value); } } - for (slot, value) in storage { - smt.set(key_storage(*addr, *slot), *value); + for (&slot, &value) in storage { + let addr = addr.compat(); + let slot = slot.compat(); + let value = value.compat(); + smt.set(key_storage(addr, slot), value); } } smt diff --git a/smt_trie/Cargo.toml b/smt_trie/Cargo.toml index 6df15c11c..c1a4cdb9d 100644 --- a/smt_trie/Cargo.toml +++ b/smt_trie/Cargo.toml @@ -12,7 +12,7 @@ homepage.workspace = true keywords.workspace = true [dependencies] -ethereum-types.workspace = true +alloy.workspace = true plonky2.workspace = true rand.workspace = true serde = { workspace = true, features = ["derive", "rc"] } diff --git a/smt_trie/src/bits.rs b/smt_trie/src/bits.rs index 0fbcac0e1..40758ba38 100644 --- a/smt_trie/src/bits.rs +++ b/smt_trie/src/bits.rs @@ -1,6 +1,6 @@ use std::ops::Add; -use ethereum_types::{BigEndianHash, H256, U256}; +use alloy::primitives::{B256, U256}; use serde::{Deserialize, Serialize}; pub type Bit = bool; @@ -22,11 +22,11 @@ impl From for Bits { } } -impl From for Bits { - fn from(packed: H256) -> Self { +impl From for Bits { + fn from(packed: B256) -> Self { Bits { count: 256, - packed: packed.into_uint(), + packed: packed.into(), } } } @@ -38,7 +38,7 @@ impl Add for Bits { assert!(self.count + rhs.count <= 256, "Overflow"); Self { count: self.count + rhs.count, - packed: self.packed * (U256::one() << rhs.count) + rhs.packed, + packed: self.packed * (U256::from(1) << rhs.count) + rhs.packed, } } } @@ -47,7 +47,7 @@ impl Bits { pub const fn empty() -> Self { Bits { count: 0, - packed: U256::zero(), + packed: U256::ZERO, } } @@ -57,7 +57,7 @@ impl Bits { pub fn pop_next_bit(&mut self) -> Bit { assert!(!self.is_empty(), "Cannot pop from empty bits"); - let b = !(self.packed & U256::one()).is_zero(); + let b = !(self.packed & U256::from(1)).is_zero(); self.packed >>= 1; self.count -= 1; b @@ -65,11 +65,11 @@ impl Bits { pub fn get_bit(&self, i: usize) -> Bit { assert!(i < self.count, "Index out of bounds"); - !(self.packed & (U256::one() << (self.count - 1 - i))).is_zero() + !(self.packed & (U256::from(1) << (self.count - 1 - i))).is_zero() } pub fn push_bit(&mut self, bit: Bit) { - self.packed = self.packed * 2 + U256::from(bit as u64); + self.packed = self.packed * U256::from(2) + U256::from(bit as u64); self.count += 1; } diff --git a/smt_trie/src/code.rs b/smt_trie/src/code.rs index 2d1a98661..e02266fee 100644 --- a/smt_trie/src/code.rs +++ b/smt_trie/src/code.rs @@ -1,6 +1,6 @@ /// Functions to hash contract bytecode using Poseidon. /// See `hashContractBytecode()` in https://github.com/0xPolygonHermez/zkevm-commonjs/blob/main/src/smt-utils.js for reference implementation. -use ethereum_types::H256; +use alloy::primitives::B256; use plonky2::field::types::Field; use plonky2::hash::poseidon::{self, Poseidon}; @@ -43,7 +43,7 @@ pub fn poseidon_pad_byte_vec(bytes: &mut Vec) { *bytes.last_mut().unwrap() |= 0x80; } -pub fn hash_bytecode_h256(code: &[u8]) -> H256 { +pub fn hash_bytecode_h256(code: &[u8]) -> B256 { hashout2h(hash_contract_bytecode(code.to_vec())) } diff --git a/smt_trie/src/keys.rs b/smt_trie/src/keys.rs index 1f122adbb..314a9b5ee 100644 --- a/smt_trie/src/keys.rs +++ b/smt_trie/src/keys.rs @@ -2,7 +2,7 @@ /// This module contains functions to generate keys for the SMT. /// See https://github.com/0xPolygonHermez/zkevm-commonjs/blob/main/src/smt-utils.js for reference implementation. -use ethereum_types::{Address, U256}; +use alloy::primitives::{Address, U256}; use plonky2::{field::types::Field, hash::poseidon::Poseidon}; use crate::smt::{Key, F}; @@ -74,8 +74,9 @@ pub fn key_storage(addr: Address, slot: U256) -> Key { let capacity: [F; 4] = { let mut arr = [F::ZERO; 12]; for i in 0..4 { - arr[2 * i] = F::from_canonical_u32(slot.0[i] as u32); - arr[2 * i + 1] = F::from_canonical_u32((slot.0[i] >> 32) as u32); + let limbs = slot.as_limbs()[i]; + arr[2 * i] = F::from_canonical_u32(limbs as u32); + arr[2 * i + 1] = F::from_canonical_u32((limbs >> 32) as u32); } F::poseidon(arr)[0..4].try_into().unwrap() }; diff --git a/smt_trie/src/smt.rs b/smt_trie/src/smt.rs index f9ea73319..d2236ae5e 100644 --- a/smt_trie/src/smt.rs +++ b/smt_trie/src/smt.rs @@ -3,7 +3,7 @@ use std::borrow::Borrow; use std::collections::{HashMap, HashSet}; -use ethereum_types::U256; +use alloy::primitives::U256; use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::types::{Field, PrimeField64}; use plonky2::hash::poseidon::{Poseidon, PoseidonHash}; @@ -145,7 +145,7 @@ impl Smt { .copied() .unwrap_or_default() .is_zero()); - U256::zero() + U256::ZERO }; } else { let b = keys.get_bit(level as usize); @@ -347,7 +347,7 @@ impl Smt { /// Delete the key in the SMT. pub fn delete(&mut self, key: Key) { self.kv_store.remove(&key); - self.set(key, U256::zero()); + self.set(key, U256::ZERO); } /// Set the key to the hash in the SMT. @@ -416,7 +416,7 @@ impl Smt { &self, keys: I, ) -> Vec { - let mut v = vec![U256::zero(); 2]; // For empty hash node. + let mut v = vec![U256::ZERO; 2]; // For empty hash node. let key = Key(self.root.elements); let mut keys_to_include = HashSet::new(); @@ -433,7 +433,7 @@ impl Smt { serialize(self, key, &mut v, Bits::empty(), &keys_to_include); if v.len() == 2 { - v.extend([U256::zero(); 2]); + v.extend([U256::ZERO; 2]); } v } @@ -457,7 +457,7 @@ fn serialize( if !keys_to_include.contains(&cur_bits) || smt.db.get_node(&key).is_none() { let index = v.len(); - v.push(HASH_TYPE.into()); + v.push(U256::from(HASH_TYPE)); v.push(key2u(key)); index } else if let Some(node) = smt.db.get_node(&key) { @@ -473,7 +473,7 @@ fn serialize( let rem_key = Key(node.0[0..4].try_into().unwrap()); let val = limbs2f(val_a); let index = v.len(); - v.push(LEAF_TYPE.into()); + v.push(U256::from(LEAF_TYPE)); v.push(key2u(rem_key)); v.push(val); index @@ -481,14 +481,24 @@ fn serialize( let key_left = Key(node.0[0..4].try_into().unwrap()); let key_right = Key(node.0[4..8].try_into().unwrap()); let index = v.len(); - v.push(INTERNAL_TYPE.into()); - v.push(U256::zero()); - v.push(U256::zero()); - let i_left = - serialize(smt, key_left, v, cur_bits.add_bit(false), keys_to_include).into(); + v.push(U256::from(INTERNAL_TYPE)); + v.push(U256::ZERO); + v.push(U256::ZERO); + let i_left = U256::from(serialize( + smt, + key_left, + v, + cur_bits.add_bit(false), + keys_to_include, + )); v[index + 1] = i_left; - let i_right = - serialize(smt, key_right, v, cur_bits.add_bit(true), keys_to_include).into(); + let i_right = U256::from(serialize( + smt, + key_right, + v, + cur_bits.add_bit(true), + keys_to_include, + )); v[index + 2] = i_right; index } @@ -507,15 +517,16 @@ pub fn hash_serialize_u256(v: &[U256]) -> U256 { } fn _hash_serialize(v: &[U256], ptr: usize) -> HashOut { - assert!(v[ptr] <= u8::MAX.into()); - match v[ptr].as_u64() as u8 { + let byte: u8 = v[ptr].try_into().expect("U256 should have been <= u8::MAX"); + match byte { HASH_TYPE => u2h(v[ptr + 1]), INTERNAL_TYPE => { let mut node = Node([F::ZERO; 12]); for b in 0..2 { let child_index = v[ptr + 1 + b]; - let child_hash = _hash_serialize(v, child_index.as_usize()); + let child_index = *(child_index.as_limbs().first().unwrap()) as usize; + let child_hash = _hash_serialize(v, child_index); node.0[b * 4..(b + 1) * 4].copy_from_slice(&child_hash.elements); } F::poseidon(node.0)[0..4].try_into().unwrap() diff --git a/smt_trie/src/smt_test.rs b/smt_trie/src/smt_test.rs index c086e17dc..c58c3681a 100644 --- a/smt_trie/src/smt_test.rs +++ b/smt_trie/src/smt_test.rs @@ -1,4 +1,4 @@ -use ethereum_types::U256; +use alloy::primitives::U256; use plonky2::field::types::{Field, Sample}; use plonky2::hash::hash_types::HashOut; use rand::seq::SliceRandom; @@ -18,11 +18,11 @@ fn test_add_and_rem() { let mut smt = Smt::::default(); let k = Key(F::rand_array()); - let v = U256(thread_rng().gen()); + let v = U256::from(thread_rng().gen::()); smt.set(k, v); assert_eq!(v, smt.get(k)); - smt.set(k, U256::zero()); + smt.set(k, U256::ZERO); assert_eq!(smt.root.elements, [F::ZERO; 4]); let ser = smt.serialize(); @@ -48,7 +48,7 @@ fn test_add_and_rem_hermez() { .map(F::from_canonical_u64) ); - smt.set(k, U256::zero()); + smt.set(k, U256::ZERO); assert_eq!(smt.root.elements, [F::ZERO; 4]); let ser = smt.serialize(); @@ -60,8 +60,8 @@ fn test_update_element_1() { let mut smt = Smt::::default(); let k = Key(F::rand_array()); - let v1 = U256(thread_rng().gen()); - let v2 = U256(thread_rng().gen()); + let v1 = U256::from(thread_rng().gen::()); + let v2 = U256::from(thread_rng().gen::()); smt.set(k, v1); let root = smt.root; smt.set(k, v2); @@ -79,12 +79,12 @@ fn test_add_shared_element_2() { let k1 = Key(F::rand_array()); let k2 = Key(F::rand_array()); assert_ne!(k1, k2, "Unlucky"); - let v1 = U256(thread_rng().gen()); - let v2 = U256(thread_rng().gen()); + let v1 = U256::from(thread_rng().gen::()); + let v2 = U256::from(thread_rng().gen::()); smt.set(k1, v1); smt.set(k2, v2); - smt.set(k1, U256::zero()); - smt.set(k2, U256::zero()); + smt.set(k1, U256::ZERO); + smt.set(k2, U256::ZERO); assert_eq!(smt.root.elements, [F::ZERO; 4]); let ser = smt.serialize(); @@ -98,15 +98,15 @@ fn test_add_shared_element_3() { let k1 = Key(F::rand_array()); let k2 = Key(F::rand_array()); let k3 = Key(F::rand_array()); - let v1 = U256(thread_rng().gen()); - let v2 = U256(thread_rng().gen()); - let v3 = U256(thread_rng().gen()); + let v1 = U256::from(thread_rng().gen::()); + let v2 = U256::from(thread_rng().gen::()); + let v3 = U256::from(thread_rng().gen::()); smt.set(k1, v1); smt.set(k2, v2); smt.set(k3, v3); - smt.set(k1, U256::zero()); - smt.set(k2, U256::zero()); - smt.set(k3, U256::zero()); + smt.set(k1, U256::ZERO); + smt.set(k2, U256::ZERO); + smt.set(k3, U256::ZERO); assert_eq!(smt.root.elements, [F::ZERO; 4]); let ser = smt.serialize(); @@ -120,7 +120,7 @@ fn test_add_remove_128() { let kvs = (0..128) .map(|_| { let k = Key(F::rand_array()); - let v = U256(thread_rng().gen()); + let v = U256::from(thread_rng().gen::()); smt.set(k, v); (k, v) }) @@ -129,7 +129,7 @@ fn test_add_remove_128() { smt.set(k, v); } for &(k, _) in &kvs { - smt.set(k, U256::zero()); + smt.set(k, U256::ZERO); } assert_eq!(smt.root.elements, [F::ZERO; 4]); @@ -144,7 +144,7 @@ fn test_should_read_random() { let kvs = (0..128) .map(|_| { let k = Key(F::rand_array()); - let v = U256(thread_rng().gen()); + let v = U256::from(thread_rng().gen::()); smt.set(k, v); (k, v) }) @@ -226,21 +226,25 @@ fn test_leaf_one_level_depth() { ] .map(F::from_canonical_u64)); - let v0 = U256::from_dec_str( + let v0 = U256::from_str_radix( "8163644824788514136399898658176031121905718480550577527648513153802600646339", + 10, ) .unwrap(); - let v1 = U256::from_dec_str( + let v1 = U256::from_str_radix( "115792089237316195423570985008687907853269984665640564039457584007913129639934", + 10, ) .unwrap(); - let v2 = U256::from_dec_str( + let v2 = U256::from_str_radix( "115792089237316195423570985008687907853269984665640564039457584007913129639935", + 10, ) .unwrap(); - let v3 = U256::from_dec_str("7943875943875408").unwrap(); - let v4 = U256::from_dec_str( + let v3 = U256::from_str_radix("7943875943875408", 10).unwrap(); + let v4 = U256::from_str_radix( "35179347944617143021579132182092200136526168785636368258055676929581544372820", + 10, ) .unwrap(); @@ -269,10 +273,10 @@ fn test_no_write_0() { let k1 = Key(F::rand_array()); let k2 = Key(F::rand_array()); - let v = U256(thread_rng().gen()); + let v = U256::from(thread_rng().gen::()); smt.set(k1, v); let root = smt.root; - smt.set(k2, U256::zero()); + smt.set(k2, U256::ZERO); assert_eq!(smt.root, root); let ser = smt.serialize(); @@ -286,7 +290,7 @@ fn test_set_hash_first_level() { let kvs = (0..128) .map(|_| { let k = Key(F::rand_array()); - let v = U256(random()); + let v = U256::from(random::()); smt.set(k, v); (k, v) }) @@ -299,11 +303,11 @@ fn test_set_hash_first_level() { let mut hash_smt = Smt::::default(); let zero = Bits { count: 1, - packed: U256::zero(), + packed: U256::ZERO, }; let one = Bits { count: 1, - packed: U256::one(), + packed: U256::from(1), }; hash_smt.set_hash( zero, @@ -334,7 +338,7 @@ fn test_set_hash_order() { .map(|i| { let k = Bits { count: level, - packed: i.into(), + packed: U256::from(i), }; let hash = HashOut { elements: F::rand_array(), @@ -353,7 +357,7 @@ fn test_set_hash_order() { break key; } }; - let val = U256(random()); + let val = U256::from(random::()); smt.set(key, val); let mut second_smt = Smt::::default(); @@ -375,7 +379,7 @@ fn test_serialize_and_prune() { for _ in 0..128 { let k = Key(F::rand_array()); - let v = U256(random()); + let v = U256::from(random::()); smt.set(k, v); } @@ -399,9 +403,9 @@ fn test_serialize_and_prune() { assert_eq!( trivial_ser, vec![ - U256::zero(), - U256::zero(), - HASH_TYPE.into(), + U256::ZERO, + U256::ZERO, + U256::from(HASH_TYPE), hashout2u(smt.root) ] ); diff --git a/smt_trie/src/utils.rs b/smt_trie/src/utils.rs index 6b30b8972..dd5524470 100644 --- a/smt_trie/src/utils.rs +++ b/smt_trie/src/utils.rs @@ -1,4 +1,4 @@ -use ethereum_types::{H256, U256}; +use alloy::primitives::{B256, U256}; use plonky2::field::types::{Field, PrimeField64}; use plonky2::hash::poseidon::Poseidon; @@ -31,7 +31,10 @@ pub(crate) fn hash_key_hash(k: Key, h: [F; 4]) -> [F; 4] { /// Split a U256 into 8 32-bit limbs in little-endian order. pub(crate) fn f2limbs(x: U256) -> [F; 8] { - std::array::from_fn(|i| F::from_canonical_u32((x >> (32 * i)).low_u32())) + std::array::from_fn(|i| { + let x = *(x >> (32 * i)).as_limbs().first().unwrap(); + F::from_canonical_u32(x as u32) + }) } /// Pack 8 32-bit limbs in little-endian order into a U256. @@ -39,7 +42,7 @@ pub(crate) fn limbs2f(limbs: [F; 8]) -> U256 { limbs .into_iter() .enumerate() - .fold(U256::zero(), |acc, (i, x)| { + .fold(U256::ZERO, |acc, (i, x)| { acc + (U256::from(x.to_canonical_u64()) << (i * 32)) }) } @@ -49,28 +52,26 @@ pub fn hashout2u(h: HashOut) -> U256 { key2u(Key(h.elements)) } -/// Convert a `HashOut` to a `H256`. -pub fn hashout2h(h: HashOut) -> H256 { - let mut it = [0; 32]; - hashout2u(h).to_big_endian(&mut it); - H256(it) +/// Convert a `HashOut` to a `B256`. +pub fn hashout2h(h: HashOut) -> B256 { + B256::new(hashout2u(h).to_be_bytes()) } /// Convert a `Key` to a `U256`. pub fn key2u(key: Key) -> U256 { - U256(key.0.map(|x| x.to_canonical_u64())) + U256::from_limbs(key.0.map(|x| x.to_canonical_u64())) } /// Convert a `U256` to a `Hashout`. pub(crate) fn u2h(x: U256) -> HashOut { HashOut { - elements: x.0.map(F::from_canonical_u64), + elements: x.as_limbs().map(F::from_canonical_u64), } } /// Convert a `U256` to a `Key`. pub(crate) fn u2k(x: U256) -> Key { - Key(x.0.map(F::from_canonical_u64)) + Key(x.as_limbs().map(F::from_canonical_u64)) } /// Given a node, return the index of the unique non-zero sibling, or -1 if