Skip to content

Commit

Permalink
Merge branch 'main' into feat/increment_nonce
Browse files Browse the repository at this point in the history
  • Loading branch information
Eikix authored Jan 17, 2025
2 parents 2ff10d8 + a2d2f0f commit 609320d
Show file tree
Hide file tree
Showing 15 changed files with 260 additions and 71 deletions.
2 changes: 1 addition & 1 deletion cairo/ethereum/cancun/vm.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ namespace EvmImpl {
return ();
}

func set_refund_counter{evm: Evm}(new_refund_counter: Uint) {
func set_refund_counter{evm: Evm}(new_refund_counter: felt) {
tempvar evm = Evm(
new EvmStruct(
pc=evm.value.pc,
Expand Down
149 changes: 147 additions & 2 deletions cairo/ethereum/cancun/vm/instructions/storage.cairo
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from ethereum.cancun.vm.stack import pop, push
from ethereum.cancun.vm import Evm, EvmImpl, Environment, EnvImpl
from ethereum.cancun.vm.exceptions import ExceptionalHalt, WriteInStaticContext
from ethereum.cancun.vm.exceptions import ExceptionalHalt, WriteInStaticContext, OutOfGasError
from ethereum.cancun.vm.gas import charge_gas, GasConstants
from ethereum.cancun.state import get_storage, get_transient_storage, set_transient_storage
from ethereum.utils.numeric import U256__eq__
from ethereum.cancun.state import (
get_storage,
get_storage_original,
set_storage,
get_transient_storage,
set_transient_storage,
)
from ethereum.cancun.fork_types import (
SetTupleAddressBytes32,
SetTupleAddressBytes32DictAccess,
Expand All @@ -21,6 +28,7 @@ from starkware.cairo.common.cairo_builtins import PoseidonBuiltin, BitwiseBuilti
from starkware.cairo.lang.compiler.lib.registers import get_fp_and_pc
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.dict_access import DictAccess
from starkware.cairo.common.math_cmp import is_le

// @notice Loads to the stack the value corresponding to a certain key from the
// storage of the current account.
Expand Down Expand Up @@ -96,6 +104,143 @@ func sload{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, poseidon_ptr: Poseidon
return ok;
}

// @notice Stores a value at a certain key in the current context's storage.
func sstore{
range_check_ptr, bitwise_ptr: BitwiseBuiltin*, poseidon_ptr: PoseidonBuiltin*, evm: Evm
}() -> ExceptionalHalt* {
alloc_locals;
// STACK
let stack = evm.value.stack;
with stack {
let (key, err) = pop();
if (cast(err, felt) != 0) {
return err;
}
let (new_value, err) = pop();
if (cast(err, felt) != 0) {
return err;
}
}

let is_gas_left_not_enough = is_le(evm.value.gas_left.value, GasConstants.GAS_CALL_STIPEND);
if (is_gas_left_not_enough != 0) {
tempvar err = new ExceptionalHalt(OutOfGasError);
return err;
}

// Get storage values
let key_bytes32 = U256_to_be_bytes(key);
let state = evm.value.env.value.state;
let current_target = evm.value.message.value.current_target;
with state {
let original_value = get_storage_original(current_target, key_bytes32);
let current_value = get_storage(current_target, key_bytes32);
}

// Gas calculation
// Check accessed storage keys
tempvar accessed_tuple = TupleAddressBytes32(
new TupleAddressBytes32Struct(current_target, key_bytes32)
);
let (serialized_keys: felt*) = alloc();
assert serialized_keys[0] = accessed_tuple.value.address.value;
assert serialized_keys[1] = accessed_tuple.value.bytes32.value.low;
assert serialized_keys[2] = accessed_tuple.value.bytes32.value.high;
let dict_ptr = cast(evm.value.accessed_storage_keys.value.dict_ptr, DictAccess*);
with dict_ptr {
let (is_present) = hashdict_read(3, serialized_keys);
if (is_present == 0) {
hashdict_write(3, serialized_keys, 1);
tempvar gas_cost = GasConstants.GAS_COLD_SLOAD;
tempvar poseidon_ptr = poseidon_ptr;
tempvar dict_ptr = dict_ptr;
} else {
tempvar gas_cost = 0;
tempvar poseidon_ptr = poseidon_ptr;
tempvar dict_ptr = dict_ptr;
}
}
let gas_cost = [ap - 3];
let poseidon_ptr = cast([ap - 2], PoseidonBuiltin*);
let dict_ptr = cast([ap - 1], DictAccess*);

let new_dict_ptr = cast(dict_ptr, SetTupleAddressBytes32DictAccess*);
tempvar new_accessed_storage_keys = SetTupleAddressBytes32(
new SetTupleAddressBytes32Struct(
evm.value.accessed_storage_keys.value.dict_ptr_start, new_dict_ptr
),
);

// Calculate storage gas cost
tempvar zero_u256 = U256(new U256Struct(0, 0));
let is_original_eq_current = U256__eq__(original_value, current_value);
let is_current_eq_new = U256__eq__(current_value, new_value);
let is_original_zero = U256__eq__(original_value, zero_u256);
if (is_original_eq_current.value != 0) {
if (is_current_eq_new.value == 0) {
if (is_original_zero.value != 0) {
tempvar gas_cost = gas_cost + GasConstants.GAS_STORAGE_SET;
} else {
tempvar gas_cost = gas_cost + (
GasConstants.GAS_STORAGE_UPDATE - GasConstants.GAS_COLD_SLOAD
);
}
}
} else {
tempvar gas_cost = GasConstants.GAS_WARM_ACCESS;
}
let gas_cost = [ap - 1];

tempvar refund_counter = evm.value.refund_counter;
let is_original_eq_new = U256__eq__(original_value, new_value);
// Refund calculation
if (is_current_eq_new.value == 0) {
let is_current_zero = U256__eq__(current_value, zero_u256);
let is_new_zero = U256__eq__(new_value, zero_u256);
if (is_original_zero.value == 0 and is_current_zero.value == 0 and is_new_zero.value != 0) {
refund_counter = refund_counter + GasConstants.GAS_STORAGE_CLEAR_REFUND;
}
if (is_original_zero.value == 0 and is_current_zero.value != 0) {
refund_counter = refund_counter - GasConstants.GAS_STORAGE_CLEAR_REFUND;
}
if (is_original_eq_new.value != 0) {
if (is_original_zero.value != 0) {
refund_counter = refund_counter +
(GasConstants.GAS_STORAGE_SET - GasConstants.GAS_WARM_ACCESS);
} else {
refund_counter = refund_counter +
(GasConstants.GAS_STORAGE_UPDATE - GasConstants.GAS_COLD_SLOAD - GasConstants.GAS_WARM_ACCESS);
}
}
}

// Charge gas
let err = charge_gas(Uint(gas_cost));
if (cast(err, felt) != 0) {
return err;
}
// Check static call
if (evm.value.message.value.is_static.value != 0) {
tempvar err = new ExceptionalHalt(WriteInStaticContext);
return err;
}

// Set storage
with state {
set_storage(current_target, key_bytes32, new_value);
}

// Update EVM state
let env = evm.value.env;
EnvImpl.set_state{env=env}(state);
EvmImpl.set_env(env);
EvmImpl.set_pc_stack(Uint(evm.value.pc.value + 1), stack);
EvmImpl.set_refund_counter(refund_counter);
EvmImpl.set_accessed_storage_keys(new_accessed_storage_keys);
let ok = cast(0, ExceptionalHalt*);
return ok;
}

// @notice Loads to the stack the value corresponding to a certain key from the
// transient storage of the current account.
func tload{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, poseidon_ptr: PoseidonBuiltin*, evm: Evm}(
Expand Down
24 changes: 12 additions & 12 deletions cairo/tests/ethereum/cancun/vm/instructions/test_arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
from hypothesis import given

from ethereum.cancun.vm.instructions.arithmetic import (
Expand All @@ -15,6 +14,7 @@
sub,
)
from tests.utils.args_gen import Evm
from tests.utils.errors import strict_raises
from tests.utils.evm_builder import EvmBuilder

arithmetic_tests_strategy = EvmBuilder().with_stack().with_gas_left().build()
Expand All @@ -26,7 +26,7 @@ def test_add(self, cairo_run, evm: Evm):
try:
cairo_result = cairo_run("add", evm)
except Exception as cairo_error:
with pytest.raises(type(cairo_error)):
with strict_raises(type(cairo_error)):
add(evm)
return

Expand All @@ -38,7 +38,7 @@ def test_sub(self, cairo_run, evm: Evm):
try:
cairo_result = cairo_run("sub", evm)
except Exception as cairo_error:
with pytest.raises(type(cairo_error)):
with strict_raises(type(cairo_error)):
sub(evm)
return

Expand All @@ -50,7 +50,7 @@ def test_mul(self, cairo_run, evm: Evm):
try:
cairo_result = cairo_run("mul", evm)
except Exception as cairo_error:
with pytest.raises(type(cairo_error)):
with strict_raises(type(cairo_error)):
mul(evm)
return

Expand All @@ -63,7 +63,7 @@ def test_div(self, cairo_run, evm: Evm):
try:
cairo_result = cairo_run("div", evm)
except Exception as cairo_error:
with pytest.raises(type(cairo_error)):
with strict_raises(type(cairo_error)):
div(evm)
return

Expand All @@ -75,7 +75,7 @@ def test_sdiv(self, cairo_run, evm: Evm):
try:
cairo_result = cairo_run("sdiv", evm)
except Exception as cairo_error:
with pytest.raises(type(cairo_error)):
with strict_raises(type(cairo_error)):
sdiv(evm)
return

Expand All @@ -87,7 +87,7 @@ def test_mod(self, cairo_run, evm: Evm):
try:
cairo_result = cairo_run("mod", evm)
except Exception as cairo_error:
with pytest.raises(type(cairo_error)):
with strict_raises(type(cairo_error)):
mod(evm)
return

Expand All @@ -99,7 +99,7 @@ def test_smod(self, cairo_run, evm: Evm):
try:
cairo_result = cairo_run("smod", evm)
except Exception as cairo_error:
with pytest.raises(type(cairo_error)):
with strict_raises(type(cairo_error)):
smod(evm)
return

Expand All @@ -111,7 +111,7 @@ def test_addmod(self, cairo_run, evm: Evm):
try:
cairo_result = cairo_run("addmod", evm)
except Exception as cairo_error:
with pytest.raises(type(cairo_error)):
with strict_raises(type(cairo_error)):
addmod(evm)
return

Expand All @@ -123,7 +123,7 @@ def test_mulmod(self, cairo_run, evm: Evm):
try:
cairo_result = cairo_run("mulmod", evm)
except Exception as cairo_error:
with pytest.raises(type(cairo_error)):
with strict_raises(type(cairo_error)):
mulmod(evm)
return

Expand All @@ -135,7 +135,7 @@ def test_exp(self, cairo_run, evm: Evm):
try:
cairo_result = cairo_run("exp", evm)
except Exception as cairo_error:
with pytest.raises(type(cairo_error)):
with strict_raises(type(cairo_error)):
exp(evm)
return

Expand All @@ -147,7 +147,7 @@ def test_signextend(self, cairo_run, evm: Evm):
try:
cairo_result = cairo_run("signextend", evm)
except Exception as cairo_error:
with pytest.raises(type(cairo_error)):
with strict_raises(type(cairo_error)):
signextend(evm)
return

Expand Down
18 changes: 9 additions & 9 deletions cairo/tests/ethereum/cancun/vm/instructions/test_bitwise.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
from hypothesis import given

from ethereum.cancun.vm.exceptions import ExceptionalHalt
Expand All @@ -13,6 +12,7 @@
get_byte,
)
from tests.utils.args_gen import Evm
from tests.utils.errors import strict_raises
from tests.utils.evm_builder import EvmBuilder

bitwise_tests_strategy = EvmBuilder().with_stack().with_gas_left().build()
Expand All @@ -24,7 +24,7 @@ def test_and(self, cairo_run, evm: Evm):
try:
cairo_result = cairo_run("bitwise_and", evm)
except ExceptionalHalt as cairo_error:
with pytest.raises(type(cairo_error)):
with strict_raises(type(cairo_error)):
bitwise_and(evm)
return

Expand All @@ -36,7 +36,7 @@ def test_or(self, cairo_run, evm: Evm):
try:
cairo_result = cairo_run("bitwise_or", evm)
except ExceptionalHalt as cairo_error:
with pytest.raises(type(cairo_error)):
with strict_raises(type(cairo_error)):
bitwise_or(evm)
return

Expand All @@ -48,7 +48,7 @@ def test_xor(self, cairo_run, evm: Evm):
try:
cairo_result = cairo_run("bitwise_xor", evm)
except ExceptionalHalt as cairo_error:
with pytest.raises(type(cairo_error)):
with strict_raises(type(cairo_error)):
bitwise_xor(evm)
return

Expand All @@ -60,7 +60,7 @@ def test_not(self, cairo_run, evm: Evm):
try:
cairo_result = cairo_run("bitwise_not", evm)
except ExceptionalHalt as cairo_error:
with pytest.raises(type(cairo_error)):
with strict_raises(type(cairo_error)):
bitwise_not(evm)
return

Expand All @@ -72,7 +72,7 @@ def test_get_byte(self, cairo_run, evm: Evm):
try:
cairo_result = cairo_run("get_byte", evm)
except ExceptionalHalt as cairo_error:
with pytest.raises(type(cairo_error)):
with strict_raises(type(cairo_error)):
get_byte(evm)
return

Expand All @@ -84,7 +84,7 @@ def test_shl(self, cairo_run, evm: Evm):
try:
cairo_result = cairo_run("bitwise_shl", evm)
except ExceptionalHalt as cairo_error:
with pytest.raises(type(cairo_error)):
with strict_raises(type(cairo_error)):
bitwise_shl(evm)
return

Expand All @@ -96,7 +96,7 @@ def test_shr(self, cairo_run, evm: Evm):
try:
cairo_result = cairo_run("bitwise_shr", evm)
except ExceptionalHalt as cairo_error:
with pytest.raises(type(cairo_error)):
with strict_raises(type(cairo_error)):
bitwise_shr(evm)
return

Expand All @@ -108,7 +108,7 @@ def test_sar(self, cairo_run, evm: Evm):
try:
cairo_result = cairo_run("bitwise_sar", evm)
except ExceptionalHalt as cairo_error:
with pytest.raises(type(cairo_error)):
with strict_raises(type(cairo_error)):
bitwise_sar(evm)
return

Expand Down
Loading

0 comments on commit 609320d

Please sign in to comment.