From 96056069c91659c9be10f8c3951b1f830a30c4e0 Mon Sep 17 00:00:00 2001 From: Elias Tazartes <66871571+Eikix@users.noreply.github.com> Date: Thu, 16 Jan 2025 11:15:35 +0100 Subject: [PATCH] feat: set_account (#437) tests are broken so far, will investigate --------- Co-authored-by: enitrat --- cairo/ethereum/cancun/fork_types.cairo | 6 ++- cairo/ethereum/cancun/state.cairo | 52 ++++++++++++++++------- cairo/ethereum/cancun/trie.cairo | 38 +++++++++-------- cairo/tests/ethereum/cancun/test_state.py | 12 +++++- cairo/tests/ethereum/cancun/test_trie.py | 14 +++--- cairo/tests/test_serde.py | 1 + cairo/tests/utils/args_gen.py | 7 +-- cairo/tests/utils/serde.py | 11 +++++ 8 files changed, 97 insertions(+), 44 deletions(-) diff --git a/cairo/ethereum/cancun/fork_types.cairo b/cairo/ethereum/cancun/fork_types.cairo index f1aefe05..bf3f9f39 100644 --- a/cairo/ethereum/cancun/fork_types.cairo +++ b/cairo/ethereum/cancun/fork_types.cairo @@ -85,6 +85,10 @@ struct Account { value: AccountStruct*, } +struct OptionalAccount { + value: AccountStruct*, +} + struct AddressAccountDictAccess { key: Address, prev_value: Account, @@ -130,7 +134,7 @@ func EMPTY_ACCOUNT() -> Account { return account; } -func Account__eq__(a: Account, b: Account) -> bool { +func Account__eq__(a: OptionalAccount, b: OptionalAccount) -> bool { if (cast(a.value, felt) == 0) { let b_is_none = is_zero(cast(b.value, felt)); let res = bool(b_is_none); diff --git a/cairo/ethereum/cancun/state.cairo b/cairo/ethereum/cancun/state.cairo index 983e87e7..cfe049f2 100644 --- a/cairo/ethereum/cancun/state.cairo +++ b/cairo/ethereum/cancun/state.cairo @@ -7,6 +7,7 @@ from starkware.cairo.common.math import assert_not_zero from ethereum.cancun.fork_types import ( Address, Account, + OptionalAccount, MappingAddressAccount, SetAddress, EMPTY_ACCOUNT, @@ -16,13 +17,14 @@ from ethereum.cancun.fork_types import ( ) from ethereum.cancun.trie import ( TrieBytes32U256, - TrieAddressAccount, - trie_get_TrieAddressAccount, + TrieAddressOptionalAccount, + trie_get_TrieAddressOptionalAccount, + trie_set_TrieAddressOptionalAccount, trie_get_TrieBytes32U256, trie_set_TrieBytes32U256, AccountStruct, TrieBytes32U256Struct, - TrieAddressAccountStruct, + TrieAddressOptionalAccountStruct, ) from ethereum_types.bytes import Bytes, Bytes32 from ethereum_types.numeric import U256, U256Struct, Bool, bool @@ -47,22 +49,22 @@ struct MappingAddressTrieBytes32U256 { value: MappingAddressTrieBytes32U256Struct*, } -struct TupleTrieAddressAccountMappingAddressTrieBytes32U256Struct { - trie_address_account: TrieAddressAccount, +struct TupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256Struct { + trie_address_account: TrieAddressOptionalAccount, mapping_address_trie: MappingAddressTrieBytes32U256, } -struct TupleTrieAddressAccountMappingAddressTrieBytes32U256 { - value: TupleTrieAddressAccountMappingAddressTrieBytes32U256Struct*, +struct TupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256 { + value: TupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256Struct*, } -struct ListTupleTrieAddressAccountMappingAddressTrieBytes32U256Struct { - data: TupleTrieAddressAccountMappingAddressTrieBytes32U256*, +struct ListTupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256Struct { + data: TupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256*, len: felt, } -struct ListTupleTrieAddressAccountMappingAddressTrieBytes32U256 { - value: ListTupleTrieAddressAccountMappingAddressTrieBytes32U256Struct*, +struct ListTupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256 { + value: ListTupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256Struct*, } struct TransientStorageSnapshotsStruct { @@ -84,9 +86,9 @@ struct TransientStorage { } struct StateStruct { - _main_trie: TrieAddressAccount, + _main_trie: TrieAddressOptionalAccount, _storage_tries: MappingAddressTrieBytes32U256, - _snapshots: ListTupleTrieAddressAccountMappingAddressTrieBytes32U256, + _snapshots: ListTupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256, created_accounts: SetAddress, } @@ -94,13 +96,12 @@ struct State { value: StateStruct*, } -using OptionalAccount = Account; func get_account_optional{poseidon_ptr: PoseidonBuiltin*, state: State}( address: Address ) -> OptionalAccount { let trie = state.value._main_trie; with trie { - let account = trie_get_TrieAddressAccount(address); + let account = trie_get_TrieAddressOptionalAccount(address); } return account; @@ -114,7 +115,26 @@ func get_account{poseidon_ptr: PoseidonBuiltin*, state: State}(address: Address) return empty_account; } - return account; + tempvar res = Account(account.value); + return res; +} + +func set_account{poseidon_ptr: PoseidonBuiltin*, state: State}( + address: Address, account: OptionalAccount +) { + let trie = state.value._main_trie; + with trie { + trie_set_TrieAddressOptionalAccount(address, account); + } + tempvar state = State( + new StateStruct( + _main_trie=trie, + _storage_tries=state.value._storage_tries, + _snapshots=state.value._snapshots, + created_accounts=state.value.created_accounts, + ), + ); + return (); } func get_storage{poseidon_ptr: PoseidonBuiltin*, state: State}( diff --git a/cairo/ethereum/cancun/trie.cairo b/cairo/ethereum/cancun/trie.cairo index c1657819..395f952c 100644 --- a/cairo/ethereum/cancun/trie.cairo +++ b/cairo/ethereum/cancun/trie.cairo @@ -34,6 +34,7 @@ from ethereum.cancun.fork_types import ( Account__eq__, AccountStruct, Address, + OptionalAccount, Bytes32U256DictAccess, MappingAddressAccount, MappingAddressAccountStruct, @@ -167,14 +168,14 @@ struct Node { value: NodeEnum*, } -struct TrieAddressAccountStruct { +struct TrieAddressOptionalAccountStruct { secured: bool, - default: Account, + default: OptionalAccount, _data: MappingAddressAccount, } -struct TrieAddressAccount { - value: TrieAddressAccountStruct*, +struct TrieAddressOptionalAccount { + value: TrieAddressOptionalAccountStruct*, } struct TrieBytes32U256Struct { @@ -334,7 +335,8 @@ func encode_node{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr: Kecc // @notice Copies the trie to a new segment. // @dev This function simply creates a new segment for the new dict and associates it with the // dict_tracker of the source dict. -func copy_trieAddressAccount{range_check_ptr, trie: TrieAddressAccount}() -> TrieAddressAccount { +func copy_TrieAddressOptionalAccount{range_check_ptr, trie: TrieAddressOptionalAccount}( + ) -> TrieAddressOptionalAccount { alloc_locals; // TODO: soundness // We need to ensure it is sound when finalizing that copy. @@ -355,8 +357,8 @@ func copy_trieAddressAccount{range_check_ptr, trie: TrieAddressAccount}() -> Tri ids.new_dict_ptr = __dict_manager.new_dict(segments, copied_data) %} - tempvar res = TrieAddressAccount( - new TrieAddressAccountStruct( + tempvar res = TrieAddressOptionalAccount( + new TrieAddressOptionalAccountStruct( trie.value.secured, trie.value.default, MappingAddressAccount( @@ -393,9 +395,9 @@ func copy_trieBytes32U256{range_check_ptr, trie: TrieBytes32U256}() -> TrieBytes return res; } -func trie_get_TrieAddressAccount{poseidon_ptr: PoseidonBuiltin*, trie: TrieAddressAccount}( - key: Address -) -> Account { +func trie_get_TrieAddressOptionalAccount{ + poseidon_ptr: PoseidonBuiltin*, trie: TrieAddressOptionalAccount +}(key: Address) -> OptionalAccount { alloc_locals; let dict_ptr = cast(trie.value._data.value.dict_ptr, DictAccess*); @@ -412,10 +414,10 @@ func trie_get_TrieAddressAccount{poseidon_ptr: PoseidonBuiltin*, trie: TrieAddre trie.value._data.value.dict_ptr_start, new_dict_ptr, original_mapping ), ); - tempvar trie = TrieAddressAccount( - new TrieAddressAccountStruct(trie.value.secured, trie.value.default, mapping) + tempvar trie = TrieAddressOptionalAccount( + new TrieAddressOptionalAccountStruct(trie.value.secured, trie.value.default, mapping) ); - tempvar res = Account(cast(pointer, AccountStruct*)); + tempvar res = OptionalAccount(cast(pointer, AccountStruct*)); return res; } @@ -441,9 +443,9 @@ func trie_get_TrieBytes32U256{poseidon_ptr: PoseidonBuiltin*, trie: TrieBytes32U return res; } -func trie_set_TrieAddressAccount{poseidon_ptr: PoseidonBuiltin*, trie: TrieAddressAccount}( - key: Address, value: Account -) { +func trie_set_TrieAddressOptionalAccount{ + poseidon_ptr: PoseidonBuiltin*, trie: TrieAddressOptionalAccount +}(key: Address, value: OptionalAccount) { let dict_ptr_start = cast(trie.value._data.value.dict_ptr_start, DictAccess*); let dict_ptr = cast(trie.value._data.value.dict_ptr, DictAccess*); @@ -473,8 +475,8 @@ func trie_set_TrieAddressAccount{poseidon_ptr: PoseidonBuiltin*, trie: TrieAddre trie.value._data.value.original_mapping, ), ); - tempvar trie = TrieAddressAccount( - new TrieAddressAccountStruct(trie.value.secured, trie.value.default, mapping) + tempvar trie = TrieAddressOptionalAccount( + new TrieAddressOptionalAccountStruct(trie.value.secured, trie.value.default, mapping) ); return (); } diff --git a/cairo/tests/ethereum/cancun/test_state.py b/cairo/tests/ethereum/cancun/test_state.py index f2c1acff..9e7104b3 100644 --- a/cairo/tests/ethereum/cancun/test_state.py +++ b/cairo/tests/ethereum/cancun/test_state.py @@ -1,3 +1,5 @@ +from typing import Optional + import pytest from ethereum_types.bytes import Bytes32 from ethereum_types.numeric import U256 @@ -5,7 +7,7 @@ from hypothesis import strategies as st from hypothesis.strategies import composite -from ethereum.cancun.fork_types import Address +from ethereum.cancun.fork_types import Account, Address from ethereum.cancun.state import ( account_exists, account_has_code_or_nonce, @@ -14,6 +16,7 @@ get_storage, get_transient_storage, is_account_empty, + set_account, set_storage, set_transient_storage, ) @@ -67,6 +70,13 @@ def test_get_account_optional(self, cairo_run, data): assert result_cairo == get_account_optional(state, address) assert state_cairo == state + @given(data=state_and_address_and_optional_key(), account=...) + def test_set_account(self, cairo_run, data, account: Optional[Account]): + state, address = data + state_cairo = cairo_run("set_account", state, address, account) + set_account(state, address, account) + assert state_cairo == state + @given(data=state_and_address_and_optional_key()) def test_account_has_code_or_nonce(self, cairo_run, data): state, address = data diff --git a/cairo/tests/ethereum/cancun/test_trie.py b/cairo/tests/ethereum/cancun/test_trie.py index 82b5e674..caca4aa8 100644 --- a/cairo/tests/ethereum/cancun/test_trie.py +++ b/cairo/tests/ethereum/cancun/test_trie.py @@ -155,10 +155,12 @@ def test_patricialize(self, cairo_run, obj: Mapping[Bytes, Bytes]): class TestTrieOperations: @given(trie=..., key=...) - def test_trie_get_TrieAddressAccount( + def test_trie_get_TrieAddressOptionalAccount( self, cairo_run, trie: Trie[Address, Optional[Account]], key: Address ): - [trie_cairo, result_cairo] = cairo_run("trie_get_TrieAddressAccount", trie, key) + [trie_cairo, result_cairo] = cairo_run( + "trie_get_TrieAddressOptionalAccount", trie, key + ) result_py = trie_get(trie, key) assert result_cairo == result_py assert trie_cairo == trie @@ -173,14 +175,14 @@ def test_trie_get_TrieBytes32U256( assert trie_cairo == trie @given(trie=..., key=..., value=...) - def test_trie_set_TrieAddressAccount( + def test_trie_set_TrieAddressOptionalAccount( self, cairo_run, trie: Trie[Address, Optional[Account]], key: Address, value: Account, ): - cairo_trie = cairo_run("trie_set_TrieAddressAccount", trie, key, value) + cairo_trie = cairo_run("trie_set_TrieAddressOptionalAccount", trie, key, value) trie_set(trie, key, value) assert cairo_trie == trie @@ -196,7 +198,9 @@ def test_trie_set_TrieBytes32U256( def test_copy_trie_AddressAccount( self, cairo_run, trie: Trie[Address, Optional[Account]] ): - [original_trie, copied_trie] = cairo_run("copy_trieAddressAccount", trie) + [original_trie, copied_trie] = cairo_run( + "copy_TrieAddressOptionalAccount", trie + ) trie_copy_py = copy_trie(trie) assert original_trie == trie assert copied_trie == trie_copy_py diff --git a/cairo/tests/test_serde.py b/cairo/tests/test_serde.py index 3cff1a91..4bc37557 100644 --- a/cairo/tests/test_serde.py +++ b/cairo/tests/test_serde.py @@ -203,6 +203,7 @@ def test_type( Address, Root, Account, + Optional[Account], Bloom, VersionedHash, Tuple[VersionedHash, ...], diff --git a/cairo/tests/utils/args_gen.py b/cairo/tests/utils/args_gen.py index f668a9e9..9ad28e82 100644 --- a/cairo/tests/utils/args_gen.py +++ b/cairo/tests/utils/args_gen.py @@ -266,6 +266,7 @@ def __eq__(self, other): ("ethereum", "cancun", "fork_types", "SetAddress"): Set[Address], ("ethereum", "cancun", "fork_types", "Root"): Root, ("ethereum", "cancun", "fork_types", "Account"): Account, + ("ethereum", "cancun", "fork_types", "OptionalAccount"): Optional[Account], ("ethereum", "cancun", "fork_types", "Bloom"): Bloom, ("ethereum", "cancun", "fork_types", "VersionedHash"): VersionedHash, ("ethereum", "cancun", "fork_types", "TupleVersionedHash"): Tuple[ @@ -311,7 +312,7 @@ def __eq__(self, other): ("ethereum", "cancun", "trie", "BranchNode"): BranchNode, ("ethereum", "cancun", "trie", "InternalNode"): InternalNode, ("ethereum", "cancun", "trie", "Node"): Node, - ("ethereum", "cancun", "trie", "TrieAddressAccount"): Trie[ + ("ethereum", "cancun", "trie", "TrieAddressOptionalAccount"): Trie[ Address, Optional[Account] ], ("ethereum", "cancun", "trie", "TrieBytes32U256"): Trie[Bytes32, U256], @@ -335,13 +336,13 @@ def __eq__(self, other): "ethereum", "cancun", "state", - "TupleTrieAddressAccountMappingAddressTrieBytes32U256", + "TupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256", ): Tuple[Trie[Address, Optional[Account]], Mapping[Address, Trie[Bytes32, U256]]], ( "ethereum", "cancun", "state", - "ListTupleTrieAddressAccountMappingAddressTrieBytes32U256", + "ListTupleTrieAddressOptionalAccountMappingAddressTrieBytes32U256", ): List[ Tuple[Trie[Address, Optional[Account]], Mapping[Address, Trie[Bytes32, U256]]] ], diff --git a/cairo/tests/utils/serde.py b/cairo/tests/utils/serde.py index b118eff4..7c941d84 100644 --- a/cairo/tests/utils/serde.py +++ b/cairo/tests/utils/serde.py @@ -160,6 +160,16 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any: python_cls, *annotations = get_args(python_cls) origin_cls = get_origin(python_cls) + # arg_type = Optional[T] <=> arg_type_origin = Union[T, None] + if origin_cls is Union and get_args(python_cls)[1] is type(None): + # Get the value pointer: if it's zero, return None. + # Otherwise, consider this the non-optional type: + value_ptr = self.serialize_pointers(path, ptr)["value"] + if value_ptr is None: + return None + python_cls = get_args(python_cls)[0] + origin_cls = get_origin(python_cls) + if origin_cls is Union: value_ptr = self.serialize_pointers(path, ptr)["value"] if value_ptr is None: @@ -176,6 +186,7 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any: if value != 0 and value is not None } if len(variant_keys) != 1: + breakpoint() raise ValueError( f"Expected 1 item only to be relocatable in enum, got {len(variant_keys)}" )