From 62be088688b7be91d5d917646e3192eb7cee8610 Mon Sep 17 00:00:00 2001 From: FabijanC Date: Fri, 23 Dec 2022 15:56:39 +0100 Subject: [PATCH] Introduce Rust VM option (#372) * [skip ci] - a release commit follows which will run the CI/CD workflow Co-authored-by: fmoletta <99273364+fmoletta@users.noreply.github.com> Co-authored-by: Federica Co-authored-by: jrigada Co-authored-by: Juan Rigada <62958725+Jrigada@users.noreply.github.com> --- .circleci/config.yml | 6 + .dockerignore | 4 + Dockerfile | 11 +- page/docs/guide/development.md | 4 +- page/docs/guide/run.md | 36 ++ page/docs/intro.md | 18 +- poetry.lock | 14 +- pyproject.toml | 3 + scripts/install_dev_tools.sh | 9 + starknet_devnet/__init__.py | 21 ++ starknet_devnet/cairo_rs_py_patch.py | 542 +++++++++++++++++++++++++++ test/test_account_custom.py | 2 +- test/test_cairo_vm.py | 68 ++++ test/test_dump.py | 8 +- test/test_fork_cli_params.py | 19 +- test/util.py | 12 +- 16 files changed, 745 insertions(+), 32 deletions(-) create mode 100644 starknet_devnet/cairo_rs_py_patch.py create mode 100644 test/test_cairo_vm.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 94f877dd5..94d1917bf 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -65,6 +65,11 @@ jobs: <<: *test_template docker: - image: cimg/python:3.9-node + + test_py_min_rust: + <<: *test_template + environment: + STARKNET_DEVNET_CAIRO_VM: "rust" package_build_and_publish: docker: @@ -120,6 +125,7 @@ workflows: jobs: - test_py_min - test_py_max + - test_py_min_rust - image_build: <<: *on_master - image_build_arm: diff --git a/.dockerignore b/.dockerignore index 3af73c72e..cc15dba7f 100644 --- a/.dockerignore +++ b/.dockerignore @@ -15,3 +15,7 @@ Dockerfile test/ .pylintrc page/ + +# Hardhat specific +node_modules/ +hardhat-cache/ diff --git a/Dockerfile b/Dockerfile index 69fa9aca5..eef2fe940 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,23 +6,22 @@ RUN apk add gmp-dev g++ gcc libffi-dev RUN pip3 install poetry -RUN poetry build -f wheel - -RUN poetry export -f requirements.txt --without-hashes > requirements.txt +# install rustc+cargo +RUN wget https://sh.rustup.rs -O - | sh -s -- -y +ENV PATH="/root/.cargo/bin:${PATH}" +RUN poetry build -f wheel +RUN poetry export -f requirements.txt --without-hashes --with vm > requirements.txt RUN pip3 wheel --no-cache-dir --no-deps --wheel-dir /wheels -r requirements.txt - FROM python:3.9.13-alpine3.16 RUN apk add --no-cache libgmpxx COPY --from=builder /dist/*.whl /wheels/ - COPY --from=builder /wheels /wheels RUN pip3 install --no-cache /wheels/* - RUN rm -rf /wheels ENV PYTHONUNBUFFERED=1 diff --git a/page/docs/guide/development.md b/page/docs/guide/development.md index cda95a0f8..2909d0525 100644 --- a/page/docs/guide/development.md +++ b/page/docs/guide/development.md @@ -4,9 +4,9 @@ sidebar_position: 18 # Development -If you're a developer willing to contribute, be sure to have installed [**Poetry**](https://pypi.org/project/poetry/) and all the dependency packages by running the following script. You are expected to have [**npm**](https://www.npmjs.com/). +If you're a developer willing to contribute, be sure to have installed [**Poetry**](https://pypi.org/project/poetry/) and all the dependency packages by running the installation script. Prerequisites for running the script: `gcc`, `g++`, `gmp`, `npm`. -```text +```bash ./scripts/install_dev_tools.sh ``` diff --git a/page/docs/guide/run.md b/page/docs/guide/run.md index e6d47af9c..f33c3a03c 100644 --- a/page/docs/guide/run.md +++ b/page/docs/guide/run.md @@ -119,3 +119,39 @@ docker run -p 127.0.0.1:5050:5050 shardlabs/starknet-devnet You may ignore any address-related output logged on container startup (e.g. `Running on all addresses` or `Running on http://172.17.0.2:5050`). What you will use is what you specified with the `-p` argument. If you don't specify the `HOST` part, the server will indeed be available on all of your host machine's addresses (localhost, local network IP, etc.), which may present a security issue if you don't want anyone from the local network to access your Devnet instance. + +## Run with the Rust implementation of Cairo VM + +By default, Devnet uses the [Python implementation](https://github.com/starkware-libs/cairo-lang/) of Cairo VM. + +Using the Rust implementation brings improvement for Cairo-VM-intensive operations, but introduces its own overhead, so it may not be useful for simple contracts. + +You can enable it by following these steps: + +1. Install compilers + +Make sure you have `gcc`, `g++` and [Rust](https://www.rust-lang.org/tools/install). + +2. Install [cairo-rs-py](https://github.com/lambdaclass/cairo-rs-py) in the [**same environment**](https://docs.python.org/3/library/venv.html) as Devnet: + +```bash +$ pip install cairo-rs-py +``` + +3. Set `STARKNET_DEVNET_CAIRO_VM=rust` + +```bash +$ STARKNET_DEVNET_CAIRO_VM=rust starknet-devnet +``` + +With Docker, use `-e`: + +```bash +$ docker run -it [OPTIONS] -e STARKNET_DEVNET_CAIRO_VM=rust shardlabs/starknet-devnet [ARGS] +``` + +To use the Python VM, **unset** the variable or set it to `python` + +```bash +$ STARKNET_DEVNET_CAIRO_VM=python starknet-devnet +``` diff --git a/page/docs/intro.md b/page/docs/intro.md index 8993ae6cb..03aad192f 100644 --- a/page/docs/intro.md +++ b/page/docs/intro.md @@ -1,11 +1,14 @@ --- sidebar_position: 1 --- + # Getting Started Let's discover **[starknet-devnet](https://github.com/Shard-Labs/starknet-devnet)**. :::danger Take care + ## ⚠️ Disclaimer ⚠️ + ::: - Devnet should not be used as a replacement for Alpha testnet. After testing on Devnet, be sure to test on testnet (alpha-goerli)! @@ -20,23 +23,30 @@ Works with Python versions >=3.8 and <3.10. On Ubuntu/Debian, first run: - ```bash -sudo apt install -y libgmp3-dev +$ sudo apt install -y libgmp3-dev ``` On Mac, you can use `brew`: ```bash -brew install gmp +$ brew install gmp ``` ## Install ```bash -pip install starknet-devnet +$ pip install starknet-devnet ``` +## Run + +``` +$ starknet-devnet +``` + +For more running possibilities, see [this](https://shard-labs.github.io/starknet-devnet/docs/guide/run). + ### Windows installation Follow this guide: https://www.spaceshard.io/blog/starknet-devnet-windows-tutorial diff --git a/poetry.lock b/poetry.lock index 785edf824..d8d709c49 100644 --- a/poetry.lock +++ b/poetry.lock @@ -156,6 +156,14 @@ sympy = "*" typeguard = "*" Web3 = "*" +[[package]] +name = "cairo-rs-py" +version = "0.1.0" +description = "" +category = "dev" +optional = false +python-versions = ">=3.7" + [[package]] name = "certifi" version = "2022.9.24" @@ -1210,7 +1218,7 @@ multidict = ">=4.0" [metadata] lock-version = "1.1" python-versions = ">=3.8,<3.10" -content-hash = "42f002ca22c20461ab5dde4668c39a54f4991b42225fe899f35eb0d0831d19fe" +content-hash = "5d7136b6419083b29b8a1c2eaae95c17c25daa6d61d4960e874ed45e7f8c9bc4" [metadata.files] aiohttp = [ @@ -1432,6 +1440,10 @@ cachetools = [ cairo-lang = [ {file = "cairo-lang-0.10.3.zip", hash = "sha256:3093946334590f199d99471912049c182cec1f90d4ad02029f460a0de3a27502"}, ] +cairo-rs-py = [ + {file = "cairo_rs_py-0.1.0-cp39-cp39-manylinux_2_34_x86_64.whl", hash = "sha256:35c6d478a859972f6703957af3b9c9d34912a6725255d0bdf5e8c9c5de26d342"}, + {file = "cairo_rs_py-0.1.0.tar.gz", hash = "sha256:363d11bcc5184c97a9643e440cb0a6a18443720a5c179b63a08946c7c928dea3"}, +] certifi = [ {file = "certifi-2022.9.24-py3-none-any.whl", hash = "sha256:90c1a32f1d68f940488354e36370f6cca89f0f106db09518524c88d6ed83f382"}, {file = "certifi-2022.9.24.tar.gz", hash = "sha256:0d9c601124e5a6ba9712dbc60d9c53c21e34f5f641fe83002317394311bdce14"}, diff --git a/pyproject.toml b/pyproject.toml index 735c416e2..951866024 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,9 @@ black = "~22.6" requests = "~2.28" isort = "^5.10.1" +[tool.poetry.group.vm.dependencies] +cairo-rs-py = "~0.1.0" + [tool.isort] profile = "black" skip_gitignore = true diff --git a/scripts/install_dev_tools.sh b/scripts/install_dev_tools.sh index 18ebe8162..0737858e7 100755 --- a/scripts/install_dev_tools.sh +++ b/scripts/install_dev_tools.sh @@ -12,6 +12,15 @@ echo "python3: $(python3 --version)" pip3 install -U poetry==1.2.1 echo "poetry: $(poetry --version)" +# https://www.rust-lang.org/tools/install +# need rust to install cairo-rs-py +if rustc --version; then + echo "rustc installed" +else + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + source "$HOME/.cargo/env" +fi + # install dependencies poetry install poetry lock --check diff --git a/starknet_devnet/__init__.py b/starknet_devnet/__init__.py index c795e8cf9..5b9867196 100644 --- a/starknet_devnet/__init__.py +++ b/starknet_devnet/__init__.py @@ -1,6 +1,8 @@ """ Contains the server implementation and its utility classes and functions. """ + +import os import sys from copy import copy @@ -9,6 +11,8 @@ from starkware.crypto.signature.fast_pedersen_hash import pedersen_hash from starkware.starknet.services.api.contract_class import ContractClass +from .util import warn + __version__ = "0.4.2" @@ -43,3 +47,20 @@ def simpler_copy(self, memo): # pylint: disable=unused-argument setattr(ContractClass, "__deepcopy__", simpler_copy) + + +# Optionally apply cairo-rs-py monkey patch +_VM_VAR = "STARKNET_DEVNET_CAIRO_VM" +_cairo_vm = os.environ.get(_VM_VAR) +if _cairo_vm == "rust": + from starknet_devnet.cairo_rs_py_patch import cairo_rs_py_monkeypatch + + cairo_rs_py_monkeypatch() + warn("Using Cairo VM: Rust") + +elif not _cairo_vm or _cairo_vm == "python": + # python VM set by default + pass + +else: + sys.exit(f"Error: Invalid value of environment variable {_VM_VAR}: '{_cairo_vm}'") diff --git a/starknet_devnet/cairo_rs_py_patch.py b/starknet_devnet/cairo_rs_py_patch.py new file mode 100644 index 000000000..387995e5a --- /dev/null +++ b/starknet_devnet/cairo_rs_py_patch.py @@ -0,0 +1,542 @@ +"""Patch starknet methods to use cairo_rs_py""" + +# pylint: disable=bare-except +# pylint: disable=missing-function-docstring +# pylint: disable=protected-access +# pylint: disable=too-many-locals + + +import logging +import sys +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, cast + +import cairo_rs_py +from starkware.cairo.common.cairo_function_runner import CairoFunctionRunner +from starkware.cairo.common.structs import CairoStructFactory, CairoStructProxy + +# from starkware.cairo.lang.compiler.identifier_definition import StructDefinition +from starkware.cairo.lang.compiler.program import Program +from starkware.cairo.lang.compiler.scoped_name import ScopedName +from starkware.cairo.lang.vm.memory_segments import MemorySegmentManager +from starkware.cairo.lang.vm.relocatable import MaybeRelocatable, RelocatableValue +from starkware.cairo.lang.vm.utils import ResourcesError +from starkware.cairo.lang.vm.vm_exceptions import ( + SecurityError, + VmException, + VmExceptionBase, +) +from starkware.python.utils import safe_zip +from starkware.starknet.business_logic.execution.execute_entry_point import ( + FAULTY_CLASS_HASH, + ExecuteEntryPoint, +) +from starkware.starknet.business_logic.execution.objects import ( + TransactionExecutionContext, +) +from starkware.starknet.business_logic.fact_state.state import ExecutionResourcesManager +from starkware.starknet.business_logic.state.state_api import SyncState +from starkware.starknet.business_logic.utils import validate_contract_deployed +from starkware.starknet.core.os import os_utils, segment_utils, syscall_utils +from starkware.starknet.core.os.class_hash import ( + get_contract_class_struct, + load_program, +) +from starkware.starknet.core.os.syscall_utils import ( # get_runtime_type, + BusinessLogicSysCallHandler, +) +from starkware.starknet.definitions.error_codes import StarknetErrorCode +from starkware.starknet.definitions.general_config import StarknetGeneralConfig +from starkware.starknet.public import abi as starknet_abi +from starkware.starknet.public.abi import SYSCALL_PTR_OFFSET +from starkware.starknet.services.api.contract_class import ContractClass +from starkware.starkware_utils.error_handling import ( + ErrorCode, + StarkException, + stark_assert, + wrap_with_stark_exception, +) + +logger = logging.getLogger(__name__) + + +def cairo_rs_py_run( + self, + state: SyncState, + resources_manager: ExecutionResourcesManager, + general_config: StarknetGeneralConfig, + tx_execution_context: TransactionExecutionContext, +) -> Tuple[CairoFunctionRunner, syscall_utils.BusinessLogicSysCallHandler]: + """ + Runs the selected entry point with the given calldata in the code of the contract deployed + at self.code_address. + The execution is done in the context (e.g., storage) of the contract at + self.contract_address. + Returns the corresponding CairoFunctionRunner and BusinessLogicSysCallHandler in order to + retrieve the execution information. + """ + # Prepare input for Cairo function runner. + class_hash = self._get_code_class_hash(state=state) + + # Hack to prevent version 0 attack on argent accounts. + if (tx_execution_context.version == 0) and (class_hash == FAULTY_CLASS_HASH): + raise StarkException( + code=StarknetErrorCode.TRANSACTION_FAILED, message="Fraud attempt blocked." + ) + + contract_class = state.get_contract_class(class_hash=class_hash) + contract_class.validate() + + entry_point = self._get_selected_entry_point( + contract_class=contract_class, class_hash=class_hash + ) + + # Run the specified contract entry point with given calldata. + with wrap_with_stark_exception(code=StarknetErrorCode.SECURITY_ERROR): + runner = cairo_rs_py.CairoRunner( # pylint: disable=no-member + program=contract_class.program.dumps(), + entrypoint=None, + layout="all", + proof_mode=False, + ) + runner.initialize_function_runner() + os_context = os_utils.prepare_os_context(runner=runner) + + validate_contract_deployed(state=state, contract_address=self.contract_address) + + initial_syscall_ptr = cast( + RelocatableValue, os_context[starknet_abi.SYSCALL_PTR_OFFSET] + ) + syscall_handler = syscall_utils.BusinessLogicSysCallHandler( + execute_entry_point_cls=ExecuteEntryPoint, + tx_execution_context=tx_execution_context, + state=state, + resources_manager=resources_manager, + caller_address=self.caller_address, + contract_address=self.contract_address, + general_config=general_config, + initial_syscall_ptr=initial_syscall_ptr, + ) + + # Positional arguments are passed to *args in the 'run_from_entrypoint' function. + entry_points_args = [ + self.entry_point_selector, + os_context, + len(self.calldata), + # Allocate and mark the segment as read-only (to mark every input array as read-only). + syscall_handler._allocate_segment(segments=runner, data=self.calldata), + ] + + try: + runner.run_from_entrypoint( + entry_point.offset, + entry_points_args, + hint_locals={ + "syscall_handler": syscall_handler, + }, + static_locals={ + "__find_element_max_size": 2**20, + "__squash_dict_max_size": 2**20, + "__keccak_max_size": 2**20, + "__usort_max_size": 2**20, + "__chained_ec_op_max_len": 1000, + }, + # run_resources=tx_execution_context.run_resources, + verify_secure=True, + ) + except VmException as exception: + code: ErrorCode = StarknetErrorCode.TRANSACTION_FAILED + + if isinstance(exception.inner_exc, syscall_utils.HandlerException): + stark_exception = exception.inner_exc.stark_exception + code = stark_exception.code + called_contract_address = exception.inner_exc.called_contract_address + message_prefix = ( + f"Error in the called contract ({hex(called_contract_address)}):\n" + ) + # Override python's traceback and keep the Cairo one of the inner exception. + exception.notes = [message_prefix + str(stark_exception.message)] + if isinstance(exception.inner_exc, ResourcesError): + code = StarknetErrorCode.OUT_OF_RESOURCES + + raise StarkException(code=code, message=str(exception)) from exception + except VmExceptionBase as exception: + raise StarkException( + code=StarknetErrorCode.TRANSACTION_FAILED, message=str(exception) + ) from exception + except SecurityError as exception: + raise StarkException( + code=StarknetErrorCode.SECURITY_ERROR, message=str(exception) + ) from exception + except Exception as exception: + logger.error("Got an unexpected exception.", exc_info=True) + raise StarkException( + code=StarknetErrorCode.UNEXPECTED_FAILURE, + message="Got an unexpected exception during the execution of the transaction.", + ) from exception + + # Complete handler validations. + os_utils.validate_and_process_os_context( + runner=runner, + syscall_handler=syscall_handler, + initial_os_context=os_context, + ) + + # When execution starts the stack holds entry_points_args + [ret_fp, ret_pc]. + args_ptr = runner.initial_fp - (len(entry_points_args) + 2) + + # The arguments are touched by the OS and should not be counted as holes, mark them + # as accessed. + # assert isinstance(args_ptr, RelocatableValue) # Downcast. + runner.mark_as_accessed(address=args_ptr, size=len(entry_points_args)) + + return runner, syscall_handler + + +def cairo_rs_py_compute_class_hash_inner( + contract_class: ContractClass, + hash_func: Callable[[int, int], int], # pylint: disable=unused-argument +) -> int: + program = load_program() + contract_class_struct = get_contract_class_struct( + identifiers=program.identifiers, contract_class=contract_class + ) + + runner = cairo_rs_py.CairoRunner( # pylint: disable=no-member + program=program.dumps(), entrypoint=None, layout="all", proof_mode=False + ) + runner.initialize_function_runner() + hash_ptr = runner.add_additional_hash_builtin() + + run_function_runner( + runner, + program, + "starkware.starknet.core.os.contracts.class_hash", + hash_ptr=hash_ptr, + contract_class=contract_class_struct, + use_full_name=True, + verify_secure=False, + ) + _, class_hash = runner.get_return_values(2) + return class_hash + + +def run_function_runner( + runner, + program, + func_name: str, + *args, + hint_locals: Optional[Dict[str, Any]] = None, + static_locals: Optional[Dict[str, Any]] = None, + verify_secure: Optional[bool] = None, + trace_on_failure: bool = False, + apply_modulo_to_args: Optional[bool] = None, + use_full_name: bool = False, + verify_implicit_args_segment: bool = False, + **kwargs, +) -> Tuple[Tuple[MaybeRelocatable, ...], Tuple[MaybeRelocatable, ...]]: + """ + Runs func_name(*args). + args are converted to Cairo-friendly ones using gen_arg. + + Returns the return values of the function, splitted into 2 tuples of implicit values and + explicit values. Structs will be flattened to a sequence of felts as part of the returned + tuple. + + Additional params: + verify_secure - Run verify_secure_runner to do extra verifications. + trace_on_failure - Run the tracer in case of failure to help debugging. + apply_modulo_to_args - Apply modulo operation on integer arguments. + use_full_name - Treat 'func_name' as a fully qualified identifier name, rather than a + relative one. + verify_implicit_args_segment - For each implicit argument, verify that the argument and the + return value are in the same segment. + """ + assert isinstance(program, Program) + entrypoint = program.get_label(func_name, full_name_lookup=use_full_name) + + structs_factory = CairoStructFactory.from_program(program=program) + func = ScopedName.from_string(scope=func_name) + + full_args_struct = structs_factory.build_func_args(func=func) + all_args = full_args_struct(*args, **kwargs) # pylint: disable=not-callable + + try: + runner.run_from_entrypoint( + entrypoint, + all_args, + typed_args=True, + hint_locals=hint_locals, + static_locals=static_locals, + verify_secure=verify_secure, + apply_modulo_to_args=apply_modulo_to_args, + ) + except (VmException, SecurityError, AssertionError) as ex: + if trace_on_failure: + print( + f"""\ +Got {type(ex).__name__} exception during the execution of {func_name}: +{str(ex)} +""" + ) + # trace_runner(runner=runner) + raise + + # The number of implicit arguments is identical to the number of implicit return values. + n_implicit_ret_vals = structs_factory.get_implicit_args_length(func=func) + n_explicit_ret_vals = structs_factory.get_explicit_return_values_length(func=func) + n_ret_vals = n_explicit_ret_vals + n_implicit_ret_vals + implicit_retvals = tuple( + runner.get_range(runner.get_ap() - n_ret_vals, n_implicit_ret_vals) + ) + + explicit_retvals = tuple( + runner.get_range(runner.get_ap() - n_explicit_ret_vals, n_explicit_ret_vals) + ) + + # Verify the memory segments of the implicit arguments. + if verify_implicit_args_segment: + implicit_args = all_args[:n_implicit_ret_vals] + for implicit_arg, implicit_retval in safe_zip(implicit_args, implicit_retvals): + assert isinstance( + implicit_arg, RelocatableValue + ), f"Implicit arguments must be RelocatableValues, {implicit_arg} is not." + assert isinstance(implicit_retval, RelocatableValue), ( + f"Argument {implicit_arg} is a RelocatableValue, but the returned value " + f"{implicit_retval} is not." + ) + assert implicit_arg.segment_index == implicit_retval.segment_index, ( + f"Implicit argument {implicit_arg} is not on the same segment as the returned " + f"{implicit_retval}." + ) + assert implicit_retval.offset >= implicit_arg.offset, ( + f"The offset of the returned implicit argument {implicit_retval} is less than " + f"the offset of the input {implicit_arg}." + ) + + return implicit_retvals, explicit_retvals + + +def cairo_rs_py_prepare_os_context( + runner: CairoFunctionRunner, +) -> List[MaybeRelocatable]: + syscall_segment = runner.add_segment() + os_context: List[MaybeRelocatable] = [syscall_segment] + os_context.extend(runner.get_program_builtins_initial_stack()) + + return os_context + + +def cairo_rs_py_validate_and_process_os_context( + runner: CairoFunctionRunner, + syscall_handler: syscall_utils.BusinessLogicSysCallHandler, + initial_os_context: List[MaybeRelocatable], +): + """ + Validates and processes an OS context that was returned by a transaction. + Returns the syscall processor object containing the accumulated syscall information. + """ + os_context_end = runner.get_ap() - 2 + stack_ptr = os_context_end + # The returned values are os_context, retdata_size, retdata_ptr. + stack_ptr = runner.get_builtins_final_stack(stack_ptr) + + final_os_context_ptr = stack_ptr - 1 + assert final_os_context_ptr + len(initial_os_context) == os_context_end + + # Validate system calls. + syscall_base_ptr, syscall_stop_ptr = segment_utils.get_os_segment_ptr_range( + runner=runner, ptr_offset=SYSCALL_PTR_OFFSET, os_context=initial_os_context + ) + + segment_utils.validate_segment_pointers( + segments=runner, + segment_base_ptr=syscall_base_ptr, + segment_stop_ptr=syscall_stop_ptr, + ) + syscall_handler.post_run(runner=runner, syscall_stop_ptr=syscall_stop_ptr) + + +def cairo_rs_py_allocate_segment( + self, segments: MemorySegmentManager, data: Iterable[MaybeRelocatable] +) -> RelocatableValue: + try: + segment_start = segments.add_segment() + except: + segment_start = segments.add() + + segment_end = segments.write_arg(ptr=segment_start, arg=data) + self.read_only_segments.append((segment_start, segment_end - segment_start)) + return segment_start + + +def cairo_rs_py_read_and_validate_syscall_request( + self, + syscall_name: str, + segments: MemorySegmentManager, + syscall_ptr: RelocatableValue, +) -> CairoStructProxy: + """ + Returns the system call request written in the syscall segment, starting at syscall_ptr. + Performs validations on the request. + """ + # Update syscall count. + self._count_syscall(syscall_name=syscall_name) + + request = self._read_syscall_request( + syscall_name=syscall_name, segments=segments, syscall_ptr=syscall_ptr + ) + + assert ( + syscall_ptr == self.expected_syscall_ptr + ), f"Bad syscall_ptr, Expected {self.expected_syscall_ptr}, got {syscall_ptr}." + + syscall_info = self.syscall_info[syscall_name] + self.expected_syscall_ptr += syscall_info.syscall_size + + selector = request.selector + assert isinstance(selector, int), ( + f"The selector argument to syscall {syscall_name} is of unexpected type. " + f"Expected: int; got: {type(selector).__name__}." + ) + assert ( + selector == syscall_info.selector + ), f"Bad syscall selector, expected {syscall_info.selector}. Got: {selector}" + + # args_struct_def: StructDefinition = ( + # syscall_info.syscall_request_struct.struct_definition_ + # ) + # for arg, (arg_name, arg_def) in safe_zip(request, args_struct_def.members.items()): + # expected_type = get_runtime_type(arg_def.cairo_type) + # assert isinstance(arg, expected_type), ( + # f"Argument {arg_name} to syscall {syscall_name} is of unexpected type. " + # f"Expected: value of type {expected_type}; got: {arg}." + # ) + + return request + + +def cairo_rs_py_get_os_segment_ptr_range( + runner: CairoFunctionRunner, ptr_offset: int, os_context: List[MaybeRelocatable] +) -> Tuple[MaybeRelocatable, MaybeRelocatable]: + """ + Returns the base and stop ptr of the OS-designated segment that starts at ptr_offset. + """ + allowed_offsets = (SYSCALL_PTR_OFFSET,) + assert ( + ptr_offset in allowed_offsets + ), f"Illegal OS ptr offset; must be one of: {allowed_offsets}." + + # The returned values are os_context, retdata_size, retdata_ptr. + os_context_end = runner.get_ap() - 2 + final_os_context_ptr = os_context_end - len(os_context) + + return os_context[ptr_offset], runner.get(final_os_context_ptr + ptr_offset) + + +def cairo_rs_py_validate_segment_pointers( + segments: MemorySegmentManager, + segment_base_ptr: MaybeRelocatable, + segment_stop_ptr: MaybeRelocatable, +): + # assert isinstance(segment_base_ptr, RelocatableValue) + assert ( + segment_base_ptr.offset == 0 + ), f"Segment base pointer must be zero; got {segment_base_ptr.offset}." + + expected_stop_ptr = segment_base_ptr + segments.get_segment_used_size( + index=segment_base_ptr.segment_index + ) + + stark_assert( + expected_stop_ptr == segment_stop_ptr, + code=StarknetErrorCode.SECURITY_ERROR, + message=( + f"Invalid stop pointer for segment. " + f"Expected: {expected_stop_ptr}, found: {segment_stop_ptr}." + ), + ) + + +def cairo_rs_py_get_return_values(runner: CairoFunctionRunner) -> List[int]: + """ + Extracts the return values of a StarkNet contract function from the Cairo runner. + """ + with wrap_with_stark_exception( + code=StarknetErrorCode.INVALID_RETURN_DATA, + message="Error extracting return data.", + logger=logger, + exception_types=[Exception], + ): + ret_data_size, ret_data_ptr = runner.get_return_values(2) + values = runner.memory.get_range(ret_data_ptr, ret_data_size) + values = runner.get_range(ret_data_ptr, ret_data_size) + stark_assert( + all(isinstance(value, int) for value in values), + code=StarknetErrorCode.INVALID_RETURN_DATA, + message="Return data expected to be non-relocatable.", + ) + return cast(List[int], values) + + +def cairo_rs_py_validate_read_only_segments(self, runner: CairoFunctionRunner): + """ + Validates that there were no out of bounds writes to read-only segments and marks + them as accessed. + """ + segments = runner + + for segment_ptr, segment_size in self.read_only_segments: + used_size = segments.get_segment_used_size(index=segment_ptr.segment_index) + stark_assert( + used_size == segment_size, + code=StarknetErrorCode.SECURITY_ERROR, + message="Out of bounds write to a read-only segment.", + ) + runner.mark_as_accessed(address=segment_ptr, size=segment_size) + + +def cairo_rs_py_monkeypatch(): + setattr(ExecuteEntryPoint, "_run", cairo_rs_py_run) + setattr( + sys.modules["starkware.starknet.core.os.class_hash"], + "class_hash_inner", + cairo_rs_py_compute_class_hash_inner, + ) + setattr( + sys.modules["starkware.starknet.core.os.os_utils"], + "prepare_os_context", + cairo_rs_py_prepare_os_context, + ) + setattr( + sys.modules["starkware.starknet.core.os.os_utils"], + "validate_and_process_os_context", + cairo_rs_py_validate_and_process_os_context, + ) + setattr( + BusinessLogicSysCallHandler, "_allocate_segment", cairo_rs_py_allocate_segment + ) + setattr( + BusinessLogicSysCallHandler, + "_read_and_validate_syscall_request", + cairo_rs_py_read_and_validate_syscall_request, + ) + setattr( + BusinessLogicSysCallHandler, + "validate_read_only_segments", + cairo_rs_py_validate_read_only_segments, + ) + setattr( + sys.modules["starkware.starknet.core.os.syscall_utils"], + "get_os_segment_ptr_range", + cairo_rs_py_get_os_segment_ptr_range, + ) + setattr( + sys.modules["starkware.starknet.core.os.segment_utils"], + "validate_segment_pointers", + cairo_rs_py_validate_segment_pointers, + ) + setattr( + sys.modules["starkware.starknet.business_logic.utils"], + "get_return_values", + cairo_rs_py_get_return_values, + ) diff --git a/test/test_account_custom.py b/test/test_account_custom.py index efccdde40..8cc62923d 100644 --- a/test/test_account_custom.py +++ b/test/test_account_custom.py @@ -58,7 +58,7 @@ def test_invalid_path(class_path: str, expected_error: str): """Test behavior on providing nonexistent path""" proc = ACTIVE_DEVNET.start("--account-class", class_path, stderr=subprocess.PIPE) assert proc.returncode == 1 - assert proc.stderr.read().decode("utf-8") == expected_error + assert expected_error in proc.stderr.read().decode("utf-8") @pytest.mark.account_custom diff --git a/test/test_cairo_vm.py b/test/test_cairo_vm.py new file mode 100644 index 000000000..f427f37b5 --- /dev/null +++ b/test/test_cairo_vm.py @@ -0,0 +1,68 @@ +"""Test specifying cairo VM""" + +import os +import subprocess + +import pytest + +from .util import DevnetBackgroundProc, read_stream, terminate_and_wait + +ACTIVE_DEVNET = DevnetBackgroundProc() + +_VM_VAR = "STARKNET_DEVNET_CAIRO_VM" +_RUST_VM_LOG_LINE = "Using Cairo VM: Rust" + + +@pytest.mark.parametrize( + "cairo_vm, assert_rust_vm_logged", + [ + ("", False), + ("python", False), + ("rust", True), + ], +) +def test_valid_cairo_vm(cairo_vm, assert_rust_vm_logged): + """Test if the invalid chain id fails""" + + env_copy = os.environ.copy() + env_copy[_VM_VAR] = cairo_vm + + proc = ACTIVE_DEVNET.start(stderr=subprocess.PIPE, env=env_copy) + terminate_and_wait(proc) + + stderr = proc.stderr.read().decode("utf-8") + if assert_rust_vm_logged: + assert _RUST_VM_LOG_LINE in stderr + else: + assert _RUST_VM_LOG_LINE not in stderr + + assert proc.returncode == 0 + + +def test_passing_if_no_cairo_vm_set(): + """If no vm env var set, it should assume python and pass""" + env_copy = os.environ.copy() + if _VM_VAR in env_copy: + del env_copy[_VM_VAR] + + proc = ACTIVE_DEVNET.start(stderr=subprocess.PIPE, env=env_copy) + terminate_and_wait(proc) + + assert _RUST_VM_LOG_LINE not in proc.stderr.read().decode("utf-8") + assert proc.returncode == 0 + + +@pytest.mark.parametrize("cairo_vm", ["invalid_value", " rust"]) +def test_invalid_cairo_vm(cairo_vm): + """Test random invalid cairo vm specifications""" + + env_copy = os.environ.copy() + env_copy[_VM_VAR] = cairo_vm + proc = ACTIVE_DEVNET.start(stderr=subprocess.PIPE, env=env_copy) + + terminate_and_wait(proc) + assert ( + f"Error: Invalid value of environment variable {_VM_VAR}: '{cairo_vm}'" + in read_stream(proc.stderr) + ) + assert proc.returncode == 1 diff --git a/test/test_dump.py b/test/test_dump.py index 3d37ed248..662268afd 100644 --- a/test/test_dump.py +++ b/test/test_dump.py @@ -120,7 +120,7 @@ def test_load_via_cli_if_no_file(): ) assert devnet_proc.returncode == 1 expected_msg = f"Error: Cannot load from {DUMP_PATH}. Make sure the file exists and contains a Devnet dump.\n" - assert expected_msg == devnet_proc.stderr.read().decode("utf-8") + assert expected_msg in devnet_proc.stderr.read().decode("utf-8") def test_mint_after_load(): @@ -174,7 +174,7 @@ def test_dumping_if_nonexistent_dir_via_cli(): assert devnet_proc.returncode == 1 expected_msg = f"Invalid dump path: directory '{NONEXISTENT_DIR}' not found.\n" - assert expected_msg == devnet_proc.stderr.read().decode("utf-8") + assert expected_msg in devnet_proc.stderr.read().decode("utf-8") @devnet_in_background() @@ -328,7 +328,7 @@ def test_invalid_dump_on_option(): assert devnet_proc.returncode == 1 expected_msg = b"Error: Invalid --dump-on option: obviously-invalid. Valid options: exit, transaction\n" - assert devnet_proc.stderr.read() == expected_msg + assert expected_msg in devnet_proc.stderr.read() def test_dump_path_not_present_with_dump_on_present(): @@ -337,7 +337,7 @@ def test_dump_path_not_present_with_dump_on_present(): assert devnet_proc.returncode == 1 expected_msg = b"Error: --dump-path required if --dump-on present\n" - assert devnet_proc.stderr.read() == expected_msg + assert expected_msg in devnet_proc.stderr.read() def assert_load(dump_path: str, contract_address: str, expected_value: str): diff --git a/test/test_fork_cli_params.py b/test/test_fork_cli_params.py index dd2eccb43..63cdce79a 100644 --- a/test/test_fork_cli_params.py +++ b/test/test_fork_cli_params.py @@ -21,8 +21,8 @@ def test_invalid_fork_network(): ) assert read_stream(proc.stdout) == "" assert ( - read_stream(proc.stderr) - == f"Error: Invalid fork-network (must be a URL or one of {{alpha-goerli, alpha-goerli2, alpha-mainnet}}). Received: {invalid_name}\n" + f"Error: Invalid fork-network (must be a URL or one of {{alpha-goerli, alpha-goerli2, alpha-mainnet}}). Received: {invalid_name}\n" + in read_stream(proc.stderr) ) assert proc.returncode == 1 @@ -37,9 +37,8 @@ def test_url_not_sequencer(): stdout=subprocess.PIPE, ) assert read_stream(proc.stdout) == "" - assert ( - read_stream(proc.stderr) - == f"Error: {invalid_url} is not a valid StarkNet sequencer\n" + assert f"Error: {invalid_url} is not a valid StarkNet sequencer\n" in read_stream( + proc.stderr ) assert proc.returncode == 1 @@ -78,9 +77,8 @@ def test_block_provided_without_network(): "--fork-block", "123", stderr=subprocess.PIPE, stdout=subprocess.PIPE ) assert read_stream(proc.stdout) == "" - assert ( - read_stream(proc.stderr) - == "Error: --fork-network required if --fork-block present\n" + assert "Error: --fork-network required if --fork-block present\n" in read_stream( + proc.stderr ) assert proc.returncode == 1 @@ -98,9 +96,10 @@ def test_malformed_block_id(fork_block: str): ) assert read_stream(proc.stdout) == "" assert ( - read_stream(proc.stderr) - == f"The value of --fork-block must be a non-negative integer or 'latest', got: {fork_block}\n" + f"The value of --fork-block must be a non-negative integer or 'latest', got: {fork_block}\n" + in read_stream(proc.stderr) ) + assert proc.returncode == 1 diff --git a/test/util.py b/test/util.py index d23aa848e..46c51e583 100644 --- a/test/util.py +++ b/test/util.py @@ -27,7 +27,7 @@ class ReturnCodeAssertionError(AssertionError): """Error to be raised when the return code of an executed process is not as expected.""" -def run_devnet_in_background(*args, stderr=None, stdout=None): +def run_devnet_in_background(*args, stderr=None, stdout=None, env=None): """ Runs starknet-devnet in background. Sleep before devnet is responsive. @@ -52,7 +52,9 @@ def run_devnet_in_background(*args, stderr=None, stdout=None): *args, ] # pylint: disable=consider-using-with - proc = subprocess.Popen(command, close_fds=True, stderr=stderr, stdout=stdout) + proc = subprocess.Popen( + command, close_fds=True, stderr=stderr, stdout=stdout, env=env + ) healthcheck_url = f"http://{HOST}:{port}/is_alive" ensure_server_alive(healthcheck_url, proc) @@ -656,10 +658,12 @@ class DevnetBackgroundProc: def __init__(self): self.proc = None - def start(self, *args, stderr=None, stdout=None): + def start(self, *args, stderr=None, stdout=None, env=None): """Starts a new devnet-server instance. Previously active instance will be stopped.""" self.stop() - self.proc = run_devnet_in_background(*args, stderr=stderr, stdout=stdout) + self.proc = run_devnet_in_background( + *args, stderr=stderr, stdout=stdout, env=env + ) return self.proc def stop(self):