Skip to content

Commit

Permalink
Cairo0 unsafe keccak hint (#361)
Browse files Browse the repository at this point in the history
* Implement CairoKeccakFinalizeHint

* Fixes from the comments in PR

* Remove type casting from AddOffset

* Implement unsafe keccak hint

* Code optimizations and bug fixes

* Add test case covering out of range error

* Resolve merge conflicts

* Fix merge conflicts, address comments, add tests

* Fill the tests with expected values

* Modify test case to include edgecase

* Fix test expected values

* Fix merging errors, add annotation to hint

* Add unsafe keccak hint annotation

* Correct hint annotation

* Resolve changes named in the comments

* Fix unit test after merging main

* Address comments from the PR

* Modify unit tests to match the input conditions for new consts

* Fix after merging main

* Fix hint annotation after update

* Assign __keccak_max_size an arbitrary value, fix tests

* Add Min generic function, fix wording in the hint annotation

---------

Co-authored-by: Carmen Irene Cabrera Rodríguez <[email protected]>
  • Loading branch information
MaksymMalicki and cicr99 committed Jun 11, 2024
1 parent 77b0805 commit 1b62872
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 6 deletions.
14 changes: 14 additions & 0 deletions pkg/hintrunner/hinter/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,20 @@ func (sm *ScopeManager) GetVariableValueAsBigInt(name string) (*big.Int, error)
return valueBig, nil
}

func (sm *ScopeManager) GetVariableValueAsUint64(name string) (uint64, error) {
value, err := sm.GetVariableValue(name)
if err != nil {
return 0, err
}

valueUint, ok := value.(uint64)
if !ok {
return 0, fmt.Errorf("value: %s is not a uint64", value)
}

return valueUint, nil
}

func (sm *ScopeManager) getCurrentScope() (*map[string]any, error) {
if len(sm.scopes) == 0 {
return nil, fmt.Errorf("expected at least one existing scope")
Expand Down
1 change: 1 addition & 0 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ const (
blake2sFinalizeCode string = "from starkware.cairo.common.cairo_blake2s.blake2s_utils import IV, blake2s_compress\n\n_n_packed_instances = int(ids.N_PACKED_INSTANCES)\nassert 0 <= _n_packed_instances < 20\n_blake2s_input_chunk_size_felts = int(ids.INPUT_BLOCK_FELTS)\nassert 0 <= _blake2s_input_chunk_size_felts < 100\n\nmessage = [0] * _blake2s_input_chunk_size_felts\nmodified_iv = [IV[0] ^ 0x01010020] + IV[1:]\noutput = blake2s_compress(\n message=message,\n h=modified_iv,\n t0=0,\n t1=0,\n f0=0xffffffff,\n f1=0,\n)\npadding = (modified_iv + message + [0, 0xffffffff] + output) * (_n_packed_instances - 1)\nsegments.write_arg(ids.blake2s_ptr_end, padding)"

// ------ Keccak hints related code ------
unsafeKeccakCode string = "from eth_hash.auto import keccak\n\ndata, length = ids.data, ids.length\n\nif '__keccak_max_size' in globals():\n assert length <= __keccak_max_size, \\\n f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \\\n f'Got: length={length}.'\n\nkeccak_input = bytearray()\nfor word_i, byte_i in enumerate(range(0, length, 16)):\n word = memory[data + word_i]\n n_bytes = min(16, length - byte_i)\n assert 0 <= word < 2 ** (8 * n_bytes)\n keccak_input += word.to_bytes(n_bytes, 'big')\n\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')"
cairoKeccakFinalizeCode string = `# Add dummy pairs of input and output.
_keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS)
_block_size = int(ids.BLOCK_SIZE)
Expand Down
2 changes: 2 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createKeccakWriteArgsHinter(resolver)
case cairoKeccakFinalizeCode:
return createCairoKeccakFinalizeHinter(resolver)
case unsafeKeccakCode:
return createUnsafeKeccakHinter(resolver)
case blockPermutationCode:
return createBlockPermutationHinter(resolver)
// Usort hints
Expand Down
115 changes: 109 additions & 6 deletions pkg/hintrunner/zero/zerohint_keccak.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package zero

import (
"fmt"
"math"

"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
"github.com/NethermindEth/cairo-vm-go/pkg/utils"
VM "github.com/NethermindEth/cairo-vm-go/pkg/vm"
"github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins"
"github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
"github.com/holiman/uint256"
"golang.org/x/crypto/sha3"
)

// CairoKeccakFinalize writes the result of F1600 Keccak permutation padded by __keccak_state_size_felts__ zeros to consecutive memory cells, __block_size__ times.
Expand Down Expand Up @@ -68,6 +70,108 @@ func createCairoKeccakFinalizeHinter(resolver hintReferenceResolver) (hinter.Hin
return newCairoKeccakFinalizeHint(keccakPtrEnd), nil
}

// UnsafeKeccak computes keccak hash of the data in memory without validity enforcement and writes the result in the `low` and `high` memory cells
//
// `newUnsafeKeccakHint` takes 4 operanders as arguments
// - `data` is the address in memory where the base of the data array to be hashed is stored. Each word in the array is 16 bytes long, except the last one, which could vary
// - `length` is the length of the data to hash
// - `low` is the low part of the produced hash
// - `high` is the high part of the produced hash
func newUnsafeKeccakHint(data, length, high, low hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "UnsafeKeccak",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> from eth_hash.auto import keccak
//> data, length = ids.data, ids.length
//> if '__keccak_max_size' in globals():
//> assert length <= __keccak_max_size, \
//> f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \
//> f'Got: length={length}.'
//> keccak_input = bytearray()
//> for word_i, byte_i in enumerate(range(0, length, 16)):
//> word = memory[data + word_i]
//> n_bytes = min(16, length - byte_i)
//> assert 0 <= word < 2 ** (8 * n_bytes)
//> keccak_input += word.to_bytes(n_bytes, 'big')
//> hashed = keccak(keccak_input)
//> ids.high = int.from_bytes(hashed[:16], 'big')
//> ids.low = int.from_bytes(hashed[16:32], 'big')

lengthVal, err := hinter.ResolveAsUint64(vm, length)
if err != nil {
return err
}
keccakMaxSize := uint64(1 << 20)
if lengthVal > keccakMaxSize {
return fmt.Errorf("unsafe_keccak() can only be used with length<=%d.\n Got: length=%d.", keccakMaxSize, lengthVal)
}
dataPtr, err := hinter.ResolveAsAddress(vm, data)
if err != nil {
return err
}

keccakInput := make([]byte, 0)
for i := uint64(0); i < lengthVal; i += 16 {
wordFelt, err := vm.Memory.ReadAsElement(dataPtr.SegmentIndex, dataPtr.Offset)
if err != nil {
return err
}
word := uint256.Int(wordFelt.Bits())
nBytes := utils.Min(lengthVal-i, 16)
if uint64(word.BitLen()) >= 8*nBytes {
return fmt.Errorf("word %v is out range 0 <= word < 2 ** %d", &word, 8*nBytes)
}
wordBytes := word.Bytes20()
keccakInput = append(keccakInput, wordBytes[20-int(nBytes):]...)
*dataPtr, err = dataPtr.AddOffset(1)
if err != nil {
return err
}
}
hash := sha3.NewLegacyKeccak256()
hash.Write(keccakInput)
hashedBytes := hash.Sum(nil)
hashedHigh := new(fp.Element).SetBytes(hashedBytes[:16])
hashedLow := new(fp.Element).SetBytes(hashedBytes[16:32])
highAddr, err := high.GetAddress(vm)
if err != nil {
return err
}
hashedHighMV := memory.MemoryValueFromFieldElement(hashedHigh)
err = vm.Memory.WriteToAddress(&highAddr, &hashedHighMV)
if err != nil {
return err
}
lowAddr, err := low.GetAddress(vm)
if err != nil {
return err
}
hashedLowMV := memory.MemoryValueFromFieldElement(hashedLow)
return vm.Memory.WriteToAddress(&lowAddr, &hashedLowMV)
},
}
}

func createUnsafeKeccakHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
data, err := resolver.GetResOperander("data")
if err != nil {
return nil, err
}
length, err := resolver.GetResOperander("length")
if err != nil {
return nil, err
}
high, err := resolver.GetResOperander("high")
if err != nil {
return nil, err
}
low, err := resolver.GetResOperander("low")
if err != nil {
return nil, err
}
return newUnsafeKeccakHint(data, length, high, low), nil
}

// KeccakWriteArgs hint writes Keccak function arguments in memory
//
// `newKeccakWriteArgsHint` takes 3 operanders as arguments
Expand Down Expand Up @@ -106,27 +210,27 @@ func newKeccakWriteArgsHint(inputs, low, high hinter.ResOperander) hinter.Hinter
lowResultUint256Low.And(&maxUint64, &lowResultUint256Low)
lowResulBytes32Low := lowResultUint256Low.Bytes32()
lowResultFeltLow, _ := fp.BigEndian.Element(&lowResulBytes32Low)
mvLowLow := mem.MemoryValueFromFieldElement(&lowResultFeltLow)
mvLowLow := memory.MemoryValueFromFieldElement(&lowResultFeltLow)

lowResultUint256High := lowUint256
lowResultUint256High.Rsh(&lowResultUint256High, 64)
lowResultUint256High.And(&lowResultUint256High, &maxUint64)
lowResulBytes32High := lowResultUint256High.Bytes32()
lowResultFeltHigh, _ := fp.BigEndian.Element(&lowResulBytes32High)
mvLowHigh := mem.MemoryValueFromFieldElement(&lowResultFeltHigh)
mvLowHigh := memory.MemoryValueFromFieldElement(&lowResultFeltHigh)

highResultUint256Low := highUint256
highResultUint256Low.And(&maxUint64, &highResultUint256Low)
highResulBytes32Low := highResultUint256Low.Bytes32()
highResultFeltLow, _ := fp.BigEndian.Element(&highResulBytes32Low)
mvHighLow := mem.MemoryValueFromFieldElement(&highResultFeltLow)
mvHighLow := memory.MemoryValueFromFieldElement(&highResultFeltLow)

highResultUint256High := highUint256
highResultUint256High.Rsh(&highResultUint256High, 64)
highResultUint256High.And(&maxUint64, &highResultUint256High)
highResulBytes32High := highResultUint256High.Bytes32()
highResultFeltHigh, _ := fp.BigEndian.Element(&highResulBytes32High)
mvHighHigh := mem.MemoryValueFromFieldElement(&highResultFeltHigh)
mvHighHigh := memory.MemoryValueFromFieldElement(&highResultFeltHigh)

err = vm.Memory.Write(inputsPtr.SegmentIndex, inputsPtr.Offset, &mvLowLow)
if err != nil {
Expand All @@ -137,7 +241,6 @@ func newKeccakWriteArgsHint(inputs, low, high hinter.ResOperander) hinter.Hinter
if err != nil {
return err
}

err = vm.Memory.Write(inputsPtr.SegmentIndex, inputsPtr.Offset+2, &mvHighLow)
if err != nil {
return err
Expand Down
96 changes: 96 additions & 0 deletions pkg/hintrunner/zero/zerohint_keccak_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package zero

import (
"fmt"
"testing"

"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
Expand Down Expand Up @@ -46,6 +47,101 @@ func TestZeroHintKeccak(t *testing.T) {
},
},
},
"UnsafeKeccak": {
{
operanders: []*hintOperander{
{Name: "data", Kind: uninitialized},
{Name: "length", Kind: apRelative, Value: feltUint64((1 << 20) + 1)},
{Name: "high", Kind: uninitialized},
{Name: "low", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newUnsafeKeccakHint(ctx.operanders["data"], ctx.operanders["length"], ctx.operanders["high"], ctx.operanders["low"])
},
errCheck: errorTextContains(fmt.Sprintf("unsafe_keccak() can only be used with length<=%d.\n Got: length=%d.", 1<<20, (1<<20)+1)),
},
{
operanders: []*hintOperander{
{Name: "data", Kind: apRelative, Value: addr(5)},
{Name: "data.0", Kind: apRelative, Value: feltUint64(65537)},
{Name: "length", Kind: apRelative, Value: feltUint64(1)},
{Name: "high", Kind: uninitialized},
{Name: "low", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newUnsafeKeccakHint(ctx.operanders["data"], ctx.operanders["length"], ctx.operanders["high"], ctx.operanders["low"])
},
errCheck: errorTextContains(fmt.Sprintf("word %v is out range 0 <= word < 2 ** %d", feltUint64(65537), 8)),
},
{
operanders: []*hintOperander{
{Name: "data", Kind: apRelative, Value: addr(5)},
{Name: "data.0", Kind: apRelative, Value: feltUint64(1)},
{Name: "data.1", Kind: apRelative, Value: feltUint64(2)},
{Name: "data.2", Kind: apRelative, Value: feltUint64(3)},
{Name: "data.3", Kind: apRelative, Value: feltUint64(4)},
{Name: "length", Kind: apRelative, Value: feltUint64(4)},
{Name: "high", Kind: uninitialized},
{Name: "low", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newUnsafeKeccakHint(ctx.operanders["data"], ctx.operanders["length"], ctx.operanders["high"], ctx.operanders["low"])
},
check: func(t *testing.T, ctx *hintTestContext) {
varValueEquals("high", feltString("108955721224378455455648573289483395612"))(t, ctx)
varValueEquals("low", feltString("253531040214470063354971884479696309631"))(t, ctx)
},
},
{
operanders: []*hintOperander{
{Name: "data", Kind: apRelative, Value: addr(5)},
{Name: "data.0", Kind: apRelative, Value: feltUint64(1)},
{Name: "data.1", Kind: apRelative, Value: feltUint64(2)},
{Name: "data.2", Kind: apRelative, Value: feltUint64(3)},
{Name: "data.3", Kind: apRelative, Value: feltUint64(4)},
{Name: "data.4", Kind: apRelative, Value: feltUint64(1)},
{Name: "data.5", Kind: apRelative, Value: feltUint64(2)},
{Name: "data.6", Kind: apRelative, Value: feltUint64(3)},
{Name: "data.7", Kind: apRelative, Value: feltUint64(4)},
{Name: "data.8", Kind: apRelative, Value: feltUint64(1)},
{Name: "data.9", Kind: apRelative, Value: feltUint64(2)},
{Name: "data.10", Kind: apRelative, Value: feltUint64(3)},
{Name: "data.11", Kind: apRelative, Value: feltUint64(4)},
{Name: "data.12", Kind: apRelative, Value: feltUint64(1)},
{Name: "data.13", Kind: apRelative, Value: feltUint64(2)},
{Name: "data.14", Kind: apRelative, Value: feltUint64(3)},
{Name: "data.15", Kind: apRelative, Value: feltUint64(4)},
{Name: "data.16", Kind: apRelative, Value: feltUint64(4)},
{Name: "data.17", Kind: apRelative, Value: feltUint64(1)},
{Name: "data.18", Kind: apRelative, Value: feltUint64(2)},
{Name: "data.19", Kind: apRelative, Value: feltUint64(3)},
{Name: "data.20", Kind: apRelative, Value: feltUint64(4)},
{Name: "data.21", Kind: apRelative, Value: feltUint64(1)},
{Name: "data.22", Kind: apRelative, Value: feltUint64(2)},
{Name: "data.23", Kind: apRelative, Value: feltUint64(3)},
{Name: "data.24", Kind: apRelative, Value: feltUint64(4)},
{Name: "data.25", Kind: apRelative, Value: feltUint64(1)},
{Name: "data.26", Kind: apRelative, Value: feltUint64(2)},
{Name: "data.27", Kind: apRelative, Value: feltUint64(3)},
{Name: "data.28", Kind: apRelative, Value: feltUint64(4)},
{Name: "data.29", Kind: apRelative, Value: feltUint64(1)},
{Name: "data.30", Kind: apRelative, Value: feltUint64(2)},
{Name: "data.31", Kind: apRelative, Value: feltUint64(3)},
{Name: "data.32", Kind: apRelative, Value: feltUint64(4)},
{Name: "data.33", Kind: apRelative, Value: feltUint64(4)},
{Name: "length", Kind: apRelative, Value: feltUint64(34)},
{Name: "high", Kind: uninitialized},
{Name: "low", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newUnsafeKeccakHint(ctx.operanders["data"], ctx.operanders["length"], ctx.operanders["high"], ctx.operanders["low"])
},
check: func(t *testing.T, ctx *hintTestContext) {
varValueEquals("high", feltString("43087684015060895958086736298363333858"))(t, ctx)
varValueEquals("low", feltString("115090685687501856751902560332884088627"))(t, ctx)
},
},
},
"KeccakWriteArgs": {
{
operanders: []*hintOperander{
Expand Down
7 changes: 7 additions & 0 deletions pkg/utils/math.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ func Max[T constraints.Integer](a, b T) T {
return b
}

func Min[T constraints.Integer](a, b T) T {
if a < b {
return a
}
return b
}

// FeltLt implements `a < b` felt comparison.
func FeltLt(a, b *fp.Element) bool {
return a.Cmp(b) == -1
Expand Down

0 comments on commit 1b62872

Please sign in to comment.