From 3417ee395e8125dc4fe9b2e9f5047f64e939e0de Mon Sep 17 00:00:00 2001 From: Lior Goldberg Date: Sun, 19 Nov 2023 21:47:47 +0200 Subject: [PATCH] Cairo v0.13.0a1 (pre). --- src/starkware/cairo/lang/VERSION | 2 +- .../starknet/business_logic/state/state.py | 24 +++--- .../business_logic/state/state_api.py | 10 +-- .../os/execution/execute_entry_point.cairo | 13 ++++ .../os/execution/execute_transactions.cairo | 76 +++++++++++-------- .../starknet/core/os/program_hash.json | 2 +- .../transaction_hash/transaction_hash.cairo | 12 ++- .../transaction_hash_test_utils.py | 7 +- src/starkware/starknet/wallets/signer.py | 76 ++++++++++++++++--- 9 files changed, 149 insertions(+), 73 deletions(-) diff --git a/src/starkware/cairo/lang/VERSION b/src/starkware/cairo/lang/VERSION index 19b4ae8e..174c8065 100644 --- a/src/starkware/cairo/lang/VERSION +++ b/src/starkware/cairo/lang/VERSION @@ -1 +1 @@ -0.13.0a0 +0.13.0a1 diff --git a/src/starkware/starknet/business_logic/state/state.py b/src/starkware/starknet/business_logic/state/state.py index 68d86fa4..dcb9fa05 100644 --- a/src/starkware/starknet/business_logic/state/state.py +++ b/src/starkware/starknet/business_logic/state/state.py @@ -277,15 +277,13 @@ async def get_compiled_class(self, compiled_class_hash: int) -> CompiledClassBas return self.compiled_classes[compiled_class_hash] - async def get_raw_compiled_class(self, compiled_class_hash: int) -> RawCompiledClass: - if compiled_class_hash not in self.raw_compiled_classes: - self.raw_compiled_classes[ - compiled_class_hash - ] = await self.state_reader.get_raw_compiled_class( - compiled_class_hash=compiled_class_hash + async def get_raw_compiled_class(self, class_hash: int) -> RawCompiledClass: + if class_hash not in self.raw_compiled_classes: + self.raw_compiled_classes[class_hash] = await self.state_reader.get_raw_compiled_class( + class_hash=class_hash ) - return self.raw_compiled_classes[compiled_class_hash] + return self.raw_compiled_classes[class_hash] async def get_compiled_class_hash(self, class_hash: int) -> int: if class_hash not in self.cache.class_hash_to_compiled_class_hash: @@ -444,13 +442,13 @@ def get_compiled_class(self, compiled_class_hash: int) -> CompiledClassBase: return self.compiled_classes[compiled_class_hash] - def get_raw_compiled_class(self, compiled_class_hash: int) -> RawCompiledClass: - if compiled_class_hash not in self.raw_compiled_classes: - self.raw_compiled_classes[ - compiled_class_hash - ] = self.state_reader.get_raw_compiled_class(compiled_class_hash=compiled_class_hash) + def get_raw_compiled_class(self, class_hash: int) -> RawCompiledClass: + if class_hash not in self.raw_compiled_classes: + self.raw_compiled_classes[class_hash] = self.state_reader.get_raw_compiled_class( + class_hash=class_hash + ) - return self.raw_compiled_classes[compiled_class_hash] + return self.raw_compiled_classes[class_hash] def get_compiled_class_hash(self, class_hash: int) -> int: if class_hash not in self.cache.class_hash_to_compiled_class_hash: diff --git a/src/starkware/starknet/business_logic/state/state_api.py b/src/starkware/starknet/business_logic/state/state_api.py index de465c47..ea08bc9f 100644 --- a/src/starkware/starknet/business_logic/state/state_api.py +++ b/src/starkware/starknet/business_logic/state/state_api.py @@ -28,12 +28,12 @@ async def get_compiled_class(self, compiled_class_hash: int) -> CompiledClassBas Raises an exception if said class was not declared. """ - async def get_raw_compiled_class(self, compiled_class_hash: int) -> RawCompiledClass: + async def get_raw_compiled_class(self, class_hash: int) -> RawCompiledClass: """ - Returns the raw compiled class of the given compiled class hash. + Returns the raw compiled class of the given class hash. Raises an exception if said class was not declared. """ - compiled_class = await self.get_compiled_class(compiled_class_hash=compiled_class_hash) + compiled_class = await self.get_compiled_class_by_class_hash(class_hash=class_hash) if isinstance(compiled_class, CompiledClass): return RawCompiledClass( raw_compiled_class=CompiledClass.Schema().dumps(compiled_class), version=1 @@ -172,8 +172,8 @@ class SyncStateReader(ABC): def get_compiled_class(self, compiled_class_hash: int) -> CompiledClassBase: pass - def get_raw_compiled_class(self, compiled_class_hash: int) -> RawCompiledClass: - compiled_class = self.get_compiled_class(compiled_class_hash=compiled_class_hash) + def get_raw_compiled_class(self, class_hash: int) -> RawCompiledClass: + compiled_class = self.get_compiled_class_by_class_hash(class_hash=class_hash) if isinstance(compiled_class, CompiledClass): return RawCompiledClass( raw_compiled_class=CompiledClass.Schema().dumps(compiled_class), version=1 diff --git a/src/starkware/starknet/core/os/execution/execute_entry_point.cairo b/src/starkware/starknet/core/os/execution/execute_entry_point.cairo index 03c150b8..2936e110 100644 --- a/src/starkware/starknet/core/os/execution/execute_entry_point.cairo +++ b/src/starkware/starknet/core/os/execution/execute_entry_point.cairo @@ -237,6 +237,19 @@ func execute_entry_point{ %} // Check that the execution was successful. + %{ + return_values = ids.entry_point_return_values + if return_values.failure_flag != 0: + # Fetch the error, up to 100 elements. + retdata_size = return_values.retdata_end - return_values.retdata_start + error = memory.get_range(return_values.retdata_start, max(0, min(100, retdata_size))) + + print("Invalid return value in execute_entry_point:") + print(f" Class hash: {hex(ids.execution_context.class_hash)}") + print(f" Selector: {hex(ids.execution_context.execution_info.selector)}") + print(f" Size: {retdata_size}") + print(f" Error (at most 100 elements): {error}") + %} assert entry_point_return_values.failure_flag = 0; let remaining_gas = entry_point_return_values.gas_builtin; diff --git a/src/starkware/starknet/core/os/execution/execute_transactions.cairo b/src/starkware/starknet/core/os/execution/execute_transactions.cairo index c32ebf24..a126c92a 100644 --- a/src/starkware/starknet/core/os/execution/execute_transactions.cairo +++ b/src/starkware/starknet/core/os/execution/execute_transactions.cairo @@ -197,6 +197,7 @@ func execute_transactions_inner{ contract_class_changes: DictAccess*, outputs: OsCarriedOutputs*, }(block_context: BlockContext*, n_txs) { + %{ print(f"execute_transactions_inner: {ids.n_txs} transactions remaining.") %} if (n_txs == 0) { return (); } @@ -641,7 +642,7 @@ func check_and_increment_nonce{contract_state_changes: DictAccess*}(tx_info: TxI %} tempvar current_nonce = state_entry.nonce; - with_attr error_message("Unexpected nonce. Expected {current_nonce}, got {tx_info.nonce}.") { + with_attr error_message("Unexpected nonce.") { assert current_nonce = tx_info.nonce; } @@ -701,6 +702,15 @@ func run_validate{ block_context=block_context, execution_context=validate_execution_context ); if (is_deprecated == 0) { + %{ + # Fetch the result, up to 100 elements. + result = memory.get_range(ids.retdata, min(100, ids.retdata_size)) + + if result != [ids.VALIDATED]: + print("Invalid return value from __validate__:") + print(f" Size: {ids.retdata_size}") + print(f" Result (at most 100 elements): {result}") + %} assert retdata_size = 1; assert retdata[0] = VALIDATED; } @@ -805,8 +815,10 @@ func execute_deploy_account_transaction{ local constructor_execution_context: ExecutionContext*, local salt ) = prepare_constructor_execution_context(block_info=block_context.block_info_for_validate); local constructor_execution_info: ExecutionInfo* = constructor_execution_context.execution_info; + local sender_address = constructor_execution_info.contract_address; // Prepare validate_deploy calldata. + local validate_deploy_calldata_size = constructor_execution_context.calldata_size + 2; let (validate_deploy_calldata: felt*) = alloc(); assert validate_deploy_calldata[0] = constructor_execution_context.class_hash; assert validate_deploy_calldata[1] = salt; @@ -816,24 +828,6 @@ func execute_deploy_account_transaction{ len=constructor_execution_context.calldata_size, ); - // Note that the members of execution_info.tx_info are not initialized at this point. - local tx_info: TxInfo* = constructor_execution_info.tx_info; - local deprecated_tx_info: DeprecatedTxInfo* = constructor_execution_context.deprecated_tx_info; - local validate_deploy_execution_context: ExecutionContext* = new ExecutionContext( - entry_point_type=ENTRY_POINT_TYPE_EXTERNAL, - class_hash=constructor_execution_context.class_hash, - calldata_size=constructor_execution_context.calldata_size + 2, - calldata=validate_deploy_calldata, - execution_info=new ExecutionInfo( - block_info=block_context.block_info_for_validate, - tx_info=tx_info, - caller_address=constructor_execution_info.caller_address, - contract_address=constructor_execution_info.contract_address, - selector=VALIDATE_DEPLOY_ENTRY_POINT_SELECTOR, - ), - deprecated_tx_info=deprecated_tx_info, - ); - // Guess tx fields. // Compute transaction hash and prepare transaction info. // The version validation is done in `compute_deploy_account_transaction_hash()`. @@ -853,7 +847,7 @@ func execute_deploy_account_transaction{ local common_tx_fields: CommonTxFields = CommonTxFields( tx_hash_prefix=DEPLOY_ACCOUNT_HASH_PREFIX, version=nondet %{ tx.version %}, - sender_address=constructor_execution_info.contract_address, + sender_address=sender_address, max_fee=nondet %{ tx.max_fee if tx.version < 3 else 0 %}, chain_id=block_context.starknet_os_config.chain_id, nonce=nondet %{ tx.nonce %}, @@ -875,7 +869,9 @@ func execute_deploy_account_transaction{ let poseidon_ptr = builtin_ptrs.selectable.poseidon; with pedersen_ptr, poseidon_ptr { let transaction_hash = compute_deploy_account_transaction_hash( - common_fields=&common_tx_fields, execution_context=validate_deploy_execution_context + common_fields=&common_tx_fields, + calldata_size=validate_deploy_calldata_size, + calldata=validate_deploy_calldata, ); } update_builtin_ptrs(pedersen_ptr=pedersen_ptr, poseidon_ptr=poseidon_ptr); @@ -886,9 +882,10 @@ func execute_deploy_account_transaction{ f"Computed hash = {ids.transaction_hash}, Expected hash = {tx.hash_value}.") %} - // Assign the transaction info to both calls. - // Note that both constructor_execution_context and - // validate_deploy_execution_context hold this pointer. + // Initialize and fill the transaction info structs. + local tx_info: TxInfo* = constructor_execution_info.tx_info; + local deprecated_tx_info: DeprecatedTxInfo* = constructor_execution_context.deprecated_tx_info; + local signature_start: felt*; local signature_len: felt; %{ @@ -898,7 +895,7 @@ func execute_deploy_account_transaction{ assert_nn_le(signature_len, SIERRA_ARRAY_LEN_BOUND - 1); assert [tx_info] = TxInfo( version=common_tx_fields.version, - account_contract_address=constructor_execution_info.contract_address, + account_contract_address=sender_address, max_fee=common_tx_fields.max_fee, signature_start=signature_start, signature_end=&signature_start[signature_len], @@ -923,23 +920,38 @@ func execute_deploy_account_transaction{ deploy_contract( block_context=block_context, constructor_execution_context=constructor_execution_context ); - let updated_execution_context = update_class_hash_in_execution_context( - execution_context=validate_deploy_execution_context - ); // Handle nonce here since 'deploy_contract' verifies that the nonce is zeroed. check_and_increment_nonce(tx_info=tx_info); - // Runs the account contract's "__validate_deploy__" entry point, - // which is responsible for signature verification. + // Run the account contract's "__validate_deploy__" entry point. + + // Fetch the newest state entry, after constructor invocation. + let (state_entry: StateEntry*) = dict_read{dict_ptr=contract_state_changes}(key=sender_address); + // Prepare execution context. + local validate_deploy_execution_context: ExecutionContext* = new ExecutionContext( + entry_point_type=ENTRY_POINT_TYPE_EXTERNAL, + class_hash=state_entry.class_hash, + calldata_size=validate_deploy_calldata_size, + calldata=validate_deploy_calldata, + execution_info=new ExecutionInfo( + block_info=block_context.block_info_for_validate, + tx_info=tx_info, + caller_address=constructor_execution_info.caller_address, + contract_address=sender_address, + selector=VALIDATE_DEPLOY_ENTRY_POINT_SELECTOR, + ), + deprecated_tx_info=deprecated_tx_info, + ); + // Run the entrypoint. let (retdata_size, retdata, is_deprecated) = select_execute_entry_point_func( - block_context=block_context, execution_context=updated_execution_context + block_context=block_context, execution_context=validate_deploy_execution_context ); if (is_deprecated == 0) { assert retdata_size = 1; assert retdata[0] = VALIDATED; } - charge_fee(block_context=block_context, tx_execution_context=updated_execution_context); + charge_fee(block_context=block_context, tx_execution_context=validate_deploy_execution_context); %{ execution_helper.end_tx() %} return (); diff --git a/src/starkware/starknet/core/os/program_hash.json b/src/starkware/starknet/core/os/program_hash.json index 545cfd9d..9481dc5e 100644 --- a/src/starkware/starknet/core/os/program_hash.json +++ b/src/starkware/starknet/core/os/program_hash.json @@ -1,3 +1,3 @@ { - "program_hash": "0x45bb5419b84aa6706e6d14eec6dc25ec95eb94f0320414076d1526e81aaecaf" + "program_hash": "0x157ffd1dd481d6b307b53d990cf351ed87dea0fddbd96f8320f60c802cc4451" } diff --git a/src/starkware/starknet/core/os/transaction_hash/transaction_hash.cairo b/src/starkware/starknet/core/os/transaction_hash/transaction_hash.cairo index ef72710a..91f4dd7a 100644 --- a/src/starkware/starknet/core/os/transaction_hash/transaction_hash.cairo +++ b/src/starkware/starknet/core/os/transaction_hash/transaction_hash.cairo @@ -252,7 +252,7 @@ func compute_l1_handler_transaction_hash{pedersen_ptr: HashBuiltin*}( // See comment above `compute_invoke_transaction_hash()`. func compute_deploy_account_transaction_hash{ range_check_ptr, pedersen_ptr: HashBuiltin*, poseidon_ptr: PoseidonBuiltin* -}(common_fields: CommonTxFields*, execution_context: ExecutionContext*) -> felt { +}(common_fields: CommonTxFields*, calldata_size: felt, calldata: felt*) -> felt { alloc_locals; local version = common_fields.version; @@ -263,8 +263,8 @@ func compute_deploy_account_transaction_hash{ version=version, contract_address=common_fields.sender_address, entry_point_selector=0, - calldata_size=execution_context.calldata_size, - calldata=execution_context.calldata, + calldata_size=calldata_size, + calldata=calldata, max_fee=common_fields.max_fee, chain_id=common_fields.chain_id, additional_data_size=1, @@ -283,11 +283,9 @@ func compute_deploy_account_transaction_hash{ with hash_state { hash_tx_common_fields(common_fields=common_fields); // Hash and add the constructor calldata to the hash state. - poseidon_hash_update_with_nested_hash( - data_ptr=&execution_context.calldata[2], data_length=execution_context.calldata_size - 2 - ); + poseidon_hash_update_with_nested_hash(data_ptr=&calldata[2], data_length=calldata_size - 2); // Add the class hash and the contract address salt to the hash state. - poseidon_hash_update(data_ptr=execution_context.calldata, data_length=2); + poseidon_hash_update(data_ptr=calldata, data_length=2); } let transaction_hash = poseidon_hash_finalize(hash_state=hash_state); diff --git a/src/starkware/starknet/core/os/transaction_hash/transaction_hash_test_utils.py b/src/starkware/starknet/core/os/transaction_hash/transaction_hash_test_utils.py index 67054f5a..36181f51 100644 --- a/src/starkware/starknet/core/os/transaction_hash/transaction_hash_test_utils.py +++ b/src/starkware/starknet/core/os/transaction_hash/transaction_hash_test_utils.py @@ -205,11 +205,8 @@ def run_cairo_deploy_account_transaction_hash( n_resource_bounds=n_resource_bounds, resource_bounds=resource_bounds, ), - execution_context=create_execution_context( - program=program, - contract_address=contract_address, - calldata=calldata, - ), + calldata_size=len(calldata), + calldata=calldata, use_full_name=True, verify_secure=False, ) diff --git a/src/starkware/starknet/wallets/signer.py b/src/starkware/starknet/wallets/signer.py index 9fb575c7..06ad9c0b 100644 --- a/src/starkware/starknet/wallets/signer.py +++ b/src/starkware/starknet/wallets/signer.py @@ -465,7 +465,7 @@ def sign_deploy_syscall_deprecated_tx( deploy_from_zero=deploy_from_zero, salt=salt, ) - deploy_tx = OpenZeppelinSigner.sign_deprecated_invoke_tx( + deploy_tx = cls.sign_deprecated_invoke_tx( signer_address=account_address, private_key=private_key, call_function=CallFunction( @@ -508,7 +508,7 @@ def sign_deploy_syscall_tx( deploy_from_zero=deploy_from_zero, salt=salt, ) - deploy_tx = OpenZeppelinSigner.sign_invoke_tx( + deploy_tx = cls.sign_invoke_tx( sender_address=account_address, private_key=private_key, contract_address=account_address, @@ -522,19 +522,33 @@ def sign_deploy_syscall_tx( return contract_address, deploy_tx -class OpenZeppelinSigner(SignerBase): +class EcdsaSignerBase(SignerBase): + """ + Base class for signing transactions using ECDSA. + """ + + @classmethod + def sign_tx_hash(cls, tx_hash: int, private_key: Optional[int]) -> List[int]: + return [] if private_key is None else list(sign(msg_hash=tx_hash, priv_key=private_key)) + + +class OpenZeppelinSigner(EcdsaSignerBase): + """ + Contains signing logic for the OpenZeppelin Cairo 0 account contract. + """ + @classmethod def format_multicall_calldata(cls, calls: List[CallFunction]) -> List[int]: call_array_len = len(calls) multicall_calldata = [call_array_len] data_offset = 0 flat_calldata_list = [] - for call_function in calls: - flat_calldata_list.extend(call_function.calldata) - data_len = len(call_function.calldata) + for call in calls: + flat_calldata_list.extend(call.calldata) + data_len = len(call.calldata) call_entry = [ - call_function.contract_address, - call_function.entry_point_selector, + call.contract_address, + call.entry_point_selector, data_offset, data_len, ] @@ -544,6 +558,50 @@ def format_multicall_calldata(cls, calls: List[CallFunction]) -> List[int]: multicall_calldata.extend([len(flat_calldata_list), *flat_calldata_list]) return multicall_calldata + +class StandardSigner(EcdsaSignerBase): + """ + Contains signing logic for the starndard Cairo 1 account contract from the Cairo compiler repo. + + Assumes the following calldata format: `calls: Array`, where `Call` struct is + struct Call { + to: ContractAddress, + selector: felt252, + calldata: Array + } + """ + + @classmethod + def format_multicall_calldata(cls, calls: List[CallFunction]) -> List[int]: + multicall_calldata = [len(calls)] + for call in calls: + multicall_calldata += [ + call.contract_address, + call.entry_point_selector, + len(call.calldata), + *call.calldata, + ] + + return multicall_calldata + + +class TrivialSigner(SignerBase): + """ + Trivial implementation for accounts without multicalls nor signature verfication. + """ + @classmethod def sign_tx_hash(cls, tx_hash: int, private_key: Optional[int]) -> List[int]: - return [] if private_key is None else list(sign(msg_hash=tx_hash, priv_key=private_key)) + assert private_key is None, "Sigining is not supproted for the TrivialSigner." + return [] + + @classmethod + def format_multicall_calldata(cls, calls: List[CallFunction]) -> List[int]: + assert len(calls) == 1, "Multicall is not supported for the TrivialSigner." + (call,) = calls + return [ + call.contract_address, + call.entry_point_selector, + len(call.calldata), + *call.calldata, + ]