From 969f48a3a24032c3dd1ec351302b5b62407dfb88 Mon Sep 17 00:00:00 2001
From: Wayne Zuo
Date: Fri, 29 Jul 2022 14:24:26 +0800
Subject: [PATCH] cmd/compile: intrinsify Add64 on riscv64
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

According to the RISC-V instruction set manual v2.2 Sec 2.4, we can
implement the overflow check for unsigned addition cheaply using SLTU
instructions.

After this CL, the performance differences in the crypto/elliptic
benchmarks on linux/riscv64 are:

name                 old time/op  new time/op   delta
ScalarBaseMult/P256  1.93ms ± 1%   1.64ms ± 1%  -14.96%  (p=0.008 n=5+5)
ScalarBaseMult/P224  1.80ms ± 2%   1.53ms ± 1%  -14.89%  (p=0.008 n=5+5)
ScalarBaseMult/P384  6.15ms ± 2%   5.12ms ± 2%  -16.73%  (p=0.008 n=5+5)
ScalarBaseMult/P521  25.9ms ± 1%   22.3ms ± 2%  -13.78%  (p=0.008 n=5+5)
ScalarMult/P256      5.59ms ± 1%   4.49ms ± 2%  -19.79%  (p=0.008 n=5+5)
ScalarMult/P224      5.42ms ± 1%   4.33ms ± 1%  -20.01%  (p=0.008 n=5+5)
ScalarMult/P384      19.9ms ± 2%   16.3ms ± 1%  -18.15%  (p=0.008 n=5+5)
ScalarMult/P521      97.3ms ± 1%  100.7ms ± 0%   +3.48%  (p=0.008 n=5+5)

Change-Id: Ic4c82ced4b072a4a6575343fa9f29dd09b0cabc4
Reviewed-on: https://go-review.googlesource.com/c/go/+/420094
Reviewed-by: David Chase
Reviewed-by: Cherry Mui
Run-TryBot: Wayne Zuo
Reviewed-by: Joel Sing
TryBot-Result: Gopher Robot
---
 .../compile/internal/ssa/gen/RISCV64.rules |  7 ++
 .../compile/internal/ssa/rewriteRISCV64.go | 76 +++++++++++++++++++
 src/cmd/compile/internal/ssagen/ssa.go     |  4 +-
 test/codegen/mathbits.go                   |  8 ++
 4 files changed, 93 insertions(+), 2 deletions(-)

diff --git a/src/cmd/compile/internal/ssa/gen/RISCV64.rules b/src/cmd/compile/internal/ssa/gen/RISCV64.rules
index 5bc47ee1cc1..9d2d785d0ea 100644
--- a/src/cmd/compile/internal/ssa/gen/RISCV64.rules
+++ b/src/cmd/compile/internal/ssa/gen/RISCV64.rules
@@ -52,6 +52,10 @@
 (Hmul32 x y) => (SRAI [32] (MUL (SignExt32to64 x) (SignExt32to64 y)))
 (Hmul32u x y) => (SRLI [32] (MUL (ZeroExt32to64 x) (ZeroExt32to64 y)))
 
+(Select0 (Add64carry x y c)) => (ADD (ADD x y) c)
+(Select1 (Add64carry x y c)) =>
+	(OR (SLTU s:(ADD x y) x) (SLTU (ADD s c) s))
+
 // (x + y) / 2 => (x / 2) + (y / 2) + (x & y & 1)
 (Avg64u x y) => (ADD (ADD (SRLI [1] x) (SRLI [1] y)) (ANDI [1] (AND x y)))
 
@@ -743,6 +747,9 @@
 (SLTI [x] (MOVDconst [y])) => (MOVDconst [b2i(int64(y) < int64(x))])
 (SLTIU [x] (MOVDconst [y])) => (MOVDconst [b2i(uint64(y) < uint64(x))])
 
+(SLT x x) => (MOVDconst [0])
+(SLTU x x) => (MOVDconst [0])
+
 // deadcode for LoweredMuluhilo
 (Select0 m:(LoweredMuluhilo x y)) && m.Uses == 1 => (MULHU x y)
 (Select1 m:(LoweredMuluhilo x y)) && m.Uses == 1 => (MUL x y)
diff --git a/src/cmd/compile/internal/ssa/rewriteRISCV64.go b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
index 9253d2d7296..e4e4003f34e 100644
--- a/src/cmd/compile/internal/ssa/rewriteRISCV64.go
+++ b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
@@ -509,10 +509,14 @@ func rewriteValueRISCV64(v *Value) bool {
 		return rewriteValueRISCV64_OpRISCV64SLL(v)
 	case OpRISCV64SLLI:
 		return rewriteValueRISCV64_OpRISCV64SLLI(v)
+	case OpRISCV64SLT:
+		return rewriteValueRISCV64_OpRISCV64SLT(v)
 	case OpRISCV64SLTI:
 		return rewriteValueRISCV64_OpRISCV64SLTI(v)
 	case OpRISCV64SLTIU:
 		return rewriteValueRISCV64_OpRISCV64SLTIU(v)
+	case OpRISCV64SLTU:
+		return rewriteValueRISCV64_OpRISCV64SLTU(v)
 	case OpRISCV64SRA:
 		return rewriteValueRISCV64_OpRISCV64SRA(v)
 	case OpRISCV64SRAI:
 		return rewriteValueRISCV64_OpRISCV64SRAI(v)
@@ -4864,6 +4868,22 @@ func rewriteValueRISCV64_OpRISCV64SLLI(v *Value) bool {
 	}
 	return false
 }
+func rewriteValueRISCV64_OpRISCV64SLT(v *Value) bool {
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	// match: (SLT x x)
+	// result: (MOVDconst [0])
+	for {
+		x := v_0
+		if x != v_1 {
+			break
+		}
+		v.reset(OpRISCV64MOVDconst)
+		v.AuxInt = int64ToAuxInt(0)
+		return true
+	}
+	return false
+}
 func rewriteValueRISCV64_OpRISCV64SLTI(v *Value) bool {
 	v_0 := v.Args[0]
 	// match: (SLTI [x] (MOVDconst [y]))
@@ -4896,6 +4916,22 @@ func rewriteValueRISCV64_OpRISCV64SLTIU(v *Value) bool {
 	}
 	return false
 }
+func rewriteValueRISCV64_OpRISCV64SLTU(v *Value) bool {
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	// match: (SLTU x x)
+	// result: (MOVDconst [0])
+	for {
+		x := v_0
+		if x != v_1 {
+			break
+		}
+		v.reset(OpRISCV64MOVDconst)
+		v.AuxInt = int64ToAuxInt(0)
+		return true
+	}
+	return false
+}
 func rewriteValueRISCV64_OpRISCV64SRA(v *Value) bool {
 	v_1 := v.Args[1]
 	v_0 := v.Args[0]
@@ -6036,6 +6072,23 @@ func rewriteValueRISCV64_OpRsh8x8(v *Value) bool {
 }
 func rewriteValueRISCV64_OpSelect0(v *Value) bool {
 	v_0 := v.Args[0]
+	b := v.Block
+	typ := &b.Func.Config.Types
+	// match: (Select0 (Add64carry x y c))
+	// result: (ADD (ADD x y) c)
+	for {
+		if v_0.Op != OpAdd64carry {
+			break
+		}
+		c := v_0.Args[2]
+		x := v_0.Args[0]
+		y := v_0.Args[1]
+		v.reset(OpRISCV64ADD)
+		v0 := b.NewValue0(v.Pos, OpRISCV64ADD, typ.UInt64)
+		v0.AddArg2(x, y)
+		v.AddArg2(v0, c)
+		return true
+	}
 	// match: (Select0 m:(LoweredMuluhilo x y))
 	// cond: m.Uses == 1
 	// result: (MULHU x y)
@@ -6057,6 +6110,29 @@
 }
 func rewriteValueRISCV64_OpSelect1(v *Value) bool {
 	v_0 := v.Args[0]
+	b := v.Block
+	typ := &b.Func.Config.Types
+	// match: (Select1 (Add64carry x y c))
+	// result: (OR (SLTU s:(ADD x y) x) (SLTU (ADD s c) s))
+	for {
+		if v_0.Op != OpAdd64carry {
+			break
+		}
+		c := v_0.Args[2]
+		x := v_0.Args[0]
+		y := v_0.Args[1]
+		v.reset(OpRISCV64OR)
+		v0 := b.NewValue0(v.Pos, OpRISCV64SLTU, typ.UInt64)
+		s := b.NewValue0(v.Pos, OpRISCV64ADD, typ.UInt64)
+		s.AddArg2(x, y)
+		v0.AddArg2(s, x)
+		v2 := b.NewValue0(v.Pos, OpRISCV64SLTU, typ.UInt64)
+		v3 := b.NewValue0(v.Pos, OpRISCV64ADD, typ.UInt64)
+		v3.AddArg2(s, c)
+		v2.AddArg2(v3, s)
+		v.AddArg2(v0, v2)
+		return true
+	}
 	// match: (Select1 m:(LoweredMuluhilo x y))
 	// cond: m.Uses == 1
 	// result: (MUL x y)
diff --git a/src/cmd/compile/internal/ssagen/ssa.go b/src/cmd/compile/internal/ssagen/ssa.go
index dda813518a5..107944170fc 100644
--- a/src/cmd/compile/internal/ssagen/ssa.go
+++ b/src/cmd/compile/internal/ssagen/ssa.go
@@ -4726,8 +4726,8 @@ func InitTables() {
 		func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
 			return s.newValue3(ssa.OpAdd64carry, types.NewTuple(types.Types[types.TUINT64], types.Types[types.TUINT64]), args[0], args[1], args[2])
 		},
-		sys.AMD64, sys.ARM64, sys.PPC64, sys.S390X)
-	alias("math/bits", "Add", "math/bits", "Add64", sys.ArchAMD64, sys.ArchARM64, sys.ArchPPC64, sys.ArchPPC64LE, sys.ArchS390X)
+		sys.AMD64, sys.ARM64, sys.PPC64, sys.S390X, sys.RISCV64)
+	alias("math/bits", "Add", "math/bits", "Add64", sys.ArchAMD64, sys.ArchARM64, sys.ArchPPC64, sys.ArchPPC64LE, sys.ArchS390X, sys.ArchRISCV64)
 	addF("math/bits", "Sub64",
 		func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value {
 			return s.newValue3(ssa.OpSub64borrow, types.NewTuple(types.Types[types.TUINT64], types.Types[types.TUINT64]), args[0], args[1], args[2])
diff --git a/test/codegen/mathbits.go b/test/codegen/mathbits.go
index a507d32843d..f36916ad03a 100644
--- a/test/codegen/mathbits.go
+++ b/test/codegen/mathbits.go
@@ -442,6 +442,7 @@ func Add(x, y, ci uint) (r, co uint) {
 	// ppc64: "ADDC", "ADDE", "ADDZE"
 	// ppc64le: "ADDC", "ADDE", "ADDZE"
 	// s390x:"ADDE","ADDC\t[$]-1,"
+	// riscv64: "ADD","SLTU"
 	return bits.Add(x, y, ci)
 }
 
@@ -451,6 +452,7 @@ func AddC(x, ci uint) (r, co uint) {
 	// ppc64: "ADDC", "ADDE", "ADDZE"
 	// ppc64le: "ADDC", "ADDE", "ADDZE"
 	// s390x:"ADDE","ADDC\t[$]-1,"
+	// riscv64: "ADD","SLTU"
 	return bits.Add(x, 7, ci)
 }
 
@@ -460,6 +462,7 @@ func AddZ(x, y uint) (r, co uint) {
 	// ppc64: "ADDC", -"ADDE", "ADDZE"
 	// ppc64le: "ADDC", -"ADDE", "ADDZE"
 	// s390x:"ADDC",-"ADDC\t[$]-1,"
+	// riscv64: "ADD","SLTU"
 	return bits.Add(x, y, 0)
 }
 
@@ -469,6 +472,7 @@ func AddR(x, y, ci uint) uint {
 	// ppc64: "ADDC", "ADDE", -"ADDZE"
 	// ppc64le: "ADDC", "ADDE", -"ADDZE"
 	// s390x:"ADDE","ADDC\t[$]-1,"
+	// riscv64: "ADD",-"SLTU"
 	r, _ := bits.Add(x, y, ci)
 	return r
 }
@@ -489,6 +493,7 @@ func Add64(x, y, ci uint64) (r, co uint64) {
 	// ppc64: "ADDC", "ADDE", "ADDZE"
 	// ppc64le: "ADDC", "ADDE", "ADDZE"
 	// s390x:"ADDE","ADDC\t[$]-1,"
+	// riscv64: "ADD","SLTU"
 	return bits.Add64(x, y, ci)
 }
 
@@ -498,6 +503,7 @@ func Add64C(x, ci uint64) (r, co uint64) {
 	// ppc64: "ADDC", "ADDE", "ADDZE"
 	// ppc64le: "ADDC", "ADDE", "ADDZE"
 	// s390x:"ADDE","ADDC\t[$]-1,"
+	// riscv64: "ADD","SLTU"
 	return bits.Add64(x, 7, ci)
 }
 
@@ -507,6 +513,7 @@ func Add64Z(x, y uint64) (r, co uint64) {
 	// ppc64: "ADDC", -"ADDE", "ADDZE"
 	// ppc64le: "ADDC", -"ADDE", "ADDZE"
 	// s390x:"ADDC",-"ADDC\t[$]-1,"
+	// riscv64: "ADD","SLTU"
 	return bits.Add64(x, y, 0)
 }
 
@@ -516,6 +523,7 @@ func Add64R(x, y, ci uint64) uint64 {
 	// ppc64: "ADDC", "ADDE", -"ADDZE"
 	// ppc64le: "ADDC", "ADDE", -"ADDZE"
 	// s390x:"ADDE","ADDC\t[$]-1,"
+	// riscv64: "ADD",-"SLTU"
 	r, _ := bits.Add64(x, y, ci)
 	return r
 }