Skip to content

Commit

Permalink
fix strategy generations
Browse files Browse the repository at this point in the history
  • Loading branch information
enitrat committed Jan 16, 2025
1 parent 09e6bbb commit b735b46
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 31 deletions.
55 changes: 38 additions & 17 deletions cairo/tests/ethereum/cancun/test_state.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from typing import Optional

import pytest
from ethereum_types.bytes import Bytes32
from ethereum_types.numeric import U256
from hypothesis import given
from hypothesis import strategies as st
from hypothesis.strategies import composite

from ethereum.cancun.fork_types import Account, Address
from ethereum.cancun.fork_types import Account
from ethereum.cancun.state import (
account_exists,
account_exists_and_is_empty,
Expand All @@ -26,7 +25,6 @@
set_storage,
set_transient_storage,
)
from tests.utils.args_gen import TransientStorage
from tests.utils.strategies import address, bytes32, state, transient_storage

pytestmark = pytest.mark.python_vm
Expand Down Expand Up @@ -63,6 +61,37 @@ def state_and_address_and_optional_key(
return state, address, key


@composite
def transient_storage_and_address_and_optional_key(
draw,
transient_storage_strategy=transient_storage,
address_strategy=address,
key_strategy=None,
):
transient_storage = draw(transient_storage_strategy)

# Shuffle from a random addres of an
address_options = []
if transient_storage._tries:
address_options.append(st.sampled_from(list(transient_storage._tries.keys())))
address_options.append(address_strategy)
address = draw(st.one_of(*address_options))

if key_strategy is None:
return transient_storage, address

# Shuffle from a random key of the address, if it exists
key_options = []
if address in transient_storage._tries:
key_options.append(
st.sampled_from(list(transient_storage._tries[address]._data.keys()))
)
key_options.append(key_strategy)
key = draw(st.one_of(*key_options))

return transient_storage, address, key


class TestStateAccounts:
@given(data=state_and_address_and_optional_key())
def test_get_account(self, cairo_run, data):
Expand Down Expand Up @@ -182,18 +211,13 @@ def test_destroy_storage(self, cairo_run, data):


class TestTransientStorage:
@given(
transient_storage=transient_storage,
address=...,
key=...,
)
@given(data=transient_storage_and_address_and_optional_key(key_strategy=bytes32))
def test_get_transient_storage(
self,
cairo_run,
transient_storage: TransientStorage,
address: Address,
key: Bytes32,
data,
):
transient_storage, address, key = data
transient_storage_cairo, result_cairo = cairo_run(
"get_transient_storage",
transient_storage,
Expand All @@ -204,19 +228,16 @@ def test_get_transient_storage(
assert transient_storage_cairo == transient_storage

@given(
transient_storage=transient_storage,
address=...,
key=...,
data=transient_storage_and_address_and_optional_key(key_strategy=bytes32),
value=...,
)
def test_set_transient_storage(
self,
cairo_run,
transient_storage: TransientStorage,
address: Address,
key: Bytes32,
data,
value: U256,
):
transient_storage, address, key = data
transient_storage_cairo = cairo_run(
"set_transient_storage",
transient_storage,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def evm_with_accessed_storage_keys(draw):
if not use_random_key and accessed_storage_keys:
# Draw a key from the set and put it on top of the stack
_, key = draw(st.sampled_from(accessed_storage_keys))
evm.stack.insert(0, U256.from_be_bytes(key))
evm.stack.insert(0, U256.from_le_bytes(key))

return evm

Expand Down
24 changes: 11 additions & 13 deletions cairo/tests/utils/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
)


def trie_strategy(thing):
def trie_strategy(thing, min_size=0):
key_type, value_type = thing.__args__
value_type_origin = get_origin(value_type) or value_type

Expand All @@ -127,7 +127,7 @@ def trie_strategy(thing):
st.none()
if value_type_origin is Union and get_args(value_type)[1] is type(None)
else (
st.just(value_type(0))
st.just(U256(0))
if value_type is U256
else st.nothing() # No other type is accepted
)
Expand All @@ -139,8 +139,11 @@ def non_default_strategy(default):
# For Optional types, just use the base type strategy (which won't generate None)
base_type = get_args(value_type)[0]
return st.from_type(base_type)
# For other types, use the regular strategy as default values are rare, and filter them out
return st.from_type(value_type).filter(lambda x: x != default)
elif value_type is U256:
# For U256, we don't want to generate 0 as default value
return st.integers(min_value=1, max_value=2**256 - 1).map(U256)
else:
raise ValueError(f"Unsupported default type in Trie: {value_type}")

# In a trie, a key that has a default value is considered not included in the trie.
# Thus it needs to be filtered out from the data generated.
Expand All @@ -152,6 +155,7 @@ def non_default_strategy(default):
_data=st.dictionaries(
st.from_type(key_type),
non_default_strategy(default),
min_size=min_size,
max_size=20,
),
)
Expand Down Expand Up @@ -240,18 +244,14 @@ def tuple_strategy(thing):
TransientStorage,
_tries=st.fixed_dictionaries(
{
address: trie_strategy(Trie[Bytes32, U256]).filter(
lambda t: bool(t._data)
)
address: trie_strategy(Trie[Bytes32, U256], min_size=1)
for address in addresses
}
),
_snapshots=st.lists(
st.fixed_dictionaries(
{
address: trie_strategy(Trie[Bytes32, U256]).filter(
lambda t: bool(t._data)
)
address: trie_strategy(Trie[Bytes32, U256], min_size=1)
for address in addresses
}
),
Expand Down Expand Up @@ -369,9 +369,7 @@ def tuple_strategy(thing):
_storage_tries=st.integers(max_value=len(addresses)).flatmap(
lambda i: st.fixed_dictionaries(
{
address: trie_strategy(Trie[Bytes32, U256]).filter(
lambda t: bool(t._data)
)
address: trie_strategy(Trie[Bytes32, U256], min_size=1)
for address in addresses[:i]
}
)
Expand Down

0 comments on commit b735b46

Please sign in to comment.