Skip to content

Commit

Permalink
update positional slice default value. use sep of struct tag
Browse files Browse the repository at this point in the history
  • Loading branch information
zkep authored and leaanthony committed May 9, 2024
1 parent 274b4ac commit 9a039b7
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 65 deletions.
87 changes: 40 additions & 47 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type Command struct {
helpFlag bool
hidden bool
positionalArgsMap map[string]reflect.Value
sliceSeparator map[string]string
}

// NewCommand creates a new Command
Expand All @@ -37,6 +38,7 @@ func NewCommand(name string, description string) *Command {
subCommandsMap: make(map[string]*Command),
hidden: false,
positionalArgsMap: make(map[string]reflect.Value),
sliceSeparator: make(map[string]string),
}

return result
Expand Down Expand Up @@ -284,7 +286,11 @@ func (c *Command) AddFlags(optionStruct interface{}) *Command {
description := tag.Get("description")
defaultValue := tag.Get("default")
pos := tag.Get("pos")
sep := tag.Get("sep")
c.positionalArgsMap[pos] = field
if sep != "" {
c.sliceSeparator[pos] = sep
}
if name == "" {
name = strings.ToLower(t.Elem().Field(i).Name)
}
Expand Down Expand Up @@ -427,7 +433,7 @@ func (c *Command) AddFlags(optionStruct interface{}) *Command {
}
c.Float64Flag(name, description, field.Addr().Interface().(*float64))
case reflect.Slice:
c.addSliceField(field, defaultValue, tag.Get("sep"))
c.addSliceField(field, defaultValue, sep)
c.addSliceFlags(name, description, field)
default:
if pos != "" {
Expand Down Expand Up @@ -489,9 +495,6 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str
if defaultValue == "" {
return c
}
if separator == "" {
separator = ","
}
if field.Kind() != reflect.Slice {
panic("addSliceField() requires a pointer to a slice")
}
Expand All @@ -502,11 +505,14 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str
if t.Elem().Kind() != reflect.Slice {
panic("addSliceField() requires a pointer to a slice")
}
defaultSlice := []string{defaultValue}
if separator != "" {
defaultSlice = strings.Split(defaultValue, separator)
}
switch t.Elem().Elem().Kind() {
case reflect.Bool:
defaultSplit := strings.Split(defaultValue, separator)
defaultValues := make([]bool, 0, len(defaultSplit))
for _, value := range defaultSplit {
defaultValues := make([]bool, 0, len(defaultSlice))
for _, value := range defaultSlice {
val, err := strconv.ParseBool(value)
if err != nil {
panic("Invalid default value for bool flag")
Expand All @@ -515,12 +521,10 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str
}
field.Set(reflect.ValueOf(defaultValues))
case reflect.String:
defaultValues := strings.Split(defaultValue, separator)
field.Set(reflect.ValueOf(defaultValues))
field.Set(reflect.ValueOf(defaultSlice))
case reflect.Int:
defaultSplit := strings.Split(defaultValue, separator)
defaultValues := make([]int, 0, len(defaultSplit))
for _, value := range defaultSplit {
defaultValues := make([]int, 0, len(defaultSlice))
for _, value := range defaultSlice {
val, err := strconv.Atoi(value)
if err != nil {
panic("Invalid default value for int flag")
Expand All @@ -529,9 +533,8 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str
}
field.Set(reflect.ValueOf(defaultValues))
case reflect.Int8:
defaultSplit := strings.Split(defaultValue, separator)
defaultValues := make([]int8, 0, len(defaultSplit))
for _, value := range defaultSplit {
defaultValues := make([]int8, 0, len(defaultSlice))
for _, value := range defaultSlice {
val, err := strconv.Atoi(value)
if err != nil {
panic("Invalid default value for int8 flag")
Expand All @@ -540,9 +543,8 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str
}
field.Set(reflect.ValueOf(defaultValues))
case reflect.Int16:
defaultSplit := strings.Split(defaultValue, separator)
defaultValues := make([]int16, 0, len(defaultSplit))
for _, value := range defaultSplit {
defaultValues := make([]int16, 0, len(defaultSlice))
for _, value := range defaultSlice {
val, err := strconv.Atoi(value)
if err != nil {
panic("Invalid default value for int16 flag")
Expand All @@ -551,20 +553,18 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str
}
field.Set(reflect.ValueOf(defaultValues))
case reflect.Int32:
defaultSplit := strings.Split(defaultValue, separator)
defaultValues := make([]int32, 0, len(defaultSplit))
for _, value := range defaultSplit {
val, err := strconv.ParseInt(value, 10, 64)
defaultValues := make([]int32, 0, len(defaultSlice))
for _, value := range defaultSlice {
val, err := strconv.ParseInt(value, 10, 32)
if err != nil {
panic("Invalid default value for int32 flag")
}
defaultValues = append(defaultValues, int32(val))
}
field.Set(reflect.ValueOf(defaultValues))
case reflect.Int64:
defaultSplit := strings.Split(defaultValue, separator)
defaultValues := make([]int64, 0, len(defaultSplit))
for _, value := range defaultSplit {
defaultValues := make([]int64, 0, len(defaultSlice))
for _, value := range defaultSlice {
val, err := strconv.ParseInt(value, 10, 64)
if err != nil {
panic("Invalid default value for int64 flag")
Expand All @@ -573,9 +573,8 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str
}
field.Set(reflect.ValueOf(defaultValues))
case reflect.Uint:
defaultSplit := strings.Split(defaultValue, separator)
defaultValues := make([]uint, 0, len(defaultSplit))
for _, value := range defaultSplit {
defaultValues := make([]uint, 0, len(defaultSlice))
for _, value := range defaultSlice {
val, err := strconv.Atoi(value)
if err != nil {
panic("Invalid default value for uint flag")
Expand All @@ -584,9 +583,8 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str
}
field.Set(reflect.ValueOf(defaultValues))
case reflect.Uint8:
defaultSplit := strings.Split(defaultValue, separator)
defaultValues := make([]uint8, 0, len(defaultSplit))
for _, value := range defaultSplit {
defaultValues := make([]uint8, 0, len(defaultSlice))
for _, value := range defaultSlice {
val, err := strconv.Atoi(value)
if err != nil {
panic("Invalid default value for uint8 flag")
Expand All @@ -595,9 +593,8 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str
}
field.Set(reflect.ValueOf(defaultValues))
case reflect.Uint16:
defaultSplit := strings.Split(defaultValue, separator)
defaultValues := make([]uint16, 0, len(defaultSplit))
for _, value := range defaultSplit {
defaultValues := make([]uint16, 0, len(defaultSlice))
for _, value := range defaultSlice {
val, err := strconv.Atoi(value)
if err != nil {
panic("Invalid default value for uint16 flag")
Expand All @@ -606,9 +603,8 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str
}
field.Set(reflect.ValueOf(defaultValues))
case reflect.Uint32:
defaultSplit := strings.Split(defaultValue, separator)
defaultValues := make([]uint32, 0, len(defaultSplit))
for _, value := range defaultSplit {
defaultValues := make([]uint32, 0, len(defaultSlice))
for _, value := range defaultSlice {
val, err := strconv.Atoi(value)
if err != nil {
panic("Invalid default value for uint32 flag")
Expand All @@ -617,9 +613,8 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str
}
field.Set(reflect.ValueOf(defaultValues))
case reflect.Uint64:
defaultSplit := strings.Split(defaultValue, separator)
defaultValues := make([]uint64, 0, len(defaultSplit))
for _, value := range defaultSplit {
defaultValues := make([]uint64, 0, len(defaultSlice))
for _, value := range defaultSlice {
val, err := strconv.Atoi(value)
if err != nil {
panic("Invalid default value for uint64 flag")
Expand All @@ -628,9 +623,8 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str
}
field.Set(reflect.ValueOf(defaultValues))
case reflect.Float32:
defaultSplit := strings.Split(defaultValue, separator)
defaultValues := make([]float32, 0, len(defaultSplit))
for _, value := range defaultSplit {
defaultValues := make([]float32, 0, len(defaultSlice))
for _, value := range defaultSlice {
val, err := strconv.Atoi(value)
if err != nil {
panic("Invalid default value for float32 flag")
Expand All @@ -639,9 +633,8 @@ func (c *Command) addSliceField(field reflect.Value, defaultValue, separator str
}
field.Set(reflect.ValueOf(defaultValues))
case reflect.Float64:
defaultSplit := strings.Split(defaultValue, separator)
defaultValues := make([]float64, 0, len(defaultSplit))
for _, value := range defaultSplit {
defaultValues := make([]float64, 0, len(defaultSlice))
for _, value := range defaultSlice {
val, err := strconv.Atoi(value)
if err != nil {
panic("Invalid default value for float64 flag")
Expand Down Expand Up @@ -1332,7 +1325,7 @@ func (c *Command) parsePositionalArgs(args []string) error {
}
field.SetFloat(value)
case reflect.Slice:
c.addSliceField(field, posArg, "")
c.addSliceField(field, posArg, c.sliceSeparator[key])
default:
return errors.New("Unsupported type for positional argument: " + fieldType.Name())
}
Expand Down
37 changes: 19 additions & 18 deletions examples/flags-slice/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,59 +9,59 @@ import (
type Flags struct {
String string `name:"string" description:"The string" pos:"1"`
Strings []string `name:"strings" description:"The strings" pos:"2"`
StringsDefault []string `name:"strings_default" description:"The strings default" default:"one,two,three" pos:"3"`
StringsDefault []string `name:"strings_default" description:"The strings default" default:"one|two|three" sep:"|" pos:"3"`

Int int `name:"int" description:"The int" pos:"4"`
Ints []int `name:"ints" description:"The ints" pos:"5"`
IntsDefault []int `name:"ints_default" description:"The ints default" default:"3|4|5" sep:"|" pos:"6"`
IntsDefault []int `name:"ints_default" description:"The ints default" default:"3|4|5" sep:"|" pos:"6"`

Int8 int8 `name:"int8" description:"The int8" pos:"7"`
Int8s []int8 `name:"int8s" description:"The int8s" pos:"8"`
Int8sDefault []int8 `name:"int8s_default" description:"The int8s default" default:"3,4,5" pos:"9"`
Int8sDefault []int8 `name:"int8s_default" description:"The int8s default" default:"3,4,5" sep:"," pos:"9"`

Int16 int16 `name:"int16" description:"The int16" pos:"10"`
Int16s []int16 `name:"int16s" description:"The int16s" pos:"11"`
Int16sDefault []int16 `name:"int16s_default" description:"The int16s default" default:"3,4,5" pos:"12"`
Int16sDefault []int16 `name:"int16s_default" description:"The int16s default" default:"3,4,5" sep:"," pos:"12"`

Int32 int32 `name:"int32" description:"The int32" pos:"13"`
Int32s []int32 `name:"int32s" description:"The int32s" pos:"14"`
Int32sDefault []int32 `name:"int32s_default" description:"The int32 default" default:"3,4,5" pos:"15"`
Int32sDefault []int32 `name:"int32s_default" description:"The int32 default" default:"3,4,5" sep:"," pos:"15"`

Int64 int64 `name:"int64" description:"The int64" pos:"16"`
Int64s []int64 `name:"int64s" description:"The int64s" pos:"17"`
Int64sDefault []int64 `name:"int64s_default" description:"The int64s default" default:"3,4,5" pos:"18"`
Int64sDefault []int64 `name:"int64s_default" description:"The int64s default" default:"3,4,5" sep:"," pos:"18"`

Uint uint `name:"uint" description:"The uint" pos:"19"`
Uints []uint `name:"uints" description:"The uints" pos:"20"`
UintsDefault []uint `name:"uints_default" description:"The uints default" default:"3,4,5" pos:"21"`
UintsDefault []uint `name:"uints_default" description:"The uints default" default:"3,4,5" sep:"," pos:"21"`

Uint8 uint8 `name:"uint8" description:"The uint8" pos:"22"`
Uint8s []uint8 `name:"uint8s" description:"The uint8s" pos:"23"`
Uint8sDefault []uint8 `name:"uint8s_default" description:"The uint8s default" default:"3,4,5" pos:"24"`
Uint8sDefault []uint8 `name:"uint8s_default" description:"The uint8s default" default:"3,4,5" sep:"," pos:"24"`

Uint16 uint16 `name:"uint16" description:"The uint16" pos:"25"`
Uint16s []uint16 `name:"uint16s" description:"The uint16s" pos:"26"`
Uint16sDefault []uint16 `name:"uint16s_default" description:"The uint16 default" default:"3,4,5" pos:"27"`
Uint16sDefault []uint16 `name:"uint16s_default" description:"The uint16 default" default:"3,4,5" sep:"," pos:"27"`

Uint32 uint32 `name:"uint32" description:"The uint32" pos:"28"`
Uint32s []uint32 `name:"uint32s" description:"The uint32s" pos:"29"`
Uint32sDefault []uint32 `name:"uint32s_default" description:"The uint32s default" default:"3,4,5" pos:"30"`
Uint32sDefault []uint32 `name:"uint32s_default" description:"The uint32s default" default:"3,4,5" sep:"," pos:"30"`

Uint64 uint64 `name:"uint64" description:"The uint64" pos:"31"`
Uint64s []uint64 `name:"uint64s" description:"The uint64s" pos:"32"`
Uint64sDefault []uint64 `name:"uint64s_default" description:"The uint64s default" default:"3,4,5" pos:"33"`
Uint64sDefault []uint64 `name:"uint64s_default" description:"The uint64s default" default:"3,4,5" sep:"," pos:"33"`

Float32 float32 `name:"float32" description:"The float32" pos:"34"`
Float32s []float32 `name:"float32s" description:"The float32s" pos:"35"`
Float32sDefault []float32 `name:"float32s_default" description:"The float32s default" default:"3,4,5" pos:"36"`
Float32sDefault []float32 `name:"float32s_default" description:"The float32s default" default:"3|4|5" sep:"|" pos:"36"`

Float64 float64 `name:"float64" description:"The float64" pos:"37"`
Float64s []float64 `name:"float64s" description:"The float64s" pos:"38"`
Float64sDefault []float64 `name:"float64s_default" description:"The float64s default" default:"3,4,5" pos:"39"`
Float64sDefault []float64 `name:"float64s_default" description:"The float64s default" default:"3|4|5" sep:"|" pos:"39"`

Bool bool `name:"bool" description:"The bool" pos:"40"`
Bools []bool `name:"bools" description:"The bools" pos:"41"`
BoolsDefault []bool `name:"bools_default" description:"The bools default" default:"false,true,false,true" pos:"42"`
BoolsDefault []bool `name:"bools_default" description:"The bools default" default:"false|true|false|true" sep:"|" pos:"42"`
}

func main() {
Expand Down Expand Up @@ -190,12 +190,12 @@ func main() {
panic(fmt.Sprintf("expected 'hello', got '%v'", f.String))
}

if !reflect.DeepEqual(f.Strings, []string{"zkep", "hello", "clir"}) {
panic(fmt.Sprintf("expected '[zkep hello clir]', got '%v'", f.Strings))
if !reflect.DeepEqual(f.Strings, []string{"zkep,hello,clir"}) {
panic(fmt.Sprintf("expected 'zkep,hello,clir', got '%v'", f.Strings))
}

if !reflect.DeepEqual(f.StringsDefault, []string{"zkep", "clir", "hello"}) {
panic(fmt.Sprintf("expected '[zkep clir hello]', got '%v'", f.StringsDefault))
panic(fmt.Sprintf("expected '[zkep,clir,hello]', got '%v'", f.StringsDefault))
}

println("string:", fmt.Sprintf("%#v", f.String))
Expand All @@ -207,7 +207,8 @@ func main() {
})

// Run!
if err := cli.Run("positional", "hello", "zkep,hello,clir", "zkep,clir,hello"); err != nil {
// The pos 3 slice separator is '|' in struct tag
if err := cli.Run("positional", "hello", "zkep,hello,clir", "zkep|clir|hello"); err != nil {
panic(err)
}

Expand Down

0 comments on commit 9a039b7

Please sign in to comment.