From 1b6287245400f096e377870acd034a35199cd230 Mon Sep 17 00:00:00 2001 From: MaksymMalicki <81577596+MaksymMalicki@users.noreply.github.com> Date: Tue, 11 Jun 2024 14:19:10 +0200 Subject: [PATCH] Cairo0 unsafe keccak hint (#361) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Implement CairoKeccakFinalizeHint * Fixes from the comments in PR * Remove type casting from AddOffset * Implement unsafe keccak hint * Code optimizations and bug fixes * Add test case covering out of range error * Resolve merge conflicts * Fix merge conflicts, address comments, add tests * Fill the tests with expected values * Modify test case to include edgecase * Fix test expected values * Fix merging errors, add annotation to hint * Add unsafe keccak hint annotation * Correct hint annotation * Resolve changes named in the comments * Fix unit test after merging main * Address comments from the PR * Modify unit tests to match the input conditions for new consts * Fix after merging main * Fix hint annotation after update * Assign __keccak_max_size an arbitrary value, fix tests * Add Min generic function, fix wording in the hint annotation --------- Co-authored-by: Carmen Irene Cabrera Rodríguez <49727740+cicr99@users.noreply.github.com> --- pkg/hintrunner/hinter/scope.go | 14 +++ pkg/hintrunner/zero/hintcode.go | 1 + pkg/hintrunner/zero/zerohint.go | 2 + pkg/hintrunner/zero/zerohint_keccak.go | 115 +++++++++++++++++++- pkg/hintrunner/zero/zerohint_keccak_test.go | 96 ++++++++++++++++ pkg/utils/math.go | 7 ++ 6 files changed, 229 insertions(+), 6 deletions(-) diff --git a/pkg/hintrunner/hinter/scope.go b/pkg/hintrunner/hinter/scope.go index 3f69a467..81804e2d 100644 --- a/pkg/hintrunner/hinter/scope.go +++ b/pkg/hintrunner/hinter/scope.go @@ -92,6 +92,20 @@ func (sm *ScopeManager) GetVariableValueAsBigInt(name string) (*big.Int, error) return valueBig, nil } +func (sm *ScopeManager) GetVariableValueAsUint64(name string) (uint64, error) { + value, err := sm.GetVariableValue(name) + if err != nil { + return 0, err + } + + valueUint, ok := value.(uint64) + if !ok { + return 0, fmt.Errorf("value: %s is not a uint64", value) + } + + return valueUint, nil +} + func (sm *ScopeManager) getCurrentScope() (*map[string]any, error) { if len(sm.scopes) == 0 { return nil, fmt.Errorf("expected at least one existing scope") diff --git a/pkg/hintrunner/zero/hintcode.go b/pkg/hintrunner/zero/hintcode.go index 3e4635c5..089758c6 100644 --- a/pkg/hintrunner/zero/hintcode.go +++ b/pkg/hintrunner/zero/hintcode.go @@ -111,6 +111,7 @@ const ( blake2sFinalizeCode string = "from starkware.cairo.common.cairo_blake2s.blake2s_utils import IV, blake2s_compress\n\n_n_packed_instances = int(ids.N_PACKED_INSTANCES)\nassert 0 <= _n_packed_instances < 20\n_blake2s_input_chunk_size_felts = int(ids.INPUT_BLOCK_FELTS)\nassert 0 <= _blake2s_input_chunk_size_felts < 100\n\nmessage = [0] * _blake2s_input_chunk_size_felts\nmodified_iv = [IV[0] ^ 0x01010020] + IV[1:]\noutput = blake2s_compress(\n message=message,\n h=modified_iv,\n t0=0,\n t1=0,\n f0=0xffffffff,\n f1=0,\n)\npadding = (modified_iv + message + [0, 0xffffffff] + output) * (_n_packed_instances - 1)\nsegments.write_arg(ids.blake2s_ptr_end, padding)" // ------ Keccak hints related code ------ + unsafeKeccakCode string = "from eth_hash.auto import keccak\n\ndata, length = ids.data, ids.length\n\nif '__keccak_max_size' in globals():\n assert length <= __keccak_max_size, \\\n f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \\\n f'Got: length={length}.'\n\nkeccak_input = bytearray()\nfor word_i, byte_i in enumerate(range(0, length, 16)):\n word = memory[data + word_i]\n n_bytes = min(16, length - byte_i)\n assert 0 <= word < 2 ** (8 * n_bytes)\n keccak_input += word.to_bytes(n_bytes, 'big')\n\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')" cairoKeccakFinalizeCode string = `# Add dummy pairs of input and output. _keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS) _block_size = int(ids.BLOCK_SIZE) diff --git a/pkg/hintrunner/zero/zerohint.go b/pkg/hintrunner/zero/zerohint.go index 23edb16e..65f47390 100644 --- a/pkg/hintrunner/zero/zerohint.go +++ b/pkg/hintrunner/zero/zerohint.go @@ -155,6 +155,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64 return createKeccakWriteArgsHinter(resolver) case cairoKeccakFinalizeCode: return createCairoKeccakFinalizeHinter(resolver) + case unsafeKeccakCode: + return createUnsafeKeccakHinter(resolver) case blockPermutationCode: return createBlockPermutationHinter(resolver) // Usort hints diff --git a/pkg/hintrunner/zero/zerohint_keccak.go b/pkg/hintrunner/zero/zerohint_keccak.go index fc68396c..0da266b8 100644 --- a/pkg/hintrunner/zero/zerohint_keccak.go +++ b/pkg/hintrunner/zero/zerohint_keccak.go @@ -1,15 +1,17 @@ package zero import ( + "fmt" "math" "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" + "github.com/NethermindEth/cairo-vm-go/pkg/utils" VM "github.com/NethermindEth/cairo-vm-go/pkg/vm" "github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins" "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" - mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" "github.com/holiman/uint256" + "golang.org/x/crypto/sha3" ) // CairoKeccakFinalize writes the result of F1600 Keccak permutation padded by __keccak_state_size_felts__ zeros to consecutive memory cells, __block_size__ times. @@ -68,6 +70,108 @@ func createCairoKeccakFinalizeHinter(resolver hintReferenceResolver) (hinter.Hin return newCairoKeccakFinalizeHint(keccakPtrEnd), nil } +// UnsafeKeccak computes keccak hash of the data in memory without validity enforcement and writes the result in the `low` and `high` memory cells +// +// `newUnsafeKeccakHint` takes 4 operanders as arguments +// - `data` is the address in memory where the base of the data array to be hashed is stored. Each word in the array is 16 bytes long, except the last one, which could vary +// - `length` is the length of the data to hash +// - `low` is the low part of the produced hash +// - `high` is the high part of the produced hash +func newUnsafeKeccakHint(data, length, high, low hinter.ResOperander) hinter.Hinter { + return &GenericZeroHinter{ + Name: "UnsafeKeccak", + Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error { + //> from eth_hash.auto import keccak + //> data, length = ids.data, ids.length + //> if '__keccak_max_size' in globals(): + //> assert length <= __keccak_max_size, \ + //> f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \ + //> f'Got: length={length}.' + //> keccak_input = bytearray() + //> for word_i, byte_i in enumerate(range(0, length, 16)): + //> word = memory[data + word_i] + //> n_bytes = min(16, length - byte_i) + //> assert 0 <= word < 2 ** (8 * n_bytes) + //> keccak_input += word.to_bytes(n_bytes, 'big') + //> hashed = keccak(keccak_input) + //> ids.high = int.from_bytes(hashed[:16], 'big') + //> ids.low = int.from_bytes(hashed[16:32], 'big') + + lengthVal, err := hinter.ResolveAsUint64(vm, length) + if err != nil { + return err + } + keccakMaxSize := uint64(1 << 20) + if lengthVal > keccakMaxSize { + return fmt.Errorf("unsafe_keccak() can only be used with length<=%d.\n Got: length=%d.", keccakMaxSize, lengthVal) + } + dataPtr, err := hinter.ResolveAsAddress(vm, data) + if err != nil { + return err + } + + keccakInput := make([]byte, 0) + for i := uint64(0); i < lengthVal; i += 16 { + wordFelt, err := vm.Memory.ReadAsElement(dataPtr.SegmentIndex, dataPtr.Offset) + if err != nil { + return err + } + word := uint256.Int(wordFelt.Bits()) + nBytes := utils.Min(lengthVal-i, 16) + if uint64(word.BitLen()) >= 8*nBytes { + return fmt.Errorf("word %v is out range 0 <= word < 2 ** %d", &word, 8*nBytes) + } + wordBytes := word.Bytes20() + keccakInput = append(keccakInput, wordBytes[20-int(nBytes):]...) + *dataPtr, err = dataPtr.AddOffset(1) + if err != nil { + return err + } + } + hash := sha3.NewLegacyKeccak256() + hash.Write(keccakInput) + hashedBytes := hash.Sum(nil) + hashedHigh := new(fp.Element).SetBytes(hashedBytes[:16]) + hashedLow := new(fp.Element).SetBytes(hashedBytes[16:32]) + highAddr, err := high.GetAddress(vm) + if err != nil { + return err + } + hashedHighMV := memory.MemoryValueFromFieldElement(hashedHigh) + err = vm.Memory.WriteToAddress(&highAddr, &hashedHighMV) + if err != nil { + return err + } + lowAddr, err := low.GetAddress(vm) + if err != nil { + return err + } + hashedLowMV := memory.MemoryValueFromFieldElement(hashedLow) + return vm.Memory.WriteToAddress(&lowAddr, &hashedLowMV) + }, + } +} + +func createUnsafeKeccakHinter(resolver hintReferenceResolver) (hinter.Hinter, error) { + data, err := resolver.GetResOperander("data") + if err != nil { + return nil, err + } + length, err := resolver.GetResOperander("length") + if err != nil { + return nil, err + } + high, err := resolver.GetResOperander("high") + if err != nil { + return nil, err + } + low, err := resolver.GetResOperander("low") + if err != nil { + return nil, err + } + return newUnsafeKeccakHint(data, length, high, low), nil +} + // KeccakWriteArgs hint writes Keccak function arguments in memory // // `newKeccakWriteArgsHint` takes 3 operanders as arguments @@ -106,27 +210,27 @@ func newKeccakWriteArgsHint(inputs, low, high hinter.ResOperander) hinter.Hinter lowResultUint256Low.And(&maxUint64, &lowResultUint256Low) lowResulBytes32Low := lowResultUint256Low.Bytes32() lowResultFeltLow, _ := fp.BigEndian.Element(&lowResulBytes32Low) - mvLowLow := mem.MemoryValueFromFieldElement(&lowResultFeltLow) + mvLowLow := memory.MemoryValueFromFieldElement(&lowResultFeltLow) lowResultUint256High := lowUint256 lowResultUint256High.Rsh(&lowResultUint256High, 64) lowResultUint256High.And(&lowResultUint256High, &maxUint64) lowResulBytes32High := lowResultUint256High.Bytes32() lowResultFeltHigh, _ := fp.BigEndian.Element(&lowResulBytes32High) - mvLowHigh := mem.MemoryValueFromFieldElement(&lowResultFeltHigh) + mvLowHigh := memory.MemoryValueFromFieldElement(&lowResultFeltHigh) highResultUint256Low := highUint256 highResultUint256Low.And(&maxUint64, &highResultUint256Low) highResulBytes32Low := highResultUint256Low.Bytes32() highResultFeltLow, _ := fp.BigEndian.Element(&highResulBytes32Low) - mvHighLow := mem.MemoryValueFromFieldElement(&highResultFeltLow) + mvHighLow := memory.MemoryValueFromFieldElement(&highResultFeltLow) highResultUint256High := highUint256 highResultUint256High.Rsh(&highResultUint256High, 64) highResultUint256High.And(&maxUint64, &highResultUint256High) highResulBytes32High := highResultUint256High.Bytes32() highResultFeltHigh, _ := fp.BigEndian.Element(&highResulBytes32High) - mvHighHigh := mem.MemoryValueFromFieldElement(&highResultFeltHigh) + mvHighHigh := memory.MemoryValueFromFieldElement(&highResultFeltHigh) err = vm.Memory.Write(inputsPtr.SegmentIndex, inputsPtr.Offset, &mvLowLow) if err != nil { @@ -137,7 +241,6 @@ func newKeccakWriteArgsHint(inputs, low, high hinter.ResOperander) hinter.Hinter if err != nil { return err } - err = vm.Memory.Write(inputsPtr.SegmentIndex, inputsPtr.Offset+2, &mvHighLow) if err != nil { return err diff --git a/pkg/hintrunner/zero/zerohint_keccak_test.go b/pkg/hintrunner/zero/zerohint_keccak_test.go index 053696d6..326e0722 100644 --- a/pkg/hintrunner/zero/zerohint_keccak_test.go +++ b/pkg/hintrunner/zero/zerohint_keccak_test.go @@ -1,6 +1,7 @@ package zero import ( + "fmt" "testing" "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" @@ -46,6 +47,101 @@ func TestZeroHintKeccak(t *testing.T) { }, }, }, + "UnsafeKeccak": { + { + operanders: []*hintOperander{ + {Name: "data", Kind: uninitialized}, + {Name: "length", Kind: apRelative, Value: feltUint64((1 << 20) + 1)}, + {Name: "high", Kind: uninitialized}, + {Name: "low", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newUnsafeKeccakHint(ctx.operanders["data"], ctx.operanders["length"], ctx.operanders["high"], ctx.operanders["low"]) + }, + errCheck: errorTextContains(fmt.Sprintf("unsafe_keccak() can only be used with length<=%d.\n Got: length=%d.", 1<<20, (1<<20)+1)), + }, + { + operanders: []*hintOperander{ + {Name: "data", Kind: apRelative, Value: addr(5)}, + {Name: "data.0", Kind: apRelative, Value: feltUint64(65537)}, + {Name: "length", Kind: apRelative, Value: feltUint64(1)}, + {Name: "high", Kind: uninitialized}, + {Name: "low", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newUnsafeKeccakHint(ctx.operanders["data"], ctx.operanders["length"], ctx.operanders["high"], ctx.operanders["low"]) + }, + errCheck: errorTextContains(fmt.Sprintf("word %v is out range 0 <= word < 2 ** %d", feltUint64(65537), 8)), + }, + { + operanders: []*hintOperander{ + {Name: "data", Kind: apRelative, Value: addr(5)}, + {Name: "data.0", Kind: apRelative, Value: feltUint64(1)}, + {Name: "data.1", Kind: apRelative, Value: feltUint64(2)}, + {Name: "data.2", Kind: apRelative, Value: feltUint64(3)}, + {Name: "data.3", Kind: apRelative, Value: feltUint64(4)}, + {Name: "length", Kind: apRelative, Value: feltUint64(4)}, + {Name: "high", Kind: uninitialized}, + {Name: "low", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newUnsafeKeccakHint(ctx.operanders["data"], ctx.operanders["length"], ctx.operanders["high"], ctx.operanders["low"]) + }, + check: func(t *testing.T, ctx *hintTestContext) { + varValueEquals("high", feltString("108955721224378455455648573289483395612"))(t, ctx) + varValueEquals("low", feltString("253531040214470063354971884479696309631"))(t, ctx) + }, + }, + { + operanders: []*hintOperander{ + {Name: "data", Kind: apRelative, Value: addr(5)}, + {Name: "data.0", Kind: apRelative, Value: feltUint64(1)}, + {Name: "data.1", Kind: apRelative, Value: feltUint64(2)}, + {Name: "data.2", Kind: apRelative, Value: feltUint64(3)}, + {Name: "data.3", Kind: apRelative, Value: feltUint64(4)}, + {Name: "data.4", Kind: apRelative, Value: feltUint64(1)}, + {Name: "data.5", Kind: apRelative, Value: feltUint64(2)}, + {Name: "data.6", Kind: apRelative, Value: feltUint64(3)}, + {Name: "data.7", Kind: apRelative, Value: feltUint64(4)}, + {Name: "data.8", Kind: apRelative, Value: feltUint64(1)}, + {Name: "data.9", Kind: apRelative, Value: feltUint64(2)}, + {Name: "data.10", Kind: apRelative, Value: feltUint64(3)}, + {Name: "data.11", Kind: apRelative, Value: feltUint64(4)}, + {Name: "data.12", Kind: apRelative, Value: feltUint64(1)}, + {Name: "data.13", Kind: apRelative, Value: feltUint64(2)}, + {Name: "data.14", Kind: apRelative, Value: feltUint64(3)}, + {Name: "data.15", Kind: apRelative, Value: feltUint64(4)}, + {Name: "data.16", Kind: apRelative, Value: feltUint64(4)}, + {Name: "data.17", Kind: apRelative, Value: feltUint64(1)}, + {Name: "data.18", Kind: apRelative, Value: feltUint64(2)}, + {Name: "data.19", Kind: apRelative, Value: feltUint64(3)}, + {Name: "data.20", Kind: apRelative, Value: feltUint64(4)}, + {Name: "data.21", Kind: apRelative, Value: feltUint64(1)}, + {Name: "data.22", Kind: apRelative, Value: feltUint64(2)}, + {Name: "data.23", Kind: apRelative, Value: feltUint64(3)}, + {Name: "data.24", Kind: apRelative, Value: feltUint64(4)}, + {Name: "data.25", Kind: apRelative, Value: feltUint64(1)}, + {Name: "data.26", Kind: apRelative, Value: feltUint64(2)}, + {Name: "data.27", Kind: apRelative, Value: feltUint64(3)}, + {Name: "data.28", Kind: apRelative, Value: feltUint64(4)}, + {Name: "data.29", Kind: apRelative, Value: feltUint64(1)}, + {Name: "data.30", Kind: apRelative, Value: feltUint64(2)}, + {Name: "data.31", Kind: apRelative, Value: feltUint64(3)}, + {Name: "data.32", Kind: apRelative, Value: feltUint64(4)}, + {Name: "data.33", Kind: apRelative, Value: feltUint64(4)}, + {Name: "length", Kind: apRelative, Value: feltUint64(34)}, + {Name: "high", Kind: uninitialized}, + {Name: "low", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newUnsafeKeccakHint(ctx.operanders["data"], ctx.operanders["length"], ctx.operanders["high"], ctx.operanders["low"]) + }, + check: func(t *testing.T, ctx *hintTestContext) { + varValueEquals("high", feltString("43087684015060895958086736298363333858"))(t, ctx) + varValueEquals("low", feltString("115090685687501856751902560332884088627"))(t, ctx) + }, + }, + }, "KeccakWriteArgs": { { operanders: []*hintOperander{ diff --git a/pkg/utils/math.go b/pkg/utils/math.go index 72fb4e0d..27589f31 100644 --- a/pkg/utils/math.go +++ b/pkg/utils/math.go @@ -63,6 +63,13 @@ func Max[T constraints.Integer](a, b T) T { return b } +func Min[T constraints.Integer](a, b T) T { + if a < b { + return a + } + return b +} + // FeltLt implements `a < b` felt comparison. func FeltLt(a, b *fp.Element) bool { return a.Cmp(b) == -1