diff --git a/pkg/hintrunner/hinter/scope.go b/pkg/hintrunner/hinter/scope.go index 3f69a467..81804e2d 100644 --- a/pkg/hintrunner/hinter/scope.go +++ b/pkg/hintrunner/hinter/scope.go @@ -92,6 +92,20 @@ func (sm *ScopeManager) GetVariableValueAsBigInt(name string) (*big.Int, error) return valueBig, nil } +func (sm *ScopeManager) GetVariableValueAsUint64(name string) (uint64, error) { + value, err := sm.GetVariableValue(name) + if err != nil { + return 0, err + } + + valueUint, ok := value.(uint64) + if !ok { + return 0, fmt.Errorf("value: %s is not a uint64", value) + } + + return valueUint, nil +} + func (sm *ScopeManager) getCurrentScope() (*map[string]any, error) { if len(sm.scopes) == 0 { return nil, fmt.Errorf("expected at least one existing scope") diff --git a/pkg/hintrunner/zero/hintcode.go b/pkg/hintrunner/zero/hintcode.go index 3e4635c5..089758c6 100644 --- a/pkg/hintrunner/zero/hintcode.go +++ b/pkg/hintrunner/zero/hintcode.go @@ -111,6 +111,7 @@ const ( blake2sFinalizeCode string = "from starkware.cairo.common.cairo_blake2s.blake2s_utils import IV, blake2s_compress\n\n_n_packed_instances = int(ids.N_PACKED_INSTANCES)\nassert 0 <= _n_packed_instances < 20\n_blake2s_input_chunk_size_felts = int(ids.INPUT_BLOCK_FELTS)\nassert 0 <= _blake2s_input_chunk_size_felts < 100\n\nmessage = [0] * _blake2s_input_chunk_size_felts\nmodified_iv = [IV[0] ^ 0x01010020] + IV[1:]\noutput = blake2s_compress(\n message=message,\n h=modified_iv,\n t0=0,\n t1=0,\n f0=0xffffffff,\n f1=0,\n)\npadding = (modified_iv + message + [0, 0xffffffff] + output) * (_n_packed_instances - 1)\nsegments.write_arg(ids.blake2s_ptr_end, padding)" // ------ Keccak hints related code ------ + unsafeKeccakCode string = "from eth_hash.auto import keccak\n\ndata, length = ids.data, ids.length\n\nif '__keccak_max_size' in globals():\n assert length <= __keccak_max_size, \\\n f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \\\n f'Got: length={length}.'\n\nkeccak_input = bytearray()\nfor word_i, byte_i in enumerate(range(0, length, 16)):\n word = memory[data + word_i]\n n_bytes = min(16, length - byte_i)\n assert 0 <= word < 2 ** (8 * n_bytes)\n keccak_input += word.to_bytes(n_bytes, 'big')\n\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')" cairoKeccakFinalizeCode string = `# Add dummy pairs of input and output. _keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS) _block_size = int(ids.BLOCK_SIZE) diff --git a/pkg/hintrunner/zero/zerohint.go b/pkg/hintrunner/zero/zerohint.go index 23edb16e..65f47390 100644 --- a/pkg/hintrunner/zero/zerohint.go +++ b/pkg/hintrunner/zero/zerohint.go @@ -155,6 +155,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64 return createKeccakWriteArgsHinter(resolver) case cairoKeccakFinalizeCode: return createCairoKeccakFinalizeHinter(resolver) + case unsafeKeccakCode: + return createUnsafeKeccakHinter(resolver) case blockPermutationCode: return createBlockPermutationHinter(resolver) // Usort hints diff --git a/pkg/hintrunner/zero/zerohint_keccak.go b/pkg/hintrunner/zero/zerohint_keccak.go index fc68396c..0da266b8 100644 --- a/pkg/hintrunner/zero/zerohint_keccak.go +++ b/pkg/hintrunner/zero/zerohint_keccak.go @@ -1,15 +1,17 @@ package zero import ( + "fmt" "math" "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" + "github.com/NethermindEth/cairo-vm-go/pkg/utils" VM "github.com/NethermindEth/cairo-vm-go/pkg/vm" "github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins" "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" - mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" "github.com/holiman/uint256" + "golang.org/x/crypto/sha3" ) // CairoKeccakFinalize writes the result of F1600 Keccak permutation padded by __keccak_state_size_felts__ zeros to consecutive memory cells, __block_size__ times. @@ -68,6 +70,108 @@ func createCairoKeccakFinalizeHinter(resolver hintReferenceResolver) (hinter.Hin return newCairoKeccakFinalizeHint(keccakPtrEnd), nil } +// UnsafeKeccak computes keccak hash of the data in memory without validity enforcement and writes the result in the `low` and `high` memory cells +// +// `newUnsafeKeccakHint` takes 4 operanders as arguments +// - `data` is the address in memory where the base of the data array to be hashed is stored. Each word in the array is 16 bytes long, except the last one, which could vary +// - `length` is the length of the data to hash +// - `low` is the low part of the produced hash +// - `high` is the high part of the produced hash +func newUnsafeKeccakHint(data, length, high, low hinter.ResOperander) hinter.Hinter { + return &GenericZeroHinter{ + Name: "UnsafeKeccak", + Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error { + //> from eth_hash.auto import keccak + //> data, length = ids.data, ids.length + //> if '__keccak_max_size' in globals(): + //> assert length <= __keccak_max_size, \ + //> f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \ + //> f'Got: length={length}.' + //> keccak_input = bytearray() + //> for word_i, byte_i in enumerate(range(0, length, 16)): + //> word = memory[data + word_i] + //> n_bytes = min(16, length - byte_i) + //> assert 0 <= word < 2 ** (8 * n_bytes) + //> keccak_input += word.to_bytes(n_bytes, 'big') + //> hashed = keccak(keccak_input) + //> ids.high = int.from_bytes(hashed[:16], 'big') + //> ids.low = int.from_bytes(hashed[16:32], 'big') + + lengthVal, err := hinter.ResolveAsUint64(vm, length) + if err != nil { + return err + } + keccakMaxSize := uint64(1 << 20) + if lengthVal > keccakMaxSize { + return fmt.Errorf("unsafe_keccak() can only be used with length<=%d.\n Got: length=%d.", keccakMaxSize, lengthVal) + } + dataPtr, err := hinter.ResolveAsAddress(vm, data) + if err != nil { + return err + } + + keccakInput := make([]byte, 0) + for i := uint64(0); i < lengthVal; i += 16 { + wordFelt, err := vm.Memory.ReadAsElement(dataPtr.SegmentIndex, dataPtr.Offset) + if err != nil { + return err + } + word := uint256.Int(wordFelt.Bits()) + nBytes := utils.Min(lengthVal-i, 16) + if uint64(word.BitLen()) >= 8*nBytes { + return fmt.Errorf("word %v is out range 0 <= word < 2 ** %d", &word, 8*nBytes) + } + wordBytes := word.Bytes20() + keccakInput = append(keccakInput, wordBytes[20-int(nBytes):]...) + *dataPtr, err = dataPtr.AddOffset(1) + if err != nil { + return err + } + } + hash := sha3.NewLegacyKeccak256() + hash.Write(keccakInput) + hashedBytes := hash.Sum(nil) + hashedHigh := new(fp.Element).SetBytes(hashedBytes[:16]) + hashedLow := new(fp.Element).SetBytes(hashedBytes[16:32]) + highAddr, err := high.GetAddress(vm) + if err != nil { + return err + } + hashedHighMV := memory.MemoryValueFromFieldElement(hashedHigh) + err = vm.Memory.WriteToAddress(&highAddr, &hashedHighMV) + if err != nil { + return err + } + lowAddr, err := low.GetAddress(vm) + if err != nil { + return err + } + hashedLowMV := memory.MemoryValueFromFieldElement(hashedLow) + return vm.Memory.WriteToAddress(&lowAddr, &hashedLowMV) + }, + } +} + +func createUnsafeKeccakHinter(resolver hintReferenceResolver) (hinter.Hinter, error) { + data, err := resolver.GetResOperander("data") + if err != nil { + return nil, err + } + length, err := resolver.GetResOperander("length") + if err != nil { + return nil, err + } + high, err := resolver.GetResOperander("high") + if err != nil { + return nil, err + } + low, err := resolver.GetResOperander("low") + if err != nil { + return nil, err + } + return newUnsafeKeccakHint(data, length, high, low), nil +} + // KeccakWriteArgs hint writes Keccak function arguments in memory // // `newKeccakWriteArgsHint` takes 3 operanders as arguments @@ -106,27 +210,27 @@ func newKeccakWriteArgsHint(inputs, low, high hinter.ResOperander) hinter.Hinter lowResultUint256Low.And(&maxUint64, &lowResultUint256Low) lowResulBytes32Low := lowResultUint256Low.Bytes32() lowResultFeltLow, _ := fp.BigEndian.Element(&lowResulBytes32Low) - mvLowLow := mem.MemoryValueFromFieldElement(&lowResultFeltLow) + mvLowLow := memory.MemoryValueFromFieldElement(&lowResultFeltLow) lowResultUint256High := lowUint256 lowResultUint256High.Rsh(&lowResultUint256High, 64) lowResultUint256High.And(&lowResultUint256High, &maxUint64) lowResulBytes32High := lowResultUint256High.Bytes32() lowResultFeltHigh, _ := fp.BigEndian.Element(&lowResulBytes32High) - mvLowHigh := mem.MemoryValueFromFieldElement(&lowResultFeltHigh) + mvLowHigh := memory.MemoryValueFromFieldElement(&lowResultFeltHigh) highResultUint256Low := highUint256 highResultUint256Low.And(&maxUint64, &highResultUint256Low) highResulBytes32Low := highResultUint256Low.Bytes32() highResultFeltLow, _ := fp.BigEndian.Element(&highResulBytes32Low) - mvHighLow := mem.MemoryValueFromFieldElement(&highResultFeltLow) + mvHighLow := memory.MemoryValueFromFieldElement(&highResultFeltLow) highResultUint256High := highUint256 highResultUint256High.Rsh(&highResultUint256High, 64) highResultUint256High.And(&maxUint64, &highResultUint256High) highResulBytes32High := highResultUint256High.Bytes32() highResultFeltHigh, _ := fp.BigEndian.Element(&highResulBytes32High) - mvHighHigh := mem.MemoryValueFromFieldElement(&highResultFeltHigh) + mvHighHigh := memory.MemoryValueFromFieldElement(&highResultFeltHigh) err = vm.Memory.Write(inputsPtr.SegmentIndex, inputsPtr.Offset, &mvLowLow) if err != nil { @@ -137,7 +241,6 @@ func newKeccakWriteArgsHint(inputs, low, high hinter.ResOperander) hinter.Hinter if err != nil { return err } - err = vm.Memory.Write(inputsPtr.SegmentIndex, inputsPtr.Offset+2, &mvHighLow) if err != nil { return err diff --git a/pkg/hintrunner/zero/zerohint_keccak_test.go b/pkg/hintrunner/zero/zerohint_keccak_test.go index 053696d6..326e0722 100644 --- a/pkg/hintrunner/zero/zerohint_keccak_test.go +++ b/pkg/hintrunner/zero/zerohint_keccak_test.go @@ -1,6 +1,7 @@ package zero import ( + "fmt" "testing" "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" @@ -46,6 +47,101 @@ func TestZeroHintKeccak(t *testing.T) { }, }, }, + "UnsafeKeccak": { + { + operanders: []*hintOperander{ + {Name: "data", Kind: uninitialized}, + {Name: "length", Kind: apRelative, Value: feltUint64((1 << 20) + 1)}, + {Name: "high", Kind: uninitialized}, + {Name: "low", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newUnsafeKeccakHint(ctx.operanders["data"], ctx.operanders["length"], ctx.operanders["high"], ctx.operanders["low"]) + }, + errCheck: errorTextContains(fmt.Sprintf("unsafe_keccak() can only be used with length<=%d.\n Got: length=%d.", 1<<20, (1<<20)+1)), + }, + { + operanders: []*hintOperander{ + {Name: "data", Kind: apRelative, Value: addr(5)}, + {Name: "data.0", Kind: apRelative, Value: feltUint64(65537)}, + {Name: "length", Kind: apRelative, Value: feltUint64(1)}, + {Name: "high", Kind: uninitialized}, + {Name: "low", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newUnsafeKeccakHint(ctx.operanders["data"], ctx.operanders["length"], ctx.operanders["high"], ctx.operanders["low"]) + }, + errCheck: errorTextContains(fmt.Sprintf("word %v is out range 0 <= word < 2 ** %d", feltUint64(65537), 8)), + }, + { + operanders: []*hintOperander{ + {Name: "data", Kind: apRelative, Value: addr(5)}, + {Name: "data.0", Kind: apRelative, Value: feltUint64(1)}, + {Name: "data.1", Kind: apRelative, Value: feltUint64(2)}, + {Name: "data.2", Kind: apRelative, Value: feltUint64(3)}, + {Name: "data.3", Kind: apRelative, Value: feltUint64(4)}, + {Name: "length", Kind: apRelative, Value: feltUint64(4)}, + {Name: "high", Kind: uninitialized}, + {Name: "low", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newUnsafeKeccakHint(ctx.operanders["data"], ctx.operanders["length"], ctx.operanders["high"], ctx.operanders["low"]) + }, + check: func(t *testing.T, ctx *hintTestContext) { + varValueEquals("high", feltString("108955721224378455455648573289483395612"))(t, ctx) + varValueEquals("low", feltString("253531040214470063354971884479696309631"))(t, ctx) + }, + }, + { + operanders: []*hintOperander{ + {Name: "data", Kind: apRelative, Value: addr(5)}, + {Name: "data.0", Kind: apRelative, Value: feltUint64(1)}, + {Name: "data.1", Kind: apRelative, Value: feltUint64(2)}, + {Name: "data.2", Kind: apRelative, Value: feltUint64(3)}, + {Name: "data.3", Kind: apRelative, Value: feltUint64(4)}, + {Name: "data.4", Kind: apRelative, Value: feltUint64(1)}, + {Name: "data.5", Kind: apRelative, Value: feltUint64(2)}, + {Name: "data.6", Kind: apRelative, Value: feltUint64(3)}, + {Name: "data.7", Kind: apRelative, Value: feltUint64(4)}, + {Name: "data.8", Kind: apRelative, Value: feltUint64(1)}, + {Name: "data.9", Kind: apRelative, Value: feltUint64(2)}, + {Name: "data.10", Kind: apRelative, Value: feltUint64(3)}, + {Name: "data.11", Kind: apRelative, Value: feltUint64(4)}, + {Name: "data.12", Kind: apRelative, Value: feltUint64(1)}, + {Name: "data.13", Kind: apRelative, Value: feltUint64(2)}, + {Name: "data.14", Kind: apRelative, Value: feltUint64(3)}, + {Name: "data.15", Kind: apRelative, Value: feltUint64(4)}, + {Name: "data.16", Kind: apRelative, Value: feltUint64(4)}, + {Name: "data.17", Kind: apRelative, Value: feltUint64(1)}, + {Name: "data.18", Kind: apRelative, Value: feltUint64(2)}, + {Name: "data.19", Kind: apRelative, Value: feltUint64(3)}, + {Name: "data.20", Kind: apRelative, Value: feltUint64(4)}, + {Name: "data.21", Kind: apRelative, Value: feltUint64(1)}, + {Name: "data.22", Kind: apRelative, Value: feltUint64(2)}, + {Name: "data.23", Kind: apRelative, Value: feltUint64(3)}, + {Name: "data.24", Kind: apRelative, Value: feltUint64(4)}, + {Name: "data.25", Kind: apRelative, Value: feltUint64(1)}, + {Name: "data.26", Kind: apRelative, Value: feltUint64(2)}, + {Name: "data.27", Kind: apRelative, Value: feltUint64(3)}, + {Name: "data.28", Kind: apRelative, Value: feltUint64(4)}, + {Name: "data.29", Kind: apRelative, Value: feltUint64(1)}, + {Name: "data.30", Kind: apRelative, Value: feltUint64(2)}, + {Name: "data.31", Kind: apRelative, Value: feltUint64(3)}, + {Name: "data.32", Kind: apRelative, Value: feltUint64(4)}, + {Name: "data.33", Kind: apRelative, Value: feltUint64(4)}, + {Name: "length", Kind: apRelative, Value: feltUint64(34)}, + {Name: "high", Kind: uninitialized}, + {Name: "low", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newUnsafeKeccakHint(ctx.operanders["data"], ctx.operanders["length"], ctx.operanders["high"], ctx.operanders["low"]) + }, + check: func(t *testing.T, ctx *hintTestContext) { + varValueEquals("high", feltString("43087684015060895958086736298363333858"))(t, ctx) + varValueEquals("low", feltString("115090685687501856751902560332884088627"))(t, ctx) + }, + }, + }, "KeccakWriteArgs": { { operanders: []*hintOperander{ diff --git a/pkg/utils/math.go b/pkg/utils/math.go index 72fb4e0d..27589f31 100644 --- a/pkg/utils/math.go +++ b/pkg/utils/math.go @@ -63,6 +63,13 @@ func Max[T constraints.Integer](a, b T) T { return b } +func Min[T constraints.Integer](a, b T) T { + if a < b { + return a + } + return b +} + // FeltLt implements `a < b` felt comparison. func FeltLt(a, b *fp.Element) bool { return a.Cmp(b) == -1