Skip to content

Commit

Permalink
Implement UsortBody hint (#352)
Browse files Browse the repository at this point in the history
* Implement UsortBody hint

* Debug hint code

* Implement error check for inputLength

* Add bugfixes, add testcases

* Implement comments from PR

* Resolve merge conflicts

* Fix unit tests issues

* Specify the type of '__usort_max_size'

* Code cleanups

* Address comments from PR

* Address comments from PR

* Lint PR
  • Loading branch information
MaksymMalicki committed May 10, 2024
1 parent 70ad3af commit a9d37d1
Show file tree
Hide file tree
Showing 7 changed files with 384 additions and 9 deletions.
19 changes: 19 additions & 0 deletions pkg/hintrunner/utils/usort_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package utils

import "github.com/consensys/gnark-crypto/ecc/stark-curve/fp"

// SortFelt is a type that implements the sort.Interface for a slice of fp.Element
// This file provides utility functions for sorting slices of fp.Element
type SortFelt []fp.Element

func (s SortFelt) Len() int {
return len(s)
}

func (s SortFelt) Less(i, j int) bool {
return s[i].Cmp(&s[j]) < 0
}

func (s SortFelt) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}
20 changes: 20 additions & 0 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,26 @@ const (
uint256MulDivModCode string = "a = (ids.a.high << 128) + ids.a.low/n b = (ids.b.high << 128) + ids.b.low/n div = (ids.div.high << 128) + ids.div.low/n quotient, remainder = divmod(a * b, div)/n ids.quotient_low.low = quotient & ((1 << 128) - 1)/n ids.quotient_low.high = (quotient >> 128) & ((1 << 128) - 1)/n ids.quotient_high.low = (quotient >> 256) & ((1 << 128) - 1)/n ids.quotient_high.high = quotient >> 384/n ids.remainder.low = remainder & ((1 << 128) - 1)/n ids.remainder.high = remainder >> 128"

// ------ Usort hints related code ------
usortBodyCode string = `
from collections import defaultdict
input_ptr = ids.input
input_len = int(ids.input_len)
if __usort_max_size is not None:
assert input_len <= __usort_max_size, (
f"usort() can only be used with input_len<={__usort_max_size}. "
f"Got: input_len={input_len}."
)
positions_dict = defaultdict(list)
for i in range(input_len):
val = memory[input_ptr + i]
positions_dict[val].append(i)
output = sorted(positions_dict.keys())
ids.output_len = len(output)
ids.output = segments.gen_arg(output)
ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])`
usortEnterScopeCode string = "vm_enter_scope(dict(__usort_max_size = globals().get('__usort_max_size')))"
usortVerifyMultiplicityAssertCode string = "assert len(positions) == 0"
usortVerifyCode string = "last_pos = 0\npositions = positions_dict[ids.value][::-1]"
Expand Down
2 changes: 2 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createUsortVerifyHinter(resolver)
case usortVerifyMultiplicityBodyCode:
return createUsortVerifyMultiplicityBodyHinter(resolver)
case usortBodyCode:
return createUsortBodyHinter(resolver)
// Dictionaries hints
case squashDictInnerAssertLenKeys:
return createSquashDictInnerAssertLenKeysHinter()
Expand Down
1 change: 0 additions & 1 deletion pkg/hintrunner/zero/zerohint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,6 @@ func runHinterTests(t *testing.T, tests map[string][]hintTestCase) {
}
}
}

h := tc.makeHinter(testCtx)

err := h.Execute(vm, ctx)
Expand Down
157 changes: 156 additions & 1 deletion pkg/hintrunner/zero/zerohint_usort.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,170 @@ package zero

import (
"fmt"
"sort"

"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
usortUtils "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/utils"
"github.com/NethermindEth/cairo-vm-go/pkg/utils"
VM "github.com/NethermindEth/cairo-vm-go/pkg/vm"
"github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
)

func createUsortBodyHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
input, err := resolver.GetResOperander("input")
if err != nil {
return nil, err
}
input_len, err := resolver.GetResOperander("input_len")
if err != nil {
return nil, err
}
output, err := resolver.GetResOperander("output")
if err != nil {
return nil, err
}
output_len, err := resolver.GetResOperander("output_len")
if err != nil {
return nil, err
}
multiplicities, err := resolver.GetResOperander("multiplicities")
if err != nil {
return nil, err
}
return newUsortBodyHint(input, input_len, output, output_len, multiplicities), nil
}

// UsortBody hint sorts the input array of field elements. The sorting results in generation of output array without duplicates and multiplicites array, where each element represents the number of times the corresponding element in the output array appears in the input array. The output and multiplicities arrays are written to the new, separate segments in memory.
//
// `newSplit64Hint` takes 5 operanders as arguments
// - `input` is the pointer to the base of input array of field elements
// - `inputLen` is the length of the input array
// - `output` is the pointer to the base of the output array of field elements
// - `outputLen` is the length of the output array
// - `multiplicities` is the pointer to the base of the multiplicities array of field elements
func newUsortBodyHint(input, inputLen, output, outputLen, multiplicities hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "UsortBody",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> from collections import defaultdict
//>
//> input_ptr = ids.input
//> input_len = int(ids.input_len)
//> if __usort_max_size is not None:
//> assert input_len <= __usort_max_size, (
//> f"usort() can only be used with input_len<={__usort_max_size}. "
//> f"Got: input_len={input_len}."
//> )
//>
//> positions_dict = defaultdict(list)
//> for i in range(input_len):
//> val = memory[input_ptr + i]
//> positions_dict[val].append(i)
//>
//> output = sorted(positions_dict.keys())
//> ids.output_len = len(output)
//> ids.output = segments.gen_arg(output)
//> ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])
//>
//> input_ptr = ids.input
inputBasePtr, err := hinter.ResolveAsAddress(vm, input)
if err != nil {
return err
}
inputLenValue, err := hinter.ResolveAsUint64(vm, inputLen)
if err != nil {
return err
}
usortMaxSizeInterface, err := ctx.ScopeManager.GetVariableValue("__usort_max_size")
if err != nil {
return err
}
usortMaxSize := usortMaxSizeInterface.(uint64)
if inputLenValue > usortMaxSize {
return fmt.Errorf("usort() can only be used with input_len<=%d.\n Got: input_len=%d", usortMaxSize, inputLenValue)
}
positionsDict := make(map[fp.Element][]uint64, inputLenValue)
for i := uint64(0); i < inputLenValue; i++ {
val, err := vm.Memory.ReadFromAddressAsElement(inputBasePtr)
if err != nil {
return err
}
positionsDict[val] = append(positionsDict[val], i)
*inputBasePtr, err = inputBasePtr.AddOffset(1)
if err != nil {
return err
}
}

outputArray := make([]fp.Element, len(positionsDict))
iterator := 0
for key := range positionsDict {
outputArray[iterator] = key
iterator++
}
sort.Sort(usortUtils.SortFelt(outputArray))

outputLenAddr, err := outputLen.GetAddress(vm)
if err != nil {
return err
}
outputLenMV := memory.MemoryValueFromFieldElement(new(fp.Element).SetUint64(uint64(len(outputArray))))
err = vm.Memory.WriteToAddress(&outputLenAddr, &outputLenMV)
if err != nil {
return err
}
outputSegmentBaseAddr := vm.Memory.AllocateEmptySegment()
outputAddr, err := output.GetAddress(vm)
if err != nil {
return err
}
outputSegmentBaseAddrMV := memory.MemoryValueFromMemoryAddress(&outputSegmentBaseAddr)
err = vm.Memory.WriteToAddress(&outputAddr, &outputSegmentBaseAddrMV)
if err != nil {
return err
}
for _, v := range outputArray {
outputElementMV := memory.MemoryValueFromFieldElement(&v)
err = vm.Memory.WriteToAddress(&outputSegmentBaseAddr, &outputElementMV)
if err != nil {
return err
}
outputSegmentBaseAddr, err = outputSegmentBaseAddr.AddOffset(1)
if err != nil {
return err
}
}
multiplicitiesArray := make([]*fp.Element, len(outputArray))
for i, v := range outputArray {
multiplicitiesArray[i] = new(fp.Element).SetUint64(uint64(len(positionsDict[v])))
}
multiplicitesSegmentBaseAddr := vm.Memory.AllocateEmptySegment()
multiplicitiesAddr, err := multiplicities.GetAddress(vm)
if err != nil {
return err
}
multiplicitesSegmentBaseAddrMV := memory.MemoryValueFromMemoryAddress(&multiplicitesSegmentBaseAddr)
err = vm.Memory.WriteToAddress(&multiplicitiesAddr, &multiplicitesSegmentBaseAddrMV)
if err != nil {
return err
}
for _, v := range multiplicitiesArray {
multiplicitiesElementMV := memory.MemoryValueFromFieldElement(v)
err = vm.Memory.WriteToAddress(&multiplicitesSegmentBaseAddr, &multiplicitiesElementMV)
if err != nil {
return err
}
multiplicitesSegmentBaseAddr, err = multiplicitesSegmentBaseAddr.AddOffset(1)
if err != nil {
return err
}
}
return nil
},
}
}

// UsortEnterScope hint enters a new scope with `__usort_max_size` value
//
// `newUsortEnterScopeHint` doesn't take any operander as argument
Expand All @@ -21,7 +177,6 @@ func newUsortEnterScopeHint() hinter.Hinter {
Name: "UsortEnterScope",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> vm_enter_scope(dict(__usort_max_size = globals().get('__usort_max_size')))

usortMaxSize, err := ctx.ScopeManager.GetVariableValue("__usort_max_size")
if err != nil {
return err
Expand Down
Loading

0 comments on commit a9d37d1

Please sign in to comment.