Skip to content

Commit

Permalink
modexp precompile allow arbitrary input length
Browse files Browse the repository at this point in the history
  • Loading branch information
jangko committed May 11, 2023
1 parent 408394a commit 4e58f9d
Show file tree
Hide file tree
Showing 8 changed files with 320 additions and 60 deletions.
5 changes: 5 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,8 @@
url = https://github.com/status-im/portal-spec-tests.git
ignore = untracked
branch = master
[submodule "vendor/libtommath"]
path = vendor/libtommath
url = https://github.com/libtom/libtommath
ignore = untracked
branch = develop
10 changes: 10 additions & 0 deletions nimbus/evm/interpreter/utils/utils_numeric.nim
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ proc rangeToPadded*[T: StUint](x: openArray[byte], first, last, size: int): T =
allowPadding = true
)

proc rangeToPadded*(x: openArray[byte], first, size: int): seq[byte] =
let last = first + size - 1
let lo = max(0, first)
let hi = min(min(x.high, last), (lo+size)-1)

result = newSeq[byte](size)
if not(lo <= hi):
return # 0
result[0..hi-lo] = x.toOpenArray(lo, hi)

# calculates the memory size required for a step
func calcMemSize*(offset, length: int): int {.inline.} =
if length.isZero: return 0
Expand Down
218 changes: 218 additions & 0 deletions nimbus/evm/modexp.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
import
std/strutils

from os import DirSep, AltSep

const
vendorPath = currentSourcePath.rsplit({DirSep, AltSep}, 3)[0] & "/vendor"
srcPath = vendorPath & "/libtommath"

{.passc: "-IMP_32BIT"}
{.compile: srcPath & "/mp_radix_size.c"}
{.compile: srcPath & "/mp_to_radix.c"}
{.compile: srcPath & "/mp_init_u64.c"}
{.compile: srcPath & "/mp_init_i32.c"}
{.compile: srcPath & "/mp_init_multi.c"}
{.compile: srcPath & "/mp_init.c"}
{.compile: srcPath & "/mp_init_size.c"}
{.compile: srcPath & "/mp_init_copy.c"}
{.compile: srcPath & "/mp_invmod.c"}
{.compile: srcPath & "/mp_abs.c"}
{.compile: srcPath & "/mp_set_u64.c"}
{.compile: srcPath & "/mp_set_u32.c"}
{.compile: srcPath & "/mp_set_i32.c"}
{.compile: srcPath & "/mp_get_i32.c"}
{.compile: srcPath & "/mp_get_i64.c"}
{.compile: srcPath & "/mp_exptmod.c"}
{.compile: srcPath & "/mp_clear_multi.c"}
{.compile: srcPath & "/mp_clear.c"}
{.compile: srcPath & "/mp_montgomery_reduce.c"}
{.compile: srcPath & "/mp_clamp.c"}
{.compile: srcPath & "/mp_grow.c"}
{.compile: srcPath & "/mp_mul.c"}
{.compile: srcPath & "/mp_mul_2.c"}
{.compile: srcPath & "/mp_mul_2d.c"}
{.compile: srcPath & "/mp_mod_2d.c"}
{.compile: srcPath & "/mp_log_n.c"}
{.compile: srcPath & "/mp_div_2.c"}
{.compile: srcPath & "/mp_div_d.c"}
{.compile: srcPath & "/mp_add.c"}
{.compile: srcPath & "/mp_sub.c"}
{.compile: srcPath & "/mp_exch.c"}
{.compile: srcPath & "/mp_rshd.c"}
{.compile: srcPath & "/mp_lshd.c"}
{.compile: srcPath & "/mp_zero.c"}
{.compile: srcPath & "/mp_dr_reduce.c"}
{.compile: srcPath & "/mp_cmp_mag.c"}
{.compile: srcPath & "/mp_cutoffs.c"}
{.compile: srcPath & "/mp_reduce.c"}
{.compile: srcPath & "/mp_count_bits.c"}
{.compile: srcPath & "/mp_montgomery_setup.c"}
{.compile: srcPath & "/mp_dr_setup.c"}
{.compile: srcPath & "/mp_reduce_2k_setup.c"}
{.compile: srcPath & "/mp_reduce_2k_setup_l.c"}
{.compile: srcPath & "/mp_reduce_2k.c"}
{.compile: srcPath & "/mp_reduce_2k_l.c"}
{.compile: srcPath & "/mp_reduce_is_2k_l.c"}
{.compile: srcPath & "/mp_reduce_is_2k.c"}
{.compile: srcPath & "/mp_reduce_setup.c"}
{.compile: srcPath & "/mp_dr_is_modulus.c"}
{.compile: srcPath & "/mp_mulmod.c"}
{.compile: srcPath & "/mp_set.c"}
{.compile: srcPath & "/mp_mod.c"}
{.compile: srcPath & "/mp_copy.c"}
{.compile: srcPath & "/mp_div.c"}
{.compile: srcPath & "/mp_div_2d.c"}
{.compile: srcPath & "/mp_mul_d.c"}
{.compile: srcPath & "/mp_2expt.c"}
{.compile: srcPath & "/mp_cmp.c"}
{.compile: srcPath & "/mp_cmp_d.c"}
{.compile: srcPath & "/mp_log.c"}
{.compile: srcPath & "/mp_sub_d.c"}
{.compile: srcPath & "/mp_add_d.c"}
{.compile: srcPath & "/mp_cnt_lsb.c"}
{.compile: srcPath & "/mp_expt_n.c"}
{.compile: srcPath & "/mp_get_mag_u32.c"}
{.compile: srcPath & "/mp_get_mag_u64.c"}
{.compile: srcPath & "/mp_from_ubin.c"}
{.compile: srcPath & "/mp_ubin_size.c"}
{.compile: srcPath & "/mp_to_ubin.c"}
{.compile: srcPath & "/mp_montgomery_calc_normalization.c"}
{.compile: srcPath & "/s_mp_exptmod.c"}
{.compile: srcPath & "/s_mp_exptmod_fast.c"}
{.compile: srcPath & "/s_mp_zero_digs.c"}
{.compile: srcPath & "/s_mp_montgomery_reduce_comba.c"}
{.compile: srcPath & "/s_mp_add.c"}
{.compile: srcPath & "/s_mp_sub.c"}
{.compile: srcPath & "/s_mp_mul.c"}
{.compile: srcPath & "/s_mp_mul_comba.c"}
{.compile: srcPath & "/s_mp_mul_toom.c"}
{.compile: srcPath & "/s_mp_mul_karatsuba.c"}
{.compile: srcPath & "/s_mp_mul_balance.c"}
{.compile: srcPath & "/s_mp_copy_digs.c"}
{.compile: srcPath & "/s_mp_div_3.c"}
{.compile: srcPath & "/s_mp_sqr.c"}
{.compile: srcPath & "/s_mp_sqr_comba.c"}
{.compile: srcPath & "/s_mp_sqr_toom.c"}
{.compile: srcPath & "/s_mp_sqr_karatsuba.c"}
{.compile: srcPath & "/s_mp_zero_buf.c"}
{.compile: srcPath & "/s_mp_radix_map.c"}
{.compile: srcPath & "/s_mp_invmod.c"}
{.compile: srcPath & "/s_mp_invmod_odd.c"}
{.compile: srcPath & "/s_mp_mul_high.c"}
{.compile: srcPath & "/s_mp_mul_high_comba.c"}
{.compile: srcPath & "/s_mp_div_recursive.c"}
{.compile: srcPath & "/s_mp_div_school.c"}
{.compile: srcPath & "/s_mp_fp_log_d.c"}
{.compile: srcPath & "/s_mp_fp_log.c"}

{.passc: "-I" & srcPath .}

type
mp_int {.importc: "mp_int",
header: "tommath.h", byref.} = object

mp_digit = uint32

mp_err {.importc: "mp_err",
header: "tommath.h".} = cint

mp_ord = cint

{.pragma: mp_abi, importc, cdecl, header: "tommath.h".}

const
MP_OKAY = 0.mp_err

MP_LT = -1
MP_EQ = 0
MP_GT = 1

template getPtr(z: untyped): untyped =
when (NimMajor, NimMinor) > (1,6):
z.addr
else:
z.unsafeAddr

# init a bignum
# proc mp_init(a: mp_int): mp_err {.mp_abi.}
# proc mp_init_size(a: mp_int, size: cint): mp_err {.mp_abi.}

# init multiple bignum, 2nd, 3rd, and soon use addr, terminated with nil
proc mp_init_multi(mp: mp_int): mp_err {.mp_abi, varargs.}

# free a bignum
proc mp_clear(a: mp_int) {.mp_abi.}

# clear multiple mp_ints, terminated with nil
proc mp_clear_multi(mp: mp_int) {.mp_abi, varargs.}

# compare against a single digit
proc mp_cmp_d(a: mp_int, b: mp_digit): mp_ord {.mp_abi.}

# conversion from/to big endian bytes
proc mp_ubin_size(a: mp_int): csize_t {.mp_abi.}
proc mp_from_ubin(a: mp_int, buf: ptr byte, size: csize_t): mp_err {.mp_abi.}
proc mp_to_ubin(a: mp_int, buf: ptr byte, maxlen: csize_t, written: var csize_t): mp_err {.mp_abi.}

# Y = G**X (mod P)
proc mp_exptmod(G, X, P, Y: mp_int): mp_err {.mp_abi.}

proc mp_get_i32(a: mp_int): int32 {.mp_abi.}
proc mp_get_u32(a: mp_int): uint32 =
cast[uint32](mp_get_i32(a))

# proc mp_init_u64(a: mp_int, b: uint64): mp_err {.mp_abi.}
# proc mp_set_u64(a: mp_int, b: uint64) {.mp_abi.}

proc mp_to_radix(a: mp_int, str: ptr char, maxlen: csize_t, written: var csize_t, radix: cint): mp_err {.mp_abi.}
proc mp_radix_size(a: mp_int, radix: cint, size: var csize_t): mp_err {.mp_abi.}

proc toString*(a: mp_int): string =
var size: csize_t
if mp_radix_size(a, 10.cint, size) != MP_OKAY:
return
if size.int == 0:
return
result = newString(size.int)
if mp_to_radix(a, result[0].getPtr, size, size, 10.cint) != MP_OKAY:
return
result.setLen(size-1)

proc modExp*(b, e, m: openArray[byte]): seq[byte] =
var
base, exp, modulo, res: mp_int

if mp_init_multi(base, exp.addr, modulo.addr, nil) != MP_OKAY:
return

if m.len > 0:
discard mp_from_ubin(modulo, m[0].getPtr, m.len.csize_t)
if mp_cmp_d(modulo, 1.mp_digit) <= MP_EQ:
# EVM special case 1
# If m == 0: EVM returns 0.
# If m == 1: we can shortcut that to 0 as well
mp_clear(modulo)
return @[0.byte]

if e.len > 0:
discard mp_from_ubin(exp, e[0].getPtr, e.len.csize_t)
if mp_cmp_d(exp, 0.mp_digit) == MP_EQ:
# EVM special case 2
# If 0^0: EVM returns 1
# For all x != 0, x^0 == 1 as well
mp_clear_multi(exp, modulo.addr, nil)
return @[1.byte]

if b.len > 0:
discard mp_from_ubin(base, b[0].getPtr, b.len.csize_t)

if mp_exptmod(base, exp, modulo, res) == MP_OKAY:
let size = mp_ubin_size(res)
if size.int > 0:
var written: csize_t
result = newSeq[byte](size.int)
discard mp_to_ubin(res, result[0].getPtr, size, written)
result.setLen(written)

mp_clear_multi(base, exp.addr, modulo.addr, res.addr, nil)
81 changes: 23 additions & 58 deletions nimbus/evm/precompiles.nim
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ import
./interpreter/[gas_meter, gas_costs, utils/utils_numeric],
../errors, eth/[common, keys], chronicles,
nimcrypto/[ripemd, sha2, utils], bncurve/[fields, groups],
../common/evmforks
../common/evmforks,
./modexp

type
PrecompileAddresses* = enum
Expand Down Expand Up @@ -162,47 +163,6 @@ proc identity*(computation: Computation) =
computation.output = computation.msg.data
#trace "Identity precompile", output = computation.output.toHex

proc modExpInternal(computation: Computation, baseLen, expLen, modLen: int, T: type StUint) =
template data: untyped {.dirty.} =
computation.msg.data

let
base = data.rangeToPadded[:T](96, 95 + baseLen, baseLen)
exp = data.rangeToPadded[:T](96 + baseLen, 95 + baseLen + expLen, expLen)
modulo = data.rangeToPadded[:T](96 + baseLen + expLen, 95 + baseLen + expLen + modLen, modLen)

# TODO: specs mentions that we should return in "M" format
# i.e. if Base and exp are uint512 and Modulo an uint256
# we should return a 256-bit big-endian byte array

# Force static evaluation
func zero(): array[T.bits div 8, byte] {.compileTime.} = discard
func one(): array[T.bits div 8, byte] {.compileTime.} =
when cpuEndian == bigEndian:
result[0] = 1
else:
result[^1] = 1

# Start with EVM special cases
let output = if modulo <= 1:
# If m == 0: EVM returns 0.
# If m == 1: we can shortcut that to 0 as well
zero()
elif exp.isZero():
# If 0^0: EVM returns 1
# For all x != 0, x^0 == 1 as well
one()
else:
powmod(base, exp, modulo).toByteArrayBE

# maximum output len is the same as modLen
# if it less than modLen, it will be zero padded at left
if output.len >= modLen:
computation.output = @(output[^modLen..^1])
else:
computation.output = newSeq[byte](modLen)
computation.output[^output.len..^1] = output[0..^1]

proc modExpFee(c: Computation, baseLen, expLen, modLen: UInt256, fork: EVMFork): GasInt =
template data: untyped {.dirty.} =
c.msg.data
Expand Down Expand Up @@ -247,8 +207,6 @@ proc modExpFee(c: Computation, baseLen, expLen, modLen: UInt256, fork: EVMFork):
let gasFee = if fork >= FkBerlin: gasCalc(mulComplexityEIP2565, GasQuadDivisorEIP2565)
else: gasCalc(mulComplexity, GasQuadDivisor)

# let gasFee = gasCalc(mulComplexity, GasQuadDivisor)

if gasFee > high(GasInt).u256:
raise newException(OutOfGas, "modExp gas overflow")

Expand Down Expand Up @@ -282,21 +240,28 @@ proc modExp*(c: Computation, fork: EVMFork = FkByzantium) =
c.output = @[]
return

let maxBytes = max(baseLen, max(expLen, modLen))
if maxBytes <= 32:
c.modExpInternal(baseLen, expLen, modLen, UInt256)
elif maxBytes <= 64:
c.modExpInternal(baseLen, expLen, modLen, StUint[512])
elif maxBytes <= 128:
c.modExpInternal(baseLen, expLen, modLen, StUint[1024])
elif maxBytes <= 256:
c.modExpInternal(baseLen, expLen, modLen, StUint[2048])
elif maxBytes <= 512:
c.modExpInternal(baseLen, expLen, modLen, StUint[4096])
elif maxBytes <= 1024:
c.modExpInternal(baseLen, expLen, modLen, StUint[8192])
const maxSize = int32.high.u256
if baseL > maxSize or expL > maxSize or modL > maxSize:
raise newException(EVMError, "The Nimbus VM doesn't support oversized modExp operand")

# TODO:
# add EVM special case:
# - modulo <= 1: return zero
# - exp == zero: return one

let output = modExp(
data.rangeToPadded(96, baseLen),
data.rangeToPadded(96 + baseLen, expLen),
data.rangeToPadded(96 + baseLen + expLen, modLen)
)

# maximum output len is the same as modLen
# if it less than modLen, it will be zero padded at left
if output.len >= modLen:
c.output = @(output[^modLen..^1])
else:
raise newException(EVMError, "The Nimbus VM doesn't support modular exponentiation with numbers larger than uint8192")
c.output = newSeq[byte](modLen)
c.output[^output.len..^1] = output[0..^1]

proc bn256ecAdd*(computation: Computation, fork: EVMFork = FkByzantium) =
let gasFee = if fork < FkIstanbul: GasECAdd else: GasECAddIstanbul
Expand Down
6 changes: 5 additions & 1 deletion tests/test_blockchain_json.nim
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# according to those terms.

import
std/[json, os, tables, strutils, options],
std/[json, os, tables, strutils, options, times],
unittest2,
eth/rlp, eth/trie/trie_defs, eth/common/eth_types_rlp,
stew/byteutils,
Expand Down Expand Up @@ -421,6 +421,8 @@ proc blockchainJsonMain*(debugMode = false) =
when isMainModule:
var message: string

let start = getTime()

## Processing command line arguments
if test_config.processArguments(message) != test_config.Success:
echo message
Expand All @@ -431,6 +433,8 @@ when isMainModule:
quit(QuitSuccess)

blockchainJsonMain(true)
let elpd = getTime() - start
echo "TIME: ", elpd

# lastBlockHash -> every fixture has it, hash of a block header
# genesisRLP -> NOT every fixture has it, rlp bytes of genesis block header
Expand Down
Loading

0 comments on commit 4e58f9d

Please sign in to comment.