Skip to content

Commit

Permalink
Support predeploying custom accounts (0xSpaceShard#294)
Browse files Browse the repository at this point in the history
  • Loading branch information
FabijanC committed Oct 18, 2022
1 parent 3c3c8fc commit 86f28da
Show file tree
Hide file tree
Showing 11 changed files with 226 additions and 51 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ starknet-devnet = "starknet_devnet.server:main"
[tool.pytest.ini_options]
markers = [
"account",
"account_custom",
"account_predeployed",
"call",
"declare",
Expand Down
37 changes: 13 additions & 24 deletions starknet_devnet/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,34 @@
"""

from starkware.cairo.lang.vm.crypto import pedersen_hash
from starkware.solidity.utils import load_nearby_contract
from starkware.starknet.public.abi import get_selector_from_name
from starkware.starknet.services.api.contract_class import ContractClass
from starkware.starknet.core.os.contract_address.contract_address import (
calculate_contract_address_from_hash,
)
from starkware.starknet.testing.contract import StarknetContract
from starkware.python.utils import to_bytes
from starkware.starknet.testing.starknet import Starknet

from starknet_devnet.contract_class_wrapper import ContractClassWrapper
from starknet_devnet.util import Uint256


class Account:
"""Account contract wrapper."""

CONTRACT_CLASS: ContractClass = None # loaded lazily
CONTRACT_PATH = "accounts_artifacts/OpenZeppelin/0.4.0b-fork/Account.cairo/Account"

# Precalculated to save time
# HASH = compute_class_hash(contract_class=Account.get_contract_class()))
HASH = 250058203962332945652607154704986145054927159797127109843768594742871092378
HASH_BYTES = to_bytes(HASH)

# pylint: disable=too-many-arguments
def __init__(
self, starknet_wrapper, private_key: int, public_key: int, initial_balance: int
self,
starknet_wrapper,
private_key: int,
public_key: int,
initial_balance: int,
account_class_wrapper: ContractClassWrapper,
):
self.starknet_wrapper = starknet_wrapper
self.private_key = private_key
self.public_key = public_key
self.contract_class = account_class_wrapper.contract_class
self.class_hash_bytes = account_class_wrapper.hash_bytes

# salt and class_hash have frozen values that make the constructor_calldata
# the only thing that affects the account address
Expand All @@ -44,15 +42,6 @@ def __init__(
)
self.initial_balance = initial_balance

@classmethod
def get_contract_class(cls):
"""Returns contract class via lazy loading."""
if not cls.CONTRACT_CLASS:
cls.CONTRACT_CLASS = ContractClass.load(
load_nearby_contract(cls.CONTRACT_PATH)
)
return cls.CONTRACT_CLASS

def to_json(self):
"""Return json account"""
return {
Expand All @@ -65,11 +54,11 @@ def to_json(self):
async def deploy(self) -> StarknetContract:
"""Deploy this account."""
starknet: Starknet = self.starknet_wrapper.starknet
contract_class = Account.get_contract_class()
contract_class = self.contract_class
await starknet.state.state.set_contract_class(
Account.HASH_BYTES, contract_class
self.class_hash_bytes, contract_class
)
await starknet.state.state.deploy_contract(self.address, Account.HASH_BYTES)
await starknet.state.state.deploy_contract(self.address, self.class_hash_bytes)

await starknet.state.state.set_storage_at(
self.address, get_selector_from_name("Account_public_key"), self.public_key
Expand Down
3 changes: 3 additions & 0 deletions starknet_devnet/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from typing import List
from starkware.crypto.signature.signature import private_to_stark_key

from .account import Account


Expand All @@ -19,6 +20,7 @@ def __init__(self, starknet_wrapper):
self.starknet_wrapper = starknet_wrapper
self.__n_accounts = starknet_wrapper.config.accounts
self.__initial_balance = starknet_wrapper.config.initial_balance
self.__account_class_wrapper = starknet_wrapper.config.account_class

self.__seed = starknet_wrapper.config.seed
if self.__seed is None:
Expand Down Expand Up @@ -61,6 +63,7 @@ def __generate(self):
private_key=private_key,
public_key=public_key,
initial_balance=self.__initial_balance,
account_class_wrapper=self.__account_class_wrapper,
)
)

Expand Down
31 changes: 31 additions & 0 deletions starknet_devnet/contract_class_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Starknet ContractClass wrapper utilities"""

from dataclasses import dataclass
import os

from starkware.python.utils import to_bytes
from starkware.starknet.services.api.contract_class import ContractClass


@dataclass
class ContractClassWrapper:
"""Wrapper of ContractClass"""

contract_class: ContractClass
hash_bytes: bytes


DEFAULT_ACCOUNT_PATH = os.path.abspath(
os.path.join(
__file__,
os.pardir,
"accounts_artifacts",
"OpenZeppelin",
"0.4.0b-fork",
"Account.cairo",
"Account.json",
)
)
DEFAULT_ACCOUNT_HASH_BYTES = to_bytes(
250058203962332945652607154704986145054927159797127109843768594742871092378
)
62 changes: 60 additions & 2 deletions starknet_devnet/devnet_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,21 @@

import argparse
from enum import Enum, auto
import json
import os
import sys
from typing import List

from marshmallow.exceptions import ValidationError
from starkware.python.utils import to_bytes
from starkware.starknet.core.os.class_hash import compute_class_hash
from starkware.starknet.services.api.contract_class import ContractClass

from .contract_class_wrapper import (
ContractClassWrapper,
DEFAULT_ACCOUNT_HASH_BYTES,
DEFAULT_ACCOUNT_PATH,
)
from . import __version__
from .constants import (
DEFAULT_ACCOUNTS,
Expand Down Expand Up @@ -41,7 +53,7 @@ class DumpOn(Enum):
DUMP_ON_OPTIONS_STRINGIFIED = ", ".join(DUMP_ON_OPTIONS)


def parse_dump_on(option: str):
def _parse_dump_on(option: str):
"""Parse dumping frequency option."""
if option in DUMP_ON_OPTIONS:
return DumpOn[option.upper()]
Expand All @@ -50,6 +62,43 @@ def parse_dump_on(option: str):
)


EXPECTED_ACCOUNT_METHODS = ["__execute__", "__validate__", "__validate_declare__"]


def _parse_account_class(class_path: str) -> ContractClassWrapper:
"""Parse account class"""
class_path = os.path.abspath(class_path)

if not os.path.isfile(class_path):
sys.exit(f"Error: {class_path} is not a valid file")

with open(class_path, mode="r", encoding="utf-8") as dict_file:
try:
loaded_dict = json.load(dict_file)
except json.JSONDecodeError:
sys.exit(f"Error: {class_path} is not a valid JSON file")

try:
contract_class = ContractClass.load(loaded_dict)
except ValidationError:
sys.exit(f"Error: {class_path} is not a valid contract class artifact")

if class_path == DEFAULT_ACCOUNT_PATH:
class_hash_bytes = DEFAULT_ACCOUNT_HASH_BYTES
else:
contract_methods = [entry["name"] for entry in contract_class.abi]
missing_methods = [
m for m in EXPECTED_ACCOUNT_METHODS if m not in contract_methods
]
if missing_methods:
sys.exit(
f"Error: {class_path} is missing account methods: {', '.join(missing_methods)}"
)
class_hash_bytes = to_bytes(compute_class_hash(contract_class))

return ContractClassWrapper(contract_class, class_hash_bytes)


class NonNegativeAction(argparse.Action):
"""
Action for parsing the non negative int argument.
Expand Down Expand Up @@ -102,7 +151,7 @@ def parse_args(raw_args: List[str]):
parser.add_argument(
"--dump-on",
help=f"Specify when to dump; can dump on: {DUMP_ON_OPTIONS_STRINGIFIED}",
type=parse_dump_on,
type=_parse_dump_on,
)
parser.add_argument(
"--lite-mode",
Expand Down Expand Up @@ -153,6 +202,14 @@ def parse_args(raw_args: List[str]):
default=DEFAULT_TIMEOUT,
help=f"Specify the server timeout in seconds; defaults to {DEFAULT_TIMEOUT}",
)
parser.add_argument(
"--account-class",
help="Specify the account implementation to be used for predeploying; "
"should be a path to the compiled JSON artifact; "
"defaults to a fork of OpenZeppelin v0.4.0b",
type=_parse_account_class,
default=DEFAULT_ACCOUNT_PATH,
)
# Uncomment this once fork support is added
# parser.add_argument(
# "--fork", "-f",
Expand Down Expand Up @@ -182,4 +239,5 @@ def __init__(self, args: argparse.Namespace = None):
self.start_time = self.args.start_time
self.gas_price = self.args.gas_price
self.lite_mode = self.args.lite_mode
self.account_class = self.args.account_class
self.hide_predeployed_accounts = self.args.hide_predeployed_accounts
1 change: 1 addition & 0 deletions test/custom_account.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions test/custom_account_missing_method.json

Large diffs are not rendered by default.

85 changes: 85 additions & 0 deletions test/test_account_custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Test custom account"""

import subprocess
import os
import pytest

from starkware.starknet.core.os.class_hash import compute_class_hash
from starkware.starknet.services.api.contract_class import ContractClass

from .account import invoke
from .shared import (
ABI_PATH,
CONTRACT_PATH,
PREDEPLOY_ACCOUNT_CLI_ARGS,
PREDEPLOYED_ACCOUNT_ADDRESS,
PREDEPLOYED_ACCOUNT_PRIVATE_KEY,
)
from .util import (
call,
deploy,
devnet_in_background,
DevnetBackgroundProc,
get_class_hash_at,
load_file_content,
)

ACTIVE_DEVNET = DevnetBackgroundProc()

NON_EXISTENT_PATH = "most-certainly-non-existent-path.txt"
DIR_PATH = os.path.abspath(os.path.join(__file__, os.pardir))
MISSING_METHOD_PATH = os.path.join(
__file__, os.pardir, "custom_account_missing_method.json"
)
CORRECT_PATH = os.path.join(__file__, os.pardir, "custom_account.json")


@pytest.mark.account_custom
@pytest.mark.parametrize(
"class_path, expected_error",
[
(
NON_EXISTENT_PATH,
f"Error: {os.path.abspath(NON_EXISTENT_PATH)} is not a valid file\n",
),
(DIR_PATH, f"Error: {DIR_PATH} is not a valid file\n"),
(__file__, f"Error: {__file__} is not a valid JSON file\n"),
(
ABI_PATH,
f"Error: {os.path.abspath(ABI_PATH)} is not a valid contract class artifact\n",
),
(
MISSING_METHOD_PATH,
f"Error: {os.path.abspath(MISSING_METHOD_PATH)} is missing account methods: __validate_declare__\n",
),
],
)
def test_invalid_path(class_path: str, expected_error: str):
"""Test behavior on providing nonexistent path"""
proc = ACTIVE_DEVNET.start("--account-class", class_path, stderr=subprocess.PIPE)
assert proc.returncode == 1
assert proc.stderr.read().decode("utf-8") == expected_error


@pytest.mark.account_custom
@devnet_in_background("--account-class", CORRECT_PATH, *PREDEPLOY_ACCOUNT_CLI_ARGS)
def test_providing_correct_account_class():
"""Test behavior if correct custom account provided"""
fetched_class_hash = int(get_class_hash_at(PREDEPLOYED_ACCOUNT_ADDRESS), 16)

expected_contract_class = ContractClass.loads(
load_file_content("custom_account.json")
)
assert fetched_class_hash == compute_class_hash(expected_contract_class)

deploy_info = deploy(CONTRACT_PATH, ["0"])
invoke(
calls=[(deploy_info["address"], "increase_balance", [10, 20])],
account_address=PREDEPLOYED_ACCOUNT_ADDRESS,
private_key=PREDEPLOYED_ACCOUNT_PRIVATE_KEY,
)
increased_value = call(
function="get_balance", address=deploy_info["address"], abi_path=ABI_PATH
)

assert increased_value == "30"
16 changes: 11 additions & 5 deletions test/test_account_predeployed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
import requests

from starkware.starknet.core.os.class_hash import compute_class_hash
from starknet_devnet.account import Account
from .util import assert_equal, devnet_in_background

from starknet_devnet.contract_class_wrapper import (
DEFAULT_ACCOUNT_HASH_BYTES,
DEFAULT_ACCOUNT_PATH,
)
from .util import assert_equal, devnet_in_background, load_contract_class
from .support.assertions import assert_valid_schema
from .settings import APP_URL

Expand All @@ -22,10 +26,12 @@


@pytest.mark.account_predeployed
def test_precomputed_contract_hash():
def test_precomputed_account_hash():
"""Test if the precomputed hash of the account contract is correct."""
recalculated_hash = compute_class_hash(contract_class=Account.get_contract_class())
assert_equal(recalculated_hash, Account.HASH)

contract_class = load_contract_class(DEFAULT_ACCOUNT_PATH)
recalculated_hash = compute_class_hash(contract_class=contract_class)
assert_equal(recalculated_hash, int.from_bytes(DEFAULT_ACCOUNT_HASH_BYTES, "big"))


@pytest.mark.account_predeployed
Expand Down
21 changes: 1 addition & 20 deletions test/test_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
call,
deploy,
devnet_in_background,
run_devnet_in_background,
terminate_and_wait,
DevnetBackgroundProc,
)
from .settings import APP_URL
from .shared import (
Expand All @@ -32,25 +32,6 @@
DUMP_PATH = "dump.pkl"


class DevnetBackgroundProc:
"""Helper for ensuring we always have only 1 active devnet server running in background"""

def __init__(self):
self.proc = None

def start(self, *args, stderr=None, stdout=None):
"""Starts a new devnet-server instance. Previously active instance will be stopped."""
self.stop()
self.proc = run_devnet_in_background(*args, stderr=stderr, stdout=stdout)
return self.proc

def stop(self):
"""Stops the currently active devnet-server instance"""
if self.proc:
terminate_and_wait(self.proc)
self.proc = None


ACTIVE_DEVNET = DevnetBackgroundProc()


Expand Down
Loading

0 comments on commit 86f28da

Please sign in to comment.