Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

modexp precompile allow arbitrary input length #1576

Merged
merged 1 commit into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
modexp precompile allow arbitrary input length
  • Loading branch information
jangko committed May 11, 2023
commit 4e58f9d79a7d0ce6ed8888dd7fc99dcb9a64314d
5 changes: 5 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,8 @@
url = https://github.com/status-im/portal-spec-tests.git
ignore = untracked
branch = master
[submodule "vendor/libtommath"]
path = vendor/libtommath
url = https://github.com/libtom/libtommath
ignore = untracked
branch = develop
10 changes: 10 additions & 0 deletions nimbus/evm/interpreter/utils/utils_numeric.nim
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ proc rangeToPadded*[T: StUint](x: openArray[byte], first, last, size: int): T =
allowPadding = true
)

proc rangeToPadded*(x: openArray[byte], first, size: int): seq[byte] =
let last = first + size - 1
let lo = max(0, first)
let hi = min(min(x.high, last), (lo+size)-1)

result = newSeq[byte](size)
if not(lo <= hi):
return # 0
result[0..hi-lo] = x.toOpenArray(lo, hi)

# calculates the memory size required for a step
func calcMemSize*(offset, length: int): int {.inline.} =
if length.isZero: return 0
Expand Down
218 changes: 218 additions & 0 deletions nimbus/evm/modexp.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
import
std/strutils

from os import DirSep, AltSep

const
vendorPath = currentSourcePath.rsplit({DirSep, AltSep}, 3)[0] & "/vendor"
srcPath = vendorPath & "/libtommath"

{.passc: "-IMP_32BIT"}
{.compile: srcPath & "/mp_radix_size.c"}
{.compile: srcPath & "/mp_to_radix.c"}
{.compile: srcPath & "/mp_init_u64.c"}
{.compile: srcPath & "/mp_init_i32.c"}
{.compile: srcPath & "/mp_init_multi.c"}
{.compile: srcPath & "/mp_init.c"}
{.compile: srcPath & "/mp_init_size.c"}
{.compile: srcPath & "/mp_init_copy.c"}
{.compile: srcPath & "/mp_invmod.c"}
{.compile: srcPath & "/mp_abs.c"}
{.compile: srcPath & "/mp_set_u64.c"}
{.compile: srcPath & "/mp_set_u32.c"}
{.compile: srcPath & "/mp_set_i32.c"}
{.compile: srcPath & "/mp_get_i32.c"}
{.compile: srcPath & "/mp_get_i64.c"}
{.compile: srcPath & "/mp_exptmod.c"}
{.compile: srcPath & "/mp_clear_multi.c"}
{.compile: srcPath & "/mp_clear.c"}
{.compile: srcPath & "/mp_montgomery_reduce.c"}
{.compile: srcPath & "/mp_clamp.c"}
{.compile: srcPath & "/mp_grow.c"}
{.compile: srcPath & "/mp_mul.c"}
{.compile: srcPath & "/mp_mul_2.c"}
{.compile: srcPath & "/mp_mul_2d.c"}
{.compile: srcPath & "/mp_mod_2d.c"}
{.compile: srcPath & "/mp_log_n.c"}
{.compile: srcPath & "/mp_div_2.c"}
{.compile: srcPath & "/mp_div_d.c"}
{.compile: srcPath & "/mp_add.c"}
{.compile: srcPath & "/mp_sub.c"}
{.compile: srcPath & "/mp_exch.c"}
{.compile: srcPath & "/mp_rshd.c"}
{.compile: srcPath & "/mp_lshd.c"}
{.compile: srcPath & "/mp_zero.c"}
{.compile: srcPath & "/mp_dr_reduce.c"}
{.compile: srcPath & "/mp_cmp_mag.c"}
{.compile: srcPath & "/mp_cutoffs.c"}
{.compile: srcPath & "/mp_reduce.c"}
{.compile: srcPath & "/mp_count_bits.c"}
{.compile: srcPath & "/mp_montgomery_setup.c"}
{.compile: srcPath & "/mp_dr_setup.c"}
{.compile: srcPath & "/mp_reduce_2k_setup.c"}
{.compile: srcPath & "/mp_reduce_2k_setup_l.c"}
{.compile: srcPath & "/mp_reduce_2k.c"}
{.compile: srcPath & "/mp_reduce_2k_l.c"}
{.compile: srcPath & "/mp_reduce_is_2k_l.c"}
{.compile: srcPath & "/mp_reduce_is_2k.c"}
{.compile: srcPath & "/mp_reduce_setup.c"}
{.compile: srcPath & "/mp_dr_is_modulus.c"}
{.compile: srcPath & "/mp_mulmod.c"}
{.compile: srcPath & "/mp_set.c"}
{.compile: srcPath & "/mp_mod.c"}
{.compile: srcPath & "/mp_copy.c"}
{.compile: srcPath & "/mp_div.c"}
{.compile: srcPath & "/mp_div_2d.c"}
{.compile: srcPath & "/mp_mul_d.c"}
{.compile: srcPath & "/mp_2expt.c"}
{.compile: srcPath & "/mp_cmp.c"}
{.compile: srcPath & "/mp_cmp_d.c"}
{.compile: srcPath & "/mp_log.c"}
{.compile: srcPath & "/mp_sub_d.c"}
{.compile: srcPath & "/mp_add_d.c"}
{.compile: srcPath & "/mp_cnt_lsb.c"}
{.compile: srcPath & "/mp_expt_n.c"}
{.compile: srcPath & "/mp_get_mag_u32.c"}
{.compile: srcPath & "/mp_get_mag_u64.c"}
{.compile: srcPath & "/mp_from_ubin.c"}
{.compile: srcPath & "/mp_ubin_size.c"}
{.compile: srcPath & "/mp_to_ubin.c"}
{.compile: srcPath & "/mp_montgomery_calc_normalization.c"}
{.compile: srcPath & "/s_mp_exptmod.c"}
{.compile: srcPath & "/s_mp_exptmod_fast.c"}
{.compile: srcPath & "/s_mp_zero_digs.c"}
{.compile: srcPath & "/s_mp_montgomery_reduce_comba.c"}
{.compile: srcPath & "/s_mp_add.c"}
{.compile: srcPath & "/s_mp_sub.c"}
{.compile: srcPath & "/s_mp_mul.c"}
{.compile: srcPath & "/s_mp_mul_comba.c"}
{.compile: srcPath & "/s_mp_mul_toom.c"}
{.compile: srcPath & "/s_mp_mul_karatsuba.c"}
{.compile: srcPath & "/s_mp_mul_balance.c"}
{.compile: srcPath & "/s_mp_copy_digs.c"}
{.compile: srcPath & "/s_mp_div_3.c"}
{.compile: srcPath & "/s_mp_sqr.c"}
{.compile: srcPath & "/s_mp_sqr_comba.c"}
{.compile: srcPath & "/s_mp_sqr_toom.c"}
{.compile: srcPath & "/s_mp_sqr_karatsuba.c"}
{.compile: srcPath & "/s_mp_zero_buf.c"}
{.compile: srcPath & "/s_mp_radix_map.c"}
{.compile: srcPath & "/s_mp_invmod.c"}
{.compile: srcPath & "/s_mp_invmod_odd.c"}
{.compile: srcPath & "/s_mp_mul_high.c"}
{.compile: srcPath & "/s_mp_mul_high_comba.c"}
{.compile: srcPath & "/s_mp_div_recursive.c"}
{.compile: srcPath & "/s_mp_div_school.c"}
{.compile: srcPath & "/s_mp_fp_log_d.c"}
{.compile: srcPath & "/s_mp_fp_log.c"}

{.passc: "-I" & srcPath .}

type
mp_int {.importc: "mp_int",
header: "tommath.h", byref.} = object

mp_digit = uint32

mp_err {.importc: "mp_err",
header: "tommath.h".} = cint

mp_ord = cint

{.pragma: mp_abi, importc, cdecl, header: "tommath.h".}

const
MP_OKAY = 0.mp_err

MP_LT = -1
MP_EQ = 0
MP_GT = 1

template getPtr(z: untyped): untyped =
when (NimMajor, NimMinor) > (1,6):
z.addr
else:
z.unsafeAddr

# init a bignum
# proc mp_init(a: mp_int): mp_err {.mp_abi.}
# proc mp_init_size(a: mp_int, size: cint): mp_err {.mp_abi.}

# init multiple bignum, 2nd, 3rd, and soon use addr, terminated with nil
proc mp_init_multi(mp: mp_int): mp_err {.mp_abi, varargs.}

# free a bignum
proc mp_clear(a: mp_int) {.mp_abi.}

# clear multiple mp_ints, terminated with nil
proc mp_clear_multi(mp: mp_int) {.mp_abi, varargs.}

# compare against a single digit
proc mp_cmp_d(a: mp_int, b: mp_digit): mp_ord {.mp_abi.}

# conversion from/to big endian bytes
proc mp_ubin_size(a: mp_int): csize_t {.mp_abi.}
proc mp_from_ubin(a: mp_int, buf: ptr byte, size: csize_t): mp_err {.mp_abi.}
proc mp_to_ubin(a: mp_int, buf: ptr byte, maxlen: csize_t, written: var csize_t): mp_err {.mp_abi.}

# Y = G**X (mod P)
proc mp_exptmod(G, X, P, Y: mp_int): mp_err {.mp_abi.}

proc mp_get_i32(a: mp_int): int32 {.mp_abi.}
proc mp_get_u32(a: mp_int): uint32 =
cast[uint32](mp_get_i32(a))

# proc mp_init_u64(a: mp_int, b: uint64): mp_err {.mp_abi.}
# proc mp_set_u64(a: mp_int, b: uint64) {.mp_abi.}

proc mp_to_radix(a: mp_int, str: ptr char, maxlen: csize_t, written: var csize_t, radix: cint): mp_err {.mp_abi.}
proc mp_radix_size(a: mp_int, radix: cint, size: var csize_t): mp_err {.mp_abi.}

proc toString*(a: mp_int): string =
var size: csize_t
if mp_radix_size(a, 10.cint, size) != MP_OKAY:
return
if size.int == 0:
return
result = newString(size.int)
if mp_to_radix(a, result[0].getPtr, size, size, 10.cint) != MP_OKAY:
return
result.setLen(size-1)

proc modExp*(b, e, m: openArray[byte]): seq[byte] =
var
base, exp, modulo, res: mp_int

if mp_init_multi(base, exp.addr, modulo.addr, nil) != MP_OKAY:
return

if m.len > 0:
discard mp_from_ubin(modulo, m[0].getPtr, m.len.csize_t)
if mp_cmp_d(modulo, 1.mp_digit) <= MP_EQ:
# EVM special case 1
# If m == 0: EVM returns 0.
# If m == 1: we can shortcut that to 0 as well
mp_clear(modulo)
return @[0.byte]

if e.len > 0:
discard mp_from_ubin(exp, e[0].getPtr, e.len.csize_t)
if mp_cmp_d(exp, 0.mp_digit) == MP_EQ:
# EVM special case 2
# If 0^0: EVM returns 1
# For all x != 0, x^0 == 1 as well
mp_clear_multi(exp, modulo.addr, nil)
return @[1.byte]

if b.len > 0:
discard mp_from_ubin(base, b[0].getPtr, b.len.csize_t)

if mp_exptmod(base, exp, modulo, res) == MP_OKAY:
let size = mp_ubin_size(res)
if size.int > 0:
var written: csize_t
result = newSeq[byte](size.int)
discard mp_to_ubin(res, result[0].getPtr, size, written)
result.setLen(written)

mp_clear_multi(base, exp.addr, modulo.addr, res.addr, nil)
81 changes: 23 additions & 58 deletions nimbus/evm/precompiles.nim
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ import
./interpreter/[gas_meter, gas_costs, utils/utils_numeric],
../errors, eth/[common, keys], chronicles,
nimcrypto/[ripemd, sha2, utils], bncurve/[fields, groups],
../common/evmforks
../common/evmforks,
./modexp

type
PrecompileAddresses* = enum
Expand Down Expand Up @@ -162,47 +163,6 @@ proc identity*(computation: Computation) =
computation.output = computation.msg.data
#trace "Identity precompile", output = computation.output.toHex

proc modExpInternal(computation: Computation, baseLen, expLen, modLen: int, T: type StUint) =
template data: untyped {.dirty.} =
computation.msg.data

let
base = data.rangeToPadded[:T](96, 95 + baseLen, baseLen)
exp = data.rangeToPadded[:T](96 + baseLen, 95 + baseLen + expLen, expLen)
modulo = data.rangeToPadded[:T](96 + baseLen + expLen, 95 + baseLen + expLen + modLen, modLen)

# TODO: specs mentions that we should return in "M" format
# i.e. if Base and exp are uint512 and Modulo an uint256
# we should return a 256-bit big-endian byte array

# Force static evaluation
func zero(): array[T.bits div 8, byte] {.compileTime.} = discard
func one(): array[T.bits div 8, byte] {.compileTime.} =
when cpuEndian == bigEndian:
result[0] = 1
else:
result[^1] = 1

# Start with EVM special cases
let output = if modulo <= 1:
# If m == 0: EVM returns 0.
# If m == 1: we can shortcut that to 0 as well
zero()
elif exp.isZero():
# If 0^0: EVM returns 1
# For all x != 0, x^0 == 1 as well
one()
else:
powmod(base, exp, modulo).toByteArrayBE

# maximum output len is the same as modLen
# if it less than modLen, it will be zero padded at left
if output.len >= modLen:
computation.output = @(output[^modLen..^1])
else:
computation.output = newSeq[byte](modLen)
computation.output[^output.len..^1] = output[0..^1]

proc modExpFee(c: Computation, baseLen, expLen, modLen: UInt256, fork: EVMFork): GasInt =
template data: untyped {.dirty.} =
c.msg.data
Expand Down Expand Up @@ -247,8 +207,6 @@ proc modExpFee(c: Computation, baseLen, expLen, modLen: UInt256, fork: EVMFork):
let gasFee = if fork >= FkBerlin: gasCalc(mulComplexityEIP2565, GasQuadDivisorEIP2565)
else: gasCalc(mulComplexity, GasQuadDivisor)

# let gasFee = gasCalc(mulComplexity, GasQuadDivisor)

if gasFee > high(GasInt).u256:
raise newException(OutOfGas, "modExp gas overflow")

Expand Down Expand Up @@ -282,21 +240,28 @@ proc modExp*(c: Computation, fork: EVMFork = FkByzantium) =
c.output = @[]
return

let maxBytes = max(baseLen, max(expLen, modLen))
if maxBytes <= 32:
c.modExpInternal(baseLen, expLen, modLen, UInt256)
elif maxBytes <= 64:
c.modExpInternal(baseLen, expLen, modLen, StUint[512])
elif maxBytes <= 128:
c.modExpInternal(baseLen, expLen, modLen, StUint[1024])
elif maxBytes <= 256:
c.modExpInternal(baseLen, expLen, modLen, StUint[2048])
elif maxBytes <= 512:
c.modExpInternal(baseLen, expLen, modLen, StUint[4096])
elif maxBytes <= 1024:
c.modExpInternal(baseLen, expLen, modLen, StUint[8192])
const maxSize = int32.high.u256
if baseL > maxSize or expL > maxSize or modL > maxSize:
raise newException(EVMError, "The Nimbus VM doesn't support oversized modExp operand")

# TODO:
# add EVM special case:
# - modulo <= 1: return zero
# - exp == zero: return one

let output = modExp(
data.rangeToPadded(96, baseLen),
data.rangeToPadded(96 + baseLen, expLen),
data.rangeToPadded(96 + baseLen + expLen, modLen)
)

# maximum output len is the same as modLen
# if it less than modLen, it will be zero padded at left
if output.len >= modLen:
c.output = @(output[^modLen..^1])
else:
raise newException(EVMError, "The Nimbus VM doesn't support modular exponentiation with numbers larger than uint8192")
c.output = newSeq[byte](modLen)
c.output[^output.len..^1] = output[0..^1]

proc bn256ecAdd*(computation: Computation, fork: EVMFork = FkByzantium) =
let gasFee = if fork < FkIstanbul: GasECAdd else: GasECAddIstanbul
Expand Down
6 changes: 5 additions & 1 deletion tests/test_blockchain_json.nim
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# according to those terms.

import
std/[json, os, tables, strutils, options],
std/[json, os, tables, strutils, options, times],
unittest2,
eth/rlp, eth/trie/trie_defs, eth/common/eth_types_rlp,
stew/byteutils,
Expand Down Expand Up @@ -421,6 +421,8 @@ proc blockchainJsonMain*(debugMode = false) =
when isMainModule:
var message: string

let start = getTime()

## Processing command line arguments
if test_config.processArguments(message) != test_config.Success:
echo message
Expand All @@ -431,6 +433,8 @@ when isMainModule:
quit(QuitSuccess)

blockchainJsonMain(true)
let elpd = getTime() - start
echo "TIME: ", elpd

# lastBlockHash -> every fixture has it, hash of a block header
# genesisRLP -> NOT every fixture has it, rlp bytes of genesis block header
Expand Down
Loading
Loading