Skip to content

Commit

Permalink
feat: set_transient_storage (#429)
Browse files Browse the repository at this point in the history
Close #408
  • Loading branch information
obatirou authored Jan 15, 2025
1 parent 70c2d7c commit 0563748
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 8 deletions.
70 changes: 68 additions & 2 deletions cairo/ethereum/cancun/state.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ from starkware.cairo.common.dict_access import DictAccess
from starkware.cairo.common.dict import dict_new
from starkware.cairo.common.registers import get_fp_and_pc
from starkware.cairo.common.math import assert_not_zero
from ethereum_types.numeric import bool

from ethereum.cancun.fork_types import (
Address,
Expand All @@ -26,7 +25,7 @@ from ethereum.cancun.trie import (
TrieAddressAccountStruct,
)
from ethereum_types.bytes import Bytes, Bytes32
from ethereum_types.numeric import U256, U256Struct
from ethereum_types.numeric import U256, U256Struct, Bool, bool

from src.utils.dict import hashdict_read, hashdict_write, hashdict_get

Expand Down Expand Up @@ -340,3 +339,70 @@ func get_transient_storage{poseidon_ptr: PoseidonBuiltin*, transient_storage: Tr

return value;
}

func set_transient_storage{poseidon_ptr: PoseidonBuiltin*, transient_storage: TransientStorage}(
address: Address, key: Bytes32, value: U256
) {
alloc_locals;
let fp_and_pc = get_fp_and_pc();
local __fp__: felt* = fp_and_pc.fp_val;

let transient_storage_tries_dict_ptr = cast(
transient_storage.value._tries.value.dict_ptr, DictAccess*
);
let (trie_ptr) = hashdict_get{dict_ptr=transient_storage_tries_dict_ptr}(1, &address.value);

if (trie_ptr == 0) {
%{ initial_dict = {} %}
let (empty_dict) = dict_new();
tempvar new_trie = new TrieBytes32U256Struct(
secured=Bool(1),
default=U256(new U256Struct(0, 0)),
_data=MappingBytes32U256(
new MappingBytes32U256Struct(
dict_ptr_start=cast(empty_dict, Bytes32U256DictAccess*),
dict_ptr=cast(empty_dict, Bytes32U256DictAccess*),
original_mapping=cast(0, MappingBytes32U256Struct*),
),
),
);
let new_trie_ptr = cast(new_trie, felt);
hashdict_write{poseidon_ptr=poseidon_ptr, dict_ptr=transient_storage_tries_dict_ptr}(
1, &address.value, new_trie_ptr
);
tempvar trie_ptr = new_trie_ptr;
} else {
tempvar trie_ptr = trie_ptr;
}

let transient_storage_tries_dict_ptr = transient_storage_tries_dict_ptr;
tempvar trie = TrieBytes32U256(cast(trie_ptr, TrieBytes32U256Struct*));
with trie {
trie_set_TrieBytes32U256{poseidon_ptr=poseidon_ptr}(key, value);
}

// Trie is not deleted if empty
// From EELS https://github.com/ethereum/execution-specs/blob/5c82ed6ac3eb992c7d87320a3e771b5e852a06df/src/ethereum/cancun/state.py#L697:
// if trie._data == {}:
// del transient_storage._tries[address]

// Update the transient storage tries
hashdict_write{poseidon_ptr=poseidon_ptr, dict_ptr=transient_storage_tries_dict_ptr}(
1, &address.value, cast(trie.value, felt)
);
let new_storage_tries_dict_ptr = cast(
transient_storage_tries_dict_ptr, AddressTrieBytes32U256DictAccess*
);
tempvar transient_storage_tries = MappingAddressTrieBytes32U256(
new MappingAddressTrieBytes32U256Struct(
transient_storage.value._tries.value.dict_ptr_start,
new_storage_tries_dict_ptr,
transient_storage.value._tries.value.original_mapping,
),
);
tempvar transient_storage = TransientStorage(
new TransientStorageStruct(transient_storage_tries, transient_storage.value._snapshots)
);

return ();
}
29 changes: 27 additions & 2 deletions cairo/tests/ethereum/cancun/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
get_storage,
get_transient_storage,
set_storage,
set_transient_storage,
)
from tests.utils.args_gen import TransientStorage
from tests.utils.strategies import address, bytes32, state, transient_storage_lite
from tests.utils.strategies import address, bytes32, state, transient_storage

pytestmark = pytest.mark.python_vm

Expand Down Expand Up @@ -108,7 +109,7 @@ def test_set_storage(self, cairo_run, data, value: U256):

class TestTransientStorage:
@given(
transient_storage=transient_storage_lite,
transient_storage=transient_storage,
address=...,
key=...,
)
Expand All @@ -128,3 +129,27 @@ def test_get_transient_storage(
result_py = get_transient_storage(transient_storage, address, key)
assert result_cairo == result_py
assert transient_storage_cairo == transient_storage

@given(
transient_storage=transient_storage,
address=...,
key=...,
value=...,
)
def test_set_transient_storage(
self,
cairo_run,
transient_storage: TransientStorage,
address: Address,
key: Bytes32,
value: U256,
):
transient_storage_cairo = cairo_run(
"set_transient_storage",
transient_storage,
address,
key,
value,
)
set_transient_storage(transient_storage, address, key, value)
assert transient_storage_cairo == transient_storage
12 changes: 11 additions & 1 deletion cairo/tests/utils/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
from starkware.cairo.lang.vm.memory_dict import UnknownMemoryError
from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager

from ethereum.cancun.state import State
from ethereum.cancun.state import State, TransientStorage
from ethereum.cancun.vm.exceptions import InvalidOpcode
from ethereum.crypto.hash import Hash32
from tests.utils.args_gen import Memory, Stack, to_python_type, vm_exception_classes
Expand Down Expand Up @@ -372,6 +372,16 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any:
for k in keys_to_delete:
del value["_storage_tries"][k]

if python_cls is TransientStorage:
if value["_tries"] is not None and value["_tries"] != {}:
# First collect all keys with empty tries
keys_to_delete = [
k for k, v in value["_tries"].items() if v._data == {}
]
# Cannot iterate over a dict while deleting items from it
for k in keys_to_delete:
del value["_tries"][k]

adjusted_value = {
k: (
None
Expand Down
7 changes: 4 additions & 3 deletions cairo/tests/utils/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ def tuple_strategy(thing):
# Using this list instead of the hash32 strategy to avoid data_to_large errors
BLOCK_HASHES_LIST = [Hash32(Bytes32(bytes([i] * 32))) for i in range(256)]

transient_storage_lite = st.lists(
address, max_size=MAX_ADDRESS_TRANSIENT_STORAGE_SIZE
transient_storage = st.lists(
address, max_size=MAX_ADDRESS_TRANSIENT_STORAGE_SIZE, unique=True
).flatmap(
lambda addresses: st.builds(
TransientStorage,
Expand Down Expand Up @@ -282,7 +282,7 @@ def tuple_strategy(thing):
blob_versioned_hashes=st.lists(
st.from_type(VersionedHash), min_size=0, max_size=5
).map(tuple),
transient_storage=transient_storage_lite,
transient_storage=transient_storage,
)
)

Expand Down Expand Up @@ -457,3 +457,4 @@ def register_type_strategies():
st.register_type_strategy(Evm, evm)
st.register_type_strategy(tuple, tuple_strategy)
st.register_type_strategy(State, state)
st.register_type_strategy(TransientStorage, transient_storage)

0 comments on commit 0563748

Please sign in to comment.