From 7a617d34a1309ded2297d8a2c0fda23db4624c88 Mon Sep 17 00:00:00 2001 From: Oba Date: Mon, 13 Jan 2025 18:25:32 +0100 Subject: [PATCH] feat: get_transient_storage --- cairo/ethereum/cancun/state.cairo | 66 +++++++++++++++++++++-- cairo/tests/ethereum/cancun/test_state.py | 30 ++++++++++- cairo/tests/utils/strategies.py | 27 +++++++++- 3 files changed, 118 insertions(+), 5 deletions(-) diff --git a/cairo/ethereum/cancun/state.cairo b/cairo/ethereum/cancun/state.cairo index 7b810b90..e90cc627 100644 --- a/cairo/ethereum/cancun/state.cairo +++ b/cairo/ethereum/cancun/state.cairo @@ -1,7 +1,7 @@ from starkware.cairo.common.cairo_builtins import PoseidonBuiltin from starkware.cairo.common.dict_access import DictAccess from starkware.cairo.common.dict import dict_new -from starkware.cairo.lang.compiler.lib.registers import get_fp_and_pc +from starkware.cairo.common.registers import get_fp_and_pc from starkware.cairo.common.math import assert_not_zero from ethereum_types.numeric import bool @@ -23,11 +23,13 @@ from ethereum.cancun.trie import ( trie_set_TrieBytes32U256, AccountStruct, TrieBytes32U256Struct, + TrieAddressAccountStruct, ) from ethereum_types.bytes import Bytes, Bytes32 -from src.utils.dict import hashdict_read, hashdict_write, hashdict_get from ethereum_types.numeric import U256, U256Struct +from src.utils.dict import hashdict_read, hashdict_write, hashdict_get + struct AddressTrieBytes32U256DictAccess { key: Address, prev_value: TrieBytes32U256, @@ -187,7 +189,6 @@ func get_storage{poseidon_ptr: PoseidonBuiltin*, state: State}( created_accounts=state.value.created_accounts, ), ); - return value; } @@ -298,3 +299,62 @@ func set_storage{poseidon_ptr: PoseidonBuiltin*, state: State}( ); return (); } + +func get_transient_storage{poseidon_ptr: PoseidonBuiltin*, transient_storage: TransientStorage}( + address: Address, key: Bytes32 +) -> U256 { + alloc_locals; + let fp_and_pc = get_fp_and_pc(); + local __fp__: felt* = fp_and_pc.fp_val; + + let transient_storage_tries_dict_ptr = cast( + transient_storage.value._tries.value.dict_ptr, DictAccess* + ); + with poseidon_ptr { + let (trie_ptr) = hashdict_get{dict_ptr=transient_storage_tries_dict_ptr}(1, &address.value); + } + + // If no storage trie is associated to that address, return the 0 default + if (trie_ptr == 0) { + let new_transient_storage_tries_dict_ptr = cast( + transient_storage_tries_dict_ptr, AddressTrieBytes32U256DictAccess* + ); + tempvar transient_storage_tries = MappingAddressTrieBytes32U256( + new MappingAddressTrieBytes32U256Struct( + transient_storage.value._tries.value.dict_ptr_start, + new_transient_storage_tries_dict_ptr, + transient_storage.value._tries.value.original_mapping, + ), + ); + tempvar transient_storage = TransientStorage( + new TransientStorageStruct(transient_storage_tries, transient_storage.value._snapshots) + ); + tempvar result = U256(new U256Struct(0, 0)); + return result; + } + + let trie = TrieBytes32U256(cast(trie_ptr, TrieBytes32U256Struct*)); + with trie { + let value = trie_get_TrieBytes32U256(key); + } + + // Rebind the trie to the transient storage + hashdict_write{poseidon_ptr=poseidon_ptr, dict_ptr=transient_storage_tries_dict_ptr}( + 1, &address.value, cast(trie.value, felt) + ); + let new_storage_tries_dict_ptr = cast( + transient_storage_tries_dict_ptr, AddressTrieBytes32U256DictAccess* + ); + tempvar transient_storage_tries = MappingAddressTrieBytes32U256( + new MappingAddressTrieBytes32U256Struct( + transient_storage.value._tries.value.dict_ptr_start, + new_storage_tries_dict_ptr, + transient_storage.value._tries.value.original_mapping, + ), + ); + tempvar transient_storage = TransientStorage( + new TransientStorageStruct(transient_storage_tries, transient_storage.value._snapshots) + ); + + return value; +} diff --git a/cairo/tests/ethereum/cancun/test_state.py b/cairo/tests/ethereum/cancun/test_state.py index fa33c11c..e74ec892 100644 --- a/cairo/tests/ethereum/cancun/test_state.py +++ b/cairo/tests/ethereum/cancun/test_state.py @@ -1,16 +1,20 @@ import pytest +from ethereum_types.bytes import Bytes32 from ethereum_types.numeric import U256 from hypothesis import given from hypothesis import strategies as st from hypothesis.strategies import composite +from ethereum.cancun.fork_types import Address from ethereum.cancun.state import ( get_account, get_account_optional, get_storage, + get_transient_storage, set_storage, ) -from tests.utils.strategies import address, bytes32, state +from tests.utils.args_gen import TransientStorage +from tests.utils.strategies import address, bytes32, state, transient_storage_lite pytestmark = pytest.mark.python_vm @@ -100,3 +104,27 @@ def test_set_storage(self, cairo_run, data, value: U256): set_storage(state, address, key, value) assert state_cairo == state + + +class TestTransientStorage: + @given( + transient_storage=transient_storage_lite, + address=..., + key=..., + ) + def test_get_transient_storage( + self, + cairo_run, + transient_storage: TransientStorage, + address: Address, + key: Bytes32, + ): + [transient_storage_cairo, result_cairo] = cairo_run( + "get_transient_storage", + transient_storage, + address, + key, + ) + result_py = get_transient_storage(transient_storage, address, key) + assert result_cairo == result_py + assert transient_storage_cairo == transient_storage diff --git a/cairo/tests/utils/strategies.py b/cairo/tests/utils/strategies.py index e9b499b9..8550be34 100644 --- a/cairo/tests/utils/strategies.py +++ b/cairo/tests/utils/strategies.py @@ -224,6 +224,31 @@ def tuple_strategy(thing): # Using this list instead of the hash32 strategy to avoid data_to_large errors BLOCK_HASHES_LIST = [Hash32(Bytes32(bytes([i] * 32))) for i in range(256)] +transient_storage_lite = st.lists(address, max_size=MAX_ADDRESS_SET_SIZE).flatmap( + lambda addresses: st.builds( + TransientStorage, + _tries=st.fixed_dictionaries( + { + address: trie_strategy(Trie[Bytes32, U256]).filter( + lambda t: bool(t._data) + ) + for address in addresses + } + ), + _snapshots=st.lists( + st.fixed_dictionaries( + { + address: trie_strategy(Trie[Bytes32, U256]).filter( + lambda t: bool(t._data) + ) + for address in addresses + } + ), + max_size=20, + ), + ) +) + environment_lite = st.integers(min_value=0).flatmap( # Generate block number first lambda number: st.builds( Environment, @@ -247,7 +272,7 @@ def tuple_strategy(thing): blob_versioned_hashes=st.lists( st.from_type(VersionedHash), min_size=0, max_size=5 ).map(tuple), - transient_storage=st.from_type(TransientStorage), + transient_storage=transient_storage_lite, ) )