Skip to content

Commit

Permalink
feat: set storage (#414)
Browse files Browse the repository at this point in the history
Setting to draft as tests cannot pass until we skip serialization of
nested dictionaries that are filled with default_values in serde.py, as
seen in
https://github.com/ethereum/execution-specs/blob/master/src/ethereum/cancun/state.py#L318

closes #409

---------

Co-authored-by: enitrat <msaug@protonmail.com>
  • Loading branch information
Eikix and enitrat authored Jan 15, 2025
1 parent 0d064d2 commit 1436c75
Show file tree
Hide file tree
Showing 8 changed files with 265 additions and 43 deletions.
11 changes: 11 additions & 0 deletions cairo/ethereum/cancun/fork_types.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ from ethereum_types.bytes import Bytes20, Bytes32, Bytes256, Bytes, BytesStruct,
from ethereum.utils.bytes import Bytes__eq__
from ethereum_types.numeric import Uint, U256, U256Struct, bool
from ethereum.crypto.hash import Hash32
from ethereum.utils.numeric import is_zero

using Address = Bytes20;

Expand Down Expand Up @@ -126,6 +127,16 @@ func EMPTY_ACCOUNT() -> Account {
}

func Account__eq__(a: Account, b: Account) -> bool {
if (cast(a.value, felt) == 0) {
let b_is_none = is_zero(cast(b.value, felt));
let res = bool(b_is_none);
return res;
}
if (cast(b.value, felt) == 0) {
let a_is_none = is_zero(cast(a.value, felt));
let res = bool(a_is_none);
return res;
}
if (a.value.nonce.value != b.value.nonce.value) {
tempvar res = bool(0);
return res;
Expand Down
123 changes: 120 additions & 3 deletions cairo/ethereum/cancun/state.cairo
Original file line number Diff line number Diff line change
@@ -1,24 +1,31 @@
from starkware.cairo.common.cairo_builtins import PoseidonBuiltin
from starkware.cairo.common.dict_access import DictAccess
from starkware.cairo.common.dict import dict_new
from starkware.cairo.lang.compiler.lib.registers import get_fp_and_pc
from starkware.cairo.common.math import assert_not_zero
from ethereum_types.numeric import bool

from ethereum.cancun.fork_types import (
Address,
Account,
MappingAddressAccount,
SetAddress,
EMPTY_ACCOUNT,
MappingBytes32U256,
MappingBytes32U256Struct,
Bytes32U256DictAccess,
)
from ethereum.cancun.trie import (
TrieBytes32U256,
TrieAddressAccount,
trie_get_TrieAddressAccount,
trie_get_TrieBytes32U256,
trie_set_TrieBytes32U256,
AccountStruct,
TrieBytes32U256Struct,
)
from ethereum_types.bytes import Bytes, Bytes32
from src.utils.dict import hashdict_read, hashdict_write
from src.utils.dict import hashdict_read, hashdict_write, hashdict_get
from ethereum_types.numeric import U256, U256Struct

struct AddressTrieBytes32U256DictAccess {
Expand Down Expand Up @@ -120,7 +127,9 @@ func get_storage{poseidon_ptr: PoseidonBuiltin*, state: State}(

let storage_tries_dict_ptr = cast(storage_tries.value.dict_ptr, DictAccess*);

let (pointer) = hashdict_read{poseidon_ptr=poseidon_ptr, dict_ptr=storage_tries_dict_ptr}(
// Use `hashdict_get` instead of `hashdict_read` because `MappingAddressTrieBytes32U256` is not a
// `default_dict`. Accessing a key that does not exist in the dict would have panicked for `hashdict_read`.
let (pointer) = hashdict_get{poseidon_ptr=poseidon_ptr, dict_ptr=storage_tries_dict_ptr}(
1, &address.value
);

Expand Down Expand Up @@ -154,7 +163,7 @@ func get_storage{poseidon_ptr: PoseidonBuiltin*, state: State}(
let value = trie_get_TrieBytes32U256{poseidon_ptr=poseidon_ptr, trie=storage_trie}(key);

// Rebind the storage trie to the state
let new_storage_trie_ptr = cast(storage_trie_ptr, felt);
let new_storage_trie_ptr = cast(storage_trie.value, felt);

hashdict_write{poseidon_ptr=poseidon_ptr, dict_ptr=storage_tries_dict_ptr}(
1, &address.value, new_storage_trie_ptr
Expand All @@ -181,3 +190,111 @@ func get_storage{poseidon_ptr: PoseidonBuiltin*, state: State}(

return value;
}

func set_storage{poseidon_ptr: PoseidonBuiltin*, state: State}(
address: Address, key: Bytes32, value: U256
) {
alloc_locals;

let storage_tries = state.value._storage_tries;
let fp_and_pc = get_fp_and_pc();
local __fp__: felt* = fp_and_pc.fp_val;

// Assert that the account exists
let account = get_account_optional(address);
if (cast(account.value, felt) == 0) {
// TODO: think about which cases lead to this error and decide on the correct type of exception to raise
// perhaps AssertionError
with_attr error_message("Cannot set storage on non-existent account") {
assert 0 = 1;
}
}

let storage_tries_dict_ptr = cast(storage_tries.value.dict_ptr, DictAccess*);
// Use `hashdict_get` instead of `hashdict_read` because `MappingAddressTrieBytes32U256` is not a
// `default_dict`. Accessing a key that does not exist in the dict would have panicked for `hashdict_read`.
let (storage_trie_pointer) = hashdict_get{
poseidon_ptr=poseidon_ptr, dict_ptr=storage_tries_dict_ptr
}(1, &address.value);

if (cast(storage_trie_pointer, felt) == 0) {
// dict_new expects an initial_dict hint argument.
%{ initial_dict = {} %}
let (new_mapping_dict_ptr) = dict_new();
tempvar _storage_trie = new TrieBytes32U256Struct(
secured=bool(1),
default=U256(new U256Struct(0, 0)),
_data=MappingBytes32U256(
new MappingBytes32U256Struct(
dict_ptr_start=cast(new_mapping_dict_ptr, Bytes32U256DictAccess*),
dict_ptr=cast(new_mapping_dict_ptr, Bytes32U256DictAccess*),
original_mapping=cast(0, MappingBytes32U256Struct*),
),
),
);

let new_storage_trie_ptr = cast(_storage_trie, felt);
hashdict_write{poseidon_ptr=poseidon_ptr, dict_ptr=storage_tries_dict_ptr}(
1, &address.value, new_storage_trie_ptr
);

let new_storage_tries_dict_ptr = cast(
storage_tries_dict_ptr, AddressTrieBytes32U256DictAccess*
);

tempvar new_storage_tries = MappingAddressTrieBytes32U256(
new MappingAddressTrieBytes32U256Struct(
dict_ptr_start=storage_tries.value.dict_ptr_start,
dict_ptr=new_storage_tries_dict_ptr,
original_mapping=storage_tries.value.original_mapping,
),
);

tempvar state = State(
new StateStruct(
_main_trie=state.value._main_trie,
_storage_tries=storage_tries,
_snapshots=state.value._snapshots,
created_accounts=state.value.created_accounts,
),
);
return ();
}
let trie_struct = cast(storage_trie_pointer, TrieBytes32U256Struct*);
let storage_trie = TrieBytes32U256(trie_struct);
trie_set_TrieBytes32U256{poseidon_ptr=poseidon_ptr, trie=storage_trie}(key, value);

// From EELS <https://github.com/ethereum/execution-specs/blob/master/src/ethereum/cancun/state.py#L318>:
// if trie._data == {}:
// del state._storage_tries[address]
// TODO: Investigate whether this is needed inside provable code
// If the storage trie is empty, then write null ptr to the mapping address -> storage trie at address

// Update state
// 1. Write the updated storage trie to the mapping address -> storage trie
let storage_trie_ptr = cast(storage_trie.value, felt);
hashdict_write{poseidon_ptr=poseidon_ptr, dict_ptr=storage_tries_dict_ptr}(
1, &address.value, storage_trie_ptr
);
// 2. Create a new storage_tries instance with the updated storage trie at address
let new_storage_tries_dict_ptr = cast(
storage_tries_dict_ptr, AddressTrieBytes32U256DictAccess*
);
tempvar new_storage_tries = MappingAddressTrieBytes32U256(
new MappingAddressTrieBytes32U256Struct(
dict_ptr_start=storage_tries.value.dict_ptr_start,
dict_ptr=new_storage_tries_dict_ptr,
original_mapping=storage_tries.value.original_mapping,
),
);
// 3. Update state with the updated storage tries
tempvar state = State(
new StateStruct(
_main_trie=state.value._main_trie,
_storage_tries=new_storage_tries,
_snapshots=state.value._snapshots,
created_accounts=state.value.created_accounts,
),
);
return ();
}
44 changes: 43 additions & 1 deletion cairo/src/utils/dict.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,44 @@ func hashdict_read{poseidon_ptr: PoseidonBuiltin*, dict_ptr: DictAccess*}(
return (value=value);
}

// A wrapper around dict_read that hashes the key before accessing the dictionary if the key
// does not fit in a felt.
// @dev This version returns 0, if the key is not found and the dict is NOT a defaultdict.
// @param key_len: The readnumber of felt values used to represent the key.
// @param key: The key to access the dictionary.
// TODO: write the associated squash function.
func hashdict_get{poseidon_ptr: PoseidonBuiltin*, dict_ptr: DictAccess*}(
key_len: felt, key: felt*
) -> (value: felt) {
alloc_locals;
local felt_key;
if (key_len == 1) {
assert felt_key = key[0];
tempvar poseidon_ptr = poseidon_ptr;
} else {
let (felt_key_) = poseidon_hash_many(key_len, key);
assert felt_key = felt_key_;
tempvar poseidon_ptr = poseidon_ptr;
}

local value;
%{
from collections import defaultdict
dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)
dict_tracker.current_ptr += ids.DictAccess.SIZE
preimage = tuple([memory[ids.key + i] for i in range(ids.key_len)])
if isinstance(dict_tracker.data, defaultdict):
ids.value = dict_tracker.data[preimage]
else:
ids.value = dict_tracker.data.get(preimage, 0)
%}
dict_ptr.key = felt_key;
dict_ptr.prev_value = value;
dict_ptr.new_value = value;
let dict_ptr = dict_ptr + DictAccess.SIZE;
return (value=value);
}

// A wrapper around dict_write that hashes the key before accessing the dictionary if the key
// does not fit in a felt.
// @param key_len: The number of felt values used to represent the key.
Expand All @@ -102,10 +140,14 @@ func hashdict_write{poseidon_ptr: PoseidonBuiltin*, dict_ptr: DictAccess*}(
tempvar poseidon_ptr = poseidon_ptr;
}
%{
from collections import defaultdict
dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)
dict_tracker.current_ptr += ids.DictAccess.SIZE
preimage = tuple([memory[ids.key + i] for i in range(ids.key_len)])
ids.dict_ptr.prev_value = dict_tracker.data[preimage]
if isinstance(dict_tracker.data, defaultdict):
ids.dict_ptr.prev_value = dict_tracker.data[preimage]
else:
ids.dict_ptr.prev_value = 0
dict_tracker.data[preimage] = ids.new_value
%}
dict_ptr.key = felt_key;
Expand Down
34 changes: 30 additions & 4 deletions cairo/tests/ethereum/cancun/test_state.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import pytest
from ethereum_types.numeric import U256
from hypothesis import given
from hypothesis import strategies as st
from hypothesis.strategies import composite

from ethereum.cancun.state import get_account, get_account_optional, get_storage
from ethereum.cancun.state import (
get_account,
get_account_optional,
get_storage,
set_storage,
)
from tests.utils.strategies import address, bytes32, state

pytestmark = pytest.mark.python_vm
Expand All @@ -18,7 +24,7 @@ def state_and_address_and_key(
# For address selection, use address_strategy if no keys in state
address_options = (
[st.sampled_from(list(state._main_trie._data.keys())), address_strategy]
if state._main_trie._data is not None and state._main_trie._data
if state._main_trie._data != {}
else [address_strategy]
)
address = draw(st.one_of(*address_options))
Expand All @@ -30,15 +36,15 @@ def state_and_address_and_key(

key_options = (
[st.sampled_from(list(storage._data.keys())), key_strategy]
if storage is not None and storage._data
if storage is not None and storage._data != {}
else [key_strategy]
)
key = draw(st.one_of(*key_options))

return state, address, key


class TestState:
class TestStateAccounts:
@given(
data=state_and_address_and_key(state_strategy=state, address_strategy=address),
)
Expand All @@ -59,6 +65,8 @@ def test_get_account_optional(self, cairo_run, data):
assert result_cairo == result_py
assert state_cairo == state


class TestStateStorage:
@given(
data=state_and_address_and_key(
state_strategy=state, address_strategy=address, key_strategy=bytes32
Expand All @@ -74,3 +82,21 @@ def test_get_storage(
result_py = get_storage(state, address, key)
assert result_cairo == result_py
assert state_cairo == state

@given(
data=state_and_address_and_key(
state_strategy=state, address_strategy=address, key_strategy=bytes32
),
value=...,
)
def test_set_storage(self, cairo_run, data, value: U256):
state, address, key = data
try:
state_cairo = cairo_run("set_storage", state, address, key, value)
except Exception as e:
with pytest.raises(type(e)):
set_storage(state, address, key, value)
return

set_storage(state, address, key, value)
assert state_cairo == state
6 changes: 5 additions & 1 deletion cairo/tests/ethereum/cancun/test_trie.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,11 @@ def test_trie_get_TrieBytes32U256(

@given(trie=..., key=..., value=...)
def test_trie_set_TrieAddressAccount(
self, cairo_run, trie: Trie[Address, Account], key: Address, value: Account
self,
cairo_run,
trie: Trie[Address, Optional[Account]],
key: Address,
value: Account,
):
cairo_trie = cairo_run("trie_set_TrieAddressAccount", trie, key, value)
trie_set(trie, key, value)
Expand Down
1 change: 1 addition & 0 deletions cairo/tests/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def test_type(
Trie[Bytes32, U256],
Trie[Address, Optional[Account]],
TransientStorage,
Mapping[Address, Trie[Bytes32, U256]],
State,
Tuple[
Trie[Address, Optional[Account]], Mapping[Address, Trie[Bytes32, U256]]
Expand Down
Loading

0 comments on commit 1436c75

Please sign in to comment.