Skip to content

Commit

Permalink
feat: remove nonce management, use starknet nonce (kkrt-labs#649)
Browse files Browse the repository at this point in the history
This pr removes nonce management and uses starknet nonce wherever nonce
is required.

## Pull request type

<!-- Please try to limit your pull request to one type, submit multiple
pull requests if needed. -->

Please check the type of change your PR introduces:

- [ ] Bugfix
- [x] Feature
- [ ] Code style update (formatting, renaming)
- [ ] Refactoring (no functional changes, no api changes)
- [ ] Build related changes
- [ ] Documentation content changes
- [ ] Other (please describe):

## What is the current behavior?

We currently manage nonce ourselves which causes multiple issues,
migrating to starknet nonce is a better solution and the PR implements
the required changes for that.

Resolves kkrt-labs#648 

## What is the new behavior?

- nonce management has been removed
- we now use starknet nonce wherever required

## Other information

- The rpc is already [using starknet
nonce](https://github.com/kkrt-labs/kakarot-rpc/blob/d6e3676de6fc9e804dc751aa3493810c95957bba/crates/core/src/client/mod.rs#L338),
which means that the current transactions by tools like metamask are
already being signed with starknet nonce and not our own nonce when RPC
is being used.
  • Loading branch information
bajpai244 authored Jul 25, 2023
1 parent 97efad6 commit 40fc24a
Show file tree
Hide file tree
Showing 21 changed files with 142 additions and 205 deletions.
18 changes: 0 additions & 18 deletions src/kakarot/accounts/contract/contract_account.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -88,21 +88,3 @@ func is_initialized{
}() -> (is_initialized: felt) {
return ContractAccount.is_initialized();
}

// @notice This function is used to read the nonce from storage
// @return nonce: The current nonce of the contract account
@view
func get_nonce{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> (nonce: felt) {
return Accounts.get_nonce();
}

// @notice This function increases the contract accounts nonce by 1
// @return nonce: The new nonce of the contract account
@external
func increment_nonce{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> (
nonce: felt
) {
Ownable.assert_only_owner();
Accounts.increment_nonce();
return Accounts.get_nonce();
}
22 changes: 0 additions & 22 deletions src/kakarot/accounts/eoa/externally_owned_account.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -136,25 +136,3 @@ func bytecode_len{
}() -> (len: felt) {
return (len=0);
}

// @notice This function is used to read the nonce from storage
// @return nonce: The current nonce of the contract account
@view
func get_nonce{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> (nonce: felt) {
return Accounts.get_nonce();
}

// @notice This function increases the contract accounts nonce by 1
// @dev Currently external for testing purposes. Otherwise would not be needed.
// @return nonce: The new nonce of the contract account
@external
func increment_nonce{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> (
nonce: felt
) {
let (caller) = get_caller_address();
with_attr error_message("ExternallyOwnedAccount: nonce can only be incremented by self") {
assert caller = 0;
}
Accounts.increment_nonce();
return Accounts.get_nonce();
}
8 changes: 4 additions & 4 deletions src/kakarot/accounts/eoa/library.cairo
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
%lang starknet

from starkware.cairo.common.cairo_builtins import HashBuiltin, BitwiseBuiltin
from starkware.starknet.common.syscalls import CallContract
from starkware.starknet.common.syscalls import CallContract, get_tx_info
from starkware.cairo.common.uint256 import Uint256, uint256_not
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.bool import TRUE, FALSE
Expand Down Expand Up @@ -91,9 +91,10 @@ namespace ExternallyOwnedAccount {
}

let (address) = evm_address.read();
let (nonce) = Accounts.get_nonce();
let (tx_info) = get_tx_info();

EthTransaction.validate(
address, nonce, [call_array].data_len, calldata + [call_array].data_offset
address, tx_info.nonce, [call_array].data_len, calldata + [call_array].data_offset
);

return validate(
Expand Down Expand Up @@ -152,7 +153,6 @@ namespace ExternallyOwnedAccount {
response + return_data_len,
);

Accounts.increment_nonce();
return (response_len=return_data_len + response_len);
}
}
17 changes: 0 additions & 17 deletions src/kakarot/accounts/library.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -97,21 +97,4 @@ namespace Accounts {
return (account_address=account_address);
}

// @notice This function is used to read the nonce from storage
// @return nonce: The current nonce of the contract account
func get_nonce{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> (
nonce: felt
) {
return nonce.read();
}

// @notice This function increases the accounts nonce by 1
// @return nonce: The incremented nonce of the contract account
func increment_nonce{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> (
nonce: felt
) {
let (current_nonce: felt) = nonce.read();
nonce.write(current_nonce + 1);
return (nonce=current_nonce + 1);
}
}
8 changes: 3 additions & 5 deletions src/kakarot/instructions/system_operations.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ from starkware.cairo.common.math import split_felt
from starkware.cairo.common.math_cmp import is_le, is_not_zero, is_nn
from starkware.cairo.common.memcpy import memcpy
from starkware.cairo.common.uint256 import Uint256
from starkware.starknet.common.syscalls import deploy as deploy_syscall, get_contract_address
from starkware.starknet.common.syscalls import deploy as deploy_syscall, get_contract_address, get_tx_info

// Internal dependencies
from kakarot.constants import contract_account_class_hash, native_token_address, Constants
Expand Down Expand Up @@ -811,14 +811,12 @@ namespace CreateHelper {
// so we use popped_len to derive the way we should handle
// the creation of evm addresses
if (popped_len != 4) {
let (nonce) = IContractAccount.get_nonce(ctx.starknet_contract_address);
let (tx_info) = get_tx_info();
let (evm_contract_address) = CreateHelper.get_create_address(
ctx.evm_contract_address, nonce
ctx.evm_contract_address, tx_info.nonce
);
let (nonce) = IContractAccount.increment_nonce(ctx.starknet_contract_address);
let (contract_account_class_hash_) = contract_account_class_hash.read();
let (starknet_contract_address) = Accounts.create(
contract_account_class_hash_, evm_contract_address
Expand Down
10 changes: 0 additions & 10 deletions src/kakarot/interfaces/interfaces.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,6 @@ namespace IAccount {
func bytecode() -> (bytecode_len: felt, bytecode: felt*) {
}

func get_nonce() -> (nonce: felt) {
}

func increment_nonce() -> (nonce: felt) {
}
}

@contract_interface
Expand All @@ -75,11 +70,6 @@ namespace IContractAccount {
func write_storage(key: Uint256, value: Uint256) {
}

func get_nonce() -> (nonce: felt) {
}

func increment_nonce() -> (nonce: felt) {
}
}

@contract_interface
Expand Down
4 changes: 2 additions & 2 deletions src/kakarot/library.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ namespace Kakarot {
alloc_locals;
let (caller_address) = get_caller_address();
let (sender_evm_address) = IAccount.get_evm_address(caller_address);
let (nonce) = IAccount.get_nonce(caller_address);
let (evm_contract_address) = CreateHelper.get_create_address(sender_evm_address, nonce);
let (tx_info) = get_tx_info();
let (evm_contract_address) = CreateHelper.get_create_address(sender_evm_address, tx_info.nonce);
let (class_hash) = contract_account_class_hash.read();
let (starknet_contract_address) = Accounts.create(class_hash, evm_contract_address);
let (empty_array: felt*) = alloc();
Expand Down
24 changes: 0 additions & 24 deletions tests/integration/accounts/test_contract_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,27 +91,3 @@ async def test_should_give_infinite_allowance_to_kakarot(
)
== "Uint256(low=340282366920938463463374607431768211455, high=340282366920938463463374607431768211455)"
)

class TestNonce:
async def test_should_increment_nonce(
self, contract_account: StarknetContract, kakarot
):
# Get current contract account nonce
initial_nonce = (await contract_account.get_nonce().call()).result.nonce

# Increment nonce
await contract_account.increment_nonce().execute(
caller_address=kakarot.contract_address
)

# Get new nonce
assert (
initial_nonce + 1
== (await contract_account.get_nonce().call()).result.nonce
)

async def test_should_raise_when_caller_is_not_kakarot(
self, contract_account: StarknetContract
):
with kakarot_error():
await contract_account.increment_nonce().execute(caller_address=1)
2 changes: 1 addition & 1 deletion tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def addresses(starknet, kakarot, externally_owned_account_class) -> List[W

# Randomly generated private keys
private_keys = [predefined_private_key] + [
generate_random_private_key(seed=i) for i in range(3)
generate_random_private_key(seed=i) for i in range(15)
]

wallets = []
Expand Down
41 changes: 29 additions & 12 deletions tests/integration/solidity_contracts/PlainOpcodes/conftest.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,51 @@
import pytest_asyncio


@pytest_asyncio.fixture(scope="module")
async def counter(deploy_solidity_contract, owner):
return await deploy_solidity_contract("PlainOpcodes", "Counter", caller_eoa=owner)
@pytest_asyncio.fixture(scope="session")
def counter_deployer(addresses):
return addresses[1]

@pytest_asyncio.fixture(scope="session")
def caller_deployer(addresses):
return addresses[2]

@pytest_asyncio.fixture(scope="module")
async def caller(deploy_solidity_contract, owner):
@pytest_asyncio.fixture(scope="session")
def plain_opcodes_deployer(addresses):
return addresses[3]

@pytest_asyncio.fixture(scope="session")
def safe_deployer(addresses):
return addresses[4]


@pytest_asyncio.fixture(scope="package")
async def counter(deploy_solidity_contract, counter_deployer):
return await deploy_solidity_contract("PlainOpcodes", "Counter", caller_eoa=counter_deployer)


@pytest_asyncio.fixture(scope="package")
async def caller(deploy_solidity_contract, caller_deployer):
return await deploy_solidity_contract(
"PlainOpcodes",
"Caller",
caller_eoa=owner,
caller_eoa=caller_deployer,
)


@pytest_asyncio.fixture(scope="module")
async def plain_opcodes(deploy_solidity_contract, owner, counter):
@pytest_asyncio.fixture(scope="package")
async def plain_opcodes(deploy_solidity_contract, plain_opcodes_deployer, counter):
return await deploy_solidity_contract(
"PlainOpcodes",
"PlainOpcodes",
counter.evm_contract_address,
caller_eoa=owner,
caller_eoa=plain_opcodes_deployer,
)


@pytest_asyncio.fixture(scope="module")
async def safe(deploy_solidity_contract, owner):
@pytest_asyncio.fixture(scope="package")
async def safe(deploy_solidity_contract, safe_deployer):
return await deploy_solidity_contract(
"PlainOpcodes",
"Safe",
caller_eoa=owner,
caller_eoa=safe_deployer
)
32 changes: 16 additions & 16 deletions tests/integration/solidity_contracts/PlainOpcodes/test_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,32 @@ async def test_should_return_0_after_deployment(self, counter):
assert await counter.count() == 0

class TestInc:
async def test_should_increase_count(self, counter, addresses):
await counter.inc(caller_address=addresses[1].starknet_address)
async def test_should_increase_count(self, counter, counter_deployer):
await counter.inc(caller_address=counter_deployer.starknet_address)
assert await counter.count() == 1

class TestDec:
async def test_should_raise_when_count_is_0(self, counter, addresses):
async def test_should_raise_when_count_is_0(self, counter, counter_deployer):
with kakarot_error("count should be strictly greater than 0"):
await counter.dec(caller_address=addresses[1].starknet_address)
await counter.dec(caller_address=counter_deployer.starknet_address)

async def test_should_decrease_count(self, counter, addresses):
await counter.inc(caller_address=addresses[1].starknet_address)
await counter.dec(caller_address=addresses[1].starknet_address)
async def test_should_decrease_count(self, counter, counter_deployer):
await counter.inc(caller_address=counter_deployer.starknet_address)
await counter.dec(caller_address=counter_deployer.starknet_address)
assert await counter.count() == 0

async def test_should_decrease_count_unchecked(self, counter, addresses):
await counter.inc(caller_address=addresses[1].starknet_address)
await counter.decUnchecked(caller_address=addresses[1].starknet_address)
async def test_should_decrease_count_unchecked(self, counter, counter_deployer):
await counter.inc(caller_address=counter_deployer.starknet_address)
await counter.decUnchecked(caller_address=counter_deployer.starknet_address)
assert await counter.count() == 0

async def test_should_decrease_count_in_place(self, counter, addresses):
await counter.inc(caller_address=addresses[1].starknet_address)
await counter.decInPlace(caller_address=addresses[1].starknet_address)
async def test_should_decrease_count_in_place(self, counter, counter_deployer):
await counter.inc(caller_address=counter_deployer.starknet_address)
await counter.decInPlace(caller_address=counter_deployer.starknet_address)
assert await counter.count() == 0

class TestReset:
async def test_should_set_count_to_0(self, counter, addresses):
await counter.inc(caller_address=addresses[1].starknet_address)
await counter.reset(caller_address=addresses[1].starknet_address)
async def test_should_set_count_to_0(self, counter, counter_deployer):
await counter.inc(caller_address=counter_deployer.starknet_address)
await counter.reset(caller_address=counter_deployer.starknet_address)
assert await counter.count() == 0
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ async def test_should_increase_counter(
self,
counter,
plain_opcodes,
addresses,
counter_deployer,
):
await plain_opcodes.opcodeCall(caller_address=addresses[1].starknet_address)
await plain_opcodes.opcodeCall(caller_address=counter_deployer.starknet_address)
assert await counter.count() == 1

class TestBlockhash:
Expand Down Expand Up @@ -146,8 +146,8 @@ async def test_should_emit_log3(self, plain_opcodes, addresses, event):
assert log_receipt["address"] == expected_address
assert plain_opcodes.events.Log3 == [event]

async def test_should_emit_log4(self, plain_opcodes, addresses, event):
await plain_opcodes.opcodeLog4(caller_address=addresses[0].starknet_address)
async def test_should_emit_log4(self, plain_opcodes, plain_opcodes_deployer, event):
await plain_opcodes.opcodeLog4(caller_address=plain_opcodes_deployer.starknet_address)
# the contract address is set at deploy time, we verify that event address is
# getting correctly set by asserting equality
expected_address = plain_opcodes.address
Expand All @@ -160,15 +160,15 @@ async def test_should_deploy_bytecode_at_address(
self,
plain_opcodes,
counter,
addresses,
plain_opcodes_deployer,
get_starknet_address,
get_solidity_contract,
):
salt = 1234
evm_address = await plain_opcodes.create2(
bytecode=counter.constructor().data_in_transaction,
salt=salt,
caller_address=addresses[0].starknet_address,
caller_address=plain_opcodes_deployer.starknet_address,
)
starknet_address = get_starknet_address(salt)
deployed_counter = get_solidity_contract(
Expand Down Expand Up @@ -259,6 +259,6 @@ async def test_should_return_owner_as_origin_and_caller_as_sender(

class TestLoop:
@pytest.mark.parametrize("steps", [0, 1, 2, 10])
async def test_loop_should_write_to_storage(self, plain_opcodes, owner, steps):
await plain_opcodes.testLoop(steps, caller_address=owner.starknet_address)
async def test_loop_should_write_to_storage(self, plain_opcodes, plain_opcodes_deployer, steps):
await plain_opcodes.testLoop(steps, caller_address=plain_opcodes_deployer.starknet_address)
assert await plain_opcodes.loopValue() == steps
7 changes: 5 additions & 2 deletions tests/integration/solidity_contracts/Solmate/test_erc20.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,19 @@
TEST_SUPPLY = 10**18
TEST_AMOUNT = int(0.9 * 10**18)

@pytest_asyncio.fixture(scope="session")
def erc_20_deployer(addresses):
return addresses[5]

@pytest_asyncio.fixture(scope="module")
async def erc_20(deploy_solidity_contract, owner):
async def erc_20(deploy_solidity_contract, erc_20_deployer):
return await deploy_solidity_contract(
"Solmate",
"ERC20",
"Kakarot Token",
"KKT",
18,
caller_eoa=owner,
caller_eoa=erc_20_deployer,
)


Expand Down
Loading

0 comments on commit 40fc24a

Please sign in to comment.