Skip to content

Commit

Permalink
Support estimate_message_fee
Browse files Browse the repository at this point in the history
  • Loading branch information
FabijanC committed Sep 27, 2022
1 parent 26cf787 commit 19fb3ea
Show file tree
Hide file tree
Showing 10 changed files with 145 additions and 35 deletions.
21 changes: 17 additions & 4 deletions starknet_devnet/blueprints/feeder_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
InvokeFunction,
)
from starkware.starknet.services.api.gateway.transaction import Transaction
from starkware.starknet.services.api.feeder_gateway.request_objects import CallFunction
from starkware.starknet.services.api.feeder_gateway.request_objects import (
CallFunction,
CallL1Handler,
)
from starkware.starknet.services.api.feeder_gateway.response_objects import (
TransactionSimulationInfo,
FeeEstimationInfo,
)
from werkzeug.datastructures import MultiDict

Expand Down Expand Up @@ -276,7 +278,7 @@ async def estimate_fee():
transaction = validate_request(request.data, InvokeFunction) # version 0

_, fee_response = await state.starknet_wrapper.calculate_trace_and_fee(transaction)
return jsonify(FeeEstimationInfo.load(fee_response))
return jsonify(fee_response)


@feeder_gateway.route("/simulate_transaction", methods=["POST"])
Expand All @@ -288,7 +290,7 @@ async def simulate_transaction():
)

simulation_info = TransactionSimulationInfo(
trace=trace, fee_estimation=FeeEstimationInfo.load(fee_response)
trace=trace, fee_estimation=fee_response
)

return jsonify(simulation_info.dump())
Expand All @@ -302,3 +304,14 @@ async def get_nonce():
nonce = await state.starknet_wrapper.get_nonce(contract_address)

return jsonify(hex(nonce))


@feeder_gateway.route("/estimate_message_fee", methods=["POST"])
async def estimate_message_fee():
"""Message fee estimation endpoint"""

_check_block_hash(request.args)

call = validate_request(request.data, CallL1Handler)
fee_estimation = await state.starknet_wrapper.estimate_message_fee(call)
return jsonify(fee_estimation)
9 changes: 5 additions & 4 deletions starknet_devnet/blueprints/rpc/structures/payloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
TransactionType,
BlockStateUpdate,
DeclareSpecificInfo,
FeeEstimationInfo,
)
from starkware.starknet.services.api.gateway.transaction import InvokeFunction
from starkware.starknet.services.api.gateway.transaction_utils import compress_program
Expand Down Expand Up @@ -231,14 +232,14 @@ class RpcFeeEstimate(TypedDict):
overall_fee: NumAsHex


def rpc_fee_estimate(fee_estimate: dict) -> dict:
def rpc_fee_estimate(fee_estimate: FeeEstimationInfo) -> dict:
"""
Convert gateway estimate_fee response to rpc_fee_estimate
"""
result: RpcFeeEstimate = {
"gas_consumed": hex(fee_estimate["gas_usage"]),
"gas_price": hex(fee_estimate["gas_price"]),
"overall_fee": hex(fee_estimate["overall_fee"]),
"gas_consumed": hex(fee_estimate.gas_usage),
"gas_price": hex(fee_estimate.gas_price),
"overall_fee": hex(fee_estimate.overall_fee),
}
return result

Expand Down
39 changes: 28 additions & 11 deletions starknet_devnet/starknet_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
InternalInvokeFunction,
InternalDeclare,
InternalDeploy,
InternalL1Handler,
)
from starkware.starknet.business_logic.state.state import BlockInfo, CachedState
from starkware.starknet.services.api.gateway.transaction import (
Expand All @@ -21,10 +22,14 @@
from starkware.starknet.testing.starknet import Starknet
from starkware.starkware_utils.error_handling import StarkException
from starkware.starknet.services.api.contract_class import EntryPointType, ContractClass
from starkware.starknet.services.api.feeder_gateway.request_objects import CallFunction
from starkware.starknet.services.api.feeder_gateway.request_objects import (
CallFunction,
CallL1Handler,
)
from starkware.starknet.services.api.feeder_gateway.response_objects import (
TransactionStatus,
)

from starkware.starknet.testing.contract import StarknetContract
from starkware.starknet.testing.objects import FunctionInvocation
from starkware.starknet.services.api.feeder_gateway.response_objects import (
Expand All @@ -37,7 +42,7 @@
StorageEntry,
)

from starknet_devnet.util import to_bytes
from starknet_devnet.util import to_bytes, get_fee_estimation_info
from starknet_devnet.constants import DUMMY_STATE_ROOT

from .lite_mode.lite_internal_deploy import LiteInternalDeploy
Expand Down Expand Up @@ -519,17 +524,29 @@ async def calculate_trace_and_fee(self, external_tx: InvokeFunction):
signature=external_tx.signature,
)

tx_fee = execution_info.actual_fee
fee_estimation_info = get_fee_estimation_info(
execution_info.actual_fee, state.state.block_info.gas_price
)

return trace, fee_estimation_info

async def estimate_message_fee(self, call: CallL1Handler):
"""Estimate fee of message from L1 to L2"""
state = self.get_state()
internal_call: InternalL1Handler = call.to_internal(
state.general_config.chain_id.value
)

gas_price = state.state.block_info.gas_price
gas_usage = tx_fee // gas_price if gas_price else 0
execution_info = await internal_call.apply_state_updates(
# pylint: disable=protected-access
state.state._copy(),
state.general_config,
)

return trace, {
"overall_fee": tx_fee,
"unit": "wei",
"gas_price": gas_price,
"gas_usage": gas_usage,
}
fee_estimation_info = get_fee_estimation_info(
execution_info.actual_fee, state.state.block_info.gas_price
)
return fee_estimation_info

def increase_block_time(self, time_s: int):
"""Increases the block time by `time_s`."""
Expand Down
16 changes: 16 additions & 0 deletions starknet_devnet/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from starkware.starknet.business_logic.state.state import CachedState
from starkware.starknet.services.api.feeder_gateway.response_objects import (
DeployedContract,
FeeEstimationInfo,
StorageEntry,
)

Expand Down Expand Up @@ -162,3 +163,18 @@ async def get_storage_diffs(
)

return storage_diffs


def get_fee_estimation_info(tx_fee: int, gas_price: int):
"""Construct fee estimation response"""

gas_usage = tx_fee // gas_price if gas_price else 0

return FeeEstimationInfo.load(
{
"overall_fee": tx_fee,
"unit": "wei",
"gas_price": gas_price,
"gas_usage": gas_usage,
}
)
2 changes: 1 addition & 1 deletion test/contracts/cairo/l1l2.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func withdraw{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
func deposit{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
from_address: felt, user: felt, amount: felt
) {
// Make sure the message was sent by the intended L1 contract.
// In a real case scenario, here we would assert from_address value

// Read the current balance.
let (res) = balance.read(user=user);
Expand Down
2 changes: 2 additions & 0 deletions test/rpc/test_rpc_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def test_add_invoke_transaction(invoke_content):
assert set(receipt.keys()) == {"transaction_hash"}
assert receipt["transaction_hash"][:3] == "0x0"


@pytest.mark.usefixtures("run_devnet_in_background")
def test_add_invoke_transaction_positional_args(invoke_content):
"""
Expand Down Expand Up @@ -340,6 +341,7 @@ def test_add_invoke_transaction_positional_args(invoke_content):
assert set(receipt.keys()) == {"transaction_hash"}
assert receipt["transaction_hash"][:3] == "0x0"


@pytest.mark.usefixtures("run_devnet_in_background")
def test_add_declare_transaction_on_incorrect_contract(declare_content):
"""
Expand Down
3 changes: 3 additions & 0 deletions test/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
DEPLOYER_CONTRACT_PATH = f"{ARTIFACTS_PATH}/deployer.cairo/deployer.json"
DEPLOYER_ABI_PATH = f"{ARTIFACTS_PATH}/deployer.cairo/deployer_abi.json"

L1L2_CONTRACT_PATH = f"{ARTIFACTS_PATH}/l1l2.cairo/l1l2.json"
L1L2_ABI_PATH = f"{ARTIFACTS_PATH}/l1l2.cairo/l1l2_abi.json"

BALANCE_KEY = (
"916907772491729262376534102982219947830828984996257231353398618781993312401"
)
Expand Down
31 changes: 29 additions & 2 deletions test/test_estimate_fee.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@


from starknet_devnet.constants import DEFAULT_GAS_PRICE
from .util import deploy, devnet_in_background, load_file_content
from .util import call, deploy, devnet_in_background, estimate_message_fee, load_file_content
from .settings import APP_URL
from .shared import CONTRACT_PATH, EXPECTED_CLASS_HASH
from .shared import CONTRACT_PATH, EXPECTED_CLASS_HASH, L1L2_ABI_PATH, L1L2_CONTRACT_PATH, PREDEPLOY_ACCOUNT_CLI_ARGS

DEPLOY_CONTENT = load_file_content("deploy.json")
INVOKE_CONTENT = load_file_content("invoke.json")
Expand Down Expand Up @@ -165,3 +165,30 @@ def test_simulate_transaction():
},
"signature": [],
}


@devnet_in_background(*PREDEPLOY_ACCOUNT_CLI_ARGS)
def test_estimate_message_fee():
"""Estimate message fee from l1 to l2"""

dummy_l1_address = "0x1"
user_id = "1"

l2_contract_address = deploy(contract=L1L2_CONTRACT_PATH)["address"]

message_fee = estimate_message_fee(
from_address=dummy_l1_address,
function="deposit",
inputs=[user_id, "100"],
to_address=l2_contract_address,
abi_path=L1L2_ABI_PATH
)
assert int(message_fee) > 0

balance = call(
function="get_balance",
address=l2_contract_address,
abi_path=L1L2_ABI_PATH,
)

assert int(balance) == 0
31 changes: 19 additions & 12 deletions test/test_postman.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,20 @@
deploy,
devnet_in_background,
ensure_server_alive,
estimate_message_fee,
load_file_content,
terminate_and_wait,
)
from .settings import APP_URL, L1_HOST, L1_PORT, L1_URL
from .shared import (
ARTIFACTS_PATH,
L1L2_ABI_PATH,
L1L2_CONTRACT_PATH,
PREDEPLOY_ACCOUNT_CLI_ARGS,
PREDEPLOYED_ACCOUNT_ADDRESS,
PREDEPLOYED_ACCOUNT_PRIVATE_KEY,
)
from .web3_util import web3_call, web3_deploy, web3_transact

CONTRACT_PATH = f"{ARTIFACTS_PATH}/l1l2.cairo/l1l2.json"
ABI_PATH = f"{ARTIFACTS_PATH}/l1l2.cairo/l1l2_abi.json"

ETH_CONTRACTS_PATH = "artifacts/contracts/solidity"
STARKNET_MESSAGING_PATH = (
f"{ETH_CONTRACTS_PATH}/MockStarknetMessaging.sol/MockStarknetMessaging.json"
Expand Down Expand Up @@ -164,7 +163,7 @@ def load_messaging_contract(starknet_messaging_contract_address):
def _init_l2_contract(l1l2_example_contract_address: str):
"""Deploys the L1L2Example cairo contract, returns the result of calling 'get_balance'"""

deploy_info = deploy(CONTRACT_PATH)
deploy_info = deploy(L1L2_CONTRACT_PATH)
l2_address = deploy_info["address"]

# increase and withdraw balance
Expand Down Expand Up @@ -200,8 +199,8 @@ def _init_l2_contract(l1l2_example_contract_address: str):
value = call(
function="get_balance",
address=deploy_info["address"],
abi_path=ABI_PATH,
inputs=[USER_ID],
abi_path=L1L2_ABI_PATH,
inputs=[str(USER_ID)],
)

assert value == "2333"
Expand Down Expand Up @@ -232,12 +231,20 @@ def _l1_l2_message_exchange(web3, l1l2_example_contract, l2_contract_address):
l2_balance = call(
function="get_balance",
address=l2_contract_address,
abi_path=ABI_PATH,
inputs=[USER_ID],
abi_path=L1L2_ABI_PATH,
inputs=[str(USER_ID)],
)

assert l2_balance == "2333"

message_fee = estimate_message_fee(
from_address=l1l2_example_contract.address,
function="deposit",
inputs=[str(USER_ID), "100"],
to_address=l2_contract_address,
abi_path=L1L2_ABI_PATH
)
assert int(message_fee) > 0

# deposit in l1 and assert contract balance
web3_transact(
web3,
Expand Down Expand Up @@ -275,8 +282,8 @@ def _l1_l2_message_exchange(web3, l1l2_example_contract, l2_contract_address):
l2_balance = call(
function="get_balance",
address=l2_contract_address,
abi_path=ABI_PATH,
inputs=[USER_ID],
abi_path=L1L2_ABI_PATH,
inputs=[str(USER_ID)],
)

assert l2_balance == "2933"
Expand Down
26 changes: 25 additions & 1 deletion test/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import re
import subprocess
import time
from typing import List
import requests

from starkware.starknet.services.api.contract_class import ContractClass
Expand Down Expand Up @@ -134,7 +135,7 @@ def extract_tx_hash(stdout):

def extract_fee(stdout) -> int:
"""Extract fee from stdout."""
return int(extract(r"(\d+)", stdout))
return int(extract(r"The estimated fee is: (\d+) WEI", stdout))


def extract_address(stdout):
Expand Down Expand Up @@ -169,6 +170,29 @@ def deploy(contract, inputs=None, salt=None):
}


def estimate_message_fee(
from_address: str, function: str, inputs: List[str], to_address: str, abi_path: str
):
"""Wrapper around starknet estimate_message_fee"""
output = run_starknet(
[
"estimate_message_fee",
"--from_address",
from_address,
"--function",
function,
"--inputs",
*inputs,
"--address",
to_address,
"--abi",
abi_path,
]
)

return extract_fee(output.stdout)


def assert_transaction(tx_hash, expected_status, expected_signature=None):
"""Wrapper around starknet get_transaction"""
output = run_starknet(["get_transaction", "--hash", tx_hash])
Expand Down

0 comments on commit 19fb3ea

Please sign in to comment.