From 9497af429b2bd448d7085c41c772ff068f8e59d5 Mon Sep 17 00:00:00 2001 From: FabijanC Date: Mon, 29 Aug 2022 17:23:47 +0200 Subject: [PATCH] Add formatting (#237) [skip ci] --- .github/pull_request_template.md | 5 +- README.md | 3 +- poetry.lock | 94 +++++++--- pyproject.toml | 1 + scripts/format.sh | 5 + starknet_devnet/account.py | 52 ++++-- starknet_devnet/accounts.py | 20 ++- starknet_devnet/block_info_generator.py | 6 +- starknet_devnet/blocks.py | 37 ++-- starknet_devnet/blueprints/base.py | 44 +++-- starknet_devnet/blueprints/feeder_gateway.py | 77 +++++++-- starknet_devnet/blueprints/gateway.py | 22 ++- starknet_devnet/blueprints/postman.py | 9 +- starknet_devnet/blueprints/rpc/call.py | 13 +- starknet_devnet/blueprints/rpc/classes.py | 31 +++- starknet_devnet/blueprints/rpc/routes.py | 32 +++- starknet_devnet/blueprints/rpc/state.py | 8 +- starknet_devnet/blueprints/rpc/storage.py | 14 +- .../blueprints/rpc/structures/payloads.py | 58 +++++-- .../blueprints/rpc/structures/responses.py | 43 ++++- .../blueprints/rpc/structures/types.py | 9 +- .../blueprints/rpc/transactions.py | 105 +++++++++--- starknet_devnet/blueprints/rpc/utils.py | 33 ++-- starknet_devnet/blueprints/shared.py | 3 +- starknet_devnet/constants.py | 10 +- starknet_devnet/contract_wrapper.py | 17 +- starknet_devnet/contracts.py | 10 +- starknet_devnet/devnet_config.py | 74 ++++---- starknet_devnet/dump.py | 3 +- starknet_devnet/fee_token.py | 44 +++-- starknet_devnet/general_config.py | 36 ++-- starknet_devnet/origin.py | 44 ++--- starknet_devnet/postman_wrapper.py | 59 ++++--- starknet_devnet/server.py | 52 +++--- starknet_devnet/starknet_wrapper.py | 141 +++++++++------ starknet_devnet/state.py | 5 +- starknet_devnet/transactions.py | 55 +++--- starknet_devnet/util.py | 73 ++++---- test/account.py | 34 ++-- test/rpc/conftest.py | 17 +- test/rpc/rpc_utils.py | 21 ++- test/rpc/test_rpc_blocks.py | 80 +++++---- test/rpc/test_rpc_call.py | 52 +++--- test/rpc/test_rpc_class.py | 33 ++-- test/rpc/test_rpc_estimate_fee.py | 25 ++- test/rpc/test_rpc_misc.py | 26 +-- test/rpc/test_rpc_storage.py | 27 +-- test/rpc/test_rpc_transactions.py | 121 ++++++------- test/settings.py | 2 + test/shared.py | 14 +- test/test_account.py | 91 ++++++---- test/test_api_specifications.py | 1 + test/test_block_number.py | 15 +- test/test_declare.py | 39 +++-- test/test_deploy.py | 29 +++- test/test_dump.py | 88 +++++++--- test/test_endpoints.py | 70 ++++++-- test/test_estimate_fee.py | 67 ++++---- test/test_fee_token.py | 123 +++++++------ test/test_general_workflow.py | 35 ++-- test/test_general_workflow_auth.py | 34 ++-- test/test_general_workflow_lite.py | 21 ++- test/test_general_workflow_lite_block_hash.py | 24 +-- .../test_general_workflow_lite_deploy_hash.py | 17 +- test/test_postman.py | 143 ++++++++++------ test/test_restart.py | 20 ++- test/test_state_update.py | 33 +++- test/test_timestamps.py | 45 +++-- test/test_transaction_trace.py | 58 +++++-- test/test_tx_version.py | 1 + test/util.py | 161 +++++++++++++----- test/web3_util.py | 9 +- 72 files changed, 1903 insertions(+), 1020 deletions(-) create mode 100755 scripts/format.sh diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 87f3a3da2..cfbac9889 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -14,8 +14,9 @@ ## Checklist: -- [ ] No linter errors -- [ ] Performed a self-review of the code +- [ ] Applied formatting - `./scripts/format.sh` +- [ ] No linter errors - `./scripts/lint.sh` +- [ ] Performed code self-review - [ ] Rebased to the last commit of the target branch (or merged it into my branch) - [ ] Documented the changes - [ ] Linked the issues which this PR resolves diff --git a/README.md b/README.md index d19afb414..289540507 100644 --- a/README.md +++ b/README.md @@ -609,9 +609,10 @@ poetry run starknet-devnet ./scripts/starknet_devnet_debug.sh ``` -### Development - Lint +### Development - Format and lint ```text +./scripts/format.sh ./scripts/lint.sh ``` diff --git a/poetry.lock b/poetry.lock index 9caacde82..e19e385aa 100644 --- a/poetry.lock +++ b/poetry.lock @@ -119,6 +119,29 @@ category = "main" optional = false python-versions = "*" +[[package]] +name = "black" +version = "22.6.0" +description = "The uncompromising code formatter." +category = "dev" +optional = false +python-versions = ">=3.6.2" + +[package.dependencies] +click = ">=8.0.0" +mypy-extensions = ">=0.4.3" +pathspec = ">=0.9.0" +platformdirs = ">=2" +tomli = {version = ">=1.1.0", markers = "python_full_version < \"3.11.0a7\""} +typed-ast = {version = ">=1.4.2", markers = "python_version < \"3.8\" and implementation_name == \"cpython\""} +typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""} + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.7.4)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + [[package]] name = "cachetools" version = "5.2.0" @@ -779,6 +802,14 @@ python-versions = "*" [package.dependencies] six = ">=1.9.0" +[[package]] +name = "pathspec" +version = "0.9.0" +description = "Utility library for gitignore style pattern matching of file paths." +category = "dev" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" + [[package]] name = "pipdeptree" version = "2.2.1" @@ -1219,7 +1250,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [metadata] lock-version = "1.1" python-versions = ">=3.7.2,<3.10" -content-hash = "abb66ed5566bb5bfbae7e8e6188dbedae49797882b665259d125bd860845ee78" +content-hash = "f0e598b9887778ff28a621307b0a91614a44a15b8a1d1bc95969995517efb497" [metadata.files] aiohttp = [ @@ -1331,6 +1362,7 @@ base58 = [ bitarray = [ {file = "bitarray-1.2.2.tar.gz", hash = "sha256:27a69ffcee3b868abab3ce8b17c69e02b63e722d4d64ffd91d659f81e9984954"}, ] +black = [] cachetools = [ {file = "cachetools-5.2.0-py3-none-any.whl", hash = "sha256:f9f17d2aec496a9aa6b76f53e3b614c965223c061982d434d160f930c698a9db"}, {file = "cachetools-5.2.0.tar.gz", hash = "sha256:6a94c6402995a99c3970cc7e4884bb60b4a8639938157eeed436098bf9831757"}, @@ -1357,10 +1389,7 @@ colorama = [ {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, ] crypto-cpp-py = [ - {file = "crypto_cpp_py-1.0.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:3a802e2b4d015f080a114c2ffeef7a117509c73f368372aca8e3f337dfdc7f51"}, {file = "crypto_cpp_py-1.0.4-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:29422520f1dabca90d3031fd5c53234aabe126dc903dd48992d26e78f104df8f"}, - {file = "crypto_cpp_py-1.0.4-cp37-cp37m-macosx_12_0_arm64.whl", hash = "sha256:05e3b586d03b553e532d5d44d119fd4f6a8e86107245b03bdaac61381324f969"}, - {file = "crypto_cpp_py-1.0.4-cp37-cp37m-macosx_12_0_x86_64.whl", hash = "sha256:2d77ec9927d9b9eaa0fd9060ad616080fa78095c899a3b9055cd131243b28095"}, {file = "crypto_cpp_py-1.0.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6590aeb6dabf674b231f32eb82d6272d3cdf9b689054195d6ebb6cba1c51fb8d"}, {file = "crypto_cpp_py-1.0.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:16e5314bf0829cbf578b8bfe228f13370c4666d0613dfd490c30396c30ee5e46"}, {file = "crypto_cpp_py-1.0.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:246a642524b61f4bc04f0deccd391e62273816174c4e87448cd9531f0c33d91a"}, @@ -1371,10 +1400,7 @@ crypto-cpp-py = [ {file = "crypto_cpp_py-1.0.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4166fbe3ea025c3dfdcbf9269323a24c43af3f92880fbd8008891ea6b9489fc2"}, {file = "crypto_cpp_py-1.0.4-cp37-cp37m-win32.whl", hash = "sha256:e78d08225bd20d829f119ba2e00c11b69bf6d2a99b655e56181580290392a010"}, {file = "crypto_cpp_py-1.0.4-cp37-cp37m-win_amd64.whl", hash = "sha256:059e2ba230a22762087032fd508f7e8b5213ea5a18d1b091eb5c796093af9548"}, - {file = "crypto_cpp_py-1.0.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:afe891d1a28b4d905257b410563fc4b2fc0c4d536c1d8c9b0bdb7c6d18a7f565"}, {file = "crypto_cpp_py-1.0.4-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:02520302ebcfa209bb79cb6a70fadffdd090b73e302142e48e1757bb513f2eca"}, - {file = "crypto_cpp_py-1.0.4-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:447af208c5e6ff59a9d44c69806d2e3a71f32ec63b9435d7837fefe258a09f0b"}, - {file = "crypto_cpp_py-1.0.4-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:7d128ce933440303706c36eb6f5aa32ee76316138ce17cb12217ad283b030f4f"}, {file = "crypto_cpp_py-1.0.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3eeeab30e7c9409d102dfb0d6f2a9024f518f5104e4f615964ae422c6cca8370"}, {file = "crypto_cpp_py-1.0.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1952ecad9ae8f9756e0510a45dae6cbd3ff85b594968266b90f504bd2c4b083"}, {file = "crypto_cpp_py-1.0.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:712e03d6516b3cc12a4ff334b47e5683c58baf7c923f6ba9d1cf6478800ad550"}, @@ -1385,10 +1411,7 @@ crypto-cpp-py = [ {file = "crypto_cpp_py-1.0.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8c94e012a1e591bbe112778ec64576fa5b3dd4856208c0a6ac3c51c5ad62c61d"}, {file = "crypto_cpp_py-1.0.4-cp38-cp38-win32.whl", hash = "sha256:d7a9889203d9e56cb5f7128c3742b7fc62a36cde8dddd71989570f0db6491c5c"}, {file = "crypto_cpp_py-1.0.4-cp38-cp38-win_amd64.whl", hash = "sha256:7573420e603b523cec2e8ce26a11bcbc380987c2b68ce64be0e0068e9c0ff515"}, - {file = "crypto_cpp_py-1.0.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:31b7a93a34fb444bd8e5540907e83737dcf6ce8d19f770fe1b49c6091cc9a6b9"}, {file = "crypto_cpp_py-1.0.4-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:a3f4ac6852b5919508cc0218c5ae166afdbff64aa63bf6dfc8602761879706a5"}, - {file = "crypto_cpp_py-1.0.4-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:0bad6955111f68e49e6b053171a8ed998a6e8a706464c72c1728497b9e4ebd4c"}, - {file = "crypto_cpp_py-1.0.4-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:88ca995b5a28d5f8e775bd3efa0a6e995decc76866c6fa548df5637729d23c09"}, {file = "crypto_cpp_py-1.0.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c4266a9851f0abba7757dbdda41b70fe986fbb333eca36d6d520400485cd64cf"}, {file = "crypto_cpp_py-1.0.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:734c5baddd300601d4c42638922a736ab3e74cef3cc3d15785cb3698b28ca90e"}, {file = "crypto_cpp_py-1.0.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8cf24b9b18775a21eaf1a7add19076184bd1d38c2c3dbced48502414f1c2cbd6"}, @@ -1440,10 +1463,7 @@ eth-utils = [ {file = "eth-utils-1.10.0.tar.gz", hash = "sha256:bf82762a46978714190b0370265a7148c954d3f0adaa31c6f085ea375e4c61af"}, {file = "eth_utils-1.10.0-py3-none-any.whl", hash = "sha256:74240a8c6f652d085ed3c85f5f1654203d2f10ff9062f83b3bad0a12ff321c7a"}, ] -execnet = [ - {file = "execnet-1.9.0-py2.py3-none-any.whl", hash = "sha256:a295f7cc774947aac58dde7fdc85f4aa00c42adf5d8f5468fc630c1acf30a142"}, - {file = "execnet-1.9.0.tar.gz", hash = "sha256:8f694f3ba9cc92cab508b152dcfe322153975c29bda272e2fd7f3f00f36e47c5"}, -] +execnet = [] fastecdsa = [ {file = "fastecdsa-2.2.3-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:c1f27c5b37aee4bafa8ee304f6da3382ba90200a6998764b3cedba506ef03ff5"}, {file = "fastecdsa-2.2.3-cp36-cp36m-macosx_11_0_x86_64.whl", hash = "sha256:05676e917fea8d56f15a7f00c81560d9a6be83767435dca17cc2a62741276656"}, @@ -1783,6 +1803,7 @@ packaging = [ parsimonious = [ {file = "parsimonious-0.8.1.tar.gz", hash = "sha256:3add338892d580e0cb3b1a39e4a1b427ff9f687858fdd61097053742391a9f6b"}, ] +pathspec = [] pipdeptree = [ {file = "pipdeptree-2.2.1-py3-none-any.whl", hash = "sha256:e20655a38d6e363d8e86d6a85e8a648680a3f4b6d039d6ee3ab0f539da1ad6ce"}, {file = "pipdeptree-2.2.1.tar.gz", hash = "sha256:2b97d80c64d229e01ad242f14229a899263c6e8645c588ec5b054c1b81f3065d"}, @@ -1825,7 +1846,40 @@ protobuf = [ {file = "protobuf-3.20.1-py2.py3-none-any.whl", hash = "sha256:adfc6cf69c7f8c50fd24c793964eef18f0ac321315439d94945820612849c388"}, {file = "protobuf-3.20.1.tar.gz", hash = "sha256:adc31566d027f45efe3f44eeb5b1f329da43891634d61c75a5944e9be6dd42c9"}, ] -psutil = [] +psutil = [ + {file = "psutil-5.9.1-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:799759d809c31aab5fe4579e50addf84565e71c1dc9f1c31258f159ff70d3f87"}, + {file = "psutil-5.9.1-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:9272167b5f5fbfe16945be3db475b3ce8d792386907e673a209da686176552af"}, + {file = "psutil-5.9.1-cp27-cp27m-win32.whl", hash = "sha256:0904727e0b0a038830b019551cf3204dd48ef5c6868adc776e06e93d615fc5fc"}, + {file = "psutil-5.9.1-cp27-cp27m-win_amd64.whl", hash = "sha256:e7e10454cb1ab62cc6ce776e1c135a64045a11ec4c6d254d3f7689c16eb3efd2"}, + {file = "psutil-5.9.1-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:56960b9e8edcca1456f8c86a196f0c3d8e3e361320071c93378d41445ffd28b0"}, + {file = "psutil-5.9.1-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:44d1826150d49ffd62035785a9e2c56afcea66e55b43b8b630d7706276e87f22"}, + {file = "psutil-5.9.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c7be9d7f5b0d206f0bbc3794b8e16fb7dbc53ec9e40bbe8787c6f2d38efcf6c9"}, + {file = "psutil-5.9.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:abd9246e4cdd5b554a2ddd97c157e292ac11ef3e7af25ac56b08b455c829dca8"}, + {file = "psutil-5.9.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:29a442e25fab1f4d05e2655bb1b8ab6887981838d22effa2396d584b740194de"}, + {file = "psutil-5.9.1-cp310-cp310-win32.whl", hash = "sha256:20b27771b077dcaa0de1de3ad52d22538fe101f9946d6dc7869e6f694f079329"}, + {file = "psutil-5.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:58678bbadae12e0db55186dc58f2888839228ac9f41cc7848853539b70490021"}, + {file = "psutil-5.9.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:3a76ad658641172d9c6e593de6fe248ddde825b5866464c3b2ee26c35da9d237"}, + {file = "psutil-5.9.1-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a6a11e48cb93a5fa606306493f439b4aa7c56cb03fc9ace7f6bfa21aaf07c453"}, + {file = "psutil-5.9.1-cp36-cp36m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:068935df39055bf27a29824b95c801c7a5130f118b806eee663cad28dca97685"}, + {file = "psutil-5.9.1-cp36-cp36m-win32.whl", hash = "sha256:0f15a19a05f39a09327345bc279c1ba4a8cfb0172cc0d3c7f7d16c813b2e7d36"}, + {file = "psutil-5.9.1-cp36-cp36m-win_amd64.whl", hash = "sha256:db417f0865f90bdc07fa30e1aadc69b6f4cad7f86324b02aa842034efe8d8c4d"}, + {file = "psutil-5.9.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:91c7ff2a40c373d0cc9121d54bc5f31c4fa09c346528e6a08d1845bce5771ffc"}, + {file = "psutil-5.9.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fea896b54f3a4ae6f790ac1d017101252c93f6fe075d0e7571543510f11d2676"}, + {file = "psutil-5.9.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3054e923204b8e9c23a55b23b6df73a8089ae1d075cb0bf711d3e9da1724ded4"}, + {file = "psutil-5.9.1-cp37-cp37m-win32.whl", hash = "sha256:d2d006286fbcb60f0b391741f520862e9b69f4019b4d738a2a45728c7e952f1b"}, + {file = "psutil-5.9.1-cp37-cp37m-win_amd64.whl", hash = "sha256:b14ee12da9338f5e5b3a3ef7ca58b3cba30f5b66f7662159762932e6d0b8f680"}, + {file = "psutil-5.9.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:19f36c16012ba9cfc742604df189f2f28d2720e23ff7d1e81602dbe066be9fd1"}, + {file = "psutil-5.9.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:944c4b4b82dc4a1b805329c980f270f170fdc9945464223f2ec8e57563139cf4"}, + {file = "psutil-5.9.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b6750a73a9c4a4e689490ccb862d53c7b976a2a35c4e1846d049dcc3f17d83b"}, + {file = "psutil-5.9.1-cp38-cp38-win32.whl", hash = "sha256:a8746bfe4e8f659528c5c7e9af5090c5a7d252f32b2e859c584ef7d8efb1e689"}, + {file = "psutil-5.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:79c9108d9aa7fa6fba6e668b61b82facc067a6b81517cab34d07a84aa89f3df0"}, + {file = "psutil-5.9.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:28976df6c64ddd6320d281128817f32c29b539a52bdae5e192537bc338a9ec81"}, + {file = "psutil-5.9.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b88f75005586131276634027f4219d06e0561292be8bd6bc7f2f00bdabd63c4e"}, + {file = "psutil-5.9.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:645bd4f7bb5b8633803e0b6746ff1628724668681a434482546887d22c7a9537"}, + {file = "psutil-5.9.1-cp39-cp39-win32.whl", hash = "sha256:32c52611756096ae91f5d1499fe6c53b86f4a9ada147ee42db4991ba1520e574"}, + {file = "psutil-5.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:f65f9a46d984b8cd9b3750c2bdb419b2996895b005aefa6cbaba9a143b1ce2c5"}, + {file = "psutil-5.9.1.tar.gz", hash = "sha256:57f1819b5d9e95cdfb0c881a8a5b7d542ed0b7c522d575706a80bedc848c8954"}, +] py = [ {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"}, {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, @@ -1903,14 +1957,8 @@ pytest-asyncio = [ {file = "pytest_asyncio-0.18.3-1-py3-none-any.whl", hash = "sha256:16cf40bdf2b4fb7fc8e4b82bd05ce3fbcd454cbf7b92afc445fe299dabb88213"}, {file = "pytest_asyncio-0.18.3-py3-none-any.whl", hash = "sha256:8fafa6c52161addfd41ee7ab35f11836c5a16ec208f93ee388f752bea3493a84"}, ] -pytest-forked = [ - {file = "pytest-forked-1.4.0.tar.gz", hash = "sha256:8b67587c8f98cbbadfdd804539ed5455b6ed03802203485dd2f53c1422d7440e"}, - {file = "pytest_forked-1.4.0-py3-none-any.whl", hash = "sha256:bbbb6717efc886b9d64537b41fb1497cfaf3c9601276be8da2cccfea5a3c8ad8"}, -] -pytest-xdist = [ - {file = "pytest-xdist-2.5.0.tar.gz", hash = "sha256:4580deca3ff04ddb2ac53eba39d76cb5dd5edeac050cb6fbc768b0dd712b4edf"}, - {file = "pytest_xdist-2.5.0-py3-none-any.whl", hash = "sha256:6fe5c74fec98906deb8f2d2b616b5c782022744978e7bd4695d39c8f42d0ce65"}, -] +pytest-forked = [] +pytest-xdist = [] pywin32 = [ {file = "pywin32-304-cp310-cp310-win32.whl", hash = "sha256:3c7bacf5e24298c86314f03fa20e16558a4e4138fc34615d7de4070c23e65af3"}, {file = "pywin32-304-cp310-cp310-win_amd64.whl", hash = "sha256:4f32145913a2447736dad62495199a8e280a77a0ca662daa2332acf849f0be48"}, diff --git a/pyproject.toml b/pyproject.toml index d9d25ee08..e7a46fdfa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ psutil = "~5.9.1" jsonschema = "~3.2.0" pytest-xdist = "~2.5.0" pylint-quotes = "~0.2.3" +black = "~22.6" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/scripts/format.sh b/scripts/format.sh new file mode 100755 index 000000000..ba9f0fec8 --- /dev/null +++ b/scripts/format.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +set -e + +poetry run black $(git ls-files '*.py') diff --git a/starknet_devnet/account.py b/starknet_devnet/account.py index 8b11a0f4c..51a89e227 100644 --- a/starknet_devnet/account.py +++ b/starknet_devnet/account.py @@ -4,20 +4,26 @@ from starkware.cairo.lang.vm.crypto import pedersen_hash from starkware.solidity.utils import load_nearby_contract -from starkware.starknet.business_logic.state.objects import ContractState, ContractCarriedState +from starkware.starknet.business_logic.state.objects import ( + ContractState, + ContractCarriedState, +) from starkware.starknet.public.abi import get_selector_from_name from starkware.starknet.services.api.contract_class import ContractClass -from starkware.starknet.core.os.contract_address.contract_address import calculate_contract_address_from_hash +from starkware.starknet.core.os.contract_address.contract_address import ( + calculate_contract_address_from_hash, +) from starkware.starknet.storage.starknet_storage import StorageLeaf from starkware.starknet.testing.contract import StarknetContract from starkware.python.utils import to_bytes from starknet_devnet.util import Uint256 + class Account: """Account contract wrapper.""" - CONTRACT_CLASS: ContractClass = None # loaded lazily + CONTRACT_CLASS: ContractClass = None # loaded lazily CONTRACT_PATH = "accounts_artifacts/OpenZeppelin/0.3.1/Account.cairo/Account" # Precalculated to save time @@ -25,7 +31,9 @@ class Account: HASH = 580711710156617243550448398501018980467831526895029280465303474122300077395 HASH_BYTES = to_bytes(HASH) - def __init__(self, starknet_wrapper, private_key: int, public_key: int, initial_balance: int): + def __init__( + self, starknet_wrapper, private_key: int, public_key: int, initial_balance: int + ): self.starknet_wrapper = starknet_wrapper self.private_key = private_key self.public_key = public_key @@ -36,7 +44,7 @@ def __init__(self, starknet_wrapper, private_key: int, public_key: int, initial_ salt=20, class_hash=1803505466663265559571280894381905521939782500874858933595227108099796801620, constructor_calldata=[public_key], - deployer_address=0 + deployer_address=0, ) self.initial_balance = initial_balance @@ -44,7 +52,9 @@ def __init__(self, starknet_wrapper, private_key: int, public_key: int, initial_ def get_contract_class(cls): """Returns contract class via lazy loading.""" if not cls.CONTRACT_CLASS: - cls.CONTRACT_CLASS = ContractClass.load(load_nearby_contract(cls.CONTRACT_PATH)) + cls.CONTRACT_CLASS = ContractClass.load( + load_nearby_contract(cls.CONTRACT_PATH) + ) return cls.CONTRACT_CLASS def to_json(self): @@ -53,7 +63,7 @@ def to_json(self): "initial_balance": self.initial_balance, "private_key": hex(self.private_key), "public_key": hex(self.public_key), - "address": hex(self.address) + "address": hex(self.address), } async def deploy(self) -> StarknetContract: @@ -68,30 +78,40 @@ async def deploy(self) -> StarknetContract: newly_deployed_account_state = await ContractState.create( contract_hash=Account.HASH_BYTES, - storage_commitment_tree=account_state.storage_commitment_tree + storage_commitment_tree=account_state.storage_commitment_tree, ) starknet.state.state.contract_states[self.address] = ContractCarriedState( state=newly_deployed_account_state, storage_updates={ - get_selector_from_name("Account_public_key"): StorageLeaf(self.public_key) - } + get_selector_from_name("Account_public_key"): StorageLeaf( + self.public_key + ) + }, ) # set initial balance fee_token_address = starknet.state.general_config.fee_token_address - fee_token_storage_updates = starknet.state.state.contract_states[fee_token_address].storage_updates + fee_token_storage_updates = starknet.state.state.contract_states[ + fee_token_address + ].storage_updates - balance_address = pedersen_hash(get_selector_from_name("ERC20_balances"), self.address) + balance_address = pedersen_hash( + get_selector_from_name("ERC20_balances"), self.address + ) initial_balance_uint256 = Uint256.from_felt(self.initial_balance) - fee_token_storage_updates[balance_address] = StorageLeaf(initial_balance_uint256.low) - fee_token_storage_updates[balance_address + 1] = StorageLeaf(initial_balance_uint256.high) + fee_token_storage_updates[balance_address] = StorageLeaf( + initial_balance_uint256.low + ) + fee_token_storage_updates[balance_address + 1] = StorageLeaf( + initial_balance_uint256.high + ) - contract = StarknetContract( + contract = StarknetContract( state=starknet.state, abi=contract_class.abi, contract_address=self.address, - deploy_execution_info=None + deploy_execution_info=None, ) self.starknet_wrapper.store_contract(self.address, contract, contract_class) diff --git a/starknet_devnet/accounts.py b/starknet_devnet/accounts.py index 1c2743a6c..712b18df2 100644 --- a/starknet_devnet/accounts.py +++ b/starknet_devnet/accounts.py @@ -9,8 +9,10 @@ from starkware.crypto.signature.signature import private_to_stark_key from .account import Account + class Accounts: """Accounts wrapper""" + list: List[Account] def __init__(self, starknet_wrapper): @@ -22,7 +24,7 @@ def __init__(self, starknet_wrapper): self.__generate( n_accounts=starknet_wrapper.config.accounts, initial_balance=starknet_wrapper.config.initial_balance, - seed=starknet_wrapper.config.seed + seed=starknet_wrapper.config.seed, ) if starknet_wrapper.config.accounts: self.__print() @@ -54,12 +56,14 @@ def __generate(self, n_accounts: int, initial_balance: int, seed: int): private_key = random_generator.getrandbits(128) public_key = private_to_stark_key(private_key) - self.add(Account( - self.starknet_wrapper, - private_key=private_key, - public_key=public_key, - initial_balance=initial_balance - )) + self.add( + Account( + self.starknet_wrapper, + private_key=private_key, + public_key=public_key, + initial_balance=initial_balance, + ) + ) def __print(self): """stdout accounts list""" @@ -74,6 +78,6 @@ def __print(self): print( "WARNING: Use these accounts and their keys ONLY for local testing. " "DO NOT use them on mainnet or other live networks because you will LOSE FUNDS.\n", - file=sys.stderr + file=sys.stderr, ) sys.stdout.flush() diff --git a/starknet_devnet/block_info_generator.py b/starknet_devnet/block_info_generator.py index 67674706d..3812ffa30 100644 --- a/starknet_devnet/block_info_generator.py +++ b/starknet_devnet/block_info_generator.py @@ -9,11 +9,13 @@ from starknet_devnet.constants import CAIRO_LANG_VERSION + def now() -> int: """Get the current time in seconds.""" return int(time.time()) -class BlockInfoGenerator(): + +class BlockInfoGenerator: """Generator of BlockInfo objects with the correct timestamp""" def __init__(self, start_time: int = None, gas_price: int = 0): @@ -37,7 +39,7 @@ def next_block(self, block_info: BlockInfo, general_config: StarknetGeneralConfi block_number=block_info.block_number, block_timestamp=block_timestamp, sequencer_address=general_config.sequencer_address, - starknet_version=CAIRO_LANG_VERSION + starknet_version=CAIRO_LANG_VERSION, ) def increase_time(self, time_s: int): diff --git a/starknet_devnet/blocks.py b/starknet_devnet/blocks.py index 75ba44e7f..14426facf 100644 --- a/starknet_devnet/blocks.py +++ b/starknet_devnet/blocks.py @@ -6,8 +6,13 @@ from starkware.starknet.testing.state import StarknetState from starkware.starknet.core.os.block_hash.block_hash import calculate_block_hash -from starkware.starknet.services.api.feeder_gateway.response_objects import StarknetBlock, BlockStatus -from starkware.starknet.services.api.feeder_gateway.response_objects import BlockStateUpdate +from starkware.starknet.services.api.feeder_gateway.response_objects import ( + StarknetBlock, + BlockStatus, +) +from starkware.starknet.services.api.feeder_gateway.response_objects import ( + BlockStateUpdate, +) from starknet_devnet.constants import CAIRO_LANG_VERSION @@ -15,10 +20,11 @@ from .util import StarknetDevnetException from .transactions import DevnetTransaction -class DevnetBlocks(): + +class DevnetBlocks: """This class is used to store the generated blocks of the devnet.""" - def __init__(self, origin: Origin, lite = False) -> None: + def __init__(self, origin: Origin, lite=False) -> None: self.origin = origin self.lite = lite self.__num2block: Dict[int, StarknetBlock] = {} @@ -42,7 +48,9 @@ def get_by_number(self, block_number: int) -> StarknetBlock: return self.origin.get_block_by_number(block_number) if block_number < 0: - message = f"Block number must be a non-negative integer; got: {block_number}." + message = ( + f"Block number must be a non-negative integer; got: {block_number}." + ) raise StarknetDevnetException(message=message) if block_number >= self.get_number_of_blocks(): @@ -85,12 +93,17 @@ def get_state_update(self, block_hash=None, block_number=None) -> BlockStateUpda return self.__state_updates[block_number] - - return self.__state_updates.get(self.get_number_of_blocks() - 1) or self.origin.get_state_update() + return ( + self.__state_updates.get(self.get_number_of_blocks() - 1) + or self.origin.get_state_update() + ) async def generate( - self, transaction: DevnetTransaction, state: StarknetState, - state_update = None, is_empty_block = False + self, + transaction: DevnetTransaction, + state: StarknetState, + state_update=None, + is_empty_block=False, ) -> StarknetBlock: """ Generates a block and stores it to blocks and hash2block. The block contains just the passed transaction. @@ -111,7 +124,7 @@ async def generate( transactions = [] else: transaction_receipts = (transaction.get_execution(),) - transactions=[transaction.internal_tx] + transactions = [transaction.internal_tx] if self.lite or is_empty_block: block_hash = block_number @@ -126,7 +139,7 @@ async def generate( tx_hashes=[transaction.internal_tx.hash_value], tx_signatures=[signature], event_hashes=[], - sequencer_address=state.general_config.sequencer_address + sequencer_address=state.general_config.sequencer_address, ) block = StarknetBlock.create( @@ -140,7 +153,7 @@ async def generate( gas_price=state.state.block_info.gas_price, sequencer_address=state.general_config.sequencer_address, parent_block_hash=parent_block_hash, - starknet_version=CAIRO_LANG_VERSION + starknet_version=CAIRO_LANG_VERSION, ) self.__num2block[block_number] = block diff --git a/starknet_devnet/blueprints/base.py b/starknet_devnet/blueprints/base.py index 8ab631de2..2d152c74d 100644 --- a/starknet_devnet/blueprints/base.py +++ b/starknet_devnet/blueprints/base.py @@ -9,30 +9,41 @@ base = Blueprint("base", __name__) + def extract_int(value): """extract int from float if an integer value""" return isinstance(value, float) and value.is_integer() and int(value) or value + def extract_positive(request_json, prop_name: str): """Expects `prop_name` from `request_json` and expects it to be positive""" value = extract_int(request_json.get(prop_name)) if value is None: - raise StarknetDevnetException(message=f"{prop_name} value must be provided.", status_code=400) + raise StarknetDevnetException( + message=f"{prop_name} value must be provided.", status_code=400 + ) if not isinstance(value, int) or isinstance(value, bool): - raise StarknetDevnetException(message=f"{prop_name} value must be an integer.", status_code=400) + raise StarknetDevnetException( + message=f"{prop_name} value must be an integer.", status_code=400 + ) if value < 0: - raise StarknetDevnetException(message=f"{prop_name} value must be greater than 0.", status_code=400) + raise StarknetDevnetException( + message=f"{prop_name} value must be greater than 0.", status_code=400 + ) return value + def extract_hex_string(request_json, prop_name: str) -> int: """Parse value from hex string to int""" value = request_json.get(prop_name) if value is None: - raise StarknetDevnetException(status_code=400, message=f"{prop_name} value must be provided.") + raise StarknetDevnetException( + status_code=400, message=f"{prop_name} value must be provided." + ) try: return int(value, 16) @@ -46,12 +57,14 @@ def is_alive(): """Health check endpoint.""" return "Alive!!!" + @base.route("/restart", methods=["POST"]) async def restart(): """Restart the starknet_wrapper""" await state.reset() return Response(status=200) + @base.route("/dump", methods=["POST"]) def dump(): """Dumps the starknet_wrapper""" @@ -69,6 +82,7 @@ def dump(): state.dumper.dump(dump_path) return Response(status=200) + @base.route("/load", methods=["POST"]) def load(): """Loads the starknet_wrapper""" @@ -81,6 +95,7 @@ def load(): state.load(load_path) return Response(status=200) + @base.route("/increase_time", methods=["POST"]) def increase_time(): """Increases the block timestamp offset""" @@ -91,6 +106,7 @@ def increase_time(): return jsonify({"timestamp_increased_by": time_s}) + @base.route("/set_time", methods=["POST"]) def set_time(): """Sets the block timestamp offset""" @@ -101,16 +117,15 @@ def set_time(): return jsonify({"next_block_timestamp": time_s}) + @base.route("/account_balance", methods=["GET"]) async def get_balance(): """Gets balance for the address""" address = request.args.get("address", type=lambda x: int(x, 16)) balance = await state.starknet_wrapper.fee_token.get_balance(address) - return jsonify({ - "amount": balance, - "unit": "wei" - }) + return jsonify({"amount": balance, "unit": "wei"}) + @base.route("/predeployed_accounts", methods=["GET"]) def get_predeployed_accounts(): @@ -118,6 +133,7 @@ def get_predeployed_accounts(): accounts = state.starknet_wrapper.accounts return jsonify([account.to_json() for account in accounts]) + @base.route("/fee_token", methods=["GET"]) async def get_fee_token(): """Get the address of the fee token""" @@ -125,6 +141,7 @@ async def get_fee_token(): symbol = FeeToken.SYMBOL return jsonify({"symbol": symbol, "address": hex(fee_token_address)}) + @base.route("/mint", methods=["POST"]) async def mint(): """Mint token and transfer to the provided address""" @@ -135,17 +152,12 @@ async def mint(): is_lite = request_json.get("lite", False) tx_hash = await state.starknet_wrapper.fee_token.mint( - to_address=address, - amount=amount, - lite=is_lite + to_address=address, amount=amount, lite=is_lite ) new_balance = await state.starknet_wrapper.fee_token.get_balance(address) - return jsonify({ - "new_balance": new_balance, - "unit": "wei", - "tx_hash": tx_hash - }) + return jsonify({"new_balance": new_balance, "unit": "wei", "tx_hash": tx_hash}) + @base.route("/create_block", methods=["POST"]) async def create_block(): diff --git a/starknet_devnet/blueprints/feeder_gateway.py b/starknet_devnet/blueprints/feeder_gateway.py index 8671193b0..be19aac78 100644 --- a/starknet_devnet/blueprints/feeder_gateway.py +++ b/starknet_devnet/blueprints/feeder_gateway.py @@ -4,7 +4,9 @@ from flask import request, jsonify, Blueprint, Response from marshmallow import ValidationError -from starkware.starknet.services.api.feeder_gateway.response_objects import BlockTransactionTraces +from starkware.starknet.services.api.feeder_gateway.response_objects import ( + BlockTransactionTraces, +) from starkware.starknet.services.api.gateway.transaction import InvokeFunction from werkzeug.datastructures import MultiDict @@ -13,26 +15,34 @@ feeder_gateway = Blueprint("feeder_gateway", __name__, url_prefix="/feeder_gateway") + def validate_call(data: bytes): """Ensure `data` is valid Starknet function call. Returns an `InvokeFunction`.""" try: call_specifications = InvokeFunction.loads(data) except (TypeError, ValidationError) as err: - raise StarknetDevnetException(message=f"Invalid Starknet function call: {err}", status_code=400) from err + raise StarknetDevnetException( + message=f"Invalid Starknet function call: {err}", status_code=400 + ) from err return call_specifications + def _check_block_hash(request_args: MultiDict): block_hash = request_args.get("blockHash", type=custom_int) if block_hash is not None: - print("Specifying a block by its hash is not supported. All interaction is done with the latest block.") + print( + "Specifying a block by its hash is not supported. All interaction is done with the latest block." + ) + def _check_block_arguments(block_hash, block_number): if block_hash is not None and block_number is not None: message = "Ambiguous criteria: only one of (block number, block hash) can be provided." raise StarknetDevnetException(message=message, status_code=500) + def _get_block_object(block_hash: str, block_number: int): """Returns the block object""" @@ -45,6 +55,7 @@ def _get_block_object(block_hash: str, block_number: int): return block + def _get_block_transaction_traces(block): traces = [] if block.transaction_receipts: @@ -58,18 +69,21 @@ def _get_block_transaction_traces(block): traces.append(trace_dict) # assert correct structure - return BlockTransactionTraces.load({ "traces": traces }) + return BlockTransactionTraces.load({"traces": traces}) + @feeder_gateway.route("/is_alive", methods=["GET"]) def is_alive(): """Health check endpoint.""" return "Alive!!!" + @feeder_gateway.route("/get_contract_addresses", methods=["GET"]) def get_contract_addresses(): """Endpoint that returns an object containing the addresses of key system components.""" return "Not implemented", 501 + @feeder_gateway.route("/call_contract", methods=["POST"]) async def call_contract(): """ @@ -82,6 +96,7 @@ async def call_contract(): return jsonify(result_dict) + @feeder_gateway.route("/get_block", methods=["GET"]) def get_block(): """Endpoint for retrieving a block identified by its hash or number.""" @@ -93,6 +108,7 @@ def get_block(): return Response(block.dumps(), status=200, mimetype="application/json") + @feeder_gateway.route("/get_block_traces", methods=["GET"]) def get_block_traces(): """Returns the traces of the transactions in the specified block.""" @@ -105,6 +121,7 @@ def get_block_traces(): return jsonify(block_transaction_traces.dump()) + @feeder_gateway.route("/get_code", methods=["GET"]) def get_code(): """ @@ -117,6 +134,7 @@ def get_code(): result_dict = state.starknet_wrapper.contracts.get_code(contract_address) return jsonify(result_dict) + @feeder_gateway.route("/get_full_contract", methods=["GET"]) def get_full_contract(): """ @@ -126,10 +144,13 @@ def get_full_contract(): contract_address = request.args.get("contractAddress", type=custom_int) - contract_class = state.starknet_wrapper.contracts.get_full_contract(contract_address) + contract_class = state.starknet_wrapper.contracts.get_full_contract( + contract_address + ) return jsonify(contract_class.dump()) + @feeder_gateway.route("/get_class_hash_at", methods=["GET"]) def get_class_hash_at(): """Get contract class hash by contract address""" @@ -138,6 +159,7 @@ def get_class_hash_at(): class_hash = state.starknet_wrapper.contracts.get_class_hash_at(contract_address) return jsonify(fixed_length_hex(class_hash)) + @feeder_gateway.route("/get_class_by_hash", methods=["GET"]) def get_class_by_hash(): """Get contract class by class hash""" @@ -146,9 +168,10 @@ def get_class_by_hash(): contract_class = state.starknet_wrapper.contracts.get_class_by_hash(class_hash) return jsonify(contract_class.dump()) + @feeder_gateway.route("/get_storage_at", methods=["GET"]) async def get_storage_at(): - """Endpoint for returning the storage identified by `key` from the contract at """ + """Endpoint for returning the storage identified by `key` from the contract at""" _check_block_hash(request.args) contract_address = request.args.get("contractAddress", type=custom_int) @@ -157,6 +180,7 @@ async def get_storage_at(): storage = await state.starknet_wrapper.get_storage_at(contract_address, key) return jsonify(storage) + @feeder_gateway.route("/get_transaction_status", methods=["GET"]) def get_transaction_status(): """ @@ -164,9 +188,12 @@ def get_transaction_status(): """ transaction_hash = request.args.get("transactionHash") - transaction_status = state.starknet_wrapper.transactions.get_transaction_status(transaction_hash) + transaction_status = state.starknet_wrapper.transactions.get_transaction_status( + transaction_hash + ) return jsonify(transaction_status) + @feeder_gateway.route("/get_transaction", methods=["GET"]) def get_transaction(): """ @@ -174,8 +201,13 @@ def get_transaction(): """ transaction_hash = request.args.get("transactionHash") - transaction_info = state.starknet_wrapper.transactions.get_transaction(transaction_hash) - return Response(response=transaction_info.dumps(), status=200, mimetype="application/json") + transaction_info = state.starknet_wrapper.transactions.get_transaction( + transaction_hash + ) + return Response( + response=transaction_info.dumps(), status=200, mimetype="application/json" + ) + @feeder_gateway.route("/get_transaction_receipt", methods=["GET"]) def get_transaction_receipt(): @@ -184,8 +216,13 @@ def get_transaction_receipt(): """ transaction_hash = request.args.get("transactionHash") - transaction_receipt = state.starknet_wrapper.transactions.get_transaction_receipt(transaction_hash) - return Response(response=transaction_receipt.dumps(), status=200, mimetype="application/json") + transaction_receipt = state.starknet_wrapper.transactions.get_transaction_receipt( + transaction_hash + ) + return Response( + response=transaction_receipt.dumps(), status=200, mimetype="application/json" + ) + @feeder_gateway.route("/get_transaction_trace", methods=["GET"]) def get_transaction_trace(): @@ -194,9 +231,14 @@ def get_transaction_trace(): """ transaction_hash = request.args.get("transactionHash") - transaction_trace = state.starknet_wrapper.transactions.get_transaction_trace(transaction_hash) + transaction_trace = state.starknet_wrapper.transactions.get_transaction_trace( + transaction_hash + ) + + return Response( + response=transaction_trace.dumps(), status=200, mimetype="application/json" + ) - return Response(response=transaction_trace.dumps(), status=200, mimetype="application/json") @feeder_gateway.route("/get_state_update", methods=["GET"]) def get_state_update(): @@ -208,13 +250,18 @@ def get_state_update(): block_hash = request.args.get("blockHash") block_number = request.args.get("blockNumber", type=custom_int) - state_update = state.starknet_wrapper.blocks.get_state_update(block_hash=block_hash, block_number=block_number) + state_update = state.starknet_wrapper.blocks.get_state_update( + block_hash=block_hash, block_number=block_number + ) if state_update is not None: - return Response(response=state_update.dumps(), status=200, mimetype="application/json") + return Response( + response=state_update.dumps(), status=200, mimetype="application/json" + ) return jsonify(state_update) + @feeder_gateway.route("/estimate_fee", methods=["POST"]) async def estimate_fee(): """Returns the estimated fee for a transaction.""" diff --git a/starknet_devnet/blueprints/gateway.py b/starknet_devnet/blueprints/gateway.py index 528c99cf6..07f165ffb 100644 --- a/starknet_devnet/blueprints/gateway.py +++ b/starknet_devnet/blueprints/gateway.py @@ -6,17 +6,19 @@ from starkware.starkware_utils.error_handling import StarkErrorCode from starknet_devnet.devnet_config import DumpOn -from starknet_devnet.util import StarknetDevnetException,fixed_length_hex +from starknet_devnet.util import StarknetDevnetException, fixed_length_hex from starknet_devnet.state import state from .shared import validate_transaction gateway = Blueprint("gateway", __name__, url_prefix="/gateway") + @gateway.route("/is_alive", methods=["GET"]) def is_alive(): """Health check endpoint.""" return "Alive!!!" + @gateway.route("/add_transaction", methods=["POST"]) async def add_transaction(): """Endpoint for accepting DEPLOY and INVOKE_FUNCTION transactions.""" @@ -29,20 +31,30 @@ async def add_transaction(): } if tx_type == TransactionType.DECLARE: - contract_class_hash, transaction_hash = await state.starknet_wrapper.declare(transaction) + contract_class_hash, transaction_hash = await state.starknet_wrapper.declare( + transaction + ) response_dict["class_hash"] = hex(contract_class_hash) elif tx_type == TransactionType.DEPLOY: - contract_address, transaction_hash = await state.starknet_wrapper.deploy(transaction) + contract_address, transaction_hash = await state.starknet_wrapper.deploy( + transaction + ) response_dict["address"] = fixed_length_hex(contract_address) elif tx_type == TransactionType.INVOKE_FUNCTION: - contract_address, transaction_hash, result_dict = await state.starknet_wrapper.invoke(transaction) + ( + contract_address, + transaction_hash, + result_dict, + ) = await state.starknet_wrapper.invoke(transaction) response_dict["address"] = fixed_length_hex(contract_address) response_dict.update(result_dict) else: - raise StarknetDevnetException(message=f"Invalid tx_type: {tx_type.name}.", status_code=400) + raise StarknetDevnetException( + message=f"Invalid tx_type: {tx_type.name}.", status_code=400 + ) response_dict["transaction_hash"] = hex(transaction_hash) diff --git a/starknet_devnet/blueprints/postman.py b/starknet_devnet/blueprints/postman.py index 4d56e1acb..9be9bfa59 100644 --- a/starknet_devnet/blueprints/postman.py +++ b/starknet_devnet/blueprints/postman.py @@ -10,6 +10,7 @@ postman = Blueprint("postman", __name__, url_prefix="/postman") + def validate_load_messaging_contract(request_dict: dict): """Ensure `data` is valid Starknet function call. Returns an `InvokeFunction`.""" @@ -20,6 +21,7 @@ def validate_load_messaging_contract(request_dict: dict): return network_url + @postman.route("/load_l1_messaging_contract", methods=["POST"]) async def load_l1_messaging_contract(): """ @@ -33,14 +35,17 @@ async def load_l1_messaging_contract(): contract_address = request_dict.get("address") network_id = request_dict.get("networkId") - result_dict = await state.starknet_wrapper.load_messaging_contract_in_l1(network_url, contract_address, network_id) + result_dict = await state.starknet_wrapper.load_messaging_contract_in_l1( + network_url, contract_address, network_id + ) return jsonify(result_dict) + @postman.route("/flush", methods=["POST"]) async def flush(): """ Handles all pending L1 <> L2 messages and sends them to the other layer """ - result_dict= await state.starknet_wrapper.postman_flush() + result_dict = await state.starknet_wrapper.postman_flush() return jsonify(result_dict) diff --git a/starknet_devnet/blueprints/rpc/call.py b/starknet_devnet/blueprints/rpc/call.py index d63fe6529..5d1f343ef 100644 --- a/starknet_devnet/blueprints/rpc/call.py +++ b/starknet_devnet/blueprints/rpc/call.py @@ -7,7 +7,10 @@ from starkware.starkware_utils.error_handling import StarkException from starknet_devnet.blueprints.rpc.utils import rpc_felt, assert_block_id_is_latest -from starknet_devnet.blueprints.rpc.structures.payloads import make_invoke_function, FunctionCall +from starknet_devnet.blueprints.rpc.structures.payloads import ( + make_invoke_function, + FunctionCall, +) from starknet_devnet.blueprints.rpc.structures.types import Felt, BlockId, RpcError from starknet_devnet.state import state from starknet_devnet.util import StarknetDevnetException @@ -19,11 +22,15 @@ async def call(request: FunctionCall, block_id: BlockId) -> List[Felt]: """ assert_block_id_is_latest(block_id) - if not state.starknet_wrapper.contracts.is_deployed(int(request["contract_address"], 16)): + if not state.starknet_wrapper.contracts.is_deployed( + int(request["contract_address"], 16) + ): raise RpcError(code=20, message="Contract not found") try: - result = await state.starknet_wrapper.call(transaction=make_invoke_function(request)) + result = await state.starknet_wrapper.call( + transaction=make_invoke_function(request) + ) result["result"] = [rpc_felt(int(res, 16)) for res in result["result"]] return result except StarknetDevnetException as ex: diff --git a/starknet_devnet/blueprints/rpc/classes.py b/starknet_devnet/blueprints/rpc/classes.py index 4905fb6de..9c60dcd08 100644 --- a/starknet_devnet/blueprints/rpc/classes.py +++ b/starknet_devnet/blueprints/rpc/classes.py @@ -4,7 +4,12 @@ from starknet_devnet.blueprints.rpc.utils import assert_block_id_is_latest, rpc_felt from starknet_devnet.blueprints.rpc.structures.payloads import rpc_contract_class -from starknet_devnet.blueprints.rpc.structures.types import BlockId, Address, Felt, RpcError +from starknet_devnet.blueprints.rpc.structures.types import ( + BlockId, + Address, + Felt, + RpcError, +) from starknet_devnet.state import state from starknet_devnet.util import StarknetDevnetException @@ -14,9 +19,13 @@ async def get_class(class_hash: Felt) -> dict: Get the contract class definition associated with the given hash """ try: - result = state.starknet_wrapper.contracts.get_class_by_hash(class_hash=int(class_hash, 16)) + result = state.starknet_wrapper.contracts.get_class_by_hash( + class_hash=int(class_hash, 16) + ) except StarknetDevnetException as ex: - raise RpcError(code=28, message="The supplied contract class hash is invalid or unknown") from ex + raise RpcError( + code=28, message="The supplied contract class hash is invalid or unknown" + ) from ex return rpc_contract_class(result) @@ -28,9 +37,13 @@ async def get_class_hash_at(block_id: BlockId, contract_address: Address) -> Fel assert_block_id_is_latest(block_id) try: - result = state.starknet_wrapper.contracts.get_class_hash_at(address=int(contract_address, 16)) + result = state.starknet_wrapper.contracts.get_class_hash_at( + address=int(contract_address, 16) + ) except StarknetDevnetException as ex: - raise RpcError(code=28, message="The supplied contract class hash is invalid or unknown") from ex + raise RpcError( + code=28, message="The supplied contract class hash is invalid or unknown" + ) from ex return rpc_felt(result) @@ -42,8 +55,12 @@ async def get_class_at(block_id: BlockId, contract_address: Address) -> dict: assert_block_id_is_latest(block_id) try: - class_hash = state.starknet_wrapper.contracts.get_class_hash_at(address=int(contract_address, 16)) - result = state.starknet_wrapper.contracts.get_class_by_hash(class_hash=class_hash) + class_hash = state.starknet_wrapper.contracts.get_class_hash_at( + address=int(contract_address, 16) + ) + result = state.starknet_wrapper.contracts.get_class_by_hash( + class_hash=class_hash + ) except StarknetDevnetException as ex: raise RpcError(code=20, message="Contract not found") from ex diff --git a/starknet_devnet/blueprints/rpc/routes.py b/starknet_devnet/blueprints/rpc/routes.py index 213570ec7..6cec209d4 100644 --- a/starknet_devnet/blueprints/rpc/routes.py +++ b/starknet_devnet/blueprints/rpc/routes.py @@ -11,16 +11,32 @@ from flask import Blueprint from flask import request -from starknet_devnet.blueprints.rpc.blocks import get_block_with_tx_hashes, get_block_with_txs, \ - get_block_transaction_count, block_number, block_hash_and_number +from starknet_devnet.blueprints.rpc.blocks import ( + get_block_with_tx_hashes, + get_block_with_txs, + get_block_transaction_count, + block_number, + block_hash_and_number, +) from starknet_devnet.blueprints.rpc.call import call -from starknet_devnet.blueprints.rpc.classes import get_class, get_class_hash_at, get_class_at +from starknet_devnet.blueprints.rpc.classes import ( + get_class, + get_class_hash_at, + get_class_at, +) from starknet_devnet.blueprints.rpc.misc import chain_id, syncing, get_events, get_nonce from starknet_devnet.blueprints.rpc.state import get_state_update from starknet_devnet.blueprints.rpc.storage import get_storage_at -from starknet_devnet.blueprints.rpc.transactions import get_transaction_by_hash, \ - get_transaction_by_block_id_and_index, get_transaction_receipt, estimate_fee, pending_transactions, \ - add_invoke_transaction, add_declare_transaction, add_deploy_transaction +from starknet_devnet.blueprints.rpc.transactions import ( + get_transaction_by_hash, + get_transaction_by_block_id_and_index, + get_transaction_receipt, + estimate_fee, + pending_transactions, + add_invoke_transaction, + add_declare_transaction, + add_deploy_transaction, +) from starknet_devnet.blueprints.rpc.utils import rpc_response, rpc_error from starknet_devnet.blueprints.rpc.structures.types import RpcError @@ -63,7 +79,9 @@ async def base_route(): try: result = await method(*args) if isinstance(args, list) else await method(**args) except NotImplementedError: - return rpc_error(message_id=message_id, code=-2, message="Method not implemented") + return rpc_error( + message_id=message_id, code=-2, message="Method not implemented" + ) except RpcError as error: return rpc_error(message_id=message_id, code=error.code, message=error.message) diff --git a/starknet_devnet/blueprints/rpc/state.py b/starknet_devnet/blueprints/rpc/state.py index d45ea203c..1e0283b00 100644 --- a/starknet_devnet/blueprints/rpc/state.py +++ b/starknet_devnet/blueprints/rpc/state.py @@ -17,9 +17,13 @@ async def get_state_update(block_id: BlockId) -> dict: try: if "block_hash" in block_id: - result = state.starknet_wrapper.blocks.get_state_update(block_hash=block_id["block_hash"]) + result = state.starknet_wrapper.blocks.get_state_update( + block_hash=block_id["block_hash"] + ) else: - result = state.starknet_wrapper.blocks.get_state_update(block_number=block_id["block_number"]) + result = state.starknet_wrapper.blocks.get_state_update( + block_number=block_id["block_number"] + ) except StarknetDevnetException as ex: raise RpcError(code=24, message="Invalid block id") from ex diff --git a/starknet_devnet/blueprints/rpc/storage.py b/starknet_devnet/blueprints/rpc/storage.py index e90382ae2..4ab2d182c 100644 --- a/starknet_devnet/blueprints/rpc/storage.py +++ b/starknet_devnet/blueprints/rpc/storage.py @@ -3,11 +3,18 @@ """ from starknet_devnet.blueprints.rpc.utils import assert_block_id_is_latest, rpc_felt -from starknet_devnet.blueprints.rpc.structures.types import Address, BlockId, Felt, RpcError +from starknet_devnet.blueprints.rpc.structures.types import ( + Address, + BlockId, + Felt, + RpcError, +) from starknet_devnet.state import state -async def get_storage_at(contract_address: Address, key: str, block_id: BlockId) -> Felt: +async def get_storage_at( + contract_address: Address, key: str, block_id: BlockId +) -> Felt: """ Get the value of the storage at the given address and key """ @@ -17,7 +24,6 @@ async def get_storage_at(contract_address: Address, key: str, block_id: BlockId) raise RpcError(code=20, message="Contract not found") storage = await state.starknet_wrapper.get_storage_at( - contract_address=int(contract_address, 16), - key=int(key, 16) + contract_address=int(contract_address, 16), key=int(key, 16) ) return rpc_felt(int(storage, 16)) diff --git a/starknet_devnet/blueprints/rpc/structures/payloads.py b/starknet_devnet/blueprints/rpc/structures/payloads.py index bc73daf80..7c6a5d814 100644 --- a/starknet_devnet/blueprints/rpc/structures/payloads.py +++ b/starknet_devnet/blueprints/rpc/structures/payloads.py @@ -15,15 +15,25 @@ TransactionSpecificInfo, TransactionType, BlockStateUpdate, - DeclareSpecificInfo + DeclareSpecificInfo, ) from starkware.starknet.services.api.gateway.transaction import InvokeFunction from starkware.starknet.services.api.gateway.transaction_utils import compress_program from typing_extensions import TypedDict from starknet_devnet.blueprints.rpc.utils import rpc_root, rpc_felt -from starknet_devnet.blueprints.rpc.structures.types import RpcBlockStatus, BlockHash, BlockNumber, Felt, \ - rpc_block_status, TxnHash, Address, NumAsHex, TxnType, rpc_txn_type +from starknet_devnet.blueprints.rpc.structures.types import ( + RpcBlockStatus, + BlockHash, + BlockNumber, + Felt, + rpc_block_status, + TxnHash, + Address, + NumAsHex, + TxnType, + rpc_txn_type, +) from starknet_devnet.state import state @@ -31,6 +41,7 @@ class RpcBlock(TypedDict): """ TypeDict for rpc block """ + status: RpcBlockStatus block_hash: BlockHash parent_hash: BlockHash @@ -41,10 +52,13 @@ class RpcBlock(TypedDict): transactions: Union[List[str], List[RpcTransaction]] -async def rpc_block(block: StarknetBlock, tx_type: Optional[str] = "TXN_HASH") -> RpcBlock: +async def rpc_block( + block: StarknetBlock, tx_type: Optional[str] = "TXN_HASH" +) -> RpcBlock: """ Convert gateway block to rpc block """ + async def transactions() -> List[RpcTransaction]: # pylint: disable=no-member return [rpc_transaction(tx) for tx in block.transactions] @@ -84,6 +98,7 @@ class RpcInvokeTransaction(TypedDict): """ TypedDict for rpc invoke transaction """ + contract_address: Address entry_point_selector: Felt calldata: List[Felt] @@ -100,6 +115,7 @@ class RpcDeclareTransaction(TypedDict): """ TypedDict for rpc declare transaction """ + class_hash: Felt sender_address: Address # Common @@ -115,6 +131,7 @@ class RpcDeployTransaction(TypedDict): """ TypedDict for rpc deploy transaction """ + transaction_hash: TxnHash class_hash: Felt version: NumAsHex @@ -124,7 +141,9 @@ class RpcDeployTransaction(TypedDict): constructor_calldata: List[Felt] -RpcTransaction = Union[RpcInvokeTransaction, RpcDeclareTransaction, RpcDeployTransaction] +RpcTransaction = Union[ + RpcInvokeTransaction, RpcDeclareTransaction, RpcDeployTransaction +] def rpc_transaction(transaction: TransactionSpecificInfo) -> RpcTransaction: @@ -143,6 +162,7 @@ class FunctionCall(TypedDict): """ TypedDict for rpc function call """ + contract_address: Address entry_point_selector: Felt calldata: List[Felt] @@ -194,7 +214,9 @@ def rpc_deploy_transaction(transaction: DeploySpecificInfo) -> RpcDeployTransact "type": rpc_txn_type(transaction.tx_type.name), "contract_address": rpc_felt(transaction.contract_address), "contract_address_salt": rpc_felt(transaction.contract_address_salt), - "constructor_calldata": [rpc_felt(data) for data in transaction.constructor_calldata], + "constructor_calldata": [ + rpc_felt(data) for data in transaction.constructor_calldata + ], } return txn @@ -203,6 +225,7 @@ class RpcFeeEstimate(TypedDict): """ Fee estimate TypedDict for rpc """ + gas_consumed: NumAsHex gas_price: NumAsHex overall_fee: NumAsHex @@ -238,6 +261,7 @@ class EntryPoint(TypedDict): """ TypedDict for rpc contract class entry point """ + offset: NumAsHex selector: Felt @@ -246,6 +270,7 @@ class EntryPoints(TypedDict): """ TypedDict for rpc contract class entry points """ + CONSTRUCTOR: List[EntryPoint] EXTERNAL: List[EntryPoint] L1_HANDLER: List[EntryPoint] @@ -255,6 +280,7 @@ class RpcContractClass(TypedDict): """ TypedDict for rpc contract class """ + program: str entry_points_by_type: EntryPoints @@ -263,6 +289,7 @@ def rpc_contract_class(contract_class: ContractClass) -> RpcContractClass: """ Convert gateway contract class to rpc contract class """ + def program() -> str: _program = contract_class.program.Schema().dump(contract_class.program) return compress_program(_program) @@ -277,14 +304,14 @@ def entry_points_by_type() -> EntryPoints: for entry_point in entry_points: _entry_point: EntryPoint = { "selector": rpc_felt(entry_point.selector), - "offset": hex(entry_point.offset) + "offset": hex(entry_point.offset), } _entry_points[typ.name].append(_entry_point) return _entry_points _contract_class: RpcContractClass = { "program": program(), - "entry_points_by_type": entry_points_by_type() + "entry_points_by_type": entry_points_by_type(), } return _contract_class @@ -293,6 +320,7 @@ class RpcStorageDiff(TypedDict): """ TypedDict for rpc storage diff """ + address: Felt key: Felt value: Felt @@ -302,6 +330,7 @@ class RpcDeclaredContractDiff(TypedDict): """ TypedDict for rpc declared contract diff """ + class_hash: Felt @@ -309,6 +338,7 @@ class RpcDeployedContractDiff(TypedDict): """ TypedDict for rpc deployed contract diff """ + address: Felt class_hash: Felt @@ -317,6 +347,7 @@ class RpcNonceDiff(TypedDict): """ TypedDict for rpc nonce diff """ + contract_address: Address nonce: Felt @@ -325,6 +356,7 @@ class RpcStateDiff(TypedDict): """ TypedDict for rpc state diff """ + storage_diffs: List[RpcStorageDiff] declared_contracts: List[RpcDeclaredContractDiff] deployed_contracts: List[RpcDeployedContractDiff] @@ -335,6 +367,7 @@ class RpcStateUpdate(TypedDict): """ TypedDict for rpc state update """ + block_hash: BlockHash new_root: Felt old_root: Felt @@ -345,6 +378,7 @@ def rpc_state_update(state_update: BlockStateUpdate) -> RpcStateUpdate: """ Convert gateway state update to rpc state update """ + def storage_diffs() -> List[RpcStorageDiff]: _storage_diffs = [] for address, diffs in state_update.state_diff.storage_diffs.items(): @@ -360,9 +394,7 @@ def storage_diffs() -> List[RpcStorageDiff]: def declared_contracts() -> List[RpcDeclaredContractDiff]: _contracts = [] for contract in state_update.state_diff.declared_contracts: - diff: RpcDeclaredContractDiff = { - "class_hash": rpc_felt(contract) - } + diff: RpcDeclaredContractDiff = {"class_hash": rpc_felt(contract)} _contracts.append(diff) return _contracts @@ -371,7 +403,7 @@ def deployed_contracts() -> List[RpcDeployedContractDiff]: for contract in state_update.state_diff.deployed_contracts: diff: RpcDeployedContractDiff = { "address": rpc_felt(contract.address), - "class_hash": rpc_felt(contract.class_hash) + "class_hash": rpc_felt(contract.class_hash), } _contracts.append(diff) return _contracts @@ -385,6 +417,6 @@ def deployed_contracts() -> List[RpcDeployedContractDiff]: "declared_contracts": declared_contracts(), "deployed_contracts": deployed_contracts(), "nonces": [], - } + }, } return rpc_state diff --git a/starknet_devnet/blueprints/rpc/structures/responses.py b/starknet_devnet/blueprints/rpc/structures/responses.py index 1d6c68468..78631b38a 100644 --- a/starknet_devnet/blueprints/rpc/structures/responses.py +++ b/starknet_devnet/blueprints/rpc/structures/responses.py @@ -6,10 +6,20 @@ from typing_extensions import TypedDict from starkware.starknet.definitions.transaction_type import TransactionType -from starkware.starknet.services.api.feeder_gateway.response_objects import TransactionReceipt, TransactionStatus +from starkware.starknet.services.api.feeder_gateway.response_objects import ( + TransactionReceipt, + TransactionStatus, +) from starknet_devnet.blueprints.rpc.utils import rpc_felt -from starknet_devnet.blueprints.rpc.structures.types import TxnHash, Felt, Address, BlockNumber, BlockHash, TxnStatus +from starknet_devnet.blueprints.rpc.structures.types import ( + TxnHash, + Felt, + Address, + BlockNumber, + BlockHash, + TxnStatus, +) from starknet_devnet.state import state @@ -17,6 +27,7 @@ class RpcInvokeTransactionResult(TypedDict): """ TypedDict for rpc invoke transaction result """ + transaction_hash: TxnHash @@ -24,6 +35,7 @@ class RpcDeclareTransactionResult(TypedDict): """ TypedDict for rpc declare transaction result """ + transaction_hash: TxnHash class_hash: Felt @@ -32,6 +44,7 @@ class RpcDeployTransactionResult(TypedDict): """ TypedDict for rpc deploy transaction result """ + transaction_hash: TxnHash contract_address: Felt @@ -40,6 +53,7 @@ class MessageToL1(TypedDict): """ TypedDict for rpc message from l2 to l1 """ + to_address: Felt payload: List[Felt] @@ -48,6 +62,7 @@ class MessageToL2(TypedDict): """ TypedDict for rpc message from l1 to l2 """ + from_address: str payload: List[Felt] @@ -56,6 +71,7 @@ class Event(TypedDict): """ TypedDict for rpc event """ + from_address: Address keys: List[Felt] data: List[Felt] @@ -65,6 +81,7 @@ class RpcBaseTransactionReceipt(TypedDict): """ TypedDict for rpc transaction receipt """ + # Common transaction_hash: TxnHash actual_fee: Felt @@ -78,6 +95,7 @@ class RpcInvokeReceipt(TypedDict): """ TypedDict for rpc invoke transaction receipt """ + messages_sent: List[MessageToL1] l1_origin_message: Optional[MessageToL2] events: List[Event] @@ -94,6 +112,7 @@ class RpcDeclareReceipt(TypedDict): """ TypedDict for rpc declare transaction receipt """ + # Common transaction_hash: TxnHash actual_fee: Felt @@ -107,6 +126,7 @@ class RpcDeployReceipt(TypedDict): """ TypedDict for rpc declare transaction receipt """ + # Common transaction_hash: TxnHash actual_fee: Felt @@ -120,11 +140,15 @@ def rpc_invoke_receipt(txr: TransactionReceipt) -> RpcInvokeReceipt: """ Convert rpc invoke transaction receipt to rpc format """ + def l2_to_l1_messages() -> List[MessageToL1]: - return [{ - "to_address": rpc_felt(message.to_address), - "payload": [rpc_felt(p) for p in message.payload] - } for message in txr.l2_to_l1_messages] + return [ + { + "to_address": rpc_felt(message.to_address), + "payload": [rpc_felt(p) for p in message.payload], + } + for message in txr.l2_to_l1_messages + ] def l1_to_l2_message() -> Optional[MessageToL2]: if txr.l1_to_l2_consumed_message is None: @@ -132,7 +156,7 @@ def l1_to_l2_message() -> Optional[MessageToL2]: msg: MessageToL2 = { "from_address": txr.l1_to_l2_consumed_message.from_address, - "payload": [rpc_felt(p) for p in txr.l1_to_l2_consumed_message.payload] + "payload": [rpc_felt(p) for p in txr.l1_to_l2_consumed_message.payload], } return msg @@ -175,6 +199,7 @@ def rpc_base_transaction_receipt(txr: TransactionReceipt) -> RpcBaseTransactionR """ Convert gateway transaction receipt to rpc transaction receipt """ + def status() -> str: if txr.status is None: return "UNKNOWN" @@ -215,6 +240,8 @@ def rpc_transaction_receipt(txr: TransactionReceipt) -> dict: TransactionType.INVOKE_FUNCTION: rpc_invoke_receipt, TransactionType.DECLARE: rpc_declare_receipt, } - transaction = state.starknet_wrapper.transactions.get_transaction(hex(txr.transaction_hash)).transaction + transaction = state.starknet_wrapper.transactions.get_transaction( + hex(txr.transaction_hash) + ).transaction tx_type = transaction.tx_type return tx_mapping[tx_type](txr) diff --git a/starknet_devnet/blueprints/rpc/structures/types.py b/starknet_devnet/blueprints/rpc/structures/types.py index 022ee722f..7f4d6f308 100644 --- a/starknet_devnet/blueprints/rpc/structures/types.py +++ b/starknet_devnet/blueprints/rpc/structures/types.py @@ -19,6 +19,7 @@ class BlockHashDict(TypedDict): """ TypedDict class for BlockId with block hash """ + block_hash: BlockHash @@ -26,6 +27,7 @@ class BlockNumberDict(TypedDict): """ TypedDict class for BlockId with block number """ + block_number: BlockNumber @@ -45,7 +47,7 @@ def rpc_block_status(block_status: BlockStatus) -> RpcBlockStatus: "ABORTED": "REJECTED", "REVERTED": "REJECTED", "ACCEPTED_ON_L2": "ACCEPTED_ON_L2", - "ACCEPTED_ON_L1": "ACCEPTED_ON_L1" + "ACCEPTED_ON_L1": "ACCEPTED_ON_L1", } return block_status_map[block_status] @@ -69,7 +71,10 @@ def rpc_txn_type(transaction_type: str) -> TxnType: "INVOKE_FUNCTION": "INVOKE", } if transaction_type not in txn_type_map: - raise RpcError(code=-1, message=f"Current implementation does not support {transaction_type} transaction type") + raise RpcError( + code=-1, + message=f"Current implementation does not support {transaction_type} transaction type", + ) return txn_type_map[transaction_type] diff --git a/starknet_devnet/blueprints/rpc/transactions.py b/starknet_devnet/blueprints/rpc/transactions.py index 89ab9400a..60704aae2 100644 --- a/starknet_devnet/blueprints/rpc/transactions.py +++ b/starknet_devnet/blueprints/rpc/transactions.py @@ -8,18 +8,45 @@ from marshmallow.exceptions import MarshmallowError from starkware.starknet.definitions import constants from starkware.starknet.services.api.contract_class import ContractClass -from starkware.starknet.services.api.feeder_gateway.response_objects import TransactionStatus -from starkware.starknet.services.api.gateway.transaction import InvokeFunction, Declare, DECLARE_SENDER_ADDRESS, Deploy +from starkware.starknet.services.api.feeder_gateway.response_objects import ( + TransactionStatus, +) +from starkware.starknet.services.api.gateway.transaction import ( + InvokeFunction, + Declare, + DECLARE_SENDER_ADDRESS, + Deploy, +) from starkware.starknet.services.api.gateway.transaction_utils import decompress_program from starkware.starkware_utils.error_handling import StarkException -from starknet_devnet.blueprints.rpc.utils import get_block_by_block_id, rpc_felt, \ - assert_block_id_is_latest -from starknet_devnet.blueprints.rpc.structures.payloads import rpc_transaction, RpcTransaction, FunctionCall, \ - RpcContractClass, RpcInvokeTransaction, make_invoke_function, rpc_fee_estimate -from starknet_devnet.blueprints.rpc.structures.responses import rpc_transaction_receipt, RpcInvokeTransactionResult, \ - RpcDeclareTransactionResult, RpcDeployTransactionResult -from starknet_devnet.blueprints.rpc.structures.types import TxnHash, BlockId, NumAsHex, Felt, RpcError +from starknet_devnet.blueprints.rpc.utils import ( + get_block_by_block_id, + rpc_felt, + assert_block_id_is_latest, +) +from starknet_devnet.blueprints.rpc.structures.payloads import ( + rpc_transaction, + RpcTransaction, + FunctionCall, + RpcContractClass, + RpcInvokeTransaction, + make_invoke_function, + rpc_fee_estimate, +) +from starknet_devnet.blueprints.rpc.structures.responses import ( + rpc_transaction_receipt, + RpcInvokeTransactionResult, + RpcDeclareTransactionResult, + RpcDeployTransactionResult, +) +from starknet_devnet.blueprints.rpc.structures.types import ( + TxnHash, + BlockId, + NumAsHex, + Felt, + RpcError, +) from starknet_devnet.state import state from starknet_devnet.util import StarknetDevnetException @@ -58,7 +85,9 @@ async def get_transaction_receipt(transaction_hash: TxnHash) -> dict: Get the transaction receipt by the transaction hash """ try: - result = state.starknet_wrapper.transactions.get_transaction_receipt(tx_hash=transaction_hash) + result = state.starknet_wrapper.transactions.get_transaction_receipt( + tx_hash=transaction_hash + ) except StarknetDevnetException as ex: raise RpcError(code=25, message="Invalid transaction hash") from ex @@ -75,8 +104,12 @@ async def pending_transactions() -> List[RpcTransaction]: raise NotImplementedError() -async def add_invoke_transaction(function_invocation: FunctionCall, max_fee: NumAsHex, version: NumAsHex, - signature: Optional[List[Felt]] = None) -> dict: +async def add_invoke_transaction( + function_invocation: FunctionCall, + max_fee: NumAsHex, + version: NumAsHex, + signature: Optional[List[Felt]] = None, +) -> dict: """ Submit a new transaction to be added to the chain """ @@ -86,21 +119,29 @@ async def add_invoke_transaction(function_invocation: FunctionCall, max_fee: Num calldata=[int(data, 16) for data in function_invocation["calldata"]], max_fee=int(max_fee, 16), version=int(version, 16), - signature=[int(data, 16) for data in signature] if signature is not None else [], + signature=[int(data, 16) for data in signature] + if signature is not None + else [], ) - _, transaction_hash, _ = await state.starknet_wrapper.invoke(invoke_function=invoke_function) + _, transaction_hash, _ = await state.starknet_wrapper.invoke( + invoke_function=invoke_function + ) return RpcInvokeTransactionResult( transaction_hash=rpc_felt(transaction_hash), ) -async def add_declare_transaction(contract_class: RpcContractClass, version: NumAsHex) -> dict: +async def add_declare_transaction( + contract_class: RpcContractClass, version: NumAsHex +) -> dict: """ Submit a new class declaration transaction """ try: - decompressed_program = decompress_program({"contract_class": contract_class}, False) + decompressed_program = decompress_program( + {"contract_class": contract_class}, False + ) decompressed_program = decompressed_program["contract_class"] contract_definition = ContractClass.load(decompressed_program) @@ -118,20 +159,27 @@ async def add_declare_transaction(contract_class: RpcContractClass, version: Num nonce=0, ) - class_hash, transaction_hash = await state.starknet_wrapper.declare(declare_transaction=declare_transaction) + class_hash, transaction_hash = await state.starknet_wrapper.declare( + declare_transaction=declare_transaction + ) return RpcDeclareTransactionResult( transaction_hash=rpc_felt(transaction_hash), class_hash=rpc_felt(class_hash), ) -async def add_deploy_transaction(contract_address_salt: Felt, constructor_calldata: List[Felt], - contract_definition: RpcContractClass) -> dict: +async def add_deploy_transaction( + contract_address_salt: Felt, + constructor_calldata: List[Felt], + contract_definition: RpcContractClass, +) -> dict: """ Submit a new deploy contract transaction """ try: - decompressed_program = decompress_program({"contract_definition": contract_definition}, False) + decompressed_program = decompress_program( + {"contract_definition": contract_definition}, False + ) decompressed_program = decompressed_program["contract_definition"] contract_class = ContractClass.load(decompressed_program) @@ -146,7 +194,9 @@ async def add_deploy_transaction(contract_address_salt: Felt, constructor_callda version=constants.TRANSACTION_VERSION, ) - contract_address, transaction_hash = await state.starknet_wrapper.deploy(deploy_transaction=deploy_transaction) + contract_address, transaction_hash = await state.starknet_wrapper.deploy( + deploy_transaction=deploy_transaction + ) return RpcDeployTransactionResult( transaction_hash=rpc_felt(transaction_hash), contract_address=rpc_felt(contract_address), @@ -159,15 +209,22 @@ async def estimate_fee(request: RpcInvokeTransaction, block_id: BlockId) -> dict """ assert_block_id_is_latest(block_id) - if not state.starknet_wrapper.contracts.is_deployed(int(request["contract_address"], 16)): + if not state.starknet_wrapper.contracts.is_deployed( + int(request["contract_address"], 16) + ): raise RpcError(code=20, message="Contract not found") invoke_function = make_invoke_function(request) try: - fee_response = await state.starknet_wrapper.calculate_actual_fee(invoke_function) + fee_response = await state.starknet_wrapper.calculate_actual_fee( + invoke_function + ) except StarkException as ex: - if f"Entry point {hex(int(request['entry_point_selector'], 16))} not found" in ex.message: + if ( + f"Entry point {hex(int(request['entry_point_selector'], 16))} not found" + in ex.message + ): raise RpcError(code=21, message="Invalid message selector") from ex if "While handling calldata" in ex.message: raise RpcError(code=22, message="Invalid call data") from ex diff --git a/starknet_devnet/blueprints/rpc/utils.py b/starknet_devnet/blueprints/rpc/utils.py index ca4f5074a..743cae760 100644 --- a/starknet_devnet/blueprints/rpc/utils.py +++ b/starknet_devnet/blueprints/rpc/utils.py @@ -13,10 +13,15 @@ def block_tag_to_block_number(block_id: BlockId) -> BlockId: """ if isinstance(block_id, str): if block_id == "latest": - return {"block_number": state.starknet_wrapper.blocks.get_number_of_blocks() - 1} + return { + "block_number": state.starknet_wrapper.blocks.get_number_of_blocks() - 1 + } if block_id == "pending": - raise RpcError(code=-1, message="Calls with block_id == 'pending' are not supported currently.") + raise RpcError( + code=-1, + message="Calls with block_id == 'pending' are not supported currently.", + ) raise RpcError(code=24, message="Invalid block id") @@ -32,8 +37,12 @@ def get_block_by_block_id(block_id: BlockId) -> dict: try: if "block_hash" in block_id: - return state.starknet_wrapper.blocks.get_by_hash(block_hash=block_id["block_hash"]) - return state.starknet_wrapper.blocks.get_by_number(block_number=block_id["block_number"]) + return state.starknet_wrapper.blocks.get_by_hash( + block_hash=block_id["block_hash"] + ) + return state.starknet_wrapper.blocks.get_by_number( + block_number=block_id["block_number"] + ) except StarknetDevnetException as ex: raise RpcError(code=24, message="Invalid block id") from ex @@ -43,7 +52,10 @@ def assert_block_id_is_latest(block_id: BlockId) -> None: Assert block_id is "latest" and throw RpcError otherwise """ if block_id != "latest": - raise RpcError(code=-1, message="Calls with block_id != 'latest' are not supported currently.") + raise RpcError( + code=-1, + message="Calls with block_id != 'latest' are not supported currently.", + ) def rpc_felt(value: int) -> Felt: @@ -75,11 +87,7 @@ def rpc_response(message_id: int, content: dict) -> dict: """ Wrap response content in rpc format """ - return { - "jsonrpc": "2.0", - "id": message_id, - "result": content - } + return {"jsonrpc": "2.0", "id": message_id, "result": content} def rpc_error(message_id: int, code: int, message: str) -> dict: @@ -89,8 +97,5 @@ def rpc_error(message_id: int, code: int, message: str) -> dict: return { "jsonrpc": "2.0", "id": message_id, - "error": { - "code": code, - "message": message - } + "error": {"code": code, "message": message}, } diff --git a/starknet_devnet/blueprints/shared.py b/starknet_devnet/blueprints/shared.py index 325e08926..8778f934d 100644 --- a/starknet_devnet/blueprints/shared.py +++ b/starknet_devnet/blueprints/shared.py @@ -8,7 +8,8 @@ from starknet_devnet.constants import CAIRO_LANG_VERSION from starknet_devnet.util import StarknetDevnetException -def validate_transaction(data: bytes, loader: Transaction=Transaction): + +def validate_transaction(data: bytes, loader: Transaction = Transaction): """Ensure `data` is a valid Starknet transaction. Returns the parsed `Transaction`.""" try: transaction = loader.loads(data) diff --git a/starknet_devnet/constants.py b/starknet_devnet/constants.py index 62a10d7b9..1fd4e325b 100644 --- a/starknet_devnet/constants.py +++ b/starknet_devnet/constants.py @@ -8,12 +8,14 @@ CAIRO_LANG_VERSION = version("cairo-lang") FAILURE_REASON_KEY = "transaction_failure_reason" -TIMEOUT_FOR_WEB3_REQUESTS = 120 #seconds -L1_MESSAGE_CANCELLATION_DELAY = 0 # Min amount of time in seconds for a message to be able to be cancelled +TIMEOUT_FOR_WEB3_REQUESTS = 120 # seconds +L1_MESSAGE_CANCELLATION_DELAY = ( + 0 # Min amount of time in seconds for a message to be able to be cancelled +) DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 5050 DEFAULT_ACCOUNTS = 10 -DEFAULT_INITIAL_BALANCE = 10 ** 21 -DEFAULT_GAS_PRICE = 10 ** 11 +DEFAULT_INITIAL_BALANCE = 10**21 +DEFAULT_GAS_PRICE = 10**11 diff --git a/starknet_devnet/contract_wrapper.py b/starknet_devnet/contract_wrapper.py index dbc105121..e0ed5d5c0 100644 --- a/starknet_devnet/contract_wrapper.py +++ b/starknet_devnet/contract_wrapper.py @@ -8,18 +8,25 @@ from starkware.starknet.testing.contract import StarknetContract from starkware.starknet.utils.api_utils import cast_to_felts + class ContractWrapper: """ Wraps a StarknetContract, storing its types and code for later use. """ - def __init__(self, contract: StarknetContract, contract_class: ContractClass, deployment_tx_hash: int = None): + + def __init__( + self, + contract: StarknetContract, + contract_class: ContractClass, + deployment_tx_hash: int = None, + ): self.contract: StarknetContract = contract self.contract_class = contract_class.remove_debug_info() self.deployment_tx_hash = deployment_tx_hash self.code: dict = { "abi": contract_class.abi, - "bytecode": self.contract_class.dump()["program"]["data"] + "bytecode": self.contract_class.dump()["program"]["data"], } # pylint: disable=too-many-arguments @@ -29,7 +36,7 @@ async def call( calldata: List[int], signature: List[int], caller_address: int, - max_fee: int + max_fee: int, ): """ Calls the function identified with `entry_point_selector`, potentially passing in `calldata` and `signature`. @@ -41,7 +48,7 @@ async def call( contract_address=self.contract.contract_address, max_fee=max_fee, selector=entry_point_selector, - signature=signature and cast_to_felts(values=signature) + signature=signature and cast_to_felts(values=signature), ) result = list(map(hex, call_info.retdata)) @@ -54,7 +61,7 @@ async def invoke( calldata: List[int], signature: List[int], caller_address: int, - max_fee: int + max_fee: int, ): """ Invokes the function identified with `entry_point_selector`, potentially passing in `calldata` and `signature`. diff --git a/starknet_devnet/contracts.py b/starknet_devnet/contracts.py index fa806e62c..a7b806a76 100644 --- a/starknet_devnet/contracts.py +++ b/starknet_devnet/contracts.py @@ -7,12 +7,10 @@ from starkware.starknet.services.api.contract_class import ContractClass from .origin import Origin -from .util import ( - StarknetDevnetException, - fixed_length_hex -) +from .util import StarknetDevnetException, fixed_length_hex from .contract_wrapper import ContractWrapper + class DevnetContracts: """ This class is used to store the deployed contracts of the devnet. @@ -47,7 +45,9 @@ def get_by_address(self, address: int) -> ContractWrapper: Get the contract wrapper by address. """ if not self.is_deployed(address): - message = f"No contract at the provided address ({fixed_length_hex(address)})." + message = ( + f"No contract at the provided address ({fixed_length_hex(address)})." + ) raise StarknetDevnetException(message=message) return self.__instances[address] diff --git a/starknet_devnet/devnet_config.py b/starknet_devnet/devnet_config.py index 643aa792c..46c9aecfe 100644 --- a/starknet_devnet/devnet_config.py +++ b/starknet_devnet/devnet_config.py @@ -11,7 +11,7 @@ DEFAULT_GAS_PRICE, DEFAULT_HOST, DEFAULT_INITIAL_BALANCE, - DEFAULT_PORT + DEFAULT_PORT, ) @@ -28,24 +28,32 @@ # # otherwise a URL; perhaps check validity # return name + class DumpOn(Enum): """Enumerate possible dumping frequencies.""" + EXIT = auto() TRANSACTION = auto() + DUMP_ON_OPTIONS = [e.name.lower() for e in DumpOn] DUMP_ON_OPTIONS_STRINGIFIED = ", ".join(DUMP_ON_OPTIONS) + def parse_dump_on(option: str): """Parse dumping frequency option.""" if option in DUMP_ON_OPTIONS: return DumpOn[option.upper()] - sys.exit(f"Error: Invalid --dump-on option: {option}. Valid options: {DUMP_ON_OPTIONS_STRINGIFIED}") + sys.exit( + f"Error: Invalid --dump-on option: {option}. Valid options: {DUMP_ON_OPTIONS_STRINGIFIED}" + ) + class NonNegativeAction(argparse.Action): """ Action for parsing the non negative int argument. """ + def __call__(self, parser, namespace, values, option_string=None): error_msg = f"{option_string} must be a positive integer; got: {values}." try: @@ -58,86 +66,89 @@ def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, value) + def parse_args(raw_args: List[str]): """ Parses CLI arguments. """ - parser = argparse.ArgumentParser(description="Run a local instance of Starknet Devnet") + parser = argparse.ArgumentParser( + description="Run a local instance of Starknet Devnet" + ) parser.add_argument( - "-v", "--version", + "-v", + "--version", help="Print the version", action="version", - version=__version__ + version=__version__, ) parser.add_argument( "--host", - help=f"Specify the address to listen at; defaults to {DEFAULT_HOST} " + - "(use the address the program outputs on start)", - default=DEFAULT_HOST + help=f"Specify the address to listen at; defaults to {DEFAULT_HOST} " + + "(use the address the program outputs on start)", + default=DEFAULT_HOST, ) parser.add_argument( - "--port", "-p", + "--port", + "-p", type=int, help=f"Specify the port to listen at; defaults to {DEFAULT_PORT}", - default=DEFAULT_PORT + default=DEFAULT_PORT, ) parser.add_argument( - "--load-path", - help="Specify the path from which the state is loaded on startup" - ) - parser.add_argument( - "--dump-path", - help="Specify the path to dump to" + "--load-path", help="Specify the path from which the state is loaded on startup" ) + parser.add_argument("--dump-path", help="Specify the path to dump to") parser.add_argument( "--dump-on", help=f"Specify when to dump; can dump on: {DUMP_ON_OPTIONS_STRINGIFIED}", - type=parse_dump_on + type=parse_dump_on, ) parser.add_argument( "--lite-mode", action="store_true", - help="Applies all lite-mode-* optimizations by disabling some features." + help="Applies all lite-mode-* optimizations by disabling some features.", ) parser.add_argument( "--lite-mode-block-hash", action="store_true", - help="Disables block hash calculation" + help="Disables block hash calculation", ) parser.add_argument( "--lite-mode-deploy-hash", action="store_true", - help="Disables deploy tx hash calculation" + help="Disables deploy tx hash calculation", ) parser.add_argument( "--accounts", action=NonNegativeAction, help=f"Specify the number of accounts to be predeployed; defaults to {DEFAULT_ACCOUNTS}", - default=DEFAULT_ACCOUNTS + default=DEFAULT_ACCOUNTS, ) parser.add_argument( - "--initial-balance", "-e", + "--initial-balance", + "-e", action=NonNegativeAction, - help="Specify the initial balance of accounts to be predeployed; " + - f"defaults to {DEFAULT_INITIAL_BALANCE:g}", - default=DEFAULT_INITIAL_BALANCE + help="Specify the initial balance of accounts to be predeployed; " + + f"defaults to {DEFAULT_INITIAL_BALANCE:g}", + default=DEFAULT_INITIAL_BALANCE, ) parser.add_argument( "--seed", type=int, - help="Specify the seed for randomness of accounts to be predeployed" + help="Specify the seed for randomness of accounts to be predeployed", ) parser.add_argument( "--start-time", action=NonNegativeAction, - help="Specify the start time of the genesis block in Unix time seconds" + help="Specify the start time of the genesis block in Unix time seconds", ) parser.add_argument( - "--gas-price", "-g", + "--gas-price", + "-g", action=NonNegativeAction, default=DEFAULT_GAS_PRICE, - help="Specify the gas price in wei per gas unit; " + - f"defaults to {DEFAULT_GAS_PRICE:g}" + help="Specify the gas price in wei per gas unit; " + + f"defaults to {DEFAULT_GAS_PRICE:g}", ) # Uncomment this once fork support is added # parser.add_argument( @@ -153,12 +164,13 @@ def parse_args(raw_args: List[str]): return parsed_args + # pylint: disable=too-few-public-methods # pylint: disable=too-many-instance-attributes class DevnetConfig: """Class holding configuration specified by user""" - def __init__(self, args: argparse.Namespace=None): + def __init__(self, args: argparse.Namespace = None): # these args are used in tests; in production, this is overwritten in `main` self.args = args or parse_args(["--accounts", "0"]) self.accounts = self.args.accounts diff --git a/starknet_devnet/dump.py b/starknet_devnet/dump.py index 74b8b9e4d..e06280da4 100644 --- a/starknet_devnet/dump.py +++ b/starknet_devnet/dump.py @@ -9,6 +9,7 @@ # Instead of "fork", the default on MacOS since Python3.8 has been "spawn", which causes pickling to fail multiprocessing.set_start_method("fork") + class Dumper: """Class for dumping objects.""" @@ -28,7 +29,7 @@ def __write_file(self, path): with open(path, "wb") as file: pickle.dump(self.dumpable, file) - def dump(self, path: str=None): + def dump(self, path: str = None): """Dump to `path`.""" path = path or self.dump_path assert path, "No dump_path defined" diff --git a/starknet_devnet/fee_token.py b/starknet_devnet/fee_token.py index 51f01e597..a025bd360 100644 --- a/starknet_devnet/fee_token.py +++ b/starknet_devnet/fee_token.py @@ -6,16 +6,20 @@ from starkware.starknet.services.api.contract_class import ContractClass from starkware.starknet.services.api.gateway.transaction import InvokeFunction from starkware.starknet.storage.starknet_storage import StorageLeaf -from starkware.starknet.business_logic.state.objects import (ContractState, ContractCarriedState) +from starkware.starknet.business_logic.state.objects import ( + ContractState, + ContractCarriedState, +) from starkware.starknet.testing.contract import StarknetContract from starkware.python.utils import to_bytes from starkware.starknet.compiler.compile import get_selector_from_name from starknet_devnet.util import Uint256 + class FeeToken: """Wrapper of token for charging fees.""" - CONTRACT_CLASS: ContractClass = None # loaded lazily + CONTRACT_CLASS: ContractClass = None # loaded lazily # Precalculated # HASH = to_bytes(compute_class_hash(contract_class=FeeToken.get_contract_class())) @@ -25,7 +29,9 @@ class FeeToken: # Precalculated to fixed address # ADDRESS = calculate_contract_address_from_hash(salt=10, class_hash=HASH, # constructor_calldata=[], caller_address=0) - ADDRESS = 2774287484619332564597403632816768868845110259953541691709975889937073775752 + ADDRESS = ( + 2774287484619332564597403632816768868845110259953541691709975889937073775752 + ) SYMBOL = "ETH" NAME = "ether" @@ -38,7 +44,9 @@ def __init__(self, starknet_wrapper): def get_contract_class(cls): """Returns contract class via lazy loading.""" if not cls.CONTRACT_CLASS: - cls.CONTRACT_CLASS = ContractClass.load(load_nearby_contract("ERC20_Mintable_OZ_0.2.0")) + cls.CONTRACT_CLASS = ContractClass.load( + load_nearby_contract("ERC20_Mintable_OZ_0.2.0") + ) return cls.CONTRACT_CLASS async def deploy(self): @@ -53,35 +61,40 @@ async def deploy(self): starknet.state.state.contract_definitions[FeeToken.HASH_BYTES] = contract_class newly_deployed_fee_token_state = await ContractState.create( contract_hash=FeeToken.HASH_BYTES, - storage_commitment_tree=fee_token_state.storage_commitment_tree + storage_commitment_tree=fee_token_state.storage_commitment_tree, ) starknet.state.state.contract_states[FeeToken.ADDRESS] = ContractCarriedState( state=newly_deployed_fee_token_state, storage_updates={ # Running the constructor doesn't need to be simulated - get_selector_from_name("ERC20_name"): StorageLeaf(int.from_bytes(bytes(FeeToken.NAME, "ascii"), "big")), - get_selector_from_name("ERC20_symbol"): StorageLeaf(int.from_bytes(bytes(FeeToken.SYMBOL, "ascii"), "big")), - get_selector_from_name("ERC20_decimals"): StorageLeaf(18) - } + get_selector_from_name("ERC20_name"): StorageLeaf( + int.from_bytes(bytes(FeeToken.NAME, "ascii"), "big") + ), + get_selector_from_name("ERC20_symbol"): StorageLeaf( + int.from_bytes(bytes(FeeToken.SYMBOL, "ascii"), "big") + ), + get_selector_from_name("ERC20_decimals"): StorageLeaf(18), + }, ) self.contract = StarknetContract( state=starknet.state, abi=FeeToken.get_contract_class().abi, contract_address=FeeToken.ADDRESS, - deploy_execution_info=None + deploy_execution_info=None, ) - self.starknet_wrapper.store_contract(FeeToken.ADDRESS, self.contract, contract_class) + self.starknet_wrapper.store_contract( + FeeToken.ADDRESS, self.contract, contract_class + ) async def get_balance(self, address: int) -> int: """Return the balance of the contract under `address`.""" response = await self.contract.balanceOf(address).call() balance = Uint256( - low=response.result.balance.low, - high=response.result.balance.high + low=response.result.balance.low, high=response.result.balance.high ).to_felt() return balance @@ -96,7 +109,7 @@ def get_mint_transaction(cls, to_address: int, amount: Uint256): str(amount.high), ], "signature": [], - "contract_address": hex(cls.ADDRESS) + "contract_address": hex(cls.ADDRESS), } return InvokeFunction.load(transaction_data) @@ -110,8 +123,7 @@ async def mint(self, to_address: int, amount: int, lite: bool): tx_hash = None if lite: await self.contract.mint( - to_address, - (amount_uint256.low, amount_uint256.high) + to_address, (amount_uint256.low, amount_uint256.high) ).invoke() else: transaction = self.get_mint_transaction(to_address, amount_uint256) diff --git a/starknet_devnet/general_config.py b/starknet_devnet/general_config.py index 451d3d61a..b51643058 100644 --- a/starknet_devnet/general_config.py +++ b/starknet_devnet/general_config.py @@ -13,20 +13,22 @@ from .fee_token import FeeToken -DEFAULT_GENERAL_CONFIG = build_general_config({ - "cairo_resource_fee_weights": { - "n_steps": constants.N_STEPS_FEE_WEIGHT, - }, - "contract_storage_commitment_tree_height": constants.CONTRACT_STATES_COMMITMENT_TREE_HEIGHT, - "event_commitment_tree_height": constants.EVENT_COMMITMENT_TREE_HEIGHT, - "global_state_commitment_tree_height": constants.CONTRACT_ADDRESS_BITS, - "invoke_tx_max_n_steps": DEFAULT_MAX_STEPS, - "min_gas_price": DEFAULT_GAS_PRICE, - "sequencer_address": hex(DEFAULT_SEQUENCER_ADDRESS), - "starknet_os_config": { - "chain_id": DEFAULT_CHAIN_ID.name, - "fee_token_address": hex(FeeToken.ADDRESS) - }, - "tx_version": constants.TRANSACTION_VERSION, - "tx_commitment_tree_height": constants.TRANSACTION_COMMITMENT_TREE_HEIGHT -}) +DEFAULT_GENERAL_CONFIG = build_general_config( + { + "cairo_resource_fee_weights": { + "n_steps": constants.N_STEPS_FEE_WEIGHT, + }, + "contract_storage_commitment_tree_height": constants.CONTRACT_STATES_COMMITMENT_TREE_HEIGHT, + "event_commitment_tree_height": constants.EVENT_COMMITMENT_TREE_HEIGHT, + "global_state_commitment_tree_height": constants.CONTRACT_ADDRESS_BITS, + "invoke_tx_max_n_steps": DEFAULT_MAX_STEPS, + "min_gas_price": DEFAULT_GAS_PRICE, + "sequencer_address": hex(DEFAULT_SEQUENCER_ADDRESS), + "starknet_os_config": { + "chain_id": DEFAULT_CHAIN_ID.name, + "fee_token_address": hex(FeeToken.ADDRESS), + }, + "tx_version": constants.TRANSACTION_VERSION, + "tx_commitment_tree_height": constants.TRANSACTION_COMMITMENT_TREE_HEIGHT, + } +) diff --git a/starknet_devnet/origin.py b/starknet_devnet/origin.py index d71cc9c5b..130dc648a 100644 --- a/starknet_devnet/origin.py +++ b/starknet_devnet/origin.py @@ -13,6 +13,7 @@ from starknet_devnet.util import StarknetDevnetException + class Origin: """ Abstraction of an L2 blockchain. @@ -66,22 +67,23 @@ def get_number_of_blocks(self): """Returns the number of blocks stored so far""" raise NotImplementedError - def get_state_update(self, block_hash: str=None, block_number: int=None) -> dict or None: + def get_state_update( + self, block_hash: str = None, block_number: int = None + ) -> dict or None: """ Returns the state update for provided block hash or block number. If none are provided return the last state update """ raise NotImplementedError + class NullOrigin(Origin): """ A default class to comply with the Origin interface. """ def get_transaction_status(self, transaction_hash: str): - return { - "tx_status": TransactionStatus.NOT_RECEIVED.name - } + return {"tx_status": TransactionStatus.NOT_RECEIVED.name} def get_transaction(self, transaction_hash: str) -> TransactionInfo: return TransactionInfo.create( @@ -100,16 +102,16 @@ def get_transaction_receipt(self, transaction_hash: str) -> TransactionReceipt: execution_resources=None, actual_fee=None, transaction_failure_reason=None, - l1_to_l2_consumed_message=None + l1_to_l2_consumed_message=None, ) def get_transaction_trace(self, transaction_hash: str): tx_hash_int = int(transaction_hash, 16) - message=f"Transaction corresponding to hash {tx_hash_int} is not found." + message = f"Transaction corresponding to hash {tx_hash_int} is not found." raise StarknetDevnetException(message=message) def get_block_by_hash(self, block_hash: str): - message=f"Block hash not found; got: {block_hash}." + message = f"Block hash not found; got: {block_hash}." raise StarknetDevnetException(message=message) def get_block_by_number(self, block_number: int): @@ -117,17 +119,10 @@ def get_block_by_number(self, block_number: int): raise StarknetDevnetException(message=message) def get_code(self, contract_address: int): - return { - "abi": {}, - "bytecode": [] - } + return {"abi": {}, "bytecode": []} def get_full_contract(self, contract_address: int) -> dict: - return { - "abi": {}, - "entry_points_by_type": {}, - "program": {} - } + return {"abi": {}, "entry_points_by_type": {}, "program": {}} def get_class_by_hash(self, class_hash: int) -> ContractClass: message = f"Class with hash {hex(class_hash)} is not declared" @@ -143,15 +138,22 @@ def get_storage_at(self, contract_address: int, key: int) -> str: def get_number_of_blocks(self): return 0 - def get_state_update(self, block_hash: str=None, block_number: int=None) -> dict or None: + def get_state_update( + self, block_hash: str = None, block_number: int = None + ) -> dict or None: if block_hash: - error_message = f"No state updates saved for the provided block hash {block_hash}" + error_message = ( + f"No state updates saved for the provided block hash {block_hash}" + ) raise StarknetDevnetException(message=error_message) if block_number is not None: - error_message = f"No state updates saved for the provided block number {block_number}" + error_message = ( + f"No state updates saved for the provided block number {block_number}" + ) raise StarknetDevnetException(message=error_message) + class ForkedOrigin(Origin): """ Abstracts an origin that the devnet was forked from. @@ -194,5 +196,7 @@ def get_storage_at(self, contract_address: int, key: int) -> str: def get_number_of_blocks(self): return self.number_of_blocks - def get_state_update(self, block_hash: str=None, block_number: int=None) -> dict or None: + def get_state_update( + self, block_hash: str = None, block_number: int = None + ) -> dict or None: raise NotImplementedError diff --git a/starknet_devnet/postman_wrapper.py b/starknet_devnet/postman_wrapper.py index 49008a958..73bf04cfb 100644 --- a/starknet_devnet/postman_wrapper.py +++ b/starknet_devnet/postman_wrapper.py @@ -30,9 +30,13 @@ def __parse_l1_l2_messages(self, l1_raw_messages, l2_raw_messages) -> dict: for message in l1_raw_messages: message["args"]["selector"] = hex(message["args"]["selector"]) - message["args"]["to_address"] = fixed_length_hex(message["args"].pop("toAddress")) # L2 addresses need the leading 0 + message["args"]["to_address"] = fixed_length_hex( + message["args"].pop("toAddress") + ) # L2 addresses need the leading 0 message["args"]["from_address"] = message["args"].pop("fromAddress") - message["args"]["payload"] = [hex(val) for val in message["args"]["payload"]] + message["args"]["payload"] = [ + hex(val) for val in message["args"]["payload"] + ] # change case to snake_case message["transaction_hash"] = message.pop("transactionHash") @@ -44,22 +48,25 @@ def __parse_l1_l2_messages(self, l1_raw_messages, l2_raw_messages) -> dict: l2_messages = [] for message in l2_raw_messages: new_message = { - "from_address": fixed_length_hex(message.from_address), # L2 addresses need the leading 0 + "from_address": fixed_length_hex( + message.from_address + ), # L2 addresses need the leading 0 "payload": [hex(val) for val in message.payload], - "to_address": hex(message.to_address) + "to_address": hex(message.to_address), } l2_messages.append(new_message) return { "l1_provider": self.__l1_provider, - "consumed_messages": { - "from_l1": l1_raw_messages, - "from_l2": l2_messages - } + "consumed_messages": {"from_l1": l1_raw_messages, "from_l2": l2_messages}, } def load_l1_messaging_contract( - self, starknet: Starknet, network_url: str, contract_address: str, network_id: str + self, + starknet: Starknet, + network_url: str, + contract_address: str, + network_id: str, ) -> dict: """Creates a Postman Wrapper instance and loads an already deployed Messaging contract in the L1 network""" @@ -68,7 +75,9 @@ def load_l1_messaging_contract( try: starknet.state.l2_to_l1_messages_log.clear() self.__postman_wrapper = LocalPostmanWrapper(network_url) - self.__postman_wrapper.load_mock_messaging_contract_in_l1(starknet,contract_address) + self.__postman_wrapper.load_mock_messaging_contract_in_l1( + starknet, contract_address + ) except Exception as error: message = f"""Unable to load the Starknet Messaging contract in a local testnet instance. Make sure you have a local testnet instance running at the provided network url ({network_url}), @@ -82,26 +91,31 @@ def load_l1_messaging_contract( return { "l1_provider": network_url, - "address": self.__postman_wrapper.mock_starknet_messaging_contract.address + "address": self.__postman_wrapper.mock_starknet_messaging_contract.address, } async def flush(self, state) -> dict: - """Handles all pending L1 <> L2 messages and sends them to the other layer. """ + """Handles all pending L1 <> L2 messages and sends them to the other layer.""" if self.__postman_wrapper is None: return {} postman = self.__postman_wrapper.postman - l1_to_l2_messages = json.loads(Web3.toJSON(self.__postman_wrapper.l1_to_l2_message_filter.get_new_entries())) - l2_to_l1_messages = state.l2_to_l1_messages_log[postman.n_consumed_l2_to_l1_messages :] + l1_to_l2_messages = json.loads( + Web3.toJSON( + self.__postman_wrapper.l1_to_l2_message_filter.get_new_entries() + ) + ) + l2_to_l1_messages = state.l2_to_l1_messages_log[ + postman.n_consumed_l2_to_l1_messages : + ] await self.__postman_wrapper.flush() return self.__parse_l1_l2_messages(l1_to_l2_messages, l2_to_l1_messages) - class PostmanWrapper(ABC): """Postman Wrapper base class""" @@ -121,6 +135,7 @@ async def flush(self): """Handles the L1 <> L2 message exchange""" await self.postman.flush() + class LocalPostmanWrapper(PostmanWrapper): """Wrapper of Postman usage on a local testnet instantiated using a local testnet""" @@ -129,20 +144,24 @@ def __init__(self, network_url: str): request_kwargs = {"timeout": TIMEOUT_FOR_WEB3_REQUESTS} self.web3 = Web3(HTTPProvider(network_url, request_kwargs=request_kwargs)) self.web3.middleware_onion.inject(geth_poa_middleware, layer=0) - self.eth_account = EthAccount(self.web3,self.web3.eth.accounts[0]) + self.eth_account = EthAccount(self.web3, self.web3.eth.accounts[0]) def load_mock_messaging_contract_in_l1(self, starknet, contract_address): if contract_address is None: self.mock_starknet_messaging_contract = self.eth_account.deploy( load_nearby_contract("MockStarknetMessaging"), - L1_MESSAGE_CANCELLATION_DELAY + L1_MESSAGE_CANCELLATION_DELAY, ) else: address = Web3.toChecksumAddress(contract_address) contract_json = load_nearby_contract("MockStarknetMessaging") abi = contract_json["abi"] - w3_contract = self.web3.eth.contract(abi=abi,address=address) - self.mock_starknet_messaging_contract = EthContract(self.web3,address,w3_contract,abi,self.eth_account) + w3_contract = self.web3.eth.contract(abi=abi, address=address) + self.mock_starknet_messaging_contract = EthContract( + self.web3, address, w3_contract, abi, self.eth_account + ) self.postman = Postman(self.mock_starknet_messaging_contract, starknet) - self.l1_to_l2_message_filter = self.mock_starknet_messaging_contract.w3_contract.events.LogMessageToL2.createFilter(fromBlock="latest") + self.l1_to_l2_message_filter = self.mock_starknet_messaging_contract.w3_contract.events.LogMessageToL2.createFilter( + fromBlock="latest" + ) diff --git a/starknet_devnet/server.py b/starknet_devnet/server.py index 45779af7d..a0a22f078 100644 --- a/starknet_devnet/server.py +++ b/starknet_devnet/server.py @@ -31,6 +31,7 @@ async def initialize_starknet(): """Initialize Starknet to assert it's defined before its first use.""" await state.starknet_wrapper.initialize() + app.register_blueprint(base) app.register_blueprint(gateway) app.register_blueprint(feeder_gateway) @@ -50,29 +51,32 @@ def __init__(self, application, args): def load_config(self): self.cfg.set("bind", f"{self.args.host}:{self.args.port}") self.cfg.set("workers", 1) - self.cfg.set("logconfig_dict", { - "loggers": { - "gunicorn.error": { - # Disable info messages like "Starting gunicorn" - "level": "WARNING", - "handlers": ["error_console"], - "propagate": False, - "qualname": "gunicorn.error" + self.cfg.set( + "logconfig_dict", + { + "loggers": { + "gunicorn.error": { + # Disable info messages like "Starting gunicorn" + "level": "WARNING", + "handlers": ["error_console"], + "propagate": False, + "qualname": "gunicorn.error", + }, + "gunicorn.access": { + "level": "INFO", + # Log access to stderr to maintain backward compatibility + "handlers": ["error_console"], + "propagate": False, + "qualname": "gunicorn.access", + }, }, - - "gunicorn.access": { - "level": "INFO", - # Log access to stderr to maintain backward compatibility - "handlers": ["error_console"], - "propagate": False, - "qualname": "gunicorn.access" - } }, - }) + ) def load(self): return self.application + def main(): """Runs the server.""" @@ -103,12 +107,17 @@ def main(): state.dumper.dump() sys.exit(0) + @app.errorhandler(StarkException) def handle(error: StarkException): - """Handles the error and responds in JSON. """ - return {"message": error.message, "status_code": error.status_code}, error.status_code + """Handles the error and responds in JSON.""" + return { + "message": error.message, + "status_code": error.status_code, + }, error.status_code + -@app.route("/api", methods = ["GET"]) +@app.route("/api", methods=["GET"]) def api(): """Return available endpoints.""" routes = {} @@ -117,9 +126,10 @@ def api(): routes[url.rule] = { "functionName": url.endpoint, "methods": list(url.methods), - "doc": app.view_functions[url.endpoint].__doc__.strip() + "doc": app.view_functions[url.endpoint].__doc__.strip(), } return jsonify(routes) + if __name__ == "__main__": main() diff --git a/starknet_devnet/starknet_wrapper.py b/starknet_devnet/starknet_wrapper.py index ddeb30444..2c50ddd2a 100644 --- a/starknet_devnet/starknet_wrapper.py +++ b/starknet_devnet/starknet_wrapper.py @@ -13,12 +13,18 @@ ) from starkware.starknet.business_logic.internal_transaction import CallInfo from starkware.starknet.business_logic.state.state import BlockInfo, CarriedState -from starkware.starknet.services.api.gateway.transaction import InvokeFunction, Deploy, Declare +from starkware.starknet.services.api.gateway.transaction import ( + InvokeFunction, + Deploy, + Declare, +) from starkware.starknet.testing.starknet import Starknet from starkware.starkware_utils.error_handling import StarkException from starkware.starknet.business_logic.transaction_fee import calculate_tx_fee from starkware.starknet.services.api.contract_class import EntryPointType, ContractClass -from starkware.starknet.services.api.feeder_gateway.response_objects import TransactionStatus +from starkware.starknet.services.api.feeder_gateway.response_objects import ( + TransactionStatus, +) from starkware.starknet.testing.contract import StarknetContract from starkware.starknet.testing.objects import FunctionInvocation @@ -27,12 +33,7 @@ from .fee_token import FeeToken from .general_config import DEFAULT_GENERAL_CONFIG from .origin import NullOrigin, Origin -from .util import ( - DummyExecutionInfo, - enable_pickling, - generate_state_update, - to_bytes -) +from .util import DummyExecutionInfo, enable_pickling, generate_state_update, to_bytes from .contract_wrapper import ContractWrapper from .postman_wrapper import DevnetL1L2 from .transactions import DevnetTransactions, DevnetTransaction @@ -43,7 +44,7 @@ enable_pickling() -#pylint: disable=too-many-instance-attributes +# pylint: disable=too-many-instance-attributes class StarknetWrapper: """ Wraps a Starknet instance and stores data to be returned by the server: @@ -94,7 +95,9 @@ async def create_empty_block(self): self.__update_block_number() state_update = await self.__update_state() state = self.get_state() - return await self.blocks.generate(None, state, state_update, is_empty_block=True) + return await self.blocks.generate( + None, state, state_update, is_empty_block=True + ) async def __preserve_current_state(self, state: CarriedState): self.__current_carried_state = deepcopy(state) @@ -123,17 +126,19 @@ async def __update_state(self): current_carried_state.block_info = self.block_info_generator.next_block( block_info=current_carried_state.block_info, - general_config=state.general_config + general_config=state.general_config, ) if not self.config.lite_mode_block_hash: # This is the most time-intensive part of the function. # With only skipping it in lite-mode, we still get the time benefit. # In regular mode it's needed for state update calculation and block state_root calculation. - updated_shared_state = await current_carried_state.shared_state.apply_state_updates( - ffc=current_carried_state.ffc, - previous_carried_state=previous_state, - current_carried_state=current_carried_state + updated_shared_state = ( + await current_carried_state.shared_state.apply_state_updates( + ffc=current_carried_state.ffc, + previous_carried_state=previous_state, + current_carried_state=current_carried_state, + ) ) state.state.shared_state = updated_shared_state @@ -143,14 +148,24 @@ async def __update_state(self): return None - def store_contract(self, - address: int, contract: StarknetContract, contract_class: ContractClass, tx_hash: int = None): + def store_contract( + self, + address: int, + contract: StarknetContract, + contract_class: ContractClass, + tx_hash: int = None, + ): """Store the provided data sa wrapped contract""" - self.contracts.store(address, ContractWrapper(contract, contract_class, tx_hash)) + self.contracts.store( + address, ContractWrapper(contract, contract_class, tx_hash) + ) async def __store_transaction( - self, transaction: DevnetTransaction, tx_hash: int, - state_update: Dict, error_message: str=None + self, + transaction: DevnetTransaction, + tx_hash: int, + state_update: Dict, + error_message: str = None, ) -> None: """ Stores the provided data as a deploy transaction in `self.transactions`. @@ -185,20 +200,21 @@ async def declare(self, declare_transaction: Declare) -> Tuple[int, int]: """ internal_declare: InternalDeclare = InternalDeclare.from_external( - declare_transaction, - self.get_state().general_config + declare_transaction, self.get_state().general_config ) declared_class = await self.starknet.declare( contract_class=declare_transaction.contract_class, ) - self.contracts.store_class(declared_class.class_hash, declare_transaction.contract_class) + self.contracts.store_class( + declared_class.class_hash, declare_transaction.contract_class + ) tx_hash = internal_declare.hash_value transaction = DevnetTransaction( internal_tx=internal_declare, status=TransactionStatus.ACCEPTED_ON_L2, execution_info=DummyExecutionInfo(), - transaction_hash=tx_hash + transaction_hash=tx_hash, ) self.__update_block_number() @@ -208,12 +224,11 @@ async def declare(self, declare_transaction: Declare) -> Tuple[int, int]: transaction=transaction, tx_hash=tx_hash, state_update=state_update, - error_message=None + error_message=None, ) return declared_class.class_hash, tx_hash - def __update_block_number(self): """Updates just the block number. Returns the old block info to allow reverting""" current_carried_state = self.get_state().state @@ -223,7 +238,7 @@ def __update_block_number(self): block_number=block_info.block_number + 1, block_timestamp=block_info.block_timestamp, sequencer_address=block_info.sequencer_address, - starknet_version=block_info.starknet_version + starknet_version=block_info.starknet_version, ) return block_info @@ -236,7 +251,9 @@ async def deploy(self, deploy_transaction: Deploy) -> Tuple[int, int]: state = self.get_state() contract_class = deploy_transaction.contract_definition - internal_tx: InternalDeploy = InternalDeploy.from_external(deploy_transaction, state.general_config) + internal_tx: InternalDeploy = InternalDeploy.from_external( + deploy_transaction, state.general_config + ) contract_address = internal_tx.contract_address if self.contracts.is_deployed(contract_address): @@ -254,13 +271,15 @@ async def deploy(self, deploy_transaction: Deploy) -> Tuple[int, int]: contract = await self.starknet.deploy( contract_class=contract_class, constructor_calldata=deploy_transaction.constructor_calldata, - contract_address_salt=deploy_transaction.contract_address_salt + contract_address_salt=deploy_transaction.contract_address_salt, ) execution_info = contract.deploy_execution_info error_message = None status = TransactionStatus.ACCEPTED_ON_L2 - self.store_contract(contract.contract_address, contract, contract_class, tx_hash) + self.store_contract( + contract.contract_address, contract, contract_class, tx_hash + ) state_update = await self.__update_state() except StarkException as err: error_message = err.message @@ -282,28 +301,34 @@ async def deploy(self, deploy_transaction: Deploy) -> Tuple[int, int]: transaction=transaction, state_update=state_update, error_message=error_message, - tx_hash=tx_hash + tx_hash=tx_hash, ) - await self.__register_new_contracts(execution_info.call_info.internal_calls, tx_hash) + await self.__register_new_contracts( + execution_info.call_info.internal_calls, tx_hash + ) return contract_address, tx_hash async def invoke(self, invoke_function: InvokeFunction): """Perform invoke according to specifications in `transaction`.""" state = self.get_state() - invoke_transaction: InternalInvokeFunction = InternalInvokeFunction.from_external(invoke_function, state.general_config) + invoke_transaction: InternalInvokeFunction = ( + InternalInvokeFunction.from_external(invoke_function, state.general_config) + ) try: preserved_block_info = self.__update_block_number() - contract_wrapper = self.contracts.get_by_address(invoke_transaction.contract_address) + contract_wrapper = self.contracts.get_by_address( + invoke_transaction.contract_address + ) adapted_result, execution_info = await contract_wrapper.invoke( entry_point_selector=invoke_transaction.entry_point_selector, calldata=invoke_transaction.calldata, signature=invoke_transaction.signature, caller_address=invoke_transaction.caller_address, - max_fee=invoke_transaction.max_fee + max_fee=invoke_transaction.max_fee, ) status = TransactionStatus.ACCEPTED_ON_L2 error_message = None @@ -325,12 +350,14 @@ async def invoke(self, invoke_function: InvokeFunction): transaction=transaction, state_update=state_update, error_message=error_message, - tx_hash=tx_hash + tx_hash=tx_hash, ) - await self.__register_new_contracts(execution_info.call_info.internal_calls, tx_hash) + await self.__register_new_contracts( + execution_info.call_info.internal_calls, tx_hash + ) - return invoke_function.contract_address, tx_hash, { "result": adapted_result } + return invoke_function.contract_address, tx_hash, {"result": adapted_result} async def call(self, transaction: InvokeFunction): """Perform call according to specifications in `transaction`.""" @@ -341,20 +368,26 @@ async def call(self, transaction: InvokeFunction): calldata=transaction.calldata, signature=transaction.signature, caller_address=0, - max_fee=transaction.max_fee + max_fee=transaction.max_fee, ) - return { "result": adapted_result } + return {"result": adapted_result} - async def __register_new_contracts(self, internal_calls: List[Union[FunctionInvocation, CallInfo]], tx_hash: int): + async def __register_new_contracts( + self, internal_calls: List[Union[FunctionInvocation, CallInfo]], tx_hash: int + ): for internal_call in internal_calls: if internal_call.entry_point_type == EntryPointType.CONSTRUCTOR: state = self.get_state() class_hash = to_bytes(internal_call.class_hash) contract_class = state.state.get_contract_class(class_hash) - contract = StarknetContract(state, contract_class.abi, internal_call.contract_address, None) - self.store_contract(internal_call.contract_address, contract, contract_class, tx_hash) + contract = StarknetContract( + state, contract_class.abi, internal_call.contract_address, None + ) + self.store_contract( + internal_call.contract_address, contract, contract_class, tx_hash + ) await self.__register_new_contracts(internal_call.internal_calls, tx_hash) async def get_storage_at(self, contract_address: int, key: int) -> Felt: @@ -370,12 +403,16 @@ async def get_storage_at(self, contract_address: int, key: int) -> Felt: return hex(contract_state.storage_updates[key].value) return self.origin.get_storage_at(contract_address, key) - async def load_messaging_contract_in_l1(self, network_url: str, contract_address: str, network_id: str) -> dict: + async def load_messaging_contract_in_l1( + self, network_url: str, contract_address: str, network_id: str + ) -> dict: """Loads the messaging contract at `contract_address`""" - return self.l1l2.load_l1_messaging_contract(self.starknet, network_url, contract_address, network_id) + return self.l1l2.load_l1_messaging_contract( + self.starknet, network_url, contract_address, network_id + ) async def postman_flush(self) -> dict: - """Handles all pending L1 <> L2 messages and sends them to the other layer. """ + """Handles all pending L1 <> L2 messages and sends them to the other layer.""" state = self.get_state() return await self.l1l2.flush(state) @@ -383,15 +420,17 @@ async def postman_flush(self) -> dict: async def calculate_actual_fee(self, external_tx: InvokeFunction): """Calculates actual fee""" state = self.get_state() - internal_tx = InternalInvokeFunction.from_external_query_tx(external_tx, state.general_config) + internal_tx = InternalInvokeFunction.from_external_query_tx( + external_tx, state.general_config + ) child_state = state.state.create_child_state_for_querying() - call_info = await internal_tx.execute(child_state, state.general_config, only_query=True) + call_info = await internal_tx.execute( + child_state, state.general_config, only_query=True + ) tx_fee = calculate_tx_fee( - state=child_state, - call_info=call_info, - general_config=state.general_config + state=child_state, call_info=call_info, general_config=state.general_config ) gas_price = state.state.block_info.gas_price diff --git a/starknet_devnet/state.py b/starknet_devnet/state.py index a7f9886bc..cfe087eff 100644 --- a/starknet_devnet/state.py +++ b/starknet_devnet/state.py @@ -8,10 +8,12 @@ from .starknet_wrapper import StarknetWrapper from .util import StarknetDevnetException, check_valid_dump_path -class State(): + +class State: """ Stores starknet wrapper and dumper """ + def __init__(self): self.set_starknet_wrapper(StarknetWrapper(DevnetConfig())) @@ -45,4 +47,5 @@ def set_dump_options(self, dump_path: str, dump_on: str): self.dumper.dump_path = dump_path self.dumper.dump_on = dump_on + state = State() diff --git a/starknet_devnet/transactions.py b/starknet_devnet/transactions.py index f81adedb1..ad004c5c5 100644 --- a/starknet_devnet/transactions.py +++ b/starknet_devnet/transactions.py @@ -15,18 +15,21 @@ StarknetBlock, FunctionInvocation, Event, - L2ToL1Message + L2ToL1Message, ) from starkware.starknet.business_logic.internal_transaction import InternalTransaction from starkware.starknet.testing.objects import ( TransactionExecutionInfo, - StarknetTransactionExecutionInfo + StarknetTransactionExecutionInfo, ) from starkware.starknet.definitions.error_codes import StarknetErrorCode -from services.everest.business_logic.transaction_execution_objects import TransactionFailureReason +from services.everest.business_logic.transaction_execution_objects import ( + TransactionFailureReason, +) from .origin import Origin + class DevnetTransaction: """Represents the devnet transaction""" @@ -34,7 +37,9 @@ def __init__( self, internal_tx: InternalTransaction, status: TransactionStatus, - execution_info: Union[TransactionExecutionInfo, StarknetTransactionExecutionInfo], + execution_info: Union[ + TransactionExecutionInfo, StarknetTransactionExecutionInfo + ], transaction_hash: int = None, ): self.block = None @@ -50,7 +55,11 @@ def __init__( def __get_actual_fee(self) -> int: """Returns the actual fee""" - return self.execution_info.actual_fee if hasattr(self.execution_info, "actual_fee") else 0 + return ( + self.execution_info.actual_fee + if hasattr(self.execution_info, "actual_fee") + else 0 + ) def __get_events(self) -> List[Event]: """Returns the events""" @@ -69,11 +78,13 @@ def __get_l2_to_l1_messages(self) -> List[L2ToL1Message]: contract_address = self.execution_info.call_info.contract_address for l2_to_l1_message in self.execution_info.call_info.l2_to_l1_messages: - l2_to_l1_messages.append(L2ToL1Message( - from_address=contract_address, - to_address=Web3.toChecksumAddress(hex(l2_to_l1_message.to_address)), - payload=l2_to_l1_message.payload, - )) + l2_to_l1_messages.append( + L2ToL1Message( + from_address=contract_address, + to_address=Web3.toChecksumAddress(hex(l2_to_l1_message.to_address)), + payload=l2_to_l1_message.payload, + ) + ) return l2_to_l1_messages @@ -92,13 +103,14 @@ def set_block(self, block: StarknetBlock): def set_failure_reason(self, error_message: str): """Sets the failure reason of the transaction""" self.transaction_failure_reason = TransactionFailureReason( - code=StarknetErrorCode.TRANSACTION_FAILED.name, - error_message=error_message + code=StarknetErrorCode.TRANSACTION_FAILED.name, error_message=error_message ) def get_signature(self) -> List[int]: """Returns the signature""" - return self.internal_tx.signature if hasattr(self.internal_tx, "signature") else [] + return ( + self.internal_tx.signature if hasattr(self.internal_tx, "signature") else [] + ) def get_tx_info(self) -> TransactionInfo: """Returns the transaction info""" @@ -108,7 +120,7 @@ def get_tx_info(self) -> TransactionInfo: transaction_index=self.transaction_index, block_hash=self.__get_block_hash(), block_number=self.__get_block_number(), - transaction_failure_reason=self.transaction_failure_reason + transaction_failure_reason=self.transaction_failure_reason, ) def get_receipt(self) -> TransactionReceipt: @@ -127,7 +139,7 @@ def get_receipt(self) -> TransactionReceipt: actual_fee=self.__get_actual_fee(), events=self.__get_events(), execution_resources=execution_resources, - l2_to_l1_messages=self.__get_l2_to_l1_messages() + l2_to_l1_messages=self.__get_l2_to_l1_messages(), ) def get_trace(self) -> TransactionTrace: @@ -138,7 +150,9 @@ def get_trace(self) -> TransactionTrace: function_invocation=( call_info if isinstance(call_info, FunctionInvocation) - else FunctionInvocation.from_internal_version(self.execution_info.call_info) + else FunctionInvocation.from_internal_version( + self.execution_info.call_info + ) ), signature=self.get_signature(), ) @@ -152,9 +166,10 @@ def get_execution(self) -> TransactionExecution: events=self.__get_events(), execution_resources=self.execution_info.call_info.execution_resources, l2_to_l1_messages=self.__get_l2_to_l1_messages(), - l1_to_l2_consumed_message=None + l1_to_l2_consumed_message=None, ) + class DevnetTransactions: """ This class is used to store transactions. @@ -194,7 +209,6 @@ def get_transaction(self, tx_hash: str): return transaction.get_tx_info() - def get_transaction_trace(self, tx_hash: str): """ Get a transaction trace. @@ -233,7 +247,10 @@ def get_transaction_status(self, tx_hash: str): } # "block_hash" will only exist after transaction enters ACCEPTED_ON_L2 - if transaction.status == TransactionStatus.ACCEPTED_ON_L2 and transaction.block is not None: + if ( + transaction.status == TransactionStatus.ACCEPTED_ON_L2 + and transaction.block is not None + ): status_response["block_hash"] = hex(transaction.block.block_hash) # "tx_failure_reason" will only exist if the transaction was rejected. diff --git a/starknet_devnet/util.py b/starknet_devnet/util.py index 53b300a23..90a429fbe 100644 --- a/starknet_devnet/util.py +++ b/starknet_devnet/util.py @@ -11,9 +11,13 @@ from starkware.starknet.business_logic.execution.objects import CallInfo from starkware.starknet.business_logic.state.state import CarriedState from starkware.starknet.services.api.feeder_gateway.response_objects import ( - BlockStateUpdate, StateDiff, StorageEntry, DeployedContract + BlockStateUpdate, + StateDiff, + StorageEntry, + DeployedContract, ) + def custom_int(arg: str) -> int: """ Converts the argument to an integer. @@ -22,15 +26,18 @@ def custom_int(arg: str) -> int: base = 16 if arg.startswith("0x") else 10 return int(arg, base) + def fixed_length_hex(arg: int) -> str: """ Converts the int input to a hex output of fixed length """ return f"0x{arg:064x}" + @dataclass class Uint256: """Abstraction of Uint256 type""" + low: int high: int @@ -41,23 +48,24 @@ def to_felt(self) -> int: @staticmethod def from_felt(felt: int) -> "Uint256": """Converts felt to Uint256""" - return Uint256( - low=felt & ((1 << 128) - 1), - high=felt >> 128 - ) + return Uint256(low=felt & ((1 << 128) - 1), high=felt >> 128) + class StarknetDevnetException(StarkException): """ Exception raised across the project. Indicates the raised issue is devnet-related. """ + def __init__(self, status_code=500, code=None, message=None): super().__init__(code=code, message=message) self.status_code = status_code + @dataclass class DummyExecutionInfo: """Used if tx fails, but execution info is still required.""" + def __init__(self): self.actual_fee = 0 self.call_info = CallInfo.empty_for_testing() @@ -74,10 +82,12 @@ def get_sorted_l2_to_l1_messages(self): """Return empty list""" return self.l2_to_l1_messages + def enable_pickling(): """ Extends the `StarknetContract` class to enable pickling. """ + def contract_getstate(self): return self.__dict__ @@ -87,26 +97,31 @@ def contract_setstate(self, state): StarknetContract.__getstate__ = contract_getstate StarknetContract.__setstate__ = contract_setstate -def generate_storage_diff(previous_storage_updates, storage_updates) -> List[StorageEntry]: + +def generate_storage_diff( + previous_storage_updates, storage_updates +) -> List[StorageEntry]: """ Returns storage diff between previous and current storage updates """ storage_diff = [] for storage_key, leaf in storage_updates.items(): - previous_leaf = previous_storage_updates.get(storage_key) if previous_storage_updates else None + previous_leaf = ( + previous_storage_updates.get(storage_key) + if previous_storage_updates + else None + ) if previous_leaf is None or previous_leaf.value != leaf.value: - storage_diff.append(StorageEntry( - key=storage_key, - value=leaf.value - ) - ) + storage_diff.append(StorageEntry(key=storage_key, value=leaf.value)) return storage_diff -def generate_state_update(previous_state: CarriedState, current_state: CarriedState) -> BlockStateUpdate: +def generate_state_update( + previous_state: CarriedState, current_state: CarriedState +) -> BlockStateUpdate: """ Returns roots, deployed contracts and storage diffs between 2 states """ @@ -116,26 +131,27 @@ def generate_state_update(previous_state: CarriedState, current_state: CarriedSt for class_hash in current_state.contract_definitions: if class_hash not in previous_state.contract_definitions: - declared_contracts.append( - int.from_bytes(class_hash, byteorder="big") - ) + declared_contracts.append(int.from_bytes(class_hash, byteorder="big")) for contract_address in current_state.contract_states: if contract_address not in previous_state.contract_states: class_hash = int.from_bytes( current_state.contract_states[contract_address].state.contract_hash, - "big" + "big", ) deployed_contracts.append( - DeployedContract( - address=contract_address, - class_hash=class_hash - ) + DeployedContract(address=contract_address, class_hash=class_hash) ) else: - previous_storage_updates = previous_state.contract_states[contract_address].storage_updates - storage_updates = current_state.contract_states[contract_address].storage_updates - storage_diff = generate_storage_diff(previous_storage_updates, storage_updates) + previous_storage_updates = previous_state.contract_states[ + contract_address + ].storage_updates + storage_updates = current_state.contract_states[ + contract_address + ].storage_updates + storage_diff = generate_storage_diff( + previous_storage_updates, storage_updates + ) if len(storage_diff) > 0: storage_diffs[contract_address] = storage_diff @@ -145,16 +161,14 @@ def generate_state_update(previous_state: CarriedState, current_state: CarriedSt state_diff = StateDiff( deployed_contracts=deployed_contracts, declared_contracts=tuple(declared_contracts), - storage_diffs=storage_diffs + storage_diffs=storage_diffs, ) return BlockStateUpdate( - block_hash=None, - new_root=new_root, - old_root=old_root, - state_diff=state_diff + block_hash=None, new_root=new_root, old_root=old_root, state_diff=state_diff ) + def to_bytes(value: Union[int, bytes]) -> bytes: """ If int, convert to 32-byte big-endian bytes instance @@ -162,6 +176,7 @@ def to_bytes(value: Union[int, bytes]) -> bytes: """ return value if isinstance(value, bytes) else value.to_bytes(32, "big") + def check_valid_dump_path(dump_path: str): """Checks if dump path is a directory. Raises ValueError if not.""" diff --git a/test/account.py b/test/account.py index cf044b799..4ef1e6906 100644 --- a/test/account.py +++ b/test/account.py @@ -9,7 +9,7 @@ from starkware.starknet.definitions.constants import TRANSACTION_VERSION, QUERY_VERSION from starkware.starknet.core.os.transaction_hash.transaction_hash import ( calculate_transaction_hash_common, - TransactionHashPrefix + TransactionHashPrefix, ) from starkware.starknet.definitions.general_config import StarknetChainId @@ -25,14 +25,17 @@ PRIVATE_KEY = 123456789987654321 PUBLIC_KEY = private_to_stark_key(PRIVATE_KEY) + def deploy_account_contract(salt=None): """Deploy account contract.""" return deploy(ACCOUNT_PATH, inputs=[str(PUBLIC_KEY)], salt=salt) + def get_nonce(account_address): """Get nonce.""" return call("get_nonce", account_address, ACCOUNT_ABI_PATH) + def get_execute_calldata(call_array, calldata, nonce): """Get calldata for __execute__.""" return [ @@ -40,18 +43,21 @@ def get_execute_calldata(call_array, calldata, nonce): *[x for t in call_array for x in t], len(calldata), *calldata, - int(nonce) + int(nonce), ] + def str_to_felt(text: str) -> int: """Converts string to felt.""" return int.from_bytes(bytes(text, "ascii"), "big") + def get_signature(message_hash: int, private_key: int) -> Tuple[str, str]: """Get signature from message hash and private key.""" sig_r, sig_s = sign(message_hash, private_key) return [str(sig_r), str(sig_s)] + def from_call_to_call_array(calls): """Transforms calls to call_array and calldata.""" call_array = [] @@ -64,17 +70,19 @@ def from_call_to_call_array(calls): call_tuple[0], get_selector_from_name(call_tuple[1]), len(calldata), - len(call_tuple[2]) + len(call_tuple[2]), ) call_array.append(entry) calldata.extend(call_tuple[2]) return (call_array, calldata) + def adapt_inputs(execute_calldata: List[int]) -> List[str]: """Get stringified inputs from execute_calldata.""" return [str(v) for v in execute_calldata] + # pylint: disable=too-many-arguments def get_execute_args( calls, @@ -82,7 +90,8 @@ def get_execute_args( private_key, nonce=None, max_fee=0, - version: int = TRANSACTION_VERSION): + version: int = TRANSACTION_VERSION, +): """Returns signature and execute calldata""" if nonce is None: @@ -97,17 +106,15 @@ def get_execute_args( contract_address=int(account_address, 16), calldata=execute_calldata, version=version, - max_fee=max_fee + max_fee=max_fee, ) signature = get_signature(message_hash, private_key) return signature, execute_calldata + def get_transaction_hash( - contract_address: int, - calldata: Sequence[int], - version: int, - max_fee: int = 0 + contract_address: int, calldata: Sequence[int], version: int, max_fee: int = 0 ) -> str: """Get transaction hash for execute transaction.""" return calculate_transaction_hash_common( @@ -121,6 +128,7 @@ def get_transaction_hash( additional_data=[], ) + def get_estimated_fee(calls, account_address, private_key, nonce=None): """Get estimated fee through account.""" signature, execute_calldata = get_execute_args( @@ -128,7 +136,7 @@ def get_estimated_fee(calls, account_address, private_key, nonce=None): account_address=account_address, private_key=private_key, nonce=nonce, - version=QUERY_VERSION + version=QUERY_VERSION, ) return estimate_fee( @@ -149,7 +157,9 @@ def execute(calls, account_address, private_key, nonce=None, max_fee=0, query=Fa version = TRANSACTION_VERSION runner = invoke - signature, execute_calldata = get_execute_args(calls, account_address, private_key, nonce, max_fee, version=version) + signature, execute_calldata = get_execute_args( + calls, account_address, private_key, nonce, max_fee, version=version + ) return runner( "__execute__", @@ -157,5 +167,5 @@ def execute(calls, account_address, private_key, nonce=None, max_fee=0, query=Fa address=account_address, abi_path=ACCOUNT_ABI_PATH, signature=signature, - max_fee=str(max_fee) + max_fee=str(max_fee), ) diff --git a/test/rpc/conftest.py b/test/rpc/conftest.py index 66fc0e630..79a9bdd5f 100644 --- a/test/rpc/conftest.py +++ b/test/rpc/conftest.py @@ -14,8 +14,17 @@ from starkware.starknet.services.api.gateway.transaction import Transaction, Deploy import pytest -from starknet_devnet.blueprints.rpc.structures.types import BlockNumberDict, BlockHashDict, Felt -from .rpc_utils import gateway_call, get_block_with_transaction, pad_zero, add_transaction +from starknet_devnet.blueprints.rpc.structures.types import ( + BlockNumberDict, + BlockHashDict, + Felt, +) +from .rpc_utils import ( + gateway_call, + get_block_with_transaction, + pad_zero, + add_transaction, +) DEPLOY_CONTENT = load_file_content("deploy_rpc.json") INVOKE_CONTENT = load_file_content("invoke_rpc.json") @@ -36,7 +45,9 @@ def fixture_class_hash(deploy_info) -> Felt: """ Class hash of deployed contract """ - class_hash = gateway_call("get_class_hash_at", contractAddress=deploy_info["address"]) + class_hash = gateway_call( + "get_class_hash_at", contractAddress=deploy_info["address"] + ) return pad_zero(class_hash) diff --git a/test/rpc/rpc_utils.py b/test/rpc/rpc_utils.py index b68914359..9ad3f3ef1 100644 --- a/test/rpc/rpc_utils.py +++ b/test/rpc/rpc_utils.py @@ -13,16 +13,16 @@ class BackgroundDevnetClient: - """ A thin wrapper for requests, to interact with a background devnet instance """ + """A thin wrapper for requests, to interact with a background devnet instance""" @staticmethod def get(endpoint: str) -> requests.Response: - """ Submit get request at given endpoint """ + """Submit get request at given endpoint""" return requests.get(f"{APP_URL}{endpoint}") @staticmethod def post(endpoint: str, body: dict) -> requests.Response: - """ Submit post request with given dict in body (JSON) """ + """Submit post request with given dict in body (JSON)""" return requests.post(f"{APP_URL}{endpoint}", json=body) @@ -30,12 +30,7 @@ def make_rpc_payload(method: str, params: Union[dict, list]): """ Make a wrapper for rpc call """ - return { - "jsonrpc": "2.0", - "method": method, - "params": params, - "id": 0 - } + return {"jsonrpc": "2.0", "method": method, "params": params, "id": 0} def rpc_call_background_devnet(method: str, params: Union[dict, list]): @@ -50,7 +45,9 @@ def rpc_call(method: str, params: Union[dict, list]) -> dict: """ Make a call to the RPC endpoint """ - return BackgroundDevnetClient.post("/rpc", body=make_rpc_payload(method, params)).json() + return BackgroundDevnetClient.post( + "/rpc", body=make_rpc_payload(method, params) + ).json() def add_transaction(body: dict) -> dict: @@ -76,7 +73,9 @@ def get_block_with_transaction(transaction_hash: str) -> dict: Retrieve block for given transaction """ transaction = gateway_call("get_transaction", transactionHash=transaction_hash) - assert transaction["status"] != "NOT_RECEIVED", f"Transaction {transaction_hash} was not received or does not exist" + assert ( + transaction["status"] != "NOT_RECEIVED" + ), f"Transaction {transaction_hash} was not received or does not exist" block_number: int = transaction["block_number"] block = gateway_call("get_block", blockNumber=block_number) return block diff --git a/test/rpc/test_rpc_blocks.py b/test/rpc/test_rpc_blocks.py index 43a684f58..f987cb128 100644 --- a/test/rpc/test_rpc_blocks.py +++ b/test/rpc/test_rpc_blocks.py @@ -4,7 +4,11 @@ from test.shared import GENESIS_BLOCK_NUMBER, INCORRECT_GENESIS_BLOCK_HASH import pytest -from starknet_devnet.blueprints.rpc.structures.types import BlockNumberDict, BlockHashDict, rpc_txn_type +from starknet_devnet.blueprints.rpc.structures.types import ( + BlockNumberDict, + BlockHashDict, + rpc_txn_type, +) from starknet_devnet.blueprints.rpc.utils import rpc_root from starknet_devnet.general_config import DEFAULT_GENERAL_CONFIG @@ -21,9 +25,7 @@ def test_get_block_with_tx_hashes(deploy_info, gateway_block, block_id): block_number: int = gateway_block["block_number"] new_root: str = rpc_root(gateway_block["state_root"]) - resp = rpc_call( - "starknet_getBlockWithTxHashes", params={"block_id": block_id} - ) + resp = rpc_call("starknet_getBlockWithTxHashes", params={"block_id": block_id}) block = resp["result"] transaction_hash: str = pad_zero(deploy_info["transaction_hash"]) @@ -40,20 +42,20 @@ def test_get_block_with_tx_hashes(deploy_info, gateway_block, block_id): @pytest.mark.usefixtures("run_devnet_in_background", "deploy_info") -@pytest.mark.parametrize("block_id", [BlockNumberDict(block_number=1234), - BlockHashDict(block_hash=pad_zero(INCORRECT_GENESIS_BLOCK_HASH))]) +@pytest.mark.parametrize( + "block_id", + [ + BlockNumberDict(block_number=1234), + BlockHashDict(block_hash=pad_zero(INCORRECT_GENESIS_BLOCK_HASH)), + ], +) def test_get_block_with_tx_hashes_raises_on_incorrect_block_id(block_id): """ Get block with tx hashes by incorrect block_id """ - ex = rpc_call( - "starknet_getBlockWithTxHashes", params={"block_id": block_id} - ) + ex = rpc_call("starknet_getBlockWithTxHashes", params={"block_id": block_id}) - assert ex["error"] == { - "code": 24, - "message": "Invalid block id" - } + assert ex["error"] == {"code": 24, "message": "Invalid block id"} @pytest.mark.usefixtures("run_devnet_in_background", "deploy_info") @@ -67,9 +69,7 @@ def test_get_block_with_txs(gateway_block, block_id): new_root: str = rpc_root(gateway_block["state_root"]) block_tx = gateway_block["transactions"][0] - resp = rpc_call( - "starknet_getBlockWithTxs", params={"block_id": block_id} - ) + resp = rpc_call("starknet_getBlockWithTxs", params={"block_id": block_id}) block = resp["result"] assert block == { @@ -83,7 +83,9 @@ def test_get_block_with_txs(gateway_block, block_id): "transactions": [ { "class_hash": pad_zero(block_tx["class_hash"]), - "constructor_calldata": [pad_zero(data) for data in block_tx["constructor_calldata"]], + "constructor_calldata": [ + pad_zero(data) for data in block_tx["constructor_calldata"] + ], "contract_address": pad_zero(block_tx["contract_address"]), "contract_address_salt": pad_zero(block_tx["contract_address_salt"]), "transaction_hash": pad_zero(block_tx["transaction_hash"]), @@ -95,20 +97,20 @@ def test_get_block_with_txs(gateway_block, block_id): @pytest.mark.usefixtures("run_devnet_in_background", "deploy_info") -@pytest.mark.parametrize("block_id", [BlockNumberDict(block_number=1234), - BlockHashDict(block_hash=pad_zero(INCORRECT_GENESIS_BLOCK_HASH))]) +@pytest.mark.parametrize( + "block_id", + [ + BlockNumberDict(block_number=1234), + BlockHashDict(block_hash=pad_zero(INCORRECT_GENESIS_BLOCK_HASH)), + ], +) def test_get_block_with_txs_raises_on_incorrect_block_id(block_id): """ Get block with txs by incorrect block_id """ - ex = rpc_call( - "starknet_getBlockWithTxHashes", params={"block_id": block_id} - ) + ex = rpc_call("starknet_getBlockWithTxHashes", params={"block_id": block_id}) - assert ex["error"] == { - "code": 24, - "message": "Invalid block id" - } + assert ex["error"] == {"code": 24, "message": "Invalid block id"} @pytest.mark.usefixtures("run_devnet_in_background", "deploy_info", "gateway_block") @@ -120,29 +122,27 @@ def test_get_block_transaction_count(block_id): if "block_number" in block_id: block_id["block_number"] = GENESIS_BLOCK_NUMBER + 1 - resp = rpc_call( - "starknet_getBlockTransactionCount", params={"block_id": block_id} - ) + resp = rpc_call("starknet_getBlockTransactionCount", params={"block_id": block_id}) count = resp["result"] assert count == 1 @pytest.mark.usefixtures("run_devnet_in_background", "deploy_info") -@pytest.mark.parametrize("block_id", [BlockNumberDict(block_number=99999), - BlockHashDict(block_hash=pad_zero(INCORRECT_GENESIS_BLOCK_HASH))]) +@pytest.mark.parametrize( + "block_id", + [ + BlockNumberDict(block_number=99999), + BlockHashDict(block_hash=pad_zero(INCORRECT_GENESIS_BLOCK_HASH)), + ], +) def test_get_block_transaction_count_raises_on_incorrect_block_id(block_id): """ Get count of transactions in block by incorrect block id """ - ex = rpc_call( - "starknet_getBlockTransactionCount", params={"block_id": block_id} - ) + ex = rpc_call("starknet_getBlockTransactionCount", params={"block_id": block_id}) - assert ex["error"] == { - "code": 24, - "message": "Invalid block id" - } + assert ex["error"] == {"code": 24, "message": "Invalid block id"} @pytest.mark.usefixtures("run_devnet_in_background", "deploy_info") @@ -154,9 +154,7 @@ def test_get_block_number(): latest_block = gateway_call("get_block", blockNumber="latest") latest_block_number: int = latest_block["block_number"] - resp = rpc_call( - "starknet_blockNumber", params={} - ) + resp = rpc_call("starknet_blockNumber", params={}) block_number: int = resp["result"] assert latest_block_number == block_number diff --git a/test/rpc/test_rpc_call.py b/test/rpc/test_rpc_call.py index a011284b8..a79f2db4a 100644 --- a/test/rpc/test_rpc_call.py +++ b/test/rpc/test_rpc_call.py @@ -16,14 +16,15 @@ def test_call(deploy_info): contract_address: str = deploy_info["address"] resp = rpc_call( - "starknet_call", params={ + "starknet_call", + params={ "request": { "contract_address": pad_zero(contract_address), "entry_point_selector": hex(get_selector_from_name("get_balance")), "calldata": [], }, - "block_id": "latest" - } + "block_id": "latest", + }, ) result = resp["result"] @@ -36,20 +37,18 @@ def test_call_raises_on_incorrect_contract_address(): Call contract with incorrect address """ ex = rpc_call( - "starknet_call", params={ + "starknet_call", + params={ "request": { "contract_address": "0x07b529269b82f3f3ebbb2c463a9e1edaa2c6eea8fa308ff70b30398766a2e20c", "entry_point_selector": hex(get_selector_from_name("get_balance")), "calldata": [], }, - "block_id": "latest" - } + "block_id": "latest", + }, ) - assert ex["error"] == { - "code": 20, - "message": "Contract not found" - } + assert ex["error"] == {"code": 20, "message": "Contract not found"} @pytest.mark.usefixtures("run_devnet_in_background") @@ -60,20 +59,18 @@ def test_call_raises_on_incorrect_selector(deploy_info): contract_address: str = deploy_info["address"] ex = rpc_call( - "starknet_call", params={ + "starknet_call", + params={ "request": { "contract_address": pad_zero(contract_address), "entry_point_selector": hex(get_selector_from_name("xxxxxxx")), "calldata": [], }, - "block_id": "latest" - } + "block_id": "latest", + }, ) - assert ex["error"] == { - "code": 21, - "message": "Invalid message selector" - } + assert ex["error"] == {"code": 21, "message": "Invalid message selector"} @pytest.mark.usefixtures("run_devnet_in_background") @@ -84,20 +81,18 @@ def test_call_raises_on_invalid_calldata(deploy_info): contract_address: str = deploy_info["address"] ex = rpc_call( - "starknet_call", params={ + "starknet_call", + params={ "request": { "contract_address": pad_zero(contract_address), "entry_point_selector": hex(get_selector_from_name("get_balance")), "calldata": ["a", "b", "123"], }, - "block_id": "latest" - } + "block_id": "latest", + }, ) - assert ex["error"] == { - "code": 22, - "message": "Invalid call data" - } + assert ex["error"] == {"code": 22, "message": "Invalid call data"} @pytest.mark.usefixtures("run_devnet_in_background") @@ -108,17 +103,18 @@ def test_call_raises_on_incorrect_block_hash(deploy_info): contract_address: str = deploy_info["address"] ex = rpc_call( - "starknet_call", params={ + "starknet_call", + params={ "request": { "contract_address": pad_zero(contract_address), "entry_point_selector": hex(get_selector_from_name("get_balance")), "calldata": [], }, - "block_id": "0x0" - } + "block_id": "0x0", + }, ) assert ex["error"] == { "code": -1, - "message": "Calls with block_id != 'latest' are not supported currently." + "message": "Calls with block_id != 'latest' are not supported currently.", } diff --git a/test/rpc/test_rpc_class.py b/test/rpc/test_rpc_class.py index 2d2a33cef..fea5baba5 100644 --- a/test/rpc/test_rpc_class.py +++ b/test/rpc/test_rpc_class.py @@ -13,19 +13,22 @@ def test_get_class(class_hash): """ Test get contract class """ - resp = rpc_call( - "starknet_getClass", - params={"class_hash": class_hash} - ) + resp = rpc_call("starknet_getClass", params={"class_hash": class_hash}) contract_class = resp["result"] assert contract_class["entry_points_by_type"] == { "CONSTRUCTOR": [], "EXTERNAL": [ - {"offset": "0x3a", "selector": "0x0362398bec32bc0ebb411203221a35a0301193a96f317ebe5e40be9f60d15320"}, - {"offset": "0x5b", "selector": "0x039e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695"} + { + "offset": "0x3a", + "selector": "0x0362398bec32bc0ebb411203221a35a0301193a96f317ebe5e40be9f60d15320", + }, + { + "offset": "0x5b", + "selector": "0x039e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695", + }, ], - "L1_HANDLER": [] + "L1_HANDLER": [], } assert isinstance(contract_class["program"], str) decompress_program({"contract_class": contract_class}, False) @@ -41,7 +44,7 @@ def test_get_class_hash_at(deploy_info, class_hash): resp = rpc_call( "starknet_getClassHashAt", - params={"contract_address": pad_zero(contract_address), "block_id": block_id} + params={"contract_address": pad_zero(contract_address), "block_id": block_id}, ) rpc_class_hash = resp["result"] @@ -58,17 +61,23 @@ def test_get_class_at(deploy_info): resp = rpc_call( "starknet_getClassAt", - params={"contract_address": pad_zero(contract_address), "block_id": block_id} + params={"contract_address": pad_zero(contract_address), "block_id": block_id}, ) contract_class = resp["result"] assert contract_class["entry_points_by_type"] == { "CONSTRUCTOR": [], "EXTERNAL": [ - {"offset": "0x3a", "selector": "0x0362398bec32bc0ebb411203221a35a0301193a96f317ebe5e40be9f60d15320"}, - {"offset": "0x5b", "selector": "0x039e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695"} + { + "offset": "0x3a", + "selector": "0x0362398bec32bc0ebb411203221a35a0301193a96f317ebe5e40be9f60d15320", + }, + { + "offset": "0x5b", + "selector": "0x039e11d48192e4333233c7eb19d10ad67c362bb28580c604d67884c85da39695", + }, ], - "L1_HANDLER": [] + "L1_HANDLER": [], } assert isinstance(contract_class["program"], str) decompress_program({"contract_class": contract_class}, False) diff --git a/test/rpc/test_rpc_estimate_fee.py b/test/rpc/test_rpc_estimate_fee.py index a924fb187..b1733c4f3 100644 --- a/test/rpc/test_rpc_estimate_fee.py +++ b/test/rpc/test_rpc_estimate_fee.py @@ -29,7 +29,9 @@ def common_estimate_response(response): @pytest.mark.usefixtures("run_devnet_in_background") -@pytest.mark.parametrize("run_devnet_in_background", [["--gas-price", str(DEFAULT_GAS_PRICE)]], indirect=True) +@pytest.mark.parametrize( + "run_devnet_in_background", [["--gas-price", str(DEFAULT_GAS_PRICE)]], indirect=True +) def test_estimate_happy_path(): """Happy path estimate_fee call""" deploy_info = deploy(CONTRACT_PATH, ["0"]) @@ -78,16 +80,13 @@ def test_estimate_fee_with_invalid_call_data(rpc_invoke_tx_common): "contract_address": deploy_info["address"], "entry_point_selector": hex(get_selector_from_name("sum_point_array")), "calldata": ["10", "20"], - **rpc_invoke_tx_common + **rpc_invoke_tx_common, } ex = rpc_call_background_devnet( "starknet_estimateFee", {"request": txn, "block_id": "latest"} ) - assert ex["error"] == { - "code": 22, - "message": "Invalid call data" - } + assert ex["error"] == {"code": 22, "message": "Invalid call data"} @pytest.mark.usefixtures("run_devnet_in_background") @@ -103,10 +102,7 @@ def test_estimate_fee_with_invalid_contract_address(rpc_invoke_tx_common): "starknet_estimateFee", {"request": txn, "block_id": "latest"} ) - assert ex["error"] == { - "code": 20, - "message": "Contract not found" - } + assert ex["error"] == {"code": 20, "message": "Contract not found"} @pytest.mark.usefixtures("run_devnet_in_background") @@ -124,14 +120,13 @@ def test_estimate_fee_with_invalid_message_selector(rpc_invoke_tx_common): "starknet_estimateFee", {"request": txn, "block_id": "latest"} ) - assert ex["error"] == { - "code": 21, - "message": "Invalid message selector" - } + assert ex["error"] == {"code": 21, "message": "Invalid message selector"} @pytest.mark.usefixtures("run_devnet_in_background") -@pytest.mark.parametrize("run_devnet_in_background", [["--gas-price", str(DEFAULT_GAS_PRICE)]], indirect=True) +@pytest.mark.parametrize( + "run_devnet_in_background", [["--gas-price", str(DEFAULT_GAS_PRICE)]], indirect=True +) def test_estimate_fee_with_complete_request_data(rpc_invoke_tx_common): """Estimate fee with complete request data""" diff --git a/test/rpc/test_rpc_misc.py b/test/rpc/test_rpc_misc.py index 15247bea9..5b6524050 100644 --- a/test/rpc/test_rpc_misc.py +++ b/test/rpc/test_rpc_misc.py @@ -30,16 +30,20 @@ def test_get_state_update(deploy_info, invoke_info, contract_class): block_id_invoke = BlockHashDict(block_hash=block_with_invoke_hash) class_hash = pad_zero(hex(compute_class_hash(contract_class))) - storage = gateway_call("get_storage_at", contractAddress=contract_address, key=get_storage_var_address("balance")) + storage = gateway_call( + "get_storage_at", + contractAddress=contract_address, + key=get_storage_var_address("balance"), + ) - new_root_deploy = "0x0" + gateway_call("get_state_update", blockHash=block_with_deploy_hash)["new_root"].lstrip("0") - new_root_invoke = "0x0" + gateway_call("get_state_update", blockHash=block_with_invoke_hash)["new_root"].lstrip("0") + new_root_deploy = "0x0" + gateway_call( + "get_state_update", blockHash=block_with_deploy_hash + )["new_root"].lstrip("0") + new_root_invoke = "0x0" + gateway_call( + "get_state_update", blockHash=block_with_invoke_hash + )["new_root"].lstrip("0") - resp = rpc_call( - "starknet_getStateUpdate", params={ - "block_id": block_id_deploy - } - ) + resp = rpc_call("starknet_getStateUpdate", params={"block_id": block_id_deploy}) state_update = resp["result"] assert state_update["block_hash"] == block_with_deploy_hash @@ -62,11 +66,7 @@ def test_get_state_update(deploy_info, invoke_info, contract_class): "nonces": [], } - resp = rpc_call( - "starknet_getStateUpdate", params={ - "block_id": block_id_invoke - } - ) + resp = rpc_call("starknet_getStateUpdate", params={"block_id": block_id_invoke}) state_update = resp["result"] assert state_update["block_hash"] == block_with_invoke_hash diff --git a/test/rpc/test_rpc_storage.py b/test/rpc/test_rpc_storage.py index 0595a3aee..941a17340 100644 --- a/test/rpc/test_rpc_storage.py +++ b/test/rpc/test_rpc_storage.py @@ -18,11 +18,12 @@ def test_get_storage_at(deploy_info): block_id: str = "latest" resp = rpc_call( - "starknet_getStorageAt", params={ + "starknet_getStorageAt", + params={ "contract_address": pad_zero(contract_address), "key": key, "block_id": block_id, - } + }, ) storage = resp["result"] @@ -38,17 +39,15 @@ def test_get_storage_at_raises_on_incorrect_contract(): block_id: str = "latest" ex = rpc_call( - "starknet_getStorageAt", params={ + "starknet_getStorageAt", + params={ "contract_address": "0x00", "key": key, "block_id": block_id, - } + }, ) - assert ex["error"] == { - "code": 20, - "message": "Contract not found" - } + assert ex["error"] == {"code": 20, "message": "Contract not found"} # internal workings of get_storage_at would have to be changed for this to work properly @@ -62,11 +61,12 @@ def test_get_storage_at_raises_on_incorrect_key(deploy_info): contract_address: str = deploy_info["address"] response = rpc_call( - "starknet_getStorageAt", params={ + "starknet_getStorageAt", + params={ "contract_address": pad_zero(contract_address), "key": "0x00", "block_id": "latest", - } + }, ) assert response["result"] == "0x00" @@ -82,14 +82,15 @@ def test_get_storage_at_raises_on_incorrect_block_id(deploy_info): key: str = hex(get_storage_var_address("balance")) ex = rpc_call( - "starknet_getStorageAt", params={ + "starknet_getStorageAt", + params={ "contract_address": pad_zero(contract_address), "key": key, "block_id": "0x0", - } + }, ) assert ex["error"] == { "code": -1, - "message": "Calls with block_id != 'latest' are not supported currently." + "message": "Calls with block_id != 'latest' are not supported currently.", } diff --git a/test/rpc/test_rpc_transactions.py b/test/rpc/test_rpc_transactions.py index 2e3d90e5c..8e6cdca75 100644 --- a/test/rpc/test_rpc_transactions.py +++ b/test/rpc/test_rpc_transactions.py @@ -25,7 +25,9 @@ def pad_zero_external_entry_points(contract_class: dict) -> dict: """ external_entry_points = contract_class["entry_points_by_type"]["EXTERNAL"] for i, _ in enumerate(external_entry_points): - external_entry_points[i]["selector"] = pad_zero(external_entry_points[i]["selector"]) + external_entry_points[i]["selector"] = pad_zero( + external_entry_points[i]["selector"] + ) contract_class["entry_points_by_type"]["EXTERNAL"] = external_entry_points @@ -43,7 +45,8 @@ def test_get_transaction_by_hash_deploy(deploy_info): contract_address: str = deploy_info["address"] resp = rpc_call( - "starknet_getTransactionByHash", params={"transaction_hash": pad_zero(transaction_hash)} + "starknet_getTransactionByHash", + params={"transaction_hash": pad_zero(transaction_hash)}, ) transaction = resp["result"] @@ -68,13 +71,12 @@ def test_get_transaction_by_hash_invoke(invoke_info): transaction_hash: str = invoke_info["transaction_hash"] contract_address: str = invoke_info["address"] entry_point_selector: str = invoke_info["entry_point_selector"] - signature: List[str] = [pad_zero(hex(int(sig))) - for sig in invoke_info["signature"]] - calldata: List[str] = [pad_zero(hex(int(data))) - for data in invoke_info["calldata"]] + signature: List[str] = [pad_zero(hex(int(sig))) for sig in invoke_info["signature"]] + calldata: List[str] = [pad_zero(hex(int(data))) for data in invoke_info["calldata"]] resp = rpc_call( - "starknet_getTransactionByHash", params={"transaction_hash": pad_zero(transaction_hash)} + "starknet_getTransactionByHash", + params={"transaction_hash": pad_zero(transaction_hash)}, ) transaction = resp["result"] @@ -99,11 +101,13 @@ def test_get_transaction_by_hash_declare(declare_info): block = get_block_with_transaction(declare_info["transaction_hash"]) block_tx = block["transactions"][0] transaction_hash: str = declare_info["transaction_hash"] - signature: List[str] = [pad_zero(hex(int(sig))) - for sig in declare_info["signature"]] + signature: List[str] = [ + pad_zero(hex(int(sig))) for sig in declare_info["signature"] + ] resp = rpc_call( - "starknet_getTransactionByHash", params={"transaction_hash": pad_zero(transaction_hash)} + "starknet_getTransactionByHash", + params={"transaction_hash": pad_zero(transaction_hash)}, ) transaction = resp["result"] @@ -124,14 +128,9 @@ def test_get_transaction_by_hash_raises_on_incorrect_hash(): """ Get transaction by incorrect hash """ - ex = rpc_call( - "starknet_getTransactionByHash", params={"transaction_hash": "0x00"} - ) + ex = rpc_call("starknet_getTransactionByHash", params={"transaction_hash": "0x00"}) - assert ex["error"] == { - "code": 25, - "message": "Invalid transaction hash" - } + assert ex["error"] == {"code": 25, "message": "Invalid transaction hash"} @pytest.mark.usefixtures("run_devnet_in_background") @@ -147,18 +146,21 @@ def test_get_transaction_by_block_id_and_index(deploy_info): index: int = 0 resp = rpc_call( - "starknet_getTransactionByBlockIdAndIndex", params={ + "starknet_getTransactionByBlockIdAndIndex", + params={ "block_id": { "block_number": block_number, }, - "index": index - } + "index": index, + }, ) transaction = resp["result"] assert transaction == { "class_hash": pad_zero(block_tx["class_hash"]), - "constructor_calldata": [pad_zero(tx) for tx in block_tx["constructor_calldata"]], + "constructor_calldata": [ + pad_zero(tx) for tx in block_tx["constructor_calldata"] + ], "contract_address": pad_zero(contract_address), "contract_address_salt": pad_zero(block_tx["contract_address_salt"]), "transaction_hash": pad_zero(transaction_hash), @@ -173,18 +175,14 @@ def test_get_transaction_by_block_id_and_index_raises_on_incorrect_block_hash(): Get transaction by incorrect block id """ ex = rpc_call( - "starknet_getTransactionByBlockIdAndIndex", params={ - "block_id": { - "block_hash": pad_zero(INCORRECT_GENESIS_BLOCK_HASH) - }, - "index": 0 - } + "starknet_getTransactionByBlockIdAndIndex", + params={ + "block_id": {"block_hash": pad_zero(INCORRECT_GENESIS_BLOCK_HASH)}, + "index": 0, + }, ) - assert ex["error"] == { - "code": 24, - "message": "Invalid block id" - } + assert ex["error"] == {"code": 24, "message": "Invalid block id"} @pytest.mark.usefixtures("run_devnet_in_background") @@ -196,17 +194,18 @@ def test_get_transaction_by_block_id_and_index_raises_on_incorrect_index(deploy_ block_hash: str = block["block_hash"] ex = rpc_call( - "starknet_getTransactionByBlockIdAndIndex", params={ + "starknet_getTransactionByBlockIdAndIndex", + params={ "block_id": { "block_hash": pad_zero(block_hash), }, - "index": 999999 - } + "index": 999999, + }, ) assert ex["error"] == { "code": 27, - "message": "Invalid transaction index in a block" + "message": "Invalid transaction index in a block", } @@ -219,9 +218,8 @@ def test_get_declare_transaction_receipt(declare_info): block = get_block_with_transaction(transaction_hash) resp = rpc_call( - "starknet_getTransactionReceipt", params={ - "transaction_hash": pad_zero(transaction_hash) - } + "starknet_getTransactionReceipt", + params={"transaction_hash": pad_zero(transaction_hash)}, ) receipt = resp["result"] @@ -243,9 +241,8 @@ def test_get_invoke_transaction_receipt(invoke_info): transaction_hash: str = invoke_info["transaction_hash"] resp = rpc_call( - "starknet_getTransactionReceipt", params={ - "transaction_hash": pad_zero(transaction_hash) - } + "starknet_getTransactionReceipt", + params={"transaction_hash": pad_zero(transaction_hash)}, ) receipt = resp["result"] @@ -264,15 +261,10 @@ def test_get_transaction_receipt_on_incorrect_hash(): Get transaction receipt by incorrect hash """ ex = rpc_call( - "starknet_getTransactionReceipt", params={ - "transaction_hash": rpc_felt(0) - } + "starknet_getTransactionReceipt", params={"transaction_hash": rpc_felt(0)} ) - assert ex["error"] == { - "code": 25, - "message": "Invalid transaction hash" - } + assert ex["error"] == {"code": 25, "message": "Invalid transaction hash"} @pytest.mark.usefixtures("run_devnet_in_background") @@ -284,9 +276,8 @@ def test_get_deploy_transaction_receipt(deploy_info): block = get_block_with_transaction(transaction_hash) resp = rpc_call( - "starknet_getTransactionReceipt", params={ - "transaction_hash": pad_zero(transaction_hash) - } + "starknet_getTransactionReceipt", + params={"transaction_hash": pad_zero(transaction_hash)}, ) receipt = resp["result"] @@ -310,13 +301,17 @@ def test_add_invoke_transaction(invoke_content): params={ "function_invocation": { "contract_address": pad_zero(invoke_content["contract_address"]), - "entry_point_selector": pad_zero(invoke_content["entry_point_selector"]), - "calldata": [pad_zero(hex(int(data))) for data in invoke_content["calldata"]], + "entry_point_selector": pad_zero( + invoke_content["entry_point_selector"] + ), + "calldata": [ + pad_zero(hex(int(data))) for data in invoke_content["calldata"] + ], }, "signature": [pad_zero(sig) for sig in invoke_content["signature"]], "max_fee": hex(0), "version": hex(constants.TRANSACTION_VERSION), - } + }, ) receipt = resp["result"] @@ -342,13 +337,10 @@ def test_add_declare_transaction_on_incorrect_contract(declare_content): params={ "contract_class": rpc_contract, "version": hex(constants.TRANSACTION_VERSION), - } + }, ) - assert ex["error"] == { - "code": 50, - "message": "Invalid contract class" - } + assert ex["error"] == {"code": 50, "message": "Invalid contract class"} @pytest.mark.usefixtures("run_devnet_in_background") @@ -369,7 +361,7 @@ def test_add_declare_transaction(declare_content): params={ "contract_class": rpc_contract, "version": hex(constants.TRANSACTION_VERSION), - } + }, ) receipt = resp["result"] @@ -399,13 +391,10 @@ def test_add_deploy_transaction_on_incorrect_contract(deploy_content): "contract_address_salt": pad_zero(salt), "constructor_calldata": calldata, "contract_definition": rpc_contract, - } + }, ) - assert ex["error"] == { - "code": 50, - "message": "Invalid contract class" - } + assert ex["error"] == {"code": 50, "message": "Invalid contract class"} @pytest.mark.usefixtures("run_devnet_in_background") @@ -429,7 +418,7 @@ def test_add_deploy_transaction(deploy_content): "contract_address_salt": pad_zero(salt), "constructor_calldata": calldata, "contract_definition": rpc_contract, - } + }, ) receipt = resp["result"] diff --git a/test/settings.py b/test/settings.py index 3340be1d6..7ce5d3526 100644 --- a/test/settings.py +++ b/test/settings.py @@ -2,6 +2,7 @@ import socket + def bind_free_port(host): """return assigned free port and test base endpoint""" sock = socket.socket() @@ -9,6 +10,7 @@ def bind_free_port(host): port = str(sock.getsockname()[1]) return port, f"http://{host}:{port}" + HOST = "127.0.0.1" PORT, APP_URL = bind_free_port(HOST) diff --git a/test/shared.py b/test/shared.py index ef8e8150a..67c63fd28 100644 --- a/test/shared.py +++ b/test/shared.py @@ -11,15 +11,21 @@ DEPLOYER_CONTRACT_PATH = f"{ARTIFACTS_PATH}/deployer.cairo/deployer.json" DEPLOYER_ABI_PATH = f"{ARTIFACTS_PATH}/deployer.cairo/deployer_abi.json" -BALANCE_KEY = "916907772491729262376534102982219947830828984996257231353398618781993312401" +BALANCE_KEY = ( + "916907772491729262376534102982219947830828984996257231353398618781993312401" +) SIGNATURE = [ "1225578735933442828068102633747590437426782890965066746429241472187377583468", - "3568809569741913715045370357918125425757114920266578211811626257903121825123" + "3568809569741913715045370357918125425757114920266578211811626257903121825123", ] -EXPECTED_SALTY_DEPLOY_ADDRESS = "0x07a0c836e446fb20e2b8e3354251b862ea45cfd039bb158576f5e8d0983ff2bb" -EXPECTED_SALTY_DEPLOY_HASH = "0x23801cc34aa43f4e2bf3e74a838fe45dd1b1ad316a2d3545aaef7efe1f39b21" +EXPECTED_SALTY_DEPLOY_ADDRESS = ( + "0x07a0c836e446fb20e2b8e3354251b862ea45cfd039bb158576f5e8d0983ff2bb" +) +EXPECTED_SALTY_DEPLOY_HASH = ( + "0x23801cc34aa43f4e2bf3e74a838fe45dd1b1ad316a2d3545aaef7efe1f39b21" +) EXPECTED_CLASS_HASH = "0x757a84aa38bf4ad191a7dfea2e8146fc7f3c4aa6090a8f0bddd7b688f0b24c" NONEXISTENT_TX_HASH = "0x1" diff --git a/test/test_account.py b/test/test_account.py index 27f4d05e5..79967e20c 100644 --- a/test/test_account.py +++ b/test/test_account.py @@ -16,7 +16,7 @@ get_transaction_receipt, load_file_content, call, - estimate_fee + estimate_fee, ) from .account import ( ACCOUNT_ABI_PATH, @@ -25,7 +25,7 @@ deploy_account_contract, get_nonce, execute, - get_estimated_fee + get_estimated_fee, ) INVOKE_CONTENT = load_file_content("invoke.json") @@ -35,28 +35,38 @@ SALT = "0x99" ACCOUNTS_SEED_DEVNET_ARGS = [ - "--accounts", "1", - "--seed", "42", - "--gas-price", "100_000_000", - "--initial-balance", "1_000_000_000_000_000_000_000" + "--accounts", + "1", + "--seed", + "42", + "--gas-price", + "100_000_000", + "--initial-balance", + "1_000_000_000_000_000_000_000", ] -PREDEPLOYED_ACCOUNT_ADDRESS = "0x347be35996a21f6bf0623e75dbce52baba918ad5ae8d83b6f416045ab22961a" -PREDEPLOYED_ACCOUNT_PRIVATE_KEY = 0xbdd640fb06671ad11c80317fa3b1799d +PREDEPLOYED_ACCOUNT_ADDRESS = ( + "0x347be35996a21f6bf0623e75dbce52baba918ad5ae8d83b6f416045ab22961a" +) +PREDEPLOYED_ACCOUNT_PRIVATE_KEY = 0xBDD640FB06671AD11C80317FA3B1799D + def deploy_empty_contract(): """Deploy sample contract with balance = 0.""" return deploy(CONTRACT_PATH, inputs=["0"], salt=SALT) + def deploy_events_contract(): """Deploy events contract with salt of 0x99.""" return deploy(EVENTS_CONTRACT_PATH, salt=SALT) + def get_account_balance(address: str) -> int: """Get balance (wei) of account with `address` (hex).""" resp = requests.get(f"{APP_URL}/account_balance?address={address}") assert resp.status_code == 200 return int(resp.json()["amount"]) + @pytest.mark.account @devnet_in_background() def test_account_contract_deploy(): @@ -64,14 +74,13 @@ def test_account_contract_deploy(): deploy_info = deploy_account_contract(salt=SALT) assert deploy_info["address"] == ACCOUNT_ADDRESS - deployed_public_key = call( - "get_public_key", ACCOUNT_ADDRESS, ACCOUNT_ABI_PATH - ) + deployed_public_key = call("get_public_key", ACCOUNT_ADDRESS, ACCOUNT_ABI_PATH) assert int(deployed_public_key, 16) == PUBLIC_KEY nonce = get_nonce(ACCOUNT_ADDRESS) assert nonce == "0" + @pytest.mark.account @devnet_in_background() def test_invoking_another_contract(): @@ -91,12 +100,15 @@ def test_invoking_another_contract(): assert nonce == "1" # check if balance is increased - balance_raw = execute([(to_address, "get_balance", [])], ACCOUNT_ADDRESS, PRIVATE_KEY, query=True) + balance_raw = execute( + [(to_address, "get_balance", [])], ACCOUNT_ADDRESS, PRIVATE_KEY, query=True + ) balance_arr = balance_raw.split() assert_equal(len(balance_arr), 2) balance = balance_arr[1] assert balance == "30" + @pytest.mark.account @devnet_in_background() def test_estimated_fee(): @@ -115,10 +127,7 @@ def test_estimated_fee(): # estimate fee without account estimated_fee_without_account = estimate_fee( - "increase_balance", - ["10", "20"], - deploy_info["address"], - ABI_PATH + "increase_balance", ["10", "20"], deploy_info["address"], ABI_PATH ) assert estimated_fee_without_account < estimated_fee @@ -127,6 +136,7 @@ def test_estimated_fee(): balance = call("get_balance", deploy_info["address"], abi_path=ABI_PATH) assert balance == initial_balance + @pytest.mark.account @devnet_in_background() def test_low_max_fee(): @@ -150,6 +160,7 @@ def test_low_max_fee(): assert_equal(balance, initial_balance) + @pytest.mark.account @devnet_in_background(*ACCOUNTS_SEED_DEVNET_ARGS) def test_sufficient_max_fee(): @@ -160,7 +171,9 @@ def test_sufficient_max_fee(): to_address = int(deploy_info["address"], 16) initial_account_balance = get_account_balance(account_address) - initial_contract_balance = call("get_balance", deploy_info["address"], abi_path=ABI_PATH) + initial_contract_balance = call( + "get_balance", deploy_info["address"], abi_path=ABI_PATH + ) args = [10, 20] calls = [(to_address, "increase_balance", args)] @@ -173,43 +186,62 @@ def test_sufficient_max_fee(): invoke_receipt = get_transaction_receipt(invoke_tx_hash) actual_fee = int(invoke_receipt["actual_fee"], 16) - final_contract_balance = call("get_balance", deploy_info["address"], abi_path=ABI_PATH) + final_contract_balance = call( + "get_balance", deploy_info["address"], abi_path=ABI_PATH + ) assert_equal(int(final_contract_balance), int(initial_contract_balance) + sum(args)) final_account_balance = get_account_balance(account_address) assert_equal(final_account_balance, initial_account_balance - actual_fee) + @pytest.mark.account @devnet_in_background( - "--accounts", "1", - "--seed", "42", - "--gas-price", "100_000_000", - "--initial-balance", "10" + "--accounts", + "1", + "--seed", + "42", + "--gas-price", + "100_000_000", + "--initial-balance", + "10", ) def test_insufficient_balance(): """Test handling of insufficient account balance""" deploy_info = deploy_empty_contract() - account_address = "0x347be35996a21f6bf0623e75dbce52baba918ad5ae8d83b6f416045ab22961a" - private_key = 0xbdd640fb06671ad11c80317fa3b1799d + account_address = ( + "0x347be35996a21f6bf0623e75dbce52baba918ad5ae8d83b6f416045ab22961a" + ) + private_key = 0xBDD640FB06671AD11C80317FA3B1799D to_address = int(deploy_info["address"], 16) initial_account_balance = get_account_balance(account_address) - initial_contract_balance = call("get_balance", deploy_info["address"], abi_path=ABI_PATH) + initial_contract_balance = call( + "get_balance", deploy_info["address"], abi_path=ABI_PATH + ) args = [10, 20] calls = [(to_address, "increase_balance", args)] - invoke_tx_hash = execute(calls, account_address, private_key, max_fee=10 ** 21) # big enough + invoke_tx_hash = execute( + calls, account_address, private_key, max_fee=10**21 + ) # big enough assert_tx_status(invoke_tx_hash, "REJECTED") invoke_receipt = get_transaction_receipt(invoke_tx_hash) - assert "subtraction overflow" in invoke_receipt["transaction_failure_reason"]["error_message"] + assert ( + "subtraction overflow" + in invoke_receipt["transaction_failure_reason"]["error_message"] + ) - final_contract_balance = call("get_balance", deploy_info["address"], abi_path=ABI_PATH) + final_contract_balance = call( + "get_balance", deploy_info["address"], abi_path=ABI_PATH + ) assert_equal(final_contract_balance, initial_contract_balance) final_account_balance = get_account_balance(account_address) assert_equal(initial_account_balance, final_account_balance) + @pytest.mark.account @devnet_in_background() def test_multicall(): @@ -221,7 +253,7 @@ def test_multicall(): # execute increase_balance calls calls = [ (to_address, "increase_balance", [10, 20]), - (to_address, "increase_balance", [30, 40]) + (to_address, "increase_balance", [30, 40]), ] tx_hash = execute(calls, ACCOUNT_ADDRESS, PRIVATE_KEY) @@ -235,6 +267,7 @@ def test_multicall(): balance = call("get_balance", deploy_info["address"], abi_path=ABI_PATH) assert balance == "100" + @pytest.mark.account @devnet_in_background(*ACCOUNTS_SEED_DEVNET_ARGS) def test_events(): diff --git a/test/test_api_specifications.py b/test/test_api_specifications.py index b4f429473..298dafae4 100644 --- a/test/test_api_specifications.py +++ b/test/test_api_specifications.py @@ -5,6 +5,7 @@ from starknet_devnet.server import app from .settings import APP_URL + def test_api_endpoint(): """Assert that /api endpoint return list of endpoints""" response = app.test_client().get(f"{APP_URL}/api") diff --git a/test/test_block_number.py b/test/test_block_number.py index c916abd62..6b0c100a5 100644 --- a/test/test_block_number.py +++ b/test/test_block_number.py @@ -8,14 +8,14 @@ BLOCK_NUMBER_CONTRACT_PATH = f"{ARTIFACTS_PATH}/block_number.cairo/block_number.json" BLOCK_NUMBER_ABI_PATH = f"{ARTIFACTS_PATH}/block_number.cairo/block_number_abi.json" + def my_get_block_number(address: str): """Execute my_get_block_number on block_number.cairo contract deployed at `address`""" return call( - function="my_get_block_number", - address=address, - abi_path=BLOCK_NUMBER_ABI_PATH + function="my_get_block_number", address=address, abi_path=BLOCK_NUMBER_ABI_PATH ) + def base_workflow(): """Used by test cases to perform the test""" deploy_info = deploy(BLOCK_NUMBER_CONTRACT_PATH) @@ -26,30 +26,33 @@ def base_workflow(): function="write_block_number", inputs=[], address=deploy_info["address"], - abi_path=BLOCK_NUMBER_ABI_PATH + abi_path=BLOCK_NUMBER_ABI_PATH, ) written_block_number = call( function="read_block_number", inputs=[], address=deploy_info["address"], - abi_path=BLOCK_NUMBER_ABI_PATH + abi_path=BLOCK_NUMBER_ABI_PATH, ) assert int(written_block_number) == GENESIS_BLOCK_NUMBER + 2 block_number_after = my_get_block_number(deploy_info["address"]) assert int(block_number_after) == GENESIS_BLOCK_NUMBER + 2 + @devnet_in_background() def test_block_number_incremented(): """Tests how block number is incremented in regular mode""" base_workflow() + @devnet_in_background("--lite-mode") def test_block_number_incremented_in_lite_mode(): """Tests compatibility with lite mode""" base_workflow() + @devnet_in_background() def test_block_number_incremented_on_declare(): """Declare tx should increment get_block_number response""" @@ -64,6 +67,7 @@ def test_block_number_incremented_on_declare(): block_number_after = my_get_block_number(deploy_info["address"]) assert int(block_number_after) == GENESIS_BLOCK_NUMBER + 2 + @devnet_in_background() def test_block_number_not_incremented_if_deploy_fails(): """ @@ -80,6 +84,7 @@ def test_block_number_not_incremented_if_deploy_fails(): block_number_after = my_get_block_number(deploy_info["address"]) assert int(block_number_after) == GENESIS_BLOCK_NUMBER + 1 + @devnet_in_background() def test_block_number_not_incremented_if_invoke_fails(): """ diff --git a/test/test_declare.py b/test/test_declare.py index c07a668ee..533ea6a5e 100644 --- a/test/test_declare.py +++ b/test/test_declare.py @@ -10,7 +10,7 @@ CONTRACT_PATH, DEPLOYER_ABI_PATH, DEPLOYER_CONTRACT_PATH, - EXPECTED_CLASS_HASH + EXPECTED_CLASS_HASH, ) from .util import ( assert_contract_class, @@ -24,9 +24,10 @@ get_class_by_hash, get_class_hash_at, get_transaction_receipt, - invoke + invoke, ) + def assert_deployed_through_syscall(tx_hash, initial_balance): """Asserts that a contract has been deployed using the deploy syscall""" assert_tx_status(tx_hash, "ACCEPTED_ON_L2") @@ -43,15 +44,15 @@ def assert_deployed_through_syscall(tx_hash, initial_balance): fetched_class_hash = get_class_hash_at(contract_address=contract_address) assert_hex_equal(fetched_class_hash, EXPECTED_CLASS_HASH) - balance = call( - function="get_balance", - address=contract_address, - abi_path=ABI_PATH - ) + balance = call(function="get_balance", address=contract_address, abi_path=ABI_PATH) assert_equal(balance, initial_balance) -PREDEPLOYED_ACCOUNT_ADDRESS = "0x347be35996a21f6bf0623e75dbce52baba918ad5ae8d83b6f416045ab22961a" -PREDEPLOYED_ACCOUNT_PRIVATE_KEY = 0xbdd640fb06671ad11c80317fa3b1799d + +PREDEPLOYED_ACCOUNT_ADDRESS = ( + "0x347be35996a21f6bf0623e75dbce52baba918ad5ae8d83b6f416045ab22961a" +) +PREDEPLOYED_ACCOUNT_PRIVATE_KEY = 0xBDD640FB06671AD11C80317FA3B1799D + @pytest.mark.declare @devnet_in_background("--accounts", "1", "--seed", "42") @@ -72,11 +73,13 @@ def test_declare_and_deploy(): initial_balance_in_constructor = "5" deployer_deploy_info = deploy( contract=DEPLOYER_CONTRACT_PATH, - inputs=[declare_info["class_hash"], initial_balance_in_constructor] + inputs=[declare_info["class_hash"], initial_balance_in_constructor], ) deployer_address = deployer_deploy_info["address"] - assert_deployed_through_syscall(deployer_deploy_info["tx_hash"], initial_balance_in_constructor) + assert_deployed_through_syscall( + deployer_deploy_info["tx_hash"], initial_balance_in_constructor + ) # Deploy a contract of the declared class through the deployer initial_balance = "10" @@ -84,18 +87,24 @@ def test_declare_and_deploy(): function="deploy_contract", inputs=[initial_balance], address=deployer_address, - abi_path=DEPLOYER_ABI_PATH + abi_path=DEPLOYER_ABI_PATH, ) assert_deployed_through_syscall(invoke_tx_hash, initial_balance) # Deploy a contract of the declared class through the deployer - using an account initial_balance_through_account = 15 invoke_through_account_tx_hash = execute( - calls=[(int(deployer_address, 16), "deploy_contract", [initial_balance_through_account])], + calls=[ + ( + int(deployer_address, 16), + "deploy_contract", + [initial_balance_through_account], + ) + ], account_address=PREDEPLOYED_ACCOUNT_ADDRESS, - private_key=PREDEPLOYED_ACCOUNT_PRIVATE_KEY + private_key=PREDEPLOYED_ACCOUNT_PRIVATE_KEY, ) assert_deployed_through_syscall( tx_hash=invoke_through_account_tx_hash, - initial_balance=str(initial_balance_through_account) + initial_balance=str(initial_balance_through_account), ) diff --git a/test/test_deploy.py b/test/test_deploy.py index 657d77a0d..7a4b8031d 100644 --- a/test/test_deploy.py +++ b/test/test_deploy.py @@ -4,21 +4,27 @@ import pytest from starkware.starknet.business_logic.internal_transaction import InternalDeploy -from starkware.starknet.core.os.contract_address.contract_address import calculate_contract_address +from starkware.starknet.core.os.contract_address.contract_address import ( + calculate_contract_address, +) from starkware.starknet.definitions import constants from starkware.starknet.services.api.contract_class import ContractClass from starkware.starknet.services.api.gateway.transaction import Deploy -from starkware.starknet.services.api.feeder_gateway.response_objects import TransactionStatus +from starkware.starknet.services.api.feeder_gateway.response_objects import ( + TransactionStatus, +) from starknet_devnet.devnet_config import parse_args, DevnetConfig from starknet_devnet.starknet_wrapper import StarknetWrapper from .shared import CONTRACT_PATH, GENESIS_BLOCK_NUMBER + def get_contract_class(): """Get the contract class from the contract.json file.""" with open(CONTRACT_PATH, "r", encoding="utf-8") as contract_class_file: return ContractClass.loads(contract_class_file.read()) + def get_deploy_transaction(inputs: List[int], salt=0): """Get a Deploy transaction.""" contract_class = get_contract_class() @@ -27,9 +33,10 @@ def get_deploy_transaction(inputs: List[int], salt=0): contract_address_salt=salt, contract_definition=contract_class, constructor_calldata=inputs, - version=constants.TRANSACTION_VERSION + version=constants.TRANSACTION_VERSION, ) + @pytest.mark.asyncio async def test_deploy(): """ @@ -39,13 +46,15 @@ async def test_deploy(): await devnet.initialize() deploy_transaction = get_deploy_transaction(inputs=[0]) - contract_address, tx_hash = await devnet.deploy(deploy_transaction=deploy_transaction) + contract_address, tx_hash = await devnet.deploy( + deploy_transaction=deploy_transaction + ) expected_contract_address = calculate_contract_address( deployer_address=0, constructor_calldata=deploy_transaction.constructor_calldata, salt=deploy_transaction.contract_address_salt, - contract_class=deploy_transaction.contract_definition + contract_class=deploy_transaction.contract_definition, ) assert contract_address == expected_contract_address @@ -53,12 +62,12 @@ async def test_deploy(): state = devnet.get_state() internal_tx = InternalDeploy.from_external( - external_tx=deploy_transaction, - general_config=state.general_config + external_tx=deploy_transaction, general_config=state.general_config ) assert tx_hash == internal_tx.hash_value + @pytest.mark.asyncio async def test_deploy_lite(): """ @@ -68,12 +77,14 @@ async def test_deploy_lite(): await devnet.initialize() deploy_transaction = get_deploy_transaction(inputs=[0]) - contract_address, tx_hash = await devnet.deploy(deploy_transaction=deploy_transaction) + contract_address, tx_hash = await devnet.deploy( + deploy_transaction=deploy_transaction + ) expected_contract_address = calculate_contract_address( deployer_address=0, constructor_calldata=deploy_transaction.constructor_calldata, salt=deploy_transaction.contract_address_salt, - contract_class=deploy_transaction.contract_definition + contract_class=deploy_transaction.contract_definition, ) assert contract_address == expected_contract_address diff --git a/test/test_dump.py b/test/test_dump.py index 1cfa1fb89..07aa1c960 100644 --- a/test/test_dump.py +++ b/test/test_dump.py @@ -12,31 +12,42 @@ from .test_account import get_account_balance from .test_fee_token import mint -from .util import call, deploy, devnet_in_background, invoke, run_devnet_in_background, terminate_and_wait +from .util import ( + call, + deploy, + devnet_in_background, + invoke, + run_devnet_in_background, + terminate_and_wait, +) from .settings import APP_URL from .shared import CONTRACT_PATH, ABI_PATH DUMP_PATH = "dump.pkl" + class DevnetBackgroundProc: - """ Helper for ensuring we always have only 1 active devnet server running in background """ + """Helper for ensuring we always have only 1 active devnet server running in background""" + def __init__(self): self.proc = None def start(self, *args, stderr=None, stdout=None): - """ Starts a new devnet-server instance. Previously active instance will be stopped. """ + """Starts a new devnet-server instance. Previously active instance will be stopped.""" self.stop() self.proc = run_devnet_in_background(*args, stderr=stderr, stdout=stdout) return self.proc def stop(self): - """ Stops the currently active devnet-server instance """ + """Stops the currently active devnet-server instance""" if self.proc: terminate_and_wait(self.proc) self.proc = None + ACTIVE_DEVNET = DevnetBackgroundProc() + @pytest.fixture(autouse=True) def run_before_and_after_test(): """Cleanup after tests finish.""" @@ -53,39 +64,46 @@ def run_before_and_after_test(): if path.endswith(".pkl"): os.remove(path) -def send_dump_request(dump_path: str=None): + +def send_dump_request(dump_path: str = None): """Send HTTP request to trigger dumping.""" - json_load = { "path": dump_path } if dump_path else None + json_load = {"path": dump_path} if dump_path else None return requests.post(f"{APP_URL}/dump", json=json_load) -def send_load_request(load_path: str=None): + +def send_load_request(load_path: str = None): """Send HTTP request to trigger loading.""" - json_load = { "path": load_path } if load_path else None + json_load = {"path": load_path} if load_path else None return requests.post(f"{APP_URL}/load", json=json_load) + def send_error_request(): """Send HTTP request to trigger error response.""" - json_body = { "dummy": "dummy_value" } + json_body = {"dummy": "dummy_value"} return requests.post(f"{APP_URL}/dump", json=json_body) + def assert_dump_present(dump_path: str, sleep_seconds=2): """Assert there is a non-empty dump file.""" time.sleep(sleep_seconds) assert os.path.isfile(dump_path) assert os.path.getsize(dump_path) > 0 + def assert_no_dump_present(dump_path: str, sleep_seconds=2): """Assert there is no dump file.""" time.sleep(sleep_seconds) assert not os.path.isfile(dump_path) -def dump_and_assert(dump_path: str=None): + +def dump_and_assert(dump_path: str = None): """Assert no dump file before dump and assert some dump file after dump.""" assert_no_dump_present(dump_path) resp = send_dump_request(dump_path) assert resp.status_code == 200 assert_dump_present(dump_path) + def assert_not_alive(): """Assert devnet is not alive.""" try: @@ -94,6 +112,7 @@ def assert_not_alive(): except requests.exceptions.ConnectionError: pass + def deploy_empty_contract(): """ Deploy sample contract with balance = 0. @@ -105,18 +124,18 @@ def deploy_empty_contract(): assert initial_balance == "0" return contract_address + def test_load_via_cli_if_no_file(): """Test loading via CLI if dump file not present.""" assert_no_dump_present(DUMP_PATH) devnet_proc = ACTIVE_DEVNET.start( - "--load-path", DUMP_PATH, - "--accounts", "0", - stderr=subprocess.PIPE + "--load-path", DUMP_PATH, "--accounts", "0", stderr=subprocess.PIPE ) assert devnet_proc.returncode == 1 expected_msg = f"Error: Cannot load from {DUMP_PATH}. Make sure the file exists and contains a Devnet dump.\n" assert expected_msg == devnet_proc.stderr.read().decode("utf-8") + def test_mint_after_load(): """Assert that minting can be done after loading.""" devnet_proc = ACTIVE_DEVNET.start("--dump-path", DUMP_PATH, "--dump-on", "exit") @@ -137,6 +156,7 @@ def test_mint_after_load(): terminate_and_wait(loaded_devnet_proc) + @devnet_in_background() def test_load_via_http_if_no_file(): """Test loading via HTTP if dump file not present.""" @@ -147,36 +167,42 @@ def test_load_via_http_if_no_file(): assert resp.json()["message"] == expected_msg assert resp.status_code == 400 + @devnet_in_background() def test_dumping_if_path_not_provided(): """Assert failure if dumping attempted without a known path.""" resp = send_dump_request() assert resp.status_code == 400 + NONEXISTENT_DIR = "nonexistent-dir" + def test_dumping_if_nonexistent_dir_via_cli(): """Assert failure if dumping attempted via cli with a path containing a nonexistent dir""" invalid_path = os.path.join(NONEXISTENT_DIR, DUMP_PATH) devnet_proc = ACTIVE_DEVNET.start( - "--dump-path", invalid_path, - "--accounts", "0", - stderr=subprocess.PIPE + "--dump-path", invalid_path, "--accounts", "0", stderr=subprocess.PIPE ) assert devnet_proc.returncode == 1 expected_msg = f"Invalid dump path: directory '{NONEXISTENT_DIR}' not found.\n" assert expected_msg == devnet_proc.stderr.read().decode("utf-8") + @devnet_in_background() def test_dumping_if_nonexistent_dir_via_http(): """Assert failure if dumping attempted via http with a path containing a nonexistent dir""" invalid_path = os.path.join(NONEXISTENT_DIR, DUMP_PATH) resp = send_dump_request(dump_path=invalid_path) - assert resp.json()["message"] == f"Invalid dump path: directory '{NONEXISTENT_DIR}' not found." + assert ( + resp.json()["message"] + == f"Invalid dump path: directory '{NONEXISTENT_DIR}' not found." + ) assert resp.status_code == 400 + @devnet_in_background("--dump-path", DUMP_PATH) def test_dumping_if_path_provided_as_cli_option(): """Test dumping if path provided as CLI option""" @@ -184,6 +210,7 @@ def test_dumping_if_path_provided_as_cli_option(): assert resp.status_code == 200 assert_dump_present(DUMP_PATH) + def test_loading_via_cli(): """Test dumping via endpoint and loading via CLI.""" # init devnet + contract @@ -207,13 +234,16 @@ def test_loading_via_cli(): # assure that new invokes can be made invoke("increase_balance", ["15", "25"], contract_address, ABI_PATH) - balance_after_invoke_on_loaded = call("get_balance", contract_address, abi_path=ABI_PATH) + balance_after_invoke_on_loaded = call( + "get_balance", contract_address, abi_path=ABI_PATH + ) assert balance_after_invoke_on_loaded == "70" os.remove(DUMP_PATH) ACTIVE_DEVNET.stop() assert_no_dump_present(DUMP_PATH) + def test_dumping_and_loading_via_endpoint(): """Test dumping and loading via endpoint.""" # init devnet + contract @@ -238,13 +268,16 @@ def test_dumping_and_loading_via_endpoint(): # assure that new invokes can be made invoke("increase_balance", ["15", "25"], contract_address, ABI_PATH) - balance_after_invoke_on_loaded = call("get_balance", contract_address, abi_path=ABI_PATH) + balance_after_invoke_on_loaded = call( + "get_balance", contract_address, abi_path=ABI_PATH + ) assert balance_after_invoke_on_loaded == "70" os.remove(DUMP_PATH) ACTIVE_DEVNET.stop() assert_no_dump_present(DUMP_PATH) + def test_dumping_on_exit(): """Test dumping on exit.""" devnet_proc = ACTIVE_DEVNET.start("--dump-on", "exit", "--dump-path", DUMP_PATH) @@ -256,21 +289,27 @@ def test_dumping_on_exit(): assert balance_after_invoke == "30" assert_no_dump_present(DUMP_PATH) - devnet_proc.send_signal(signal.SIGINT) # send SIGINT because devnet doesn't handle SIGKILL + devnet_proc.send_signal( + signal.SIGINT + ) # send SIGINT because devnet doesn't handle SIGKILL assert_dump_present(DUMP_PATH, sleep_seconds=3) + def test_invalid_dump_on_option(): """Test behavior when invalid dump-on is provided.""" devnet_proc = ACTIVE_DEVNET.start( - "--dump-on", "obviously-invalid", - "--dump-path", DUMP_PATH, - stderr=subprocess.PIPE + "--dump-on", + "obviously-invalid", + "--dump-path", + DUMP_PATH, + stderr=subprocess.PIPE, ) assert devnet_proc.returncode == 1 expected_msg = b"Error: Invalid --dump-on option: obviously-invalid. Valid options: exit, transaction\n" assert devnet_proc.stderr.read() == expected_msg + def test_dump_path_not_present_with_dump_on_present(): """Test behavior when dump-path is not present and dump-on is.""" devnet_proc = ACTIVE_DEVNET.start("--dump-on", "exit", stderr=subprocess.PIPE) @@ -279,6 +318,7 @@ def test_dump_path_not_present_with_dump_on_present(): expected_msg = b"Error: --dump-path required if --dump-on present\n" assert devnet_proc.stderr.read() == expected_msg + def assert_load(dump_path: str, contract_address: str, expected_value: str): """Load from `dump_path` and assert get_balance at `contract_address` returns `expected_value`.""" @@ -287,6 +327,7 @@ def assert_load(dump_path: str, contract_address: str, expected_value: str): ACTIVE_DEVNET.stop() os.remove(dump_path) + def test_dumping_on_each_tx(): """Test dumping on each transaction.""" ACTIVE_DEVNET.start("--dump-on", "transaction", "--dump-path", DUMP_PATH) @@ -308,6 +349,7 @@ def test_dumping_on_each_tx(): assert_load(dump_after_deploy_path, contract_address, "0") assert_load(dump_after_invoke_path, contract_address, "10") + @devnet_in_background() def test_dumping_call_with_invalid_body(): """Call with invalid body and test status code and message.""" diff --git a/test/test_endpoints.py b/test/test_endpoints.py index 561dc03e7..4b8a74ee1 100644 --- a/test/test_endpoints.py +++ b/test/test_endpoints.py @@ -19,35 +19,42 @@ INVALID_HASH = "0x58d4d4ed7580a7a98ab608883ec9fe722424ce52c19f2f369eeea301f535914" INVALID_ADDRESS = "0x123" + def send_transaction(req_dict: dict): """Sends the dict in a POST request and returns the response data.""" return app.test_client().post( "/gateway/add_transaction", content_type="application/json", - data=json.dumps(req_dict) + data=json.dumps(req_dict), ) + def send_call(req_dict: dict): """Sends the call dict in a POST request and returns the response data.""" return app.test_client().post( "/feeder_gateway/call_contract", content_type="application/json", - data=json.dumps(req_dict) + data=json.dumps(req_dict), ) + def assert_deploy_resp(resp: bytes): """Asserts the validity of deploy response body.""" resp_dict = json.loads(resp.data.decode("utf-8")) assert set(resp_dict.keys()) == set(["address", "code", "transaction_hash"]) assert resp_dict["code"] == "TRANSACTION_RECEIVED" + def assert_invoke_resp(resp: bytes): """Asserts the validity of invoke response body.""" resp_dict = json.loads(resp.data.decode("utf-8")) - assert set(resp_dict.keys()) == set(["address", "code", "transaction_hash", "result"]) + assert set(resp_dict.keys()) == set( + ["address", "code", "transaction_hash", "result"] + ) assert resp_dict["code"] == "TRANSACTION_RECEIVED" assert resp_dict["result"] == [] + @pytest.mark.deploy def test_deploy_without_calldata(): """Deploy with complete request data""" @@ -56,16 +63,16 @@ def test_deploy_without_calldata(): resp = send_transaction(req_dict) assert resp.status_code == 400 + @pytest.mark.deploy def test_deploy_with_complete_request_data(): """Deploy without calldata""" resp = app.test_client().post( - "/gateway/add_transaction", - content_type="application/json", - data=DEPLOY_CONTENT + "/gateway/add_transaction", content_type="application/json", data=DEPLOY_CONTENT ) assert_deploy_resp(resp) + @pytest.mark.invoke def test_invoke_without_signature(): """Invoke without signature""" @@ -74,6 +81,7 @@ def test_invoke_without_signature(): resp = send_transaction(req_dict) assert resp.status_code == 400 + @pytest.mark.invoke def test_invoke_without_calldata(): """Invoke without calldata""" @@ -82,6 +90,7 @@ def test_invoke_without_calldata(): resp = send_transaction(req_dict) assert resp.status_code == 400 + @pytest.mark.invoke def test_invoke_with_complete_request_data(): """Invoke with complete request data""" @@ -89,6 +98,7 @@ def test_invoke_with_complete_request_data(): resp = send_transaction(req_dict) assert_invoke_resp(resp) + @pytest.mark.call def test_call_without_signature(): """Call without signature""" @@ -97,6 +107,7 @@ def test_call_without_signature(): resp = send_call(req_dict) assert resp.status_code == 400 + @pytest.mark.call def test_call_without_calldata(): """Call without calldata""" @@ -105,29 +116,31 @@ def test_call_without_calldata(): resp = send_call(req_dict) assert resp.status_code == 400 + @pytest.mark.call def test_call_with_complete_request_data(): """Call with complete request data""" req_dict = json.loads(CALL_CONTENT) resp = send_call(req_dict) resp_dict = json.loads(resp.data.decode("utf-8")) - assert resp_dict == { "result": ["0xa"] } + assert resp_dict == {"result": ["0xa"]} + # Error response tests def send_transaction_with_requests(req_dict: dict): """Sends the dict in a POST request and returns the response data.""" return requests.post( - f"{APP_URL}/gateway/add_transaction", - json=json.dumps(req_dict) + f"{APP_URL}/gateway/add_transaction", json=json.dumps(req_dict) ) + def send_call_with_requests(req_dict: dict): """Sends the call dict in a POST request and returns the response data.""" return requests.post( - f"{APP_URL}/feeder_gateway/call_contract", - json=json.dumps(req_dict) + f"{APP_URL}/feeder_gateway/call_contract", json=json.dumps(req_dict) ) + def get_block_by_number(req_dict: dict): """Get block number from request dict""" block_number = req_dict["blockNumber"] @@ -135,42 +148,51 @@ def get_block_by_number(req_dict: dict): f"{APP_URL}/feeder_gateway/get_block?blockNumber={block_number}" ) -def get_transaction_trace(transaction_hash:str): + +def get_transaction_trace(transaction_hash: str): """Get transaction trace from request dict""" return requests.get( f"{APP_URL}/feeder_gateway/get_transaction_trace?transactionHash={transaction_hash}" ) + def get_full_contract(contract_adress): """Get full contract class of a contract at a specific address""" return requests.get( f"{APP_URL}/feeder_gateway/get_full_contract?contractAddress={contract_adress}" ) + def get_class_by_hash(class_hash: str): """Get contract class by class hash""" return requests.get( f"{APP_URL}/feeder_gateway/get_class_by_hash?classHash={class_hash}" ) + def get_class_hash_at(contract_address: str): """Get class hash of a contract at the provided address""" return requests.get( f"{APP_URL}/feeder_gateway/get_class_hash_at?contractAddress={contract_address}" ) + def get_state_update(block_hash, block_number): """Get state update""" return requests.get( f"{APP_URL}/feeder_gateway/get_state_update?blockHash={block_hash}&blockNumber={block_number}" ) + def get_transaction_status(tx_hash): """Get transaction status""" - response = requests.get(f"{APP_URL}/feeder_gateway/get_transaction_status?transactionHash={tx_hash}") + response = requests.get( + f"{APP_URL}/feeder_gateway/get_transaction_status?transactionHash={tx_hash}" + ) assert response.status_code == 200 return response.json() + @pytest.mark.deploy @devnet_in_background() def test_error_response_deploy_without_calldata(): @@ -183,6 +205,7 @@ def test_error_response_deploy_without_calldata(): msg = "Invalid tx:" assert msg in json_error_message + @pytest.mark.call @devnet_in_background() def test_error_response_call_without_calldata(): @@ -195,6 +218,7 @@ def test_error_response_call_without_calldata(): assert resp.status_code == 400 assert json_error_message is not None + @pytest.mark.call @devnet_in_background() def test_error_response_call_with_negative_block_number(): @@ -205,6 +229,7 @@ def test_error_response_call_with_negative_block_number(): assert resp.status_code == 500 assert json_error_message is not None + @pytest.mark.call @devnet_in_background() def test_error_response_call_with_invalid_transaction_hash(): @@ -216,6 +241,7 @@ def test_error_response_call_with_invalid_transaction_hash(): assert resp.status_code == 500 assert json_error_message.startswith(msg) + @pytest.mark.call @devnet_in_background() def test_error_response_call_with_unavailable_contract(): @@ -226,6 +252,7 @@ def test_error_response_call_with_unavailable_contract(): assert resp.status_code == 500 assert json_error_message is not None + @pytest.mark.call @devnet_in_background() def test_error_response_call_with_state_update(): @@ -236,6 +263,7 @@ def test_error_response_call_with_state_update(): assert resp.status_code == 500 assert json_error_message is not None + @devnet_in_background() def test_error_response_class_hash_at(): """Get class hash of invalid address""" @@ -247,6 +275,7 @@ def test_error_response_class_hash_at(): expected_message = f"Contract with address {INVALID_ADDRESS} is not deployed" assert expected_message == error_message + @devnet_in_background() def test_error_response_class_by_hash(): """Get class by invalid hash""" @@ -258,6 +287,7 @@ def test_error_response_class_by_hash(): expected_message = f"Class with hash {INVALID_HASH} is not declared" assert expected_message == error_message + @devnet_in_background() def test_create_block_endpoint(): """test empty block creationn""" @@ -280,14 +310,18 @@ def test_create_block_endpoint(): assert resp.get("block_number") == GENESIS_BLOCK_NUMBER + 3 assert resp.get("block_hash") == hex(GENESIS_BLOCK_NUMBER + 3) + @devnet_in_background() def test_get_transaction_status(): """Assert valid response schema""" - #Create Transaction - response = requests.post(f"{APP_URL}/mint", json={ - "address": "0x0513493b4Fe460031d445fFACacACf3B19196a05Fd146Ed1609B7248101eF847", - "amount": 1000e18 - }) + # Create Transaction + response = requests.post( + f"{APP_URL}/mint", + json={ + "address": "0x0513493b4Fe460031d445fFACacACf3B19196a05Fd146Ed1609B7248101eF847", + "amount": 1000e18, + }, + ) assert response.status_code == 200 tx_hash = response.json().get("tx_hash") diff --git a/test/test_estimate_fee.py b/test/test_estimate_fee.py index bb5c787ce..8b72c4862 100644 --- a/test/test_estimate_fee.py +++ b/test/test_estimate_fee.py @@ -6,30 +6,23 @@ import requests from starknet_devnet.constants import DEFAULT_GAS_PRICE -from .util import ( - deploy, - devnet_in_background, - load_file_content -) +from .util import deploy, devnet_in_background, load_file_content from .settings import APP_URL from .shared import CONTRACT_PATH DEPLOY_CONTENT = load_file_content("deploy.json") INVOKE_CONTENT = load_file_content("invoke.json") + def estimate_fee_local(req_dict: dict): """Estimate fee of a given transaction""" - return requests.post( - f"{APP_URL}/feeder_gateway/estimate_fee", - json=req_dict - ) + return requests.post(f"{APP_URL}/feeder_gateway/estimate_fee", json=req_dict) + def send_estimate_fee_with_requests(req_dict: dict): """Sends the estimate fee dict in a POST request and returns the response data.""" - return requests.post( - f"{APP_URL}/feeder_gateway/estimate_fee", - json=req_dict - ) + return requests.post(f"{APP_URL}/feeder_gateway/estimate_fee", json=req_dict) + def common_estimate_response(response): """expected response from estimate_fee request""" @@ -38,25 +31,31 @@ def common_estimate_response(response): assert response_parsed.get("gas_price") == DEFAULT_GAS_PRICE assert isinstance(response_parsed.get("gas_usage"), int) - assert response_parsed.get("overall_fee") == response_parsed.get("gas_price") * response_parsed.get("gas_usage") + assert response_parsed.get("overall_fee") == response_parsed.get( + "gas_price" + ) * response_parsed.get("gas_usage") assert response_parsed.get("unit") == "wei" + @devnet_in_background() def test_estimate_fee_with_genesis_block(): """Call without transaction, expect pass with gas_price zero""" - response = send_estimate_fee_with_requests({ - "entry_point_selector": "0x2f0b3c5710379609eb5495f1ecd348cb28167711b73609fe565a72734550354", - "calldata": [ - "1786654640273905855542517570545751199272449814774211541121677632577420730552", - "1000000000000000000000", - "0" - ], - "signature": [], - "contract_address": "0x62230ea046a9a5fbc261ac77d03c8d41e5d442db2284587570ab46455fd2488" - }) + response = send_estimate_fee_with_requests( + { + "entry_point_selector": "0x2f0b3c5710379609eb5495f1ecd348cb28167711b73609fe565a72734550354", + "calldata": [ + "1786654640273905855542517570545751199272449814774211541121677632577420730552", + "1000000000000000000000", + "0", + ], + "signature": [], + "contract_address": "0x62230ea046a9a5fbc261ac77d03c8d41e5d442db2284587570ab46455fd2488", + } + ) common_estimate_response(response) + @pytest.mark.estimate_fee @devnet_in_background() def test_estimate_fee_in_unknown_address(): @@ -69,6 +68,7 @@ def test_estimate_fee_in_unknown_address(): assert resp.status_code == 500 assert json_error_message.startswith("Contract with address") + @pytest.mark.estimate_fee @devnet_in_background() def test_estimate_fee_with_invalid_data(): @@ -80,6 +80,7 @@ def test_estimate_fee_with_invalid_data(): assert resp.status_code == 400 assert "Invalid Starknet function call" in json_error_message + @pytest.mark.estimate_fee @devnet_in_background("--gas-price", str(DEFAULT_GAS_PRICE)) def test_estimate_fee_with_complete_request_data(): @@ -87,13 +88,15 @@ def test_estimate_fee_with_complete_request_data(): deploy_info = deploy(CONTRACT_PATH, ["0"]) # increase balance with 10+20 - response = send_estimate_fee_with_requests({ - "contract_address": deploy_info["address"], - "version": "0x100000000000000000000000000000000", - "signature": [], - "calldata": ["10", "20"], - "max_fee": "0x0", - "entry_point_selector": "0x362398bec32bc0ebb411203221a35a0301193a96f317ebe5e40be9f60d15320" - }) + response = send_estimate_fee_with_requests( + { + "contract_address": deploy_info["address"], + "version": "0x100000000000000000000000000000000", + "signature": [], + "calldata": ["10", "20"], + "max_fee": "0x0", + "entry_point_selector": "0x362398bec32bc0ebb411203221a35a0301193a96f317ebe5e40be9f60d15320", + } + ) common_estimate_response(response) diff --git a/test/test_fee_token.py b/test/test_fee_token.py index 8e6feada4..986cf2072 100644 --- a/test/test_fee_token.py +++ b/test/test_fee_token.py @@ -1,7 +1,13 @@ """Fee token related tests.""" from test.settings import APP_URL -from test.test_account import deploy_empty_contract, execute, assert_tx_status, get_transaction_receipt, get_account_balance +from test.test_account import ( + deploy_empty_contract, + execute, + assert_tx_status, + get_transaction_receipt, + get_account_balance, +) from test.shared import GENESIS_BLOCK_NUMBER import json import pytest @@ -10,120 +16,114 @@ from starknet_devnet.server import app from .util import assert_equal, devnet_in_background, get_block + @pytest.mark.fee_token def test_precomputed_address_unchanged(): """Assert that the precomputed fee_token address is unchanged.""" - assert_equal(FeeToken.ADDRESS, 2774287484619332564597403632816768868845110259953541691709975889937073775752) + assert_equal( + FeeToken.ADDRESS, + 2774287484619332564597403632816768868845110259953541691709975889937073775752, + ) + @pytest.mark.fee_token def test_fee_token_address(): """Sends fee token request;""" response = app.test_client().get("/fee_token") assert response.status_code == 200 - assert response.json.get("address") == "0x62230ea046a9a5fbc261ac77d03c8d41e5d442db2284587570ab46455fd2488" + assert ( + response.json.get("address") + == "0x62230ea046a9a5fbc261ac77d03c8d41e5d442db2284587570ab46455fd2488" + ) assert response.json.get("symbol") == "ETH" def mint(address: str, amount: int, lite=False): """Sends mint request; returns parsed json body""" - response = requests.post(f"{APP_URL}/mint", json={ - "address": address, - "amount": amount, - "lite": lite - }) + response = requests.post( + f"{APP_URL}/mint", json={"address": address, "amount": amount, "lite": lite} + ) assert response.status_code == 200 return response.json() + def mint_client(data: dict): """Send mint request to app test client""" return app.test_client().post( - "/mint", - content_type="application/json", - data=json.dumps(data) + "/mint", content_type="application/json", data=json.dumps(data) ) + def test_negative_mint(): """Assert failure if mint amount negative""" - resp = mint_client({ - "amount": -10, - "address": "0x1" - }) + resp = mint_client({"amount": -10, "address": "0x1"}) assert resp.status_code == 400 assert resp.json["message"] == "amount value must be greater than 0." + def test_mint_amount_string(): """Assert failure if mint amount not int""" - resp = mint_client({ - "amount": "abc", - "address": "0x1" - }) + resp = mint_client({"amount": "abc", "address": "0x1"}) assert resp.status_code == 400 assert resp.json["message"] == "amount value must be an integer." + def test_mint_amount_bool(): """Assert failure if mint amount not int""" - resp = mint_client({ - "amount": True, - "address": "0x1" - }) + resp = mint_client({"amount": True, "address": "0x1"}) assert resp.status_code == 400 assert resp.json["message"] == "amount value must be an integer." + def test_mint_amount_scientific(): """Assert failure if mint amount not int""" - resp = mint_client({ - "amount": 10e21, - "address": "0x1" - }) + resp = mint_client({"amount": 10e21, "address": "0x1"}) assert resp.status_code == 200 + def test_mint_amount_integer_float(): """Assert failure if mint amount not int""" - resp = mint_client({ - "amount": 12.00, - "address": "0x1" - }) + resp = mint_client({"amount": 12.00, "address": "0x1"}) assert resp.status_code == 200 + def test_missing_mint_amount(): """Assert failure if mint amount missing""" - resp = mint_client({ - "address": "0x1" - }) + resp = mint_client({"address": "0x1"}) assert resp.status_code == 400 assert resp.json["message"] == "amount value must be provided." + def test_wrong_mint_address_format(): """Assert failure if mint address of wrong format""" - resp = mint_client({ - "amount": 10, - "address": "invalid_address" - }) + resp = mint_client({"amount": 10, "address": "invalid_address"}) assert resp.status_code == 400 assert resp.json["message"] == "address value must be a hex string." + def test_missing_mint_address(): """Assert failure if mint address missing""" - resp = mint_client({ - "amount": 10 - }) + resp = mint_client({"amount": 10}) assert resp.status_code == 400 assert resp.json["message"] == "address value must be provided." + @pytest.mark.fee_token @devnet_in_background() def test_mint(): """Assert that mint will increase account balance and latest block created with correct transaction amount""" - account_address = "0x6e3205f9b7c4328f00f718fdecf56ab31acfb3cd6ffeb999dcbac4123655502" + account_address = ( + "0x6e3205f9b7c4328f00f718fdecf56ab31acfb3cd6ffeb999dcbac4123655502" + ) response = mint(address=account_address, amount=50_000) assert response.get("new_balance") == 50_000 assert response.get("unit") == "wei" @@ -135,6 +135,7 @@ def test_mint(): assert response.json().get("block_number") == GENESIS_BLOCK_NUMBER + 1 assert int(response.json().get("transactions")[0].get("calldata")[1], 16) == 50_000 + @pytest.mark.fee_token @devnet_in_background() def test_mint_lite(): @@ -142,35 +143,47 @@ def test_mint_lite(): response = mint( address="0x34d09711b5c047471fd21d424afbf405c09fd584057e1d69c77223b535cf769", amount=50_000, - lite=True + lite=True, ) assert response.get("new_balance") == 50000 assert response.get("unit") == "wei" assert response.get("tx_hash") is None + @pytest.mark.fee_token @devnet_in_background( - "--accounts", "1", - "--seed", "42", - "--gas-price", "100_000_000", - "--initial-balance", "10" + "--accounts", + "1", + "--seed", + "42", + "--gas-price", + "100_000_000", + "--initial-balance", + "10", ) def test_increase_balance(): """Assert tx failure if insufficient funds; assert tx success after mint""" deploy_info = deploy_empty_contract() - account_address = "0x347be35996a21f6bf0623e75dbce52baba918ad5ae8d83b6f416045ab22961a" - private_key = 0xbdd640fb06671ad11c80317fa3b1799d + account_address = ( + "0x347be35996a21f6bf0623e75dbce52baba918ad5ae8d83b6f416045ab22961a" + ) + private_key = 0xBDD640FB06671AD11C80317FA3B1799D to_address = int(deploy_info["address"], 16) initial_account_balance = get_account_balance(account_address) args = [10, 20] calls = [(to_address, "increase_balance", args)] - invoke_tx_hash = execute(calls, account_address, private_key, max_fee=10 ** 21) # big enough + invoke_tx_hash = execute( + calls, account_address, private_key, max_fee=10**21 + ) # big enough assert_tx_status(invoke_tx_hash, "REJECTED") invoke_receipt = get_transaction_receipt(invoke_tx_hash) - assert "subtraction overflow" in invoke_receipt["transaction_failure_reason"]["error_message"] + assert ( + "subtraction overflow" + in invoke_receipt["transaction_failure_reason"]["error_message"] + ) intermediate_account_balance = get_account_balance(account_address) assert_equal(initial_account_balance, intermediate_account_balance) @@ -180,11 +193,15 @@ def test_increase_balance(): balance_after_mint = get_account_balance(account_address) assert_equal(balance_after_mint, initial_account_balance + mint_amount) - invoke_tx_hash = execute(calls, account_address, private_key, max_fee=10 ** 21) # big enough + invoke_tx_hash = execute( + calls, account_address, private_key, max_fee=10**21 + ) # big enough assert_tx_status(invoke_tx_hash, "ACCEPTED_ON_L2") invoke_receipt = get_transaction_receipt(invoke_tx_hash) actual_fee = int(invoke_receipt["actual_fee"], 16) final_account_balance = get_account_balance(account_address) - assert_equal(final_account_balance, initial_account_balance + mint_amount - actual_fee) + assert_equal( + final_account_balance, initial_account_balance + mint_amount - actual_fee + ) diff --git a/test/test_general_workflow.py b/test/test_general_workflow.py index f67676b5a..d1c6017e3 100644 --- a/test/test_general_workflow.py +++ b/test/test_general_workflow.py @@ -10,12 +10,22 @@ assert_transaction_not_received, assert_transaction_receipt_not_received, devnet_in_background, - assert_block, assert_contract_code, assert_equal, assert_failing_deploy, assert_receipt, assert_salty_deploy, - assert_storage, assert_transaction, assert_tx_status, assert_events, - call, deploy, + assert_block, + assert_contract_code, + assert_equal, + assert_failing_deploy, + assert_receipt, + assert_salty_deploy, + assert_storage, + assert_transaction, + assert_tx_status, + assert_events, + call, + deploy, get_class_by_hash, get_class_hash_at, - get_full_contract, invoke + get_full_contract, + invoke, ) from .shared import ( @@ -28,9 +38,10 @@ EXPECTED_SALTY_DEPLOY_HASH, FAILING_CONTRACT_PATH, GENESIS_BLOCK_NUMBER, - NONEXISTENT_TX_HASH + NONEXISTENT_TX_HASH, ) + @pytest.mark.general_workflow @devnet_in_background() def test_general_workflow(): @@ -68,12 +79,10 @@ def test_general_workflow(): function="increase_balance", address=deploy_info["address"], abi_path=ABI_PATH, - inputs=["10", "20"] + inputs=["10", "20"], ) value = call( - function="get_balance", - address=deploy_info["address"], - abi_path=ABI_PATH + function="get_balance", address=deploy_info["address"], abi_path=ABI_PATH ) assert_equal(value, "30", "Invoke+call failed!") @@ -87,7 +96,7 @@ def test_general_workflow(): function="sum_point_array", address=deploy_info["address"], abi_path=ABI_PATH, - inputs=["2", "10", "20", "30", "40"] + inputs=["2", "10", "20", "30", "40"], ) assert_equal(value, "40 60", "Checking complex input failed!") @@ -98,7 +107,7 @@ def test_general_workflow(): inputs=None, expected_status="ACCEPTED_ON_L2", expected_address=EXPECTED_SALTY_DEPLOY_ADDRESS, - expected_tx_hash=EXPECTED_SALTY_DEPLOY_HASH + expected_tx_hash=EXPECTED_SALTY_DEPLOY_HASH, ) assert_salty_deploy( @@ -107,14 +116,14 @@ def test_general_workflow(): inputs=None, expected_status="ACCEPTED_ON_L2", expected_address=EXPECTED_SALTY_DEPLOY_ADDRESS, - expected_tx_hash=EXPECTED_SALTY_DEPLOY_HASH + expected_tx_hash=EXPECTED_SALTY_DEPLOY_HASH, ) salty_invoke_tx_hash = invoke( function="increase_balance", address=EXPECTED_SALTY_DEPLOY_ADDRESS, abi_path=EVENTS_ABI_PATH, - inputs=["10"] + inputs=["10"], ) assert_events(salty_invoke_tx_hash, "test/expected/invoke_receipt_event.json") diff --git a/test/test_general_workflow_auth.py b/test/test_general_workflow_auth.py index 35c434568..2c602c00b 100644 --- a/test/test_general_workflow_auth.py +++ b/test/test_general_workflow_auth.py @@ -5,9 +5,15 @@ from .util import ( devnet_in_background, - assert_block, assert_equal, assert_receipt, - assert_storage, assert_transaction, assert_tx_status, - call, deploy, invoke + assert_block, + assert_equal, + assert_receipt, + assert_storage, + assert_transaction, + assert_tx_status, + call, + deploy, + invoke, ) from .shared import ARTIFACTS_PATH, GENESIS_BLOCK_NUMBER, SIGNATURE @@ -16,14 +22,19 @@ ABI_PATH = f"{ARTIFACTS_PATH}/auth_contract.cairo/auth_contract_abi.json" # PRIVATE_KEY = "12345" -PUBLIC_KEY = "1628448741648245036800002906075225705100596136133912895015035902954123957052" +PUBLIC_KEY = ( + "1628448741648245036800002906075225705100596136133912895015035902954123957052" +) INITIAL_BALANCE = "1000" SIGNATURE = [ - "1225578735933442828068102633747590437426782890965066746429241472187377583468", - "3568809569741913715045370357918125425757114920266578211811626257903121825123" + "1225578735933442828068102633747590437426782890965066746429241472187377583468", + "3568809569741913715045370357918125425757114920266578211811626257903121825123", ] -BALANCE_KEY = "142452623821144136554572927896792266630776240502820879601186867231282346767" +BALANCE_KEY = ( + "142452623821144136554572927896792266630776240502820879601186867231282346767" +) + @pytest.mark.general_workflow @devnet_in_background() @@ -40,25 +51,26 @@ def test_general_workflow_auth(): assert_block(GENESIS_BLOCK_NUMBER + 1, deploy_info["tx_hash"]) assert_receipt(deploy_info["tx_hash"], "test/expected/deploy_receipt_auth.json") - # increase and assert balance invoke_tx_hash = invoke( function="increase_balance", address=deploy_info["address"], abi_path=ABI_PATH, inputs=[PUBLIC_KEY, "4321"], - signature=SIGNATURE + signature=SIGNATURE, ) value = call( function="get_balance", address=deploy_info["address"], abi_path=ABI_PATH, - inputs=[PUBLIC_KEY] + inputs=[PUBLIC_KEY], ) assert_equal(value, "5321", "Invoke+call failed!") # check storage after deployment assert_storage(deploy_info["address"], BALANCE_KEY, "0x14c9") expected_signature = [hex(int(s)) for s in SIGNATURE] - assert_transaction(invoke_tx_hash, "ACCEPTED_ON_L2", expected_signature=expected_signature) + assert_transaction( + invoke_tx_hash, "ACCEPTED_ON_L2", expected_signature=expected_signature + ) assert_receipt(invoke_tx_hash, "test/expected/invoke_receipt_auth.json") diff --git a/test/test_general_workflow_lite.py b/test/test_general_workflow_lite.py index 198130ba3..e6a7b6ca7 100644 --- a/test/test_general_workflow_lite.py +++ b/test/test_general_workflow_lite.py @@ -10,17 +10,18 @@ devnet_in_background, assert_equal, assert_tx_status, - call, deploy, invoke + call, + deploy, + invoke, ) -from .shared import ( - ABI_PATH, - CONTRACT_PATH, - GENESIS_BLOCK_NUMBER -) +from .shared import ABI_PATH, CONTRACT_PATH, GENESIS_BLOCK_NUMBER NONEXISTENT_TX_HASH = "0x12345678910111213" -BALANCE_KEY = "916907772491729262376534102982219947830828984996257231353398618781993312401" +BALANCE_KEY = ( + "916907772491729262376534102982219947830828984996257231353398618781993312401" +) + @pytest.mark.general_workflow @devnet_in_background("--lite-mode") @@ -42,12 +43,10 @@ def test_general_workflow_lite(): function="increase_balance", address=deploy_info["address"], abi_path=ABI_PATH, - inputs=["10", "20"] + inputs=["10", "20"], ) value = call( - function="get_balance", - address=deploy_info["address"], - abi_path=ABI_PATH + function="get_balance", address=deploy_info["address"], abi_path=ABI_PATH ) assert_equal(value, "30", "Invoke+call failed!") diff --git a/test/test_general_workflow_lite_block_hash.py b/test/test_general_workflow_lite_block_hash.py index 0a0aecbae..4a4df5fab 100644 --- a/test/test_general_workflow_lite_block_hash.py +++ b/test/test_general_workflow_lite_block_hash.py @@ -8,19 +8,21 @@ assert_block_hash, assert_negative_block_input, devnet_in_background, - assert_block, assert_equal, + assert_block, + assert_equal, assert_tx_status, - call, deploy, invoke + call, + deploy, + invoke, ) -from .shared import ( - ABI_PATH, - CONTRACT_PATH, - GENESIS_BLOCK_NUMBER -) +from .shared import ABI_PATH, CONTRACT_PATH, GENESIS_BLOCK_NUMBER NONEXISTENT_TX_HASH = "0x12345678910111213" -BALANCE_KEY = "916907772491729262376534102982219947830828984996257231353398618781993312401" +BALANCE_KEY = ( + "916907772491729262376534102982219947830828984996257231353398618781993312401" +) + @pytest.mark.general_workflow @devnet_in_background("--lite-mode-block-hash") @@ -43,12 +45,10 @@ def test_general_workflow_lite(): function="increase_balance", address=deploy_info["address"], abi_path=ABI_PATH, - inputs=["10", "20"] + inputs=["10", "20"], ) value = call( - function="get_balance", - address=deploy_info["address"], - abi_path=ABI_PATH + function="get_balance", address=deploy_info["address"], abi_path=ABI_PATH ) assert_equal(value, "30", "Invoke+call failed!") diff --git a/test/test_general_workflow_lite_deploy_hash.py b/test/test_general_workflow_lite_deploy_hash.py index 845401ef9..ced667389 100644 --- a/test/test_general_workflow_lite_deploy_hash.py +++ b/test/test_general_workflow_lite_deploy_hash.py @@ -8,7 +8,9 @@ devnet_in_background, assert_equal, assert_tx_status, - call, deploy, invoke + call, + deploy, + invoke, ) from .shared import ( @@ -16,7 +18,10 @@ CONTRACT_PATH, ) -BALANCE_KEY = "916907772491729262376534102982219947830828984996257231353398618781993312401" +BALANCE_KEY = ( + "916907772491729262376534102982219947830828984996257231353398618781993312401" +) + @pytest.mark.general_workflow @devnet_in_background("--lite-mode-deploy-hash") @@ -27,18 +32,16 @@ def test_general_workflow_lite(): print("Deployment:", deploy_info) assert_tx_status(deploy_info["tx_hash"], "ACCEPTED_ON_L2") - assert_equal(deploy_info["tx_hash"],"0x0") + assert_equal(deploy_info["tx_hash"], "0x0") # increase and assert balance invoke( function="increase_balance", address=deploy_info["address"], abi_path=ABI_PATH, - inputs=["10", "20"] + inputs=["10", "20"], ) value = call( - function="get_balance", - address=deploy_info["address"], - abi_path=ABI_PATH + function="get_balance", address=deploy_info["address"], abi_path=ABI_PATH ) assert_equal(value, "30", "Invoke+call failed!") diff --git a/test/test_postman.py b/test/test_postman.py index c8e263008..7483855ca 100644 --- a/test/test_postman.py +++ b/test/test_postman.py @@ -7,7 +7,15 @@ from test.web3_util import web3_call, web3_deploy, web3_transact from test.settings import APP_URL, L1_HOST, L1_PORT, L1_URL -from test.util import call, deploy, devnet_in_background, ensure_server_alive, invoke, load_file_content, terminate_and_wait +from test.util import ( + call, + deploy, + devnet_in_background, + ensure_server_alive, + invoke, + load_file_content, + terminate_and_wait, +) import psutil import pytest @@ -22,9 +30,12 @@ ABI_PATH = f"{ARTIFACTS_PATH}/l1l2.cairo/l1l2_abi.json" ETH_CONTRACTS_PATH = "artifacts/contracts/solidity" -STARKNET_MESSAGING_PATH = f"{ETH_CONTRACTS_PATH}/MockStarknetMessaging.sol/MockStarknetMessaging.json" +STARKNET_MESSAGING_PATH = ( + f"{ETH_CONTRACTS_PATH}/MockStarknetMessaging.sol/MockStarknetMessaging.json" +) L1L2_EXAMPLE_PATH = f"{ETH_CONTRACTS_PATH}/L1L2.sol/L1L2Example.json" + @pytest.fixture(autouse=True) def run_before_and_after_test(): """Run l1 testnet before and kill it after the test run""" @@ -46,27 +57,43 @@ def run_before_and_after_test(): print("Children after killing", wrapped_node_proc.children(recursive=True)) terminate_and_wait(node_proc) + def flush(): """Flushes the postman messages. Returns response data""" - res = requests.post( - f"{APP_URL}/postman/flush" - ) + res = requests.post(f"{APP_URL}/postman/flush") return res.json() -def assert_flush_response(response, expected_from_l1, expected_from_l2, expected_l1_provider): + +def assert_flush_response( + response, expected_from_l1, expected_from_l2, expected_l1_provider +): """Asserts that the flush response is correct""" assert response["l1_provider"] == expected_l1_provider for i, l1_message in enumerate(response["consumed_messages"]["from_l1"]): - assert l1_message["args"]["from_address"] == expected_from_l1[i]["args"]["from_address"] - assert l1_message["args"]["to_address"] == expected_from_l1[i]["args"]["to_address"] - assert l1_message["args"]["payload"] == [hex(x) for x in expected_from_l1[i]["args"]["payload"]] + assert ( + l1_message["args"]["from_address"] + == expected_from_l1[i]["args"]["from_address"] + ) + assert ( + l1_message["args"]["to_address"] + == expected_from_l1[i]["args"]["to_address"] + ) + assert l1_message["args"]["payload"] == [ + hex(x) for x in expected_from_l1[i]["args"]["payload"] + ] # check if correct keys are present expected_keys = [ - "block_hash", "block_number", "transaction_hash", "transaction_index", "address", "event", "log_index" + "block_hash", + "block_number", + "transaction_hash", + "transaction_index", + "address", + "event", + "log_index", ] for key in expected_keys: @@ -86,12 +113,10 @@ def assert_flush_response(response, expected_from_l1, expected_from_l2, expected def init_messaging_contract(): """Initializes the messaging contract""" - deploy_messaging_contract_request = { - "networkUrl": L1_URL - } + deploy_messaging_contract_request = {"networkUrl": L1_URL} resp = requests.post( f"{APP_URL}/postman/load_l1_messaging_contract", - json=deploy_messaging_contract_request + json=deploy_messaging_contract_request, ) return json.loads(resp.text) @@ -105,8 +130,12 @@ def deploy_l1_contracts(web3): # Min amount of time in seconds for a message to be able to be cancelled l1_message_cancellation_delay = 0 # Deploys a new mock contract so that the feature for loading an already deployed messaging contract can be tested - starknet_messaging_contract = web3_deploy(web3, messaging_contract, l1_message_cancellation_delay) - l1l2_example = web3_deploy(web3,l1l2_example_contract,starknet_messaging_contract.address) + starknet_messaging_contract = web3_deploy( + web3, messaging_contract, l1_message_cancellation_delay + ) + l1l2_example = web3_deploy( + web3, l1l2_example_contract, starknet_messaging_contract.address + ) return starknet_messaging_contract, l1l2_example @@ -116,18 +145,19 @@ def load_messaging_contract(starknet_messaging_contract_address): load_messaging_contract_request = { "networkUrl": L1_URL, - "address": starknet_messaging_contract_address + "address": starknet_messaging_contract_address, } resp = requests.post( f"{APP_URL}/postman/load_l1_messaging_contract", - json=load_messaging_contract_request + json=load_messaging_contract_request, ) return json.loads(resp.text) + def init_l2_contract(l1l2_example_contract_address): - """Deploys the L1L2Example cairo contract, returns the result of calling 'get_balance' """ + """Deploys the L1L2Example cairo contract, returns the result of calling 'get_balance'""" deploy_info = deploy(CONTRACT_PATH) @@ -136,13 +166,13 @@ def init_l2_contract(l1l2_example_contract_address): function="increase_balance", address=deploy_info["address"], abi_path=ABI_PATH, - inputs=["1","3333"] + inputs=["1", "3333"], ) invoke( function="withdraw", address=deploy_info["address"], abi_path=ABI_PATH, - inputs=["1","1000",l1l2_example_contract_address] + inputs=["1", "1000", l1l2_example_contract_address], ) # flush L2 to L1 messages @@ -151,33 +181,33 @@ def init_l2_contract(l1l2_example_contract_address): assert_flush_response( response=flush_response, expected_from_l1=[], - expected_from_l2=[{ - "from_address": deploy_info["address"], - "to_address": l1l2_example_contract_address, - "payload": [0, 1, 1000] # MESSAGE_WITHDRAW, user, amount - }], - expected_l1_provider=L1_URL + expected_from_l2=[ + { + "from_address": deploy_info["address"], + "to_address": l1l2_example_contract_address, + "payload": [0, 1, 1000], # MESSAGE_WITHDRAW, user, amount + } + ], + expected_l1_provider=L1_URL, ) - #assert balance + # assert balance value = call( function="get_balance", address=deploy_info["address"], abi_path=ABI_PATH, - inputs=["1"] + inputs=["1"], ) assert value == "2333" return deploy_info["address"] + def l1_l2_message_exchange(web3, l1l2_example_contract, l2_contract_address): """Tests message exchange""" # assert contract balance when starting - balance = web3_call( - "userBalances", - l1l2_example_contract, - 1) + balance = web3_call("userBalances", l1l2_example_contract, 1) assert balance == 0 # withdraw in l1 and assert contract balance @@ -185,12 +215,12 @@ def l1_l2_message_exchange(web3, l1l2_example_contract, l2_contract_address): web3, "withdraw", l1l2_example_contract, - int(l2_contract_address,base=16), 1, 1000) + int(l2_contract_address, base=16), + 1, + 1000, + ) - balance = web3_call( - "userBalances", - l1l2_example_contract, - 1) + balance = web3_call("userBalances", l1l2_example_contract, 1) assert balance == 1000 # assert l2 contract balance @@ -198,7 +228,7 @@ def l1_l2_message_exchange(web3, l1l2_example_contract, l2_contract_address): function="get_balance", address=l2_contract_address, abi_path=ABI_PATH, - inputs=["1"] + inputs=["1"], ) assert l2_balance == "2333" @@ -208,12 +238,12 @@ def l1_l2_message_exchange(web3, l1l2_example_contract, l2_contract_address): web3, "deposit", l1l2_example_contract, - int(l2_contract_address,base=16), 1, 600) + int(l2_contract_address, base=16), + 1, + 600, + ) - balance = web3_call( - "userBalances", - l1l2_example_contract, - 1) + balance = web3_call("userBalances", l1l2_example_contract, 1) assert balance == 400 @@ -222,14 +252,16 @@ def l1_l2_message_exchange(web3, l1l2_example_contract, l2_contract_address): assert_flush_response( response=flush_response, - expected_from_l1=[{ - "address": None, - "args": { - "from_address": l1l2_example_contract.address, - "to_address": l2_contract_address, - "payload": [1, 600] # user, amount + expected_from_l1=[ + { + "address": None, + "args": { + "from_address": l1l2_example_contract.address, + "to_address": l2_contract_address, + "payload": [1, 600], # user, amount + }, } - }], + ], expected_from_l2=[], expected_l1_provider=L1_URL, ) @@ -239,11 +271,12 @@ def l1_l2_message_exchange(web3, l1l2_example_contract, l2_contract_address): function="get_balance", address=l2_contract_address, abi_path=ABI_PATH, - inputs=["1"] + inputs=["1"], ) assert l2_balance == "2933" + @pytest.mark.web3_messaging @devnet_in_background() def test_postman(): @@ -270,16 +303,16 @@ def test_postman(): # Test initializing the l2 example contract l2_contract_address = init_l2_contract(l1l2_example_contract.address) - l1_l2_message_exchange(web3,l1l2_example_contract,l2_contract_address) + l1_l2_message_exchange(web3, l1l2_example_contract, l2_contract_address) def load_l1_messaging_contract(req_dict: dict): """Load L1 messaging contract""" return requests.post( - f"{APP_URL}/postman/load_l1_messaging_contract", - json=(req_dict) + f"{APP_URL}/postman/load_l1_messaging_contract", json=(req_dict) ) + @devnet_in_background() def test_invalid_starknet_function_call_load_l1_messaging_contract(): """Call with invalid data on starknet function call""" diff --git a/test/test_restart.py b/test/test_restart.py index ae49f3879..adbdb7f02 100644 --- a/test/test_restart.py +++ b/test/test_restart.py @@ -6,13 +6,23 @@ import requests from .settings import APP_URL -from .util import devnet_in_background, deploy, assert_transaction_not_received, assert_tx_status, call, get_block, invoke +from .util import ( + devnet_in_background, + deploy, + assert_transaction_not_received, + assert_tx_status, + call, + get_block, + invoke, +) from .shared import CONTRACT_PATH, ABI_PATH, GENESIS_BLOCK_HASH + def restart(): """Get restart response""" return requests.post(f"{APP_URL}/restart") + def get_state_update(): """Get state update""" res = requests.get(f"{APP_URL}/feeder_gateway/get_state_update") @@ -23,6 +33,7 @@ def deploy_contract(salt=None): """Deploy empty contract with balance of 0""" return deploy(CONTRACT_PATH, inputs=["0"], salt=salt) + @pytest.mark.restart @devnet_in_background() def test_restart_on_initial_state(): @@ -51,13 +62,14 @@ def test_transaction(): assert_transaction_not_received(tx_hash=tx_hash) + @pytest.mark.restart @devnet_in_background() def test_contract(): """Checks if contract storage is reset""" salt = "0x99" deploy_info = deploy_contract(salt) - contract_address = deploy_info["address"] + contract_address = deploy_info["address"] balance = call("get_balance", contract_address, ABI_PATH) assert balance == "0" @@ -72,6 +84,7 @@ def test_contract(): balance = call("get_balance", contract_address, ABI_PATH) assert balance == "0" + @pytest.mark.restart @devnet_in_background() def test_state_update(): @@ -88,7 +101,10 @@ def test_state_update(): assert state_update["block_hash"] == GENESIS_BLOCK_HASH + GAS_PRICE = str(int(1e9)) + + @devnet_in_background("--gas-price", GAS_PRICE) def test_gas_price_unaffected_by_restart(): """Checks that gas price is not affected by restart""" diff --git a/test/test_state_update.py b/test/test_state_update.py index 41b73eda6..5a902689d 100644 --- a/test/test_state_update.py +++ b/test/test_state_update.py @@ -9,13 +9,24 @@ from starkware.starknet.public.abi import get_selector_from_name from .util import ( - deploy, invoke, load_contract_class, devnet_in_background, get_block, assert_equal, + deploy, + invoke, + load_contract_class, + devnet_in_background, + get_block, + assert_equal, ) from .settings import APP_URL -from .shared import GENESIS_BLOCK_HASH, STORAGE_CONTRACT_PATH, STORAGE_ABI_PATH, GENESIS_BLOCK_NUMBER +from .shared import ( + GENESIS_BLOCK_HASH, + STORAGE_CONTRACT_PATH, + STORAGE_ABI_PATH, + GENESIS_BLOCK_NUMBER, +) STORAGE_KEY = hex(get_selector_from_name("storage")) + def get_state_update_response(block_hash=None, block_number=None): """Get state update response""" params = { @@ -23,17 +34,16 @@ def get_state_update_response(block_hash=None, block_number=None): "blockNumber": block_number, } - res = requests.get( - f"{APP_URL}/feeder_gateway/get_state_update", - params=params - ) + res = requests.get(f"{APP_URL}/feeder_gateway/get_state_update", params=params) return res + def get_state_update(block_hash=None, block_number=None): """Get state update""" return get_state_update_response(block_hash, block_number).json() + def deploy_empty_contract(): """ Deploy storage contract @@ -44,11 +54,13 @@ def deploy_empty_contract(): return contract_address + def get_contract_hash(): """Get contract hash of the sample contract""" contract_class = load_contract_class(STORAGE_CONTRACT_PATH) return compute_class_hash(contract_class) + @pytest.mark.state_update @devnet_in_background() def test_initial_state_update(): @@ -57,6 +69,7 @@ def test_initial_state_update(): assert_equal(state_update["block_hash"], GENESIS_BLOCK_HASH) + @pytest.mark.state_update @devnet_in_background() def test_deployed_contracts(): @@ -69,10 +82,11 @@ def test_deployed_contracts(): assert_equal(len(deployed_contracts), 1) assert_equal(int(deployed_contracts[0]["address"], 16), int(contract_address, 16)) - deployed_contract_hash = deployed_contracts[0]["class_hash"] + deployed_contract_hash = deployed_contracts[0]["class_hash"] assert_equal(int(deployed_contract_hash, 16), get_contract_hash()) + @pytest.mark.state_update @devnet_in_background() def test_storage_diff(): @@ -101,6 +115,7 @@ def test_storage_diff(): assert_equal(contract_storage_diffs[0]["value"], hex(0)) assert_equal(contract_storage_diffs[0]["key"], STORAGE_KEY) + @pytest.mark.state_update @devnet_in_background() def test_block_hash(): @@ -122,6 +137,7 @@ def test_block_hash(): assert new_state_update["block_hash"] != first_block_hash assert_equal(previous_state_update, initial_state_update) + @pytest.mark.state_update @devnet_in_background() def test_wrong_block_hash(): @@ -130,6 +146,7 @@ def test_wrong_block_hash(): assert_equal(state_update_response.status_code, 500) + @pytest.mark.state_update @devnet_in_background() def test_block_number(): @@ -147,6 +164,7 @@ def test_block_number(): assert_equal(first_block_state_update, initial_state_update) assert_equal(second_block_state_update, new_state_update) + @pytest.mark.state_update @devnet_in_background() def test_wrong_block_number(): @@ -155,6 +173,7 @@ def test_wrong_block_number(): assert_equal(state_update_response.status_code, 500) + @pytest.mark.state_update @devnet_in_background() def test_roots(): diff --git a/test/test_timestamps.py b/test/test_timestamps.py index 2a98ae071..67082621f 100644 --- a/test/test_timestamps.py +++ b/test/test_timestamps.py @@ -19,31 +19,40 @@ SET_TIME_ARGUMENT = 1514764800 + def deploy_ts_contract(): """Deploys the timestamp contract""" return deploy(TS_CONTRACT_PATH) + def get_ts_from_contract(address): """Returns the timestamp of the contract""" - return int(call( - function="get_timestamp", - address=address, - abi_path=TS_ABI_PATH, - )) + return int( + call( + function="get_timestamp", + address=address, + abi_path=TS_ABI_PATH, + ) + ) + def get_ts_from_last_block(): """Returns the timestamp of the last block""" return get_block(parse=True)["timestamp"] + def increase_time(time_s): """Increases the block timestamp offset""" - increase_time_response = requests.post(f"{APP_URL}/increase_time", json={"time": time_s}) + increase_time_response = requests.post( + f"{APP_URL}/increase_time", json={"time": time_s} + ) if increase_time_response.status_code == 200: assert increase_time_response.json().get("timestamp_increased_by") == time_s return increase_time_response + def set_time(time_s): """Sets the block timestamp and offset""" set_time_response = requests.post(f"{APP_URL}/set_time", json={"time": time_s}) @@ -53,6 +62,7 @@ def set_time(time_s): return set_time_response + @pytest.mark.timestamps @devnet_in_background() def test_timestamps(): @@ -75,6 +85,7 @@ def test_timestamps(): assert ts_after_second_deploy == ts_from_second_call assert ts_from_second_call > ts_from_first_call + @pytest.mark.timestamps @devnet_in_background() def test_increase_time(): @@ -133,6 +144,7 @@ def test_set_time(): # check if offset is still the same assert third_block_ts - first_block_ts >= 86400 + @pytest.mark.timestamps @devnet_in_background("--start-time", str(SET_TIME_ARGUMENT)) def test_set_time_argument(): @@ -141,6 +153,7 @@ def test_set_time_argument(): assert first_block_ts == SET_TIME_ARGUMENT + @pytest.mark.timestamps @devnet_in_background() def test_set_time_errors(): @@ -165,6 +178,7 @@ def test_set_time_errors(): assert response.status_code == 400 assert message == "time value must be an integer." + @pytest.mark.timestamps @devnet_in_background() def test_increase_time_errors(): @@ -196,11 +210,12 @@ def test_block_info_generator(): start = int(time.time()) block_info = BlockInfo.create_for_testing(block_number=0, block_timestamp=start) - # Test if start time is set by the constructor generator = BlockInfoGenerator(start_time=10) - block_with_start_time = generator.next_block(block_info=block_info, general_config=DEFAULT_GENERAL_CONFIG) + block_with_start_time = generator.next_block( + block_info=block_info, general_config=DEFAULT_GENERAL_CONFIG + ) assert block_with_start_time.block_timestamp == 10 @@ -208,7 +223,9 @@ def test_block_info_generator(): generator.increase_time(22) - block_after_increase = generator.next_block(block_info=block_info, general_config=DEFAULT_GENERAL_CONFIG) + block_after_increase = generator.next_block( + block_info=block_info, general_config=DEFAULT_GENERAL_CONFIG + ) assert block_after_increase.block_timestamp == 32 @@ -217,16 +234,20 @@ def test_block_info_generator(): generator = BlockInfoGenerator() generator.increase_time(1_000_000_000) - block_with_increase_time = generator.next_block(block_info=block_info, general_config=DEFAULT_GENERAL_CONFIG) + block_with_increase_time = generator.next_block( + block_info=block_info, general_config=DEFAULT_GENERAL_CONFIG + ) assert block_with_increase_time.block_timestamp >= 1_000_000_000 + int(time.time()) - generator.set_next_block_time(222) - block_after_set_time = generator.next_block(block_info=block_info, general_config=DEFAULT_GENERAL_CONFIG) + block_after_set_time = generator.next_block( + block_info=block_info, general_config=DEFAULT_GENERAL_CONFIG + ) assert block_after_set_time.block_timestamp == 222 + @pytest.mark.timestamps @devnet_in_background("--lite-mode") def test_lite_mode_compatibility(): diff --git a/test/test_transaction_trace.py b/test/test_transaction_trace.py index 1365a0cc6..770b12ba1 100644 --- a/test/test_transaction_trace.py +++ b/test/test_transaction_trace.py @@ -4,11 +4,27 @@ import pytest import requests -from starkware.starknet.services.api.feeder_gateway.response_objects import BlockTransactionTraces - -from .util import declare, deploy, get_transaction_receipt, invoke, load_json_from_path, devnet_in_background +from starkware.starknet.services.api.feeder_gateway.response_objects import ( + BlockTransactionTraces, +) + +from .util import ( + declare, + deploy, + get_transaction_receipt, + invoke, + load_json_from_path, + devnet_in_background, +) from .settings import APP_URL -from .shared import ABI_PATH, CONTRACT_PATH, SIGNATURE, NONEXISTENT_TX_HASH, GENESIS_BLOCK_NUMBER +from .shared import ( + ABI_PATH, + CONTRACT_PATH, + SIGNATURE, + NONEXISTENT_TX_HASH, + GENESIS_BLOCK_NUMBER, +) + def get_transaction_trace_response(tx_hash=None): """Get transaction trace response""" @@ -16,13 +32,11 @@ def get_transaction_trace_response(tx_hash=None): "transactionHash": tx_hash, } - res = requests.get( - f"{APP_URL}/feeder_gateway/get_transaction_trace", - params=params - ) + res = requests.get(f"{APP_URL}/feeder_gateway/get_transaction_trace", params=params) return res + def deploy_empty_contract(): """ Deploy sample contract with balance = 0. @@ -30,11 +44,13 @@ def deploy_empty_contract(): """ return deploy(CONTRACT_PATH, inputs=["0"], salt="0x99") + def assert_function_invocation(function_invocation, expected_path): """Asserts function invocation""" expected_function_invocation = load_json_from_path(expected_path) assert function_invocation == expected_function_invocation + @pytest.mark.transaction_trace @devnet_in_background() def test_deploy_transaction_trace(): @@ -48,9 +64,10 @@ def test_deploy_transaction_trace(): assert transaction_trace["signature"] == [] assert_function_invocation( transaction_trace["function_invocation"], - "test/expected/deploy_function_invocation.json" + "test/expected/deploy_function_invocation.json", ) + @pytest.mark.transaction_trace @devnet_in_background() def test_invoke_transaction_hash(): @@ -65,7 +82,7 @@ def test_invoke_transaction_hash(): assert transaction_trace["signature"] == [] assert_function_invocation( transaction_trace["function_invocation"], - "test/expected/invoke_function_invocation.json" + "test/expected/invoke_function_invocation.json", ) @@ -74,7 +91,9 @@ def test_invoke_transaction_hash(): def test_invoke_transaction_hash_with_signature(): """Test invoke transaction trace with signature""" contract_address = deploy_empty_contract()["address"] - tx_hash = invoke("increase_balance", ["10", "20"], contract_address, ABI_PATH, SIGNATURE) + tx_hash = invoke( + "increase_balance", ["10", "20"], contract_address, ABI_PATH, SIGNATURE + ) res = get_transaction_trace_response(tx_hash) assert res.status_code == 200 @@ -86,9 +105,10 @@ def test_invoke_transaction_hash_with_signature(): assert_function_invocation( transaction_trace["function_invocation"], - "test/expected/invoke_function_invocation.json" + "test/expected/invoke_function_invocation.json", ) + @pytest.mark.transaction_trace @devnet_in_background() def test_nonexistent_transaction_hash(): @@ -97,11 +117,11 @@ def test_nonexistent_transaction_hash(): assert res.status_code == 500 + def assert_get_block_traces_response(params, expected_tx_hash): """Assert response of get_block_traces""" block_traces = requests.get( - f"{APP_URL}/feeder_gateway/get_block_traces", - params=params + f"{APP_URL}/feeder_gateway/get_block_traces", params=params ).json() # loading to assert valid structure @@ -111,6 +131,7 @@ def assert_get_block_traces_response(params, expected_tx_hash): actual_tx_hash = block_traces["traces"][0]["transaction_hash"] assert actual_tx_hash == expected_tx_hash + @pytest.mark.transaction_trace @devnet_in_background() def test_get_block_traces(): @@ -121,9 +142,12 @@ def test_get_block_traces(): tx_receipt = get_transaction_receipt(tx_hash=tx_hash) block_hash = tx_receipt["block_hash"] - assert_get_block_traces_response({ "blockHash": block_hash }, tx_hash) - assert_get_block_traces_response({ "blockNumber": GENESIS_BLOCK_NUMBER + 1 }, tx_hash) - assert_get_block_traces_response({}, tx_hash) # default behavior - no params provided + assert_get_block_traces_response({"blockHash": block_hash}, tx_hash) + assert_get_block_traces_response({"blockNumber": GENESIS_BLOCK_NUMBER + 1}, tx_hash) + assert_get_block_traces_response( + {}, tx_hash + ) # default behavior - no params provided + @pytest.mark.transaction_trace @devnet_in_background() diff --git a/test/test_tx_version.py b/test/test_tx_version.py index 2b4ab180a..ebbc264eb 100644 --- a/test/test_tx_version.py +++ b/test/test_tx_version.py @@ -12,6 +12,7 @@ CONTRACT_PATH = f"{ARTIFACTS_PATH}/tx_version.cairo/tx_version.json" ABI_PATH = f"{ARTIFACTS_PATH}/tx_version.cairo/tx_version_abi.json" + @pytest.mark.tx_version @devnet_in_background() def test_transaction_version(): diff --git a/test/util.py b/test/util.py index 391a0349b..eac2659a2 100644 --- a/test/util.py +++ b/test/util.py @@ -15,9 +15,11 @@ from starknet_devnet.general_config import DEFAULT_GENERAL_CONFIG from .settings import HOST, PORT, APP_URL + class ReturnCodeAssertionError(AssertionError): """Error to be raised when the return code of an executed process is not as expected.""" + def run_devnet_in_background(*args, stderr=None, stdout=None): """ Runs starknet-devnet in background. @@ -30,18 +32,29 @@ def run_devnet_in_background(*args, stderr=None, stdout=None): if "--accounts" not in args: args = [*args, "--accounts", "1"] - command = ["poetry", "run", "starknet-devnet", "--host", HOST, "--port", PORT, *args] + command = [ + "poetry", + "run", + "starknet-devnet", + "--host", + HOST, + "--port", + PORT, + *args, + ] # pylint: disable=consider-using-with proc = subprocess.Popen(command, close_fds=True, stderr=stderr, stdout=stdout) ensure_server_alive(f"{APP_URL}/is_alive", proc) return proc + def devnet_in_background(*devnet_args, **devnet_kwargs): """ Decorator that runs devnet in background and later kills it. Prints devnet output in case of AssertionError. """ + def wrapper(func): @functools.wraps(func) def inner_wrapper(*args, **kwargs): @@ -50,15 +63,21 @@ def inner_wrapper(*args, **kwargs): func(*args, **kwargs) finally: terminate_and_wait(proc) + return inner_wrapper + return wrapper + def terminate_and_wait(proc: subprocess.Popen): """Terminates the process and waits.""" proc.terminate() proc.wait() -def ensure_server_alive(url: str, proc: subprocess.Popen, check_period=0.5, max_wait=60): + +def ensure_server_alive( + url: str, proc: subprocess.Popen, check_period=0.5, max_wait=60 +): """ Ensures that server at provided `url` is alive or that `proc` has terminated. Checks every `check_period` seconds. @@ -82,14 +101,19 @@ def ensure_server_alive(url: str, proc: subprocess.Popen, check_period=0.5, max_ terminate_and_wait(proc) raise RuntimeError(f"max_wait time {max_wait} exceeded while checking {url}") + def assert_equal(actual, expected, explanation=None): """Assert that the two values are equal. Optionally provide explanation.""" - assert actual == expected, f"\nActual: {actual}\nExpected: {expected}\nAdditional_info: {explanation}" + assert ( + actual == expected + ), f"\nActual: {actual}\nExpected: {expected}\nAdditional_info: {explanation}" + def assert_hex_equal(actual, expected): """Assert that two hex strings are equal when converted to ints""" assert int(actual, 16) == int(expected, 16) + def extract(regex, stdout): """Extract from `stdout` what matches `regex`.""" matched = re.search(regex, stdout) @@ -97,30 +121,32 @@ def extract(regex, stdout): return matched.group(1) raise RuntimeError(f"Cannot extract from {stdout}") + def extract_class_hash(stdout): """Extract class hash from stdout.""" return extract(r"Contract class hash: (\w*)", stdout) + def extract_tx_hash(stdout): """Extract tx_hash from stdout.""" return extract(r"Transaction hash: (\w*)", stdout) + def extract_fee(stdout) -> int: """Extract fee from stdout.""" return int(extract(r"(\d+)", stdout)) + def extract_address(stdout): """Extract address from stdout.""" return extract(r"Contract address: (\w*)", stdout) + def run_starknet(args, raise_on_nonzero=True, add_gateway_urls=True): """Wrapper around subprocess.run""" my_args = ["poetry", "run", "starknet", *args, "--no_wallet"] if add_gateway_urls: - my_args.extend([ - "--gateway_url", APP_URL, - "--feeder_gateway_url", APP_URL - ]) + my_args.extend(["--gateway_url", APP_URL, "--feeder_gateway_url", APP_URL]) output = subprocess.run(my_args, encoding="utf-8", check=False, capture_output=True) if output.returncode != 0 and raise_on_nonzero: if output.stderr: @@ -128,15 +154,17 @@ def run_starknet(args, raise_on_nonzero=True, add_gateway_urls=True): raise ReturnCodeAssertionError(output.stdout) return output + def declare(contract): """Wrapper around starknet declare""" args = ["declare", "--contract", contract] output = run_starknet(args) return { "tx_hash": extract_tx_hash(output.stdout), - "class_hash": extract_class_hash(output.stdout) + "class_hash": extract_class_hash(output.stdout), } + def deploy(contract, inputs=None, salt=None): """Wrapper around starknet deploy""" args = ["deploy", "--contract", contract] @@ -147,9 +175,10 @@ def deploy(contract, inputs=None, salt=None): output = run_starknet(args) return { "tx_hash": extract_tx_hash(output.stdout), - "address": extract_address(output.stdout) + "address": extract_address(output.stdout), } + def assert_transaction(tx_hash, expected_status, expected_signature=None): """Wrapper around starknet get_transaction""" output = run_starknet(["get_transaction", "--hash", tx_hash]) @@ -170,50 +199,69 @@ def assert_transaction(tx_hash, expected_status, expected_signature=None): if tx_type == "INVOKE_FUNCTION": invoke_transaction_keys = [ - "calldata", "contract_address", "entry_point_selector", "entry_point_type", - "max_fee", "signature", "transaction_hash", "type" + "calldata", + "contract_address", + "entry_point_selector", + "entry_point_type", + "max_fee", + "signature", + "transaction_hash", + "type", ] assert_keys(transaction["transaction"], invoke_transaction_keys) if tx_type == "DEPLOY": deploy_transaction_keys = [ - "class_hash", "constructor_calldata", "contract_address", - "contract_address_salt", "transaction_hash", "type" + "class_hash", + "constructor_calldata", + "contract_address", + "contract_address_salt", + "transaction_hash", + "type", ] assert_keys(transaction["transaction"], deploy_transaction_keys) + def assert_keys(dictionary, keys): """Asserts that the dict has the correct keys""" expected_set = set(keys) assert dictionary.keys() == expected_set, f"{dictionary.keys()} != {expected_set}" + def assert_transaction_not_received(tx_hash): """Assert correct tx response when there is no tx with `tx_hash`.""" output = run_starknet(["get_transaction", "--hash", tx_hash]) transaction = json.loads(output.stdout) - assert_equal(transaction, { - "status": "NOT_RECEIVED" - }) + assert_equal(transaction, {"status": "NOT_RECEIVED"}) + def assert_transaction_receipt_not_received(tx_hash): """Assert correct tx receipt response when there is no tx with `tx_hash`.""" receipt = get_transaction_receipt(tx_hash) - assert_equal(receipt, { - "events": [], - "l2_to_l1_messages": [], - "status": "NOT_RECEIVED", - "transaction_hash": tx_hash - }) + assert_equal( + receipt, + { + "events": [], + "l2_to_l1_messages": [], + "status": "NOT_RECEIVED", + "transaction_hash": tx_hash, + }, + ) + # pylint: disable=too-many-arguments def invoke(function, inputs, address, abi_path, signature=None, max_fee=None): """Wrapper around starknet invoke. Returns tx hash.""" args = [ "invoke", - "--function", function, - "--inputs", *inputs, - "--address", address, - "--abi", abi_path, + "--function", + function, + "--inputs", + *inputs, + "--address", + address, + "--abi", + abi_path, ] if signature: args.extend(["--signature", *signature]) @@ -231,10 +279,14 @@ def estimate_fee(function, inputs, address, abi_path, signature=None): """Wrapper around starknet estimate_fee. Returns fee in wei.""" args = [ "estimate_fee", - "--function", function, - "--inputs", *inputs, - "--address", address, - "--abi", abi_path, + "--function", + function, + "--inputs", + *inputs, + "--address", + address, + "--abi", + abi_path, ] if signature: args.extend(["--signature", *signature]) @@ -249,9 +301,12 @@ def call(function, address, abi_path, inputs=None, signature=None, max_fee=None) """Wrapper around starknet call""" args = [ "call", - "--function", function, - "--address", address, - "--abi", abi_path, + "--function", + function, + "--address", + address, + "--abi", + abi_path, ] if inputs: args.extend(["--inputs", *inputs]) @@ -265,12 +320,14 @@ def call(function, address, abi_path, inputs=None, signature=None, max_fee=None) print("Call successful!") return output.stdout.rstrip() + def load_contract_class(contract_path: str): """Loads the contract class from the contract path""" loaded_contract = load_json_from_path(contract_path) return ContractClass.load(loaded_contract) + def assert_tx_status(tx_hash, expected_tx_status): """Asserts the tx_status of the tx with tx_hash.""" output = run_starknet(["tx_status", "--hash", tx_hash]) @@ -281,6 +338,7 @@ def assert_tx_status(tx_hash, expected_tx_status): if tx_status == "REJECTED": assert "tx_failure_reason" in response, f"Key not found in {response}" + def assert_contract_code(address): """Asserts the content of the code of a contract at address.""" output = run_starknet(["get_code", "--contract_address", address]) @@ -288,46 +346,52 @@ def assert_contract_code(address): # just checking key equality assert_equal(sorted(code.keys()), ["abi", "bytecode"]) + def assert_contract_class(actual_class: ContractClass, expected_class_path: str): """Asserts equality between `actual_class` and class at `expected_class_path`.""" loaded_contract_class = load_contract_class(expected_class_path) assert_equal(actual_class, loaded_contract_class.remove_debug_info()) + def assert_storage(address, key, expected_value): """Asserts the storage value stored at (address, key).""" - output = run_starknet([ - "get_storage_at", - "--contract_address", address, - "--key", key - ]) + output = run_starknet( + ["get_storage_at", "--contract_address", address, "--key", key] + ) assert_equal(output.stdout.rstrip(), expected_value) + def load_json_from_path(path): """Loads a json file from `path`.""" with open(path, encoding="utf-8") as expected_file: return json.load(expected_file) + def get_transaction_receipt(tx_hash): """Fetches the transaction receipt of transaction with tx_hash""" output = run_starknet(["get_transaction_receipt", "--hash", tx_hash]) return json.loads(output.stdout) + def get_full_contract(contract_address: str) -> ContractClass: """Gets contract class by contract address""" output = run_starknet(["get_full_contract", "--contract_address", contract_address]) return ContractClass.loads(output.stdout) + def get_class_hash_at(contract_address: str) -> str: """Gets class hash at given contract address""" output = run_starknet(["get_class_hash_at", "--contract_address", contract_address]) return json.loads(output.stdout) + def get_class_by_hash(class_hash: str) -> str: """Gets contract class by contract hash""" output = run_starknet(["get_class_by_hash", "--class_hash", class_hash]) return ContractClass.loads(output.stdout) + def assert_receipt(tx_hash, expected_path): """Asserts the content of the receipt of tx with tx_hash.""" receipt = get_transaction_receipt(tx_hash) @@ -340,12 +404,14 @@ def assert_receipt(tx_hash, expected_path): expected_receipt.pop(ignorable_key) assert_equal(receipt, expected_receipt) + def assert_events(tx_hash, expected_path): """Asserts the content of the events element of the receipt of tx with tx_hash.""" receipt = get_transaction_receipt(tx_hash) expected_receipt = load_json_from_path(expected_path) assert_equal(receipt["events"], expected_receipt["events"]) + def get_block(block_number=None, parse=False): """Get the block with block_number. If no number provided, return the last.""" args = ["get_block"] @@ -357,6 +423,7 @@ def get_block(block_number=None, parse=False): return run_starknet(args, raise_on_nonzero=False) + def assert_negative_block_input(): """Test behavior if get_block provided with negative input.""" try: @@ -365,11 +432,14 @@ def assert_negative_block_input(): except ReturnCodeAssertionError: print("Correctly rejecting negative block number") + def assert_block(latest_block_number, latest_tx_hash): """Asserts the content of the block with block_number.""" too_big = 1000 error_message = get_block(block_number=too_big, parse=False).stderr - total_blocks_str = re.search("There are currently (.*) blocks.", error_message).group(1) + total_blocks_str = re.search( + "There are currently (.*) blocks.", error_message + ).group(1) total_blocks = int(total_blocks_str) extracted_last_block_number = total_blocks - 1 assert_equal(extracted_last_block_number, latest_block_number) @@ -386,9 +456,12 @@ def assert_block(latest_block_number, latest_tx_hash): latest_transaction = latest_block_transactions[0] assert_equal(latest_transaction["transaction_hash"], latest_tx_hash) - assert_equal(latest_block["sequencer_address"], hex(DEFAULT_GENERAL_CONFIG.sequencer_address)) + assert_equal( + latest_block["sequencer_address"], hex(DEFAULT_GENERAL_CONFIG.sequencer_address) + ) assert_equal(latest_block["gas_price"], hex(DEFAULT_GENERAL_CONFIG.min_gas_price)) + def assert_block_hash(latest_block_number, expected_block_hash): """Asserts the content of the block with block_number.""" @@ -396,7 +469,10 @@ def assert_block_hash(latest_block_number, expected_block_hash): assert_equal(block["block_hash"], expected_block_hash) assert_equal(block["status"], "ACCEPTED_ON_L2") -def assert_salty_deploy(contract_path, inputs, salt, expected_status, expected_address, expected_tx_hash): + +def assert_salty_deploy( + contract_path, inputs, salt, expected_status, expected_address, expected_tx_hash +): """Deploy with salt and assert.""" deploy_info = deploy(contract_path, inputs, salt=salt) @@ -404,18 +480,21 @@ def assert_salty_deploy(contract_path, inputs, salt, expected_status, expected_a assert_equal(deploy_info["address"], expected_address) assert_equal(deploy_info["tx_hash"], expected_tx_hash) + def assert_failing_deploy(contract_path): """Run deployment for a contract that's expected to be rejected.""" deploy_info = deploy(contract_path) assert_tx_status(deploy_info["tx_hash"], "REJECTED") assert_transaction(deploy_info["tx_hash"], "REJECTED") + def load_file_content(file_name: str): """Load content of file located in the same directory as this test file.""" full_file_path = os.path.join(os.path.dirname(__file__), file_name) with open(full_file_path, encoding="utf-8") as deploy_file: return deploy_file.read() + def create_empty_block(): """Creates an empty block and returns it.""" resp = requests.post(f"{APP_URL}/create_block") diff --git a/test/web3_util.py b/test/web3_util.py index d89b3c8c3..4f5cbddf4 100644 --- a/test/web3_util.py +++ b/test/web3_util.py @@ -2,16 +2,18 @@ from web3 import Web3 + def web3_deploy(web3: Web3, contract, *inputs): """Deploys a Solidity contract""" - abi=contract["abi"] - bytecode=contract["bytecode"] + abi = contract["abi"] + bytecode = contract["bytecode"] contract = web3.eth.contract(abi=abi, bytecode=bytecode) tx_hash = contract.constructor(*inputs).transact() tx_receipt = web3.eth.wait_for_transaction_receipt(tx_hash) return web3.eth.contract(address=tx_receipt.contractAddress, abi=abi) -def web3_transact(web3: Web3, function, contract, *inputs): + +def web3_transact(web3: Web3, function, contract, *inputs): """Invokes a function in a Web3 contract""" contract_function = contract.get_function_by_name(function)(*inputs) @@ -20,6 +22,7 @@ def web3_transact(web3: Web3, function, contract, *inputs): return tx_hash + def web3_call(function, contract, *inputs): """Calls a function in a Web3 contract"""