Skip to content

Commit

Permalink
cmd/compile: enable late expansion for interface calls
Browse files Browse the repository at this point in the history
Includes a few tweaks to Value.copyOf(a) (make it a no-op for
a self-copy) and a new pattern hack: "___" (3 underscores) is
like an ellipsis, except that the replacement doesn't need to
have a matching ellipsis/underscores.

Moved the arg-length check in generated pattern-matching code
BEFORE the args are probed, because not all instances of
variable length OpFoo will have all the args mentioned in
some rule for OpFoo, and when that happens, the compiler
panics without the early check.

Change-Id: I66de40672b3794a6427890ff96c805a488d783f4
Reviewed-on: https://go-review.googlesource.com/c/go/+/247537
Trust: David Chase <[email protected]>
Run-TryBot: David Chase <[email protected]>
TryBot-Result: Go Bot <[email protected]>
Reviewed-by: Cherry Zhang <[email protected]>
  • Loading branch information
dr2chase committed Oct 1, 2020
1 parent 75ea995 commit adef4de
Show file tree
Hide file tree
Showing 10 changed files with 221 additions and 55 deletions.
22 changes: 13 additions & 9 deletions src/cmd/compile/internal/gc/ssa.go
Original file line number Diff line number Diff line change
Expand Up @@ -2556,7 +2556,7 @@ func (s *state) expr(n *Node) *ssa.Value {
return s.addr(n.Left)

case ORESULT:
if s.prevCall == nil || s.prevCall.Op != ssa.OpStaticLECall {
if s.prevCall == nil || s.prevCall.Op != ssa.OpStaticLECall && s.prevCall.Op != ssa.OpInterLECall && s.prevCall.Op != ssa.OpClosureLECall {
// Do the old thing
addr := s.constOffPtrSP(types.NewPtr(n.Type), n.Xoffset)
return s.rawLoad(n.Type, addr)
Expand Down Expand Up @@ -4409,6 +4409,9 @@ func (s *state) call(n *Node, k callKind, returnResultAddr bool) *ssa.Value {
iclosure, rcvr = s.getClosureAndRcvr(fn)
if k == callNormal {
codeptr = s.load(types.Types[TUINTPTR], iclosure)
if ssa.LateCallExpansionEnabledWithin(s.f) {
testLateExpansion = true
}
} else {
closure = iclosure
}
Expand Down Expand Up @@ -4555,16 +4558,17 @@ func (s *state) call(n *Node, k callKind, returnResultAddr bool) *ssa.Value {
codeptr = s.rawLoad(types.Types[TUINTPTR], closure)
call = s.newValue3A(ssa.OpClosureCall, types.TypeMem, ssa.ClosureAuxCall(ACArgs, ACResults), codeptr, closure, s.mem())
case codeptr != nil:
call = s.newValue2A(ssa.OpInterCall, types.TypeMem, ssa.InterfaceAuxCall(ACArgs, ACResults), codeptr, s.mem())
if testLateExpansion {
aux := ssa.InterfaceAuxCall(ACArgs, ACResults)
call = s.newValue1A(ssa.OpInterLECall, aux.LateExpansionResultType(), aux, codeptr)
call.AddArgs(callArgs...)
} else {
call = s.newValue2A(ssa.OpInterCall, types.TypeMem, ssa.InterfaceAuxCall(ACArgs, ACResults), codeptr, s.mem())
}
case sym != nil:
if testLateExpansion {
var tys []*types.Type
aux := ssa.StaticAuxCall(sym.Linksym(), ACArgs, ACResults)
for i := int64(0); i < aux.NResults(); i++ {
tys = append(tys, aux.TypeOfResult(i))
}
tys = append(tys, types.TypeMem)
call = s.newValue0A(ssa.OpStaticLECall, types.NewResults(tys), aux)
call = s.newValue0A(ssa.OpStaticLECall, aux.LateExpansionResultType(), aux)
call.AddArgs(callArgs...)
} else {
call = s.newValue1A(ssa.OpStaticCall, types.TypeMem, ssa.StaticAuxCall(sym.Linksym(), ACArgs, ACResults), s.mem())
Expand Down Expand Up @@ -4713,7 +4717,7 @@ func (s *state) addr(n *Node) *ssa.Value {
}
case ORESULT:
// load return from callee
if s.prevCall == nil || s.prevCall.Op != ssa.OpStaticLECall {
if s.prevCall == nil || s.prevCall.Op != ssa.OpStaticLECall && s.prevCall.Op != ssa.OpInterLECall && s.prevCall.Op != ssa.OpClosureLECall {
return s.constOffPtrSP(t, n.Xoffset)
}
which := s.prevCall.Aux.(*ssa.AuxCall).ResultForOffset(n.Xoffset)
Expand Down
93 changes: 59 additions & 34 deletions src/cmd/compile/internal/ssa/expand_calls.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func expandCalls(f *Func) {
} else {
hiOffset = 4
}

pairTypes := func(et types.EType) (tHi, tLo *types.Type) {
tHi = tUint32
if et == types.TINT64 {
Expand Down Expand Up @@ -231,46 +232,64 @@ func expandCalls(f *Func) {
return x
}

// rewriteArgs converts the value arguments of the late-expanded call v into
// a chain of memory operations and returns the final memory value of that chain.
//
// firstArg is the index of the first real parameter in v.Args; operands before
// it (such as a code pointer or closure context) are skipped here. The last
// element of v.Args is taken to be the incoming memory. As a side effect, all
// of v's args are cleared (v.resetArgs), so the caller must re-install
// whatever operands the call still needs together with the returned memory.
rewriteArgs := func(v *Value, firstArg int) *Value {
// Thread the stores on the memory arg
aux := v.Aux.(*AuxCall)
pos := v.Pos.WithNotStmt()
// m0 is the memory the call originally consumed; mem tracks the newest link
// in the store chain being built on top of it.
m0 := v.Args[len(v.Args)-1]
mem := m0
for i, a := range v.Args {
if i < firstArg {
continue
}
if a == m0 { // mem is last.
break
}
// auxI is the parameter index as seen by the AuxCall, i.e. with the
// skipped leading operands removed.
auxI := int64(i - firstArg)
if a.Op == OpDereference {
// The Dereference must read from the same memory state the call started with.
if a.MemoryArg() != m0 {
f.Fatalf("Op...LECall and OpDereference have mismatched mem, %s and %s", v.LongString(), a.LongString())
}
// "Dereference" of addressed (probably not-SSA-eligible) value becomes Move
// TODO this will be more complicated with registers in the picture.
src := a.Args[0]
dst := f.ConstOffPtrSP(src.Type, aux.OffsetOfArg(auxI), sp)
if a.Uses == 1 {
// This call is the only user, so the Dereference is rewritten in place into the Move.
a.reset(OpMove)
a.Pos = pos
a.Type = types.TypeMem
a.Aux = aux.TypeOfArg(auxI)
a.AuxInt = aux.SizeOfArg(auxI)
a.SetArgs3(dst, src, mem)
mem = a
} else {
// The Dereference has other users; leave it alone and emit a separate Move in its block.
mem = a.Block.NewValue3A(pos, OpMove, types.TypeMem, aux.TypeOfArg(auxI), dst, src, mem)
mem.AuxInt = aux.SizeOfArg(auxI)
}
} else {
// Any other argument goes through storeArg (defined outside this view),
// whose result becomes the new end of the memory chain.
mem = storeArg(pos, v.Block, a, aux.TypeOfArg(auxI), aux.OffsetOfArg(auxI), mem)
}
}
v.resetArgs()
return mem
}

// Step 0: rewrite the calls to convert incoming args to stores.
for _, b := range f.Blocks {
for _, v := range b.Values {
switch v.Op {
case OpStaticLECall:
// Thread the stores on the memory arg
m0 := v.MemoryArg()
mem := m0
pos := v.Pos.WithNotStmt()
aux := v.Aux.(*AuxCall)
for i, a := range v.Args {
if a == m0 { // mem is last.
break
}
if a.Op == OpDereference {
// "Dereference" of addressed (probably not-SSA-eligible) value becomes Move
// TODO this will be more complicated with registers in the picture.
if a.MemoryArg() != m0 {
f.Fatalf("Op...LECall and OpDereference have mismatched mem, %s and %s", v.LongString(), a.LongString())
}
src := a.Args[0]
dst := f.ConstOffPtrSP(src.Type, aux.OffsetOfArg(int64(i)), sp)
if a.Uses == 1 {
a.reset(OpMove)
a.Pos = pos
a.Type = types.TypeMem
a.Aux = aux.TypeOfArg(int64(i))
a.AuxInt = aux.SizeOfArg(int64(i))
a.SetArgs3(dst, src, mem)
mem = a
} else {
mem = a.Block.NewValue3A(pos, OpMove, types.TypeMem, aux.TypeOfArg(int64(i)), dst, src, mem)
mem.AuxInt = aux.SizeOfArg(int64(i))
}
} else {
mem = storeArg(pos, b, a, aux.TypeOfArg(int64(i)), aux.OffsetOfArg(int64(i)), mem)
}
}
v.resetArgs()
mem := rewriteArgs(v, 0)
v.SetArgs1(mem)
case OpClosureLECall:
code := v.Args[0]
context := v.Args[1]
mem := rewriteArgs(v, 2)
v.SetArgs3(code, context, mem)
case OpInterLECall:
code := v.Args[0]
mem := rewriteArgs(v, 1)
v.SetArgs2(code, mem)
}
}
}
Expand Down Expand Up @@ -370,6 +389,12 @@ func expandCalls(f *Func) {
case OpStaticLECall:
v.Op = OpStaticCall
v.Type = types.TypeMem
case OpClosureLECall:
v.Op = OpClosureCall
v.Type = types.TypeMem
case OpInterLECall:
v.Op = OpInterCall
v.Type = types.TypeMem
}
}
}
Expand Down
7 changes: 7 additions & 0 deletions src/cmd/compile/internal/ssa/gen/generic.rules
Original file line number Diff line number Diff line change
Expand Up @@ -2024,6 +2024,13 @@
(InterCall [argsize] {auxCall} (Load (OffPtr [off] (ITab (IMake (Addr {itab} (SB)) _))) _) mem) && devirt(v, auxCall, itab, off) != nil =>
(StaticCall [int32(argsize)] {devirt(v, auxCall, itab, off)} mem)

// De-virtualize late-expanded interface calls into late-expanded static calls.
// Note that (ITab (IMake)) doesn't get rewritten until after the first opt pass,
// so this rule should trigger reliably.
// devirtLECall removes the first argument, adds the devirtualized symbol to the AuxCall, and changes the opcode
(InterLECall [argsize] {auxCall} (Load (OffPtr [off] (ITab (IMake (Addr {itab} (SB)) _))) _) ___) && devirtLESym(v, auxCall, itab, off) !=
nil => devirtLECall(v, devirtLESym(v, auxCall, itab, off))

// Move and Zero optimizations.
// Move source and destination may overlap.

Expand Down
10 changes: 6 additions & 4 deletions src/cmd/compile/internal/ssa/gen/genericOps.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,10 +389,12 @@ var genericOps = []opData{
// TODO(josharian): ClosureCall and InterCall should have Int32 aux
// to match StaticCall's 32 bit arg size limit.
// TODO(drchase,josharian): could the arg size limit be bundled into the rules for CallOff?
{name: "ClosureCall", argLength: 3, aux: "CallOff", call: true}, // arg0=code pointer, arg1=context ptr, arg2=memory. auxint=arg size. Returns memory.
{name: "StaticCall", argLength: 1, aux: "CallOff", call: true}, // call function aux.(*obj.LSym), arg0=memory. auxint=arg size. Returns memory.
{name: "InterCall", argLength: 2, aux: "CallOff", call: true}, // interface call. arg0=code pointer, arg1=memory, auxint=arg size. Returns memory.
{name: "StaticLECall", argLength: -1, aux: "CallOff", call: true}, // late-expanded static call function aux.(*ssa.AuxCall.Fn). arg0..argN-1 are inputs, argN is mem. auxint = arg size. Result is tuple of result(s), plus mem.
{name: "ClosureCall", argLength: 3, aux: "CallOff", call: true}, // arg0=code pointer, arg1=context ptr, arg2=memory. auxint=arg size. Returns memory.
{name: "StaticCall", argLength: 1, aux: "CallOff", call: true}, // call function aux.(*obj.LSym), arg0=memory. auxint=arg size. Returns memory.
{name: "InterCall", argLength: 2, aux: "CallOff", call: true}, // interface call. arg0=code pointer, arg1=memory, auxint=arg size. Returns memory.
{name: "ClosureLECall", argLength: -1, aux: "CallOff", call: true}, // late-expanded closure call. arg0=code pointer, arg1=context ptr, arg2..argN-1 are inputs, argN is mem. auxint = arg size. Result is tuple of result(s), plus mem.
{name: "StaticLECall", argLength: -1, aux: "CallOff", call: true}, // late-expanded static call function aux.(*ssa.AuxCall.Fn). arg0..argN-1 are inputs, argN is mem. auxint = arg size. Result is tuple of result(s), plus mem.
{name: "InterLECall", argLength: -1, aux: "CallOff", call: true}, // late-expanded interface call. arg0=code pointer, arg1..argN-1 are inputs, argN is mem. auxint = arg size. Result is tuple of result(s), plus mem.

// Conversions: signed extensions, zero (unsigned) extensions, truncations
{name: "SignExt8to16", argLength: 1, typ: "Int16"},
Expand Down
22 changes: 18 additions & 4 deletions src/cmd/compile/internal/ssa/gen/rulegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,12 @@ import (
// variable ::= some token
// opcode ::= one of the opcodes from the *Ops.go files

// special rules: trailing ellipsis "..." (in the outermost sexpr?) must match on both sides of a rule.
// trailing three underscores "___" in the outermost match sexpr indicate the presence of
// extra ignored args that need not appear in the replacement

// extra conditions is just a chunk of Go that evaluates to a boolean. It may use
// variables declared in the matching sexpr. The variable "v" is predefined to be
// variables declared in the matching sexpr. The variable "v" is predefined to be
// the value matched by the entire rule.

// If multiple rules match, the first one in file order is selected.
Expand Down Expand Up @@ -1019,6 +1023,19 @@ func genMatch0(rr *RuleRewrite, arch arch, match, v string, cnt map[string]int,
pos = v + ".Pos"
}

// If the last argument is ___, it means "don't care about trailing arguments, really".
// The likely/intended use is for rewrites that are too tricky to express in the existing pattern language.
// Do the length check early, because long patterns fed short (ultimately non-matching) inputs would
// otherwise index out of range during pattern matching.
if op.argLength == -1 {
l := len(args)
if l == 0 || args[l-1] != "___" {
rr.add(breakf("len(%s.Args) != %d", v, l))
} else if l > 1 && args[l-1] == "___" {
rr.add(breakf("len(%s.Args) < %d", v, l-1))
}
}

for _, e := range []struct {
name, field, dclType string
}{
Expand Down Expand Up @@ -1159,9 +1176,6 @@ func genMatch0(rr *RuleRewrite, arch arch, match, v string, cnt map[string]int,
}
}

if op.argLength == -1 {
rr.add(breakf("len(%s.Args) != %d", v, len(args)))
}
return pos, checkOp
}

Expand Down
11 changes: 11 additions & 0 deletions src/cmd/compile/internal/ssa/op.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,17 @@ func (a *AuxCall) NResults() int64 {
return int64(len(a.results))
}

// LateExpansionResultType returns the result type of a call that will be
// expanded later in the SSA phase: the type of each declared result, in
// order, followed by a trailing memory type.
func (a *AuxCall) LateExpansionResultType() *types.Type {
	n := a.NResults()
	// One slot per result plus the final slot for memory.
	resultTypes := make([]*types.Type, n+1)
	for i := int64(0); i < n; i++ {
		resultTypes[i] = a.TypeOfResult(i)
	}
	resultTypes[n] = types.TypeMem
	return types.NewResults(resultTypes)
}

// NArgs returns the number of arguments
func (a *AuxCall) NArgs() int64 {
return int64(len(a.args))
Expand Down
16 changes: 16 additions & 0 deletions src/cmd/compile/internal/ssa/opGen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions src/cmd/compile/internal/ssa/rewrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,36 @@ func devirt(v *Value, aux interface{}, sym Sym, offset int64) *AuxCall {
return StaticAuxCall(lsym, va.args, va.results)
}

// devirtLESym looks up the concrete symbol to call for an InterLECall.
// sym is the symbol for the itab and offset is handed to DerefItab together
// with it. The result is nil when sym is not an *obj.LSym or when DerefItab
// yields nothing. The aux parameter is not consulted.
func devirtLESym(v *Value, aux interface{}, sym Sym, offset int64) *obj.LSym {
	itab, isLSym := sym.(*obj.LSym)
	if !isLSym {
		return nil
	}

	fn := v.Block.Func
	target := fn.fe.DerefItab(itab, offset)
	// Report the outcome either way when pass debugging is enabled.
	if fn.pass.debug > 0 {
		if target == nil {
			fn.Warnl(v.Pos, "couldn't de-virtualize call")
		} else {
			fn.Warnl(v.Pos, "de-virtualizing call")
		}
	}
	return target
}

// devirtLECall turns v in place into a late-expanded static call to sym:
// the opcode becomes OpStaticLECall, the AuxCall's Fn is set to sym, and
// argument 0 is removed. It returns v.
func devirtLECall(v *Value, sym *obj.LSym) *Value {
	v.Op = OpStaticLECall
	auxCall := v.Aux.(*AuxCall)
	auxCall.Fn = sym
	v.RemoveArg(0)
	return v
}

// isSamePtr reports whether p1 and p2 point to the same address.
func isSamePtr(p1, p2 *Value) bool {
if p1 == p2 {
Expand Down
Loading

0 comments on commit adef4de

Please sign in to comment.