Skip to content

Commit

Permalink
feat: rework error types (#338)
Browse files Browse the repository at this point in the history
read PR on gas before as this follows up
  • Loading branch information
enitrat authored Jan 3, 2025
1 parent 3226794 commit 88e7192
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 98 deletions.
2 changes: 1 addition & 1 deletion cairo/ethereum/cancun/vm.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ from ethereum.cancun.transactions import To
from ethereum.cancun.vm.stack import Stack
from ethereum.cancun.vm.memory import Memory

using OptionalEthereumException = EthereumException;
using OptionalEthereumException = EthereumException*;
using OptionalEvm = Evm;
using OptionalAddress = Address*;

Expand Down
70 changes: 22 additions & 48 deletions cairo/ethereum/cancun/vm/exceptions.cairo
Original file line number Diff line number Diff line change
@@ -1,49 +1,23 @@
from ethereum_types.bytes import BytesStruct

struct StackUnderflowError {
value: BytesStruct*,
}

struct StackOverflowError {
value: BytesStruct*,
}

struct OutOfGasError {
value: BytesStruct*,
}

struct InvalidOpcodeError {
value: BytesStruct*,
}

struct InvalidJumpDestError {
value: BytesStruct*,
}

struct StackDepthLimitError {
value: BytesStruct*,
}

struct WriteInStaticContextError {
value: BytesStruct*,
}

struct OutOfBoundsReadError {
value: BytesStruct*,
}

struct InvalidParameterError {
value: BytesStruct*,
}

struct InvalidContractPrefixError {
value: BytesStruct*,
}

struct AddressCollisionError {
value: BytesStruct*,
}

struct KZGProofError {
value: BytesStruct*,
struct ExceptionalHalt {
value: felt,
}

const StackUnderflowError = 'StackUnderflowError';
const StackOverflowError = 'StackOverflowError';
const OutOfGasError = 'OutOfGasError';
const InvalidOpcode = 'InvalidOpcode';
const InvalidJumpDestError = 'InvalidJumpDestError';
const StackDepthLimitError = 'StackDepthLimitError';
const WriteInStaticContext = 'WriteInStaticContext';
const OutOfBoundsRead = 'OutOfBoundsRead';
const InvalidParameter = 'InvalidParameter';
const InvalidContractPrefix = 'InvalidContractPrefix';
const AddressCollision = 'AddressCollision';
const KZGProofError = 'KZGProofError';

func InvalidOpcodeError(param: felt) -> ExceptionalHalt {
let param = param * 2 ** 30;
let error_string = InvalidOpcode + param;
let res = ExceptionalHalt(error_string);
return res;
}
8 changes: 4 additions & 4 deletions cairo/ethereum/cancun/vm/gas.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ from ethereum_types.bytes import BytesStruct
from ethereum.cancun.blocks import Header
from ethereum.cancun.transactions import Transaction
from ethereum.cancun.vm import Evm, EvmStruct, EvmImpl
from ethereum.cancun.vm.exceptions import OutOfGasError
from ethereum.cancun.vm.exceptions import ExceptionalHalt, OutOfGasError

from starkware.cairo.common.math_cmp import is_le, is_not_zero, RC_BOUND
from starkware.cairo.common.math import assert_le_felt
Expand Down Expand Up @@ -32,7 +32,7 @@ struct MessageCallGas {
// @param evm The pointer to the current execution context.
// @param amount The amount of gas the current operation requires.
// @return EVM The pointer to the updated execution context.
func charge_gas{range_check_ptr, evm: Evm}(amount: Uint) -> OutOfGasError {
func charge_gas{range_check_ptr, evm: Evm}(amount: Uint) -> ExceptionalHalt* {
// This is equivalent to is_nn(evm.value.gas_left - amount)
with_attr error_message("charge_gas: gas_left > 2**128") {
assert [range_check_ptr] = evm.value.gas_left.value;
Expand Down Expand Up @@ -64,14 +64,14 @@ func charge_gas{range_check_ptr, evm: Evm}(amount: Uint) -> OutOfGasError {
let evm_struct = cast([fp - 4], EvmStruct*);
tempvar evm = Evm(evm_struct);
EvmImpl.set_gas_left(Uint(a));
tempvar ok = OutOfGasError(cast(0, BytesStruct*));
tempvar ok = cast(0, ExceptionalHalt*);
return ok;

not_enough_gas:
let range_check_ptr = [ap - 1];
let evm_struct = cast([fp - 4], EvmStruct*);
tempvar evm = Evm(evm_struct);
tempvar err = OutOfGasError(new BytesStruct(cast(0, felt*), 0));
tempvar err = new ExceptionalHalt(OutOfGasError);
return err;
}

Expand Down
14 changes: 7 additions & 7 deletions cairo/ethereum/cancun/vm/stack.cairo
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ethereum_types.numeric import U256, U256Struct
from ethereum_types.bytes import BytesStruct, Bytes
from starkware.cairo.common.dict import DictAccess, dict_read, dict_write
from ethereum.cancun.vm.exceptions import StackOverflowError, StackUnderflowError
from ethereum.cancun.vm.exceptions import ExceptionalHalt, StackOverflowError, StackUnderflowError

struct Stack {
value: StackStruct*,
Expand All @@ -21,11 +21,11 @@ struct StackDictAccess {

const STACK_MAX_SIZE = 1024;

func pop{stack: Stack}() -> (U256, StackUnderflowError) {
func pop{stack: Stack}() -> (U256, ExceptionalHalt*) {
alloc_locals;
let len = stack.value.len;
if (len == 0) {
tempvar err = StackUnderflowError(new BytesStruct(cast(0, felt*), 0));
tempvar err = new ExceptionalHalt(StackUnderflowError);
let val = U256(cast(0, U256Struct*));
return (val, err);
}
Expand All @@ -39,15 +39,15 @@ func pop{stack: Stack}() -> (U256, StackUnderflowError) {
tempvar stack = Stack(new StackStruct(stack.value.dict_ptr_start, new_dict_ptr, len - 1));
tempvar value = U256(cast(pointer, U256Struct*));

tempvar ok = StackUnderflowError(cast(0, BytesStruct*));
tempvar ok = cast(0, ExceptionalHalt*);
return (value, ok);
}

func push{stack: Stack}(value: U256) -> StackOverflowError {
func push{stack: Stack}(value: U256) -> ExceptionalHalt* {
alloc_locals;
let len = stack.value.len;
if (len == STACK_MAX_SIZE) {
tempvar err = StackOverflowError(new BytesStruct(cast(0, felt*), 0));
tempvar err = new ExceptionalHalt(StackOverflowError);
return err;
}

Expand All @@ -58,7 +58,7 @@ func push{stack: Stack}(value: U256) -> StackOverflowError {
let new_dict_ptr = cast(dict_ptr, StackDictAccess*);

tempvar stack = Stack(new StackStruct(stack.value.dict_ptr_start, new_dict_ptr, len + 1));
tempvar ok = StackOverflowError(cast(0, BytesStruct*));
tempvar ok = cast(0, ExceptionalHalt*);

return ok;
}
10 changes: 4 additions & 6 deletions cairo/ethereum/exceptions.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,17 @@
// Example:
// This is an error:
// ```
// tempvar inner_error = new BytesStruct(cast(0, felt*), 0);
// let error = EthereumException(inner_error);
// let error = cast(0, EthereumException*);
// ```
//
// This is not an error:
// ```
// tempvar no_error = EthereumException(cast(0, BytesStruct*));
// from ethereum.cancun.vm.exceptions import StackUnderflowError
// tempvar no_error = new EthereumException(StackUnderflowError);
// ```

from ethereum_types.bytes import BytesStruct

// @notice Base type for all exceptions _expected_ to be thrown during normal
// operation.
struct EthereumException {
value: BytesStruct*,
value: felt,
}
2 changes: 1 addition & 1 deletion cairo/tests/test_serde.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ from ethereum.cancun.vm.gas import MessageCallGas

from ethereum.cancun.trie import BranchNode, ExtensionNode, InternalNode, LeafNode, Node, Subnodes
from ethereum.exceptions import EthereumException
from ethereum.cancun.vm.exceptions import StackOverflowError, StackUnderflowError
from ethereum.cancun.vm.exceptions import ExceptionalHalt
from ethereum.cancun.state import TransientStorage
from ethereum.cancun.vm import Environment
13 changes: 10 additions & 3 deletions cairo/tests/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
Node,
Trie,
)
from ethereum.cancun.vm.exceptions import StackOverflowError, StackUnderflowError
from ethereum.cancun.vm.exceptions import (
InvalidOpcode,
StackOverflowError,
StackUnderflowError,
)
from ethereum.cancun.vm.gas import MessageCallGas
from ethereum.crypto.hash import Hash32
from ethereum.exceptions import EthereumException
Expand Down Expand Up @@ -257,14 +261,17 @@ def test_exception(
segments,
serde,
gen_arg,
err: Union[EthereumException, StackOverflowError, StackUnderflowError],
err: Union[
EthereumException, StackOverflowError, StackUnderflowError, InvalidOpcode
],
):
base = segments.gen_arg([gen_arg(type(err), err)])
result = serde.serialize(to_cairo_type(type(err)), base, shift=0)
assert issubclass(result.__class__, Exception)

@pytest.mark.parametrize(
"error_type", [EthereumException, StackOverflowError, StackUnderflowError]
"error_type",
[EthereumException, StackOverflowError, StackUnderflowError, InvalidOpcode],
)
def test_none_exception(self, to_cairo_type, serde, gen_arg, error_type):
base = gen_arg(error_type, None)
Expand Down
26 changes: 13 additions & 13 deletions cairo/tests/utils/args_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def __eq__(self, other):

vm_exception_classes = inspect.getmembers(
sys.modules["ethereum.cancun.vm.exceptions"],
lambda x: inspect.isclass(x) and issubclass(x, ExceptionalHalt),
lambda x: inspect.isclass(x) and issubclass(x, EthereumException),
)

vm_exception_mappings = {
Expand All @@ -230,7 +230,6 @@ def __eq__(self, other):
f"{name}",
): cls
for name, cls in vm_exception_classes
if cls is not ExceptionalHalt
}

_cairo_struct_to_python_type: Dict[Tuple[str, ...], Any] = {
Expand Down Expand Up @@ -589,16 +588,12 @@ def _gen_arg(
)

if isinstance(arg_type, type) and issubclass(arg_type, Exception):
# For exceptions, we either return 0 (no error) or create an error with a message
# For exceptions, we either return 0 (no error) or the ascii representation of the error message
if arg is None:
return 0

error_bytes = str(arg).encode()
message_ptr = segments.add()
segments.load_data(message_ptr, list(error_bytes))
struct_ptr = segments.add()
segments.load_data(struct_ptr, [message_ptr, len(error_bytes)])
return struct_ptr
error_bytes = str(arg.__class__.__name__).encode()
error_int = int.from_bytes(error_bytes, "big")
return error_int

return arg

Expand Down Expand Up @@ -648,9 +643,14 @@ def to_cairo_type(program: Program, type_name: Type):
_python_type_to_cairo_struct = {
v: k for k, v in _cairo_struct_to_python_type.items()
}
scope = ScopedName(
_python_type_to_cairo_struct[_type_aliases.get(type_name, type_name)]
)

if isinstance(type_name, type) and issubclass(type_name, Exception):
scope = ScopedName(_python_type_to_cairo_struct[ExceptionalHalt])
else:
scope = ScopedName(
_python_type_to_cairo_struct[_type_aliases.get(type_name, type_name)]
)

identifier = program.identifiers.as_dict()[scope]

if isinstance(identifier, TypeDefinition):
Expand Down
41 changes: 26 additions & 15 deletions cairo/tests/utils/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@
from starkware.cairo.lang.vm.memory_dict import UnknownMemoryError
from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager

from ethereum.cancun.vm.exceptions import InvalidOpcode
from ethereum.crypto.hash import Hash32
from tests.utils.args_gen import Memory, Stack, to_python_type
from tests.utils.args_gen import Memory, Stack, to_python_type, vm_exception_classes

# Sentinel object for indicating no error in exception handling
NO_ERROR_FLAG = object()
Expand Down Expand Up @@ -298,21 +299,22 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any:
and isinstance(python_cls, type)
and issubclass(python_cls, Exception)
):
tuple_struct_ptr = self.serialize_pointers(path, ptr)["value"]
if not tuple_struct_ptr:
error_value = self.serialize_pointers(path, ptr)["value"]
if error_value == 0:
return NO_ERROR_FLAG
value_type = (
get_struct_definition(self.program, path).members["value"].cairo_type
)
struct_name = value_type.pointee.scope.path[-1]
path = (*path[:-1], struct_name)
raw = self.serialize_pointers(path, tuple_struct_ptr)
error_bytes = bytes(
[self.memory.get(raw["data"] + i) for i in range(raw["len"])]
# Get the first 30 bytes for the error message
error_bytes = (error_value & ((1 << (30 * 8)) - 1)).to_bytes(30, "big")
ascii_value = error_bytes.decode().strip("\x00")
actual_error_cls = next(
(cls for name, cls in vm_exception_classes if name == ascii_value), None
)
if error_bytes == b"":
return python_cls()
return python_cls(error_bytes.decode())
if actual_error_cls is None:
raise ValueError(f"Unknown error class: {ascii_value}")
if actual_error_cls is InvalidOpcode:
# Custom parameter (isolated case)
param_value = (error_value >> (30 * 8)) & 0xFF
return InvalidOpcode(param_value)
return actual_error_cls()

if python_cls == Bytes256:
base_ptr = self.memory.get(ptr)
Expand Down Expand Up @@ -407,7 +409,16 @@ def _serialize(self, cairo_type, ptr, length=1):
pointee = self.memory.get(ptr)
# Edge case: 0 pointers are not pointer but no data
if pointee == 0:
return None
if isinstance(cairo_type.pointee, TypeFelt):
return None
# If the pointer is to an exception, return the error flag
python_cls = to_python_type(cairo_type.pointee.scope.path)
return (
NO_ERROR_FLAG
if isinstance(python_cls, type)
and issubclass(python_cls, Exception)
else None
)
if isinstance(cairo_type.pointee, TypeFelt):
return self.serialize_list(pointee)
serialized = self.serialize_list(
Expand Down

0 comments on commit 88e7192

Please sign in to comment.