From 70c7fb75e9768630ca23ff5cbf79c9b597bc068e Mon Sep 17 00:00:00 2001 From: Joel Sing Date: Sun, 26 Nov 2023 02:48:05 +1100 Subject: [PATCH] cmd/compile: correct code generation for right shifts on riscv64 The code generation on riscv64 will currently result in incorrect assembly when a 32 bit integer is right shifted by an amount that exceeds the size of the type. In particular, this occurs when an int32 or uint32 is cast to a 64 bit type and right shifted by a value larger than 31. Fix this by moving the SRAW/SRLW conversion into the right shift rules and removing the SignExt32to64/ZeroExt32to64. Add additional rules that rewrite to SRAIW/SRLIW when the shift is less than the size of the type, or replace/eliminate the shift when it exceeds the size of the type. Add SSA tests that would have caught this issue. Also add additional codegen tests to ensure that the resulting assembly is what we expect in these overflow cases. Fixes #64285 Change-Id: Ie97b05668597cfcb91413afefaab18ee1aa145ec Reviewed-on: https://go-review.googlesource.com/c/go/+/545035 Reviewed-by: Russ Cox Reviewed-by: Cherry Mui Reviewed-by: M Zhuo Reviewed-by: Mark Ryan Run-TryBot: Joel Sing TryBot-Result: Gopher Robot --- .../compile/internal/ssa/_gen/RISCV64.rules | 100 +++-- .../compile/internal/ssa/rewriteRISCV64.go | 409 ++++++++++-------- .../internal/test/testdata/arith_test.go | 66 +++ test/codegen/shift.go | 30 ++ 4 files changed, 383 insertions(+), 222 deletions(-) diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules index 9afe5995ae8..fc206c42d3d 100644 --- a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules +++ b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules @@ -153,27 +153,27 @@ // SRL only considers the bottom 6 bits of y, similarly SRLW only considers the // bottom 5 bits of y. Ensure that the result is always zero if the shift exceeds // the maximum value. See Lsh above for a detailed description. -(Rsh8Ux8 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt8to64 x) y) (Neg8 (SLTIU [64] (ZeroExt8to64 y)))) -(Rsh8Ux16 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt8to64 x) y) (Neg8 (SLTIU [64] (ZeroExt16to64 y)))) -(Rsh8Ux32 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt8to64 x) y) (Neg8 (SLTIU [64] (ZeroExt32to64 y)))) -(Rsh8Ux64 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt8to64 x) y) (Neg8 (SLTIU [64] y))) -(Rsh16Ux8 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt16to64 x) y) (Neg16 (SLTIU [64] (ZeroExt8to64 y)))) -(Rsh16Ux16 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt16to64 x) y) (Neg16 (SLTIU [64] (ZeroExt16to64 y)))) -(Rsh16Ux32 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt16to64 x) y) (Neg16 (SLTIU [64] (ZeroExt32to64 y)))) -(Rsh16Ux64 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt16to64 x) y) (Neg16 (SLTIU [64] y))) -(Rsh32Ux8 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] (ZeroExt8to64 y)))) -(Rsh32Ux16 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] (ZeroExt16to64 y)))) -(Rsh32Ux32 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] (ZeroExt32to64 y)))) -(Rsh32Ux64 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] y))) -(Rsh64Ux8 x y) && !shiftIsBounded(v) => (AND (SRL x y) (Neg64 (SLTIU [64] (ZeroExt8to64 y)))) -(Rsh64Ux16 x y) && !shiftIsBounded(v) => (AND (SRL x y) (Neg64 (SLTIU [64] (ZeroExt16to64 y)))) -(Rsh64Ux32 x y) && !shiftIsBounded(v) => (AND (SRL x y) (Neg64 (SLTIU [64] (ZeroExt32to64 y)))) -(Rsh64Ux64 x y) && !shiftIsBounded(v) => (AND (SRL x y) (Neg64 (SLTIU [64] y))) - -(Rsh8Ux(64|32|16|8) x y) && shiftIsBounded(v) => (SRL (ZeroExt8to64 x) y) -(Rsh16Ux(64|32|16|8) x y) && shiftIsBounded(v) => (SRL (ZeroExt16to64 x) y) -(Rsh32Ux(64|32|16|8) x y) && shiftIsBounded(v) => (SRL (ZeroExt32to64 x) y) -(Rsh64Ux(64|32|16|8) x y) && shiftIsBounded(v) => (SRL x y) +(Rsh8Ux8 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt8to64 x) y) (Neg8 (SLTIU [64] (ZeroExt8to64 y)))) +(Rsh8Ux16 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt8to64 x) y) (Neg8 (SLTIU [64] (ZeroExt16to64 y)))) +(Rsh8Ux32 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt8to64 x) y) (Neg8 (SLTIU [64] (ZeroExt32to64 y)))) +(Rsh8Ux64 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt8to64 x) y) (Neg8 (SLTIU [64] y))) +(Rsh16Ux8 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt16to64 x) y) (Neg16 (SLTIU [64] (ZeroExt8to64 y)))) +(Rsh16Ux16 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt16to64 x) y) (Neg16 (SLTIU [64] (ZeroExt16to64 y)))) +(Rsh16Ux32 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt16to64 x) y) (Neg16 (SLTIU [64] (ZeroExt32to64 y)))) +(Rsh16Ux64 x y) && !shiftIsBounded(v) => (AND (SRL (ZeroExt16to64 x) y) (Neg16 (SLTIU [64] y))) +(Rsh32Ux8 x y) && !shiftIsBounded(v) => (AND (SRLW x y) (Neg32 (SLTIU [32] (ZeroExt8to64 y)))) +(Rsh32Ux16 x y) && !shiftIsBounded(v) => (AND (SRLW x y) (Neg32 (SLTIU [32] (ZeroExt16to64 y)))) +(Rsh32Ux32 x y) && !shiftIsBounded(v) => (AND (SRLW x y) (Neg32 (SLTIU [32] (ZeroExt32to64 y)))) +(Rsh32Ux64 x y) && !shiftIsBounded(v) => (AND (SRLW x y) (Neg32 (SLTIU [32] y))) +(Rsh64Ux8 x y) && !shiftIsBounded(v) => (AND (SRL x y) (Neg64 (SLTIU [64] (ZeroExt8to64 y)))) +(Rsh64Ux16 x y) && !shiftIsBounded(v) => (AND (SRL x y) (Neg64 (SLTIU [64] (ZeroExt16to64 y)))) +(Rsh64Ux32 x y) && !shiftIsBounded(v) => (AND (SRL x y) (Neg64 (SLTIU [64] (ZeroExt32to64 y)))) +(Rsh64Ux64 x y) && !shiftIsBounded(v) => (AND (SRL x y) (Neg64 (SLTIU [64] y))) + +(Rsh8Ux(64|32|16|8) x y) && shiftIsBounded(v) => (SRL (ZeroExt8to64 x) y) +(Rsh16Ux(64|32|16|8) x y) && shiftIsBounded(v) => (SRL (ZeroExt16to64 x) y) +(Rsh32Ux(64|32|16|8) x y) && shiftIsBounded(v) => (SRLW x y) +(Rsh64Ux(64|32|16|8) x y) && shiftIsBounded(v) => (SRL x y) // SRA only considers the bottom 6 bits of y, similarly SRAW only considers the // bottom 5 bits. If y is greater than the maximum value (either 63 or 31 @@ -188,27 +188,27 @@ // // We don't need to sign-extend the OR result, as it will be at minimum 8 bits, // more than the 5 or 6 bits SRAW and SRA care about. -(Rsh8x8 x y) && !shiftIsBounded(v) => (SRA (SignExt8to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt8to64 y))))) -(Rsh8x16 x y) && !shiftIsBounded(v) => (SRA (SignExt8to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt16to64 y))))) -(Rsh8x32 x y) && !shiftIsBounded(v) => (SRA (SignExt8to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt32to64 y))))) -(Rsh8x64 x y) && !shiftIsBounded(v) => (SRA (SignExt8to64 x) (OR y (ADDI [-1] (SLTIU [64] y)))) -(Rsh16x8 x y) && !shiftIsBounded(v) => (SRA (SignExt16to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt8to64 y))))) -(Rsh16x16 x y) && !shiftIsBounded(v) => (SRA (SignExt16to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt16to64 y))))) -(Rsh16x32 x y) && !shiftIsBounded(v) => (SRA (SignExt16to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt32to64 y))))) -(Rsh16x64 x y) && !shiftIsBounded(v) => (SRA (SignExt16to64 x) (OR y (ADDI [-1] (SLTIU [64] y)))) -(Rsh32x8 x y) && !shiftIsBounded(v) => (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] (ZeroExt8to64 y))))) -(Rsh32x16 x y) && !shiftIsBounded(v) => (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] (ZeroExt16to64 y))))) -(Rsh32x32 x y) && !shiftIsBounded(v) => (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] (ZeroExt32to64 y))))) -(Rsh32x64 x y) && !shiftIsBounded(v) => (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] y)))) -(Rsh64x8 x y) && !shiftIsBounded(v) => (SRA x (OR y (ADDI [-1] (SLTIU [64] (ZeroExt8to64 y))))) -(Rsh64x16 x y) && !shiftIsBounded(v) => (SRA x (OR y (ADDI [-1] (SLTIU [64] (ZeroExt16to64 y))))) -(Rsh64x32 x y) && !shiftIsBounded(v) => (SRA x (OR y (ADDI [-1] (SLTIU [64] (ZeroExt32to64 y))))) -(Rsh64x64 x y) && !shiftIsBounded(v) => (SRA x (OR y (ADDI [-1] (SLTIU [64] y)))) - -(Rsh8x(64|32|16|8) x y) && shiftIsBounded(v) => (SRA (SignExt8to64 x) y) -(Rsh16x(64|32|16|8) x y) && shiftIsBounded(v) => (SRA (SignExt16to64 x) y) -(Rsh32x(64|32|16|8) x y) && shiftIsBounded(v) => (SRA (SignExt32to64 x) y) -(Rsh64x(64|32|16|8) x y) && shiftIsBounded(v) => (SRA x y) +(Rsh8x8 x y) && !shiftIsBounded(v) => (SRA (SignExt8to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt8to64 y))))) +(Rsh8x16 x y) && !shiftIsBounded(v) => (SRA (SignExt8to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt16to64 y))))) +(Rsh8x32 x y) && !shiftIsBounded(v) => (SRA (SignExt8to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt32to64 y))))) +(Rsh8x64 x y) && !shiftIsBounded(v) => (SRA (SignExt8to64 x) (OR y (ADDI [-1] (SLTIU [64] y)))) +(Rsh16x8 x y) && !shiftIsBounded(v) => (SRA (SignExt16to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt8to64 y))))) +(Rsh16x16 x y) && !shiftIsBounded(v) => (SRA (SignExt16to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt16to64 y))))) +(Rsh16x32 x y) && !shiftIsBounded(v) => (SRA (SignExt16to64 x) (OR y (ADDI [-1] (SLTIU [64] (ZeroExt32to64 y))))) +(Rsh16x64 x y) && !shiftIsBounded(v) => (SRA (SignExt16to64 x) (OR y (ADDI [-1] (SLTIU [64] y)))) +(Rsh32x8 x y) && !shiftIsBounded(v) => (SRAW x (OR y (ADDI [-1] (SLTIU [32] (ZeroExt8to64 y))))) +(Rsh32x16 x y) && !shiftIsBounded(v) => (SRAW x (OR y (ADDI [-1] (SLTIU [32] (ZeroExt16to64 y))))) +(Rsh32x32 x y) && !shiftIsBounded(v) => (SRAW x (OR y (ADDI [-1] (SLTIU [32] (ZeroExt32to64 y))))) +(Rsh32x64 x y) && !shiftIsBounded(v) => (SRAW x (OR y (ADDI [-1] (SLTIU [32] y)))) +(Rsh64x8 x y) && !shiftIsBounded(v) => (SRA x (OR y (ADDI [-1] (SLTIU [64] (ZeroExt8to64 y))))) +(Rsh64x16 x y) && !shiftIsBounded(v) => (SRA x (OR y (ADDI [-1] (SLTIU [64] (ZeroExt16to64 y))))) +(Rsh64x32 x y) && !shiftIsBounded(v) => (SRA x (OR y (ADDI [-1] (SLTIU [64] (ZeroExt32to64 y))))) +(Rsh64x64 x y) && !shiftIsBounded(v) => (SRA x (OR y (ADDI [-1] (SLTIU [64] y)))) + +(Rsh8x(64|32|16|8) x y) && shiftIsBounded(v) => (SRA (SignExt8to64 x) y) +(Rsh16x(64|32|16|8) x y) && shiftIsBounded(v) => (SRA (SignExt16to64 x) y) +(Rsh32x(64|32|16|8) x y) && shiftIsBounded(v) => (SRAW x y) +(Rsh64x(64|32|16|8) x y) && shiftIsBounded(v) => (SRA x y) // Rotates. (RotateLeft8 x (MOVDconst [c])) => (Or8 (Lsh8x64 x (MOVDconst [c&7])) (Rsh8Ux64 x (MOVDconst [-c&7]))) @@ -710,10 +710,18 @@ (MOVDnop (MOVDconst [c])) => (MOVDconst [c]) // Avoid unnecessary zero and sign extension when right shifting. -(SRL (MOVWUreg x) y) => (SRLW x y) -(SRLI [x] (MOVWUreg y)) => (SRLIW [int64(x&31)] y) -(SRA (MOVWreg x) y) => (SRAW x y) -(SRAI [x] (MOVWreg y)) => (SRAIW [int64(x&31)] y) +(SRAI [x] (MOVWreg y)) && x >= 0 && x <= 31 => (SRAIW [int64(x)] y) +(SRLI [x] (MOVWUreg y)) && x >= 0 && x <= 31 => (SRLIW [int64(x)] y) + +// Replace right shifts that exceed size of signed type. +(SRAI [x] (MOVBreg y)) && x >= 8 => (SRAI [63] (SLLI [56] y)) +(SRAI [x] (MOVHreg y)) && x >= 16 => (SRAI [63] (SLLI [48] y)) +(SRAI [x] (MOVWreg y)) && x >= 32 => (SRAIW [31] y) + +// Eliminate right shifts that exceed size of unsigned type. +(SRLI [x] (MOVBUreg y)) && x >= 8 => (MOVDconst [0]) +(SRLI [x] (MOVHUreg y)) && x >= 16 => (MOVDconst [0]) +(SRLI [x] (MOVWUreg y)) && x >= 32 => (MOVDconst [0]) // Fold constant into immediate instructions where possible. (ADD (MOVDconst [val]) x) && is32Bit(val) && !t.IsPtr() => (ADDI [val] x) diff --git a/src/cmd/compile/internal/ssa/rewriteRISCV64.go b/src/cmd/compile/internal/ssa/rewriteRISCV64.go index 6009c41f2d5..52ddca1c7d5 100644 --- a/src/cmd/compile/internal/ssa/rewriteRISCV64.go +++ b/src/cmd/compile/internal/ssa/rewriteRISCV64.go @@ -6260,20 +6260,6 @@ func rewriteValueRISCV64_OpRISCV64SNEZ(v *Value) bool { func rewriteValueRISCV64_OpRISCV64SRA(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] - // match: (SRA (MOVWreg x) y) - // result: (SRAW x y) - for { - t := v.Type - if v_0.Op != OpRISCV64MOVWreg { - break - } - x := v_0.Args[0] - y := v_1 - v.reset(OpRISCV64SRAW) - v.Type = t - v.AddArg2(x, y) - return true - } // match: (SRA x (MOVDconst [val])) // result: (SRAI [int64(val&63)] x) for { @@ -6291,8 +6277,10 @@ func rewriteValueRISCV64_OpRISCV64SRA(v *Value) bool { } func rewriteValueRISCV64_OpRISCV64SRAI(v *Value) bool { v_0 := v.Args[0] + b := v.Block // match: (SRAI [x] (MOVWreg y)) - // result: (SRAIW [int64(x&31)] y) + // cond: x >= 0 && x <= 31 + // result: (SRAIW [int64(x)] y) for { t := v.Type x := auxIntToInt64(v.AuxInt) @@ -6300,9 +6288,71 @@ func rewriteValueRISCV64_OpRISCV64SRAI(v *Value) bool { break } y := v_0.Args[0] + if !(x >= 0 && x <= 31) { + break + } v.reset(OpRISCV64SRAIW) v.Type = t - v.AuxInt = int64ToAuxInt(int64(x & 31)) + v.AuxInt = int64ToAuxInt(int64(x)) + v.AddArg(y) + return true + } + // match: (SRAI [x] (MOVBreg y)) + // cond: x >= 8 + // result: (SRAI [63] (SLLI [56] y)) + for { + t := v.Type + x := auxIntToInt64(v.AuxInt) + if v_0.Op != OpRISCV64MOVBreg { + break + } + y := v_0.Args[0] + if !(x >= 8) { + break + } + v.reset(OpRISCV64SRAI) + v.AuxInt = int64ToAuxInt(63) + v0 := b.NewValue0(v.Pos, OpRISCV64SLLI, t) + v0.AuxInt = int64ToAuxInt(56) + v0.AddArg(y) + v.AddArg(v0) + return true + } + // match: (SRAI [x] (MOVHreg y)) + // cond: x >= 16 + // result: (SRAI [63] (SLLI [48] y)) + for { + t := v.Type + x := auxIntToInt64(v.AuxInt) + if v_0.Op != OpRISCV64MOVHreg { + break + } + y := v_0.Args[0] + if !(x >= 16) { + break + } + v.reset(OpRISCV64SRAI) + v.AuxInt = int64ToAuxInt(63) + v0 := b.NewValue0(v.Pos, OpRISCV64SLLI, t) + v0.AuxInt = int64ToAuxInt(48) + v0.AddArg(y) + v.AddArg(v0) + return true + } + // match: (SRAI [x] (MOVWreg y)) + // cond: x >= 32 + // result: (SRAIW [31] y) + for { + x := auxIntToInt64(v.AuxInt) + if v_0.Op != OpRISCV64MOVWreg { + break + } + y := v_0.Args[0] + if !(x >= 32) { + break + } + v.reset(OpRISCV64SRAIW) + v.AuxInt = int64ToAuxInt(31) v.AddArg(y) return true } @@ -6341,20 +6391,6 @@ func rewriteValueRISCV64_OpRISCV64SRAW(v *Value) bool { func rewriteValueRISCV64_OpRISCV64SRL(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] - // match: (SRL (MOVWUreg x) y) - // result: (SRLW x y) - for { - t := v.Type - if v_0.Op != OpRISCV64MOVWUreg { - break - } - x := v_0.Args[0] - y := v_1 - v.reset(OpRISCV64SRLW) - v.Type = t - v.AddArg2(x, y) - return true - } // match: (SRL x (MOVDconst [val])) // result: (SRLI [int64(val&63)] x) for { @@ -6373,7 +6409,8 @@ func rewriteValueRISCV64_OpRISCV64SRL(v *Value) bool { func rewriteValueRISCV64_OpRISCV64SRLI(v *Value) bool { v_0 := v.Args[0] // match: (SRLI [x] (MOVWUreg y)) - // result: (SRLIW [int64(x&31)] y) + // cond: x >= 0 && x <= 31 + // result: (SRLIW [int64(x)] y) for { t := v.Type x := auxIntToInt64(v.AuxInt) @@ -6381,12 +6418,66 @@ func rewriteValueRISCV64_OpRISCV64SRLI(v *Value) bool { break } y := v_0.Args[0] + if !(x >= 0 && x <= 31) { + break + } v.reset(OpRISCV64SRLIW) v.Type = t - v.AuxInt = int64ToAuxInt(int64(x & 31)) + v.AuxInt = int64ToAuxInt(int64(x)) v.AddArg(y) return true } + // match: (SRLI [x] (MOVBUreg y)) + // cond: x >= 8 + // result: (MOVDconst [0]) + for { + t := v.Type + x := auxIntToInt64(v.AuxInt) + if v_0.Op != OpRISCV64MOVBUreg { + break + } + if !(x >= 8) { + break + } + v.reset(OpRISCV64MOVDconst) + v.Type = t + v.AuxInt = int64ToAuxInt(0) + return true + } + // match: (SRLI [x] (MOVHUreg y)) + // cond: x >= 16 + // result: (MOVDconst [0]) + for { + t := v.Type + x := auxIntToInt64(v.AuxInt) + if v_0.Op != OpRISCV64MOVHUreg { + break + } + if !(x >= 16) { + break + } + v.reset(OpRISCV64MOVDconst) + v.Type = t + v.AuxInt = int64ToAuxInt(0) + return true + } + // match: (SRLI [x] (MOVWUreg y)) + // cond: x >= 32 + // result: (MOVDconst [0]) + for { + t := v.Type + x := auxIntToInt64(v.AuxInt) + if v_0.Op != OpRISCV64MOVWUreg { + break + } + if !(x >= 32) { + break + } + v.reset(OpRISCV64MOVDconst) + v.Type = t + v.AuxInt = int64ToAuxInt(0) + return true + } // match: (SRLI [x] (MOVDconst [y])) // result: (MOVDconst [int64(uint64(y) >> uint32(x))]) for { @@ -7035,7 +7126,7 @@ func rewriteValueRISCV64_OpRsh32Ux16(v *Value) bool { typ := &b.Func.Config.Types // match: (Rsh32Ux16 x y) // cond: !shiftIsBounded(v) - // result: (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] (ZeroExt16to64 y)))) + // result: (AND (SRLW x y) (Neg32 (SLTIU [32] (ZeroExt16to64 y)))) for { t := v.Type x := v_0 @@ -7044,33 +7135,29 @@ func rewriteValueRISCV64_OpRsh32Ux16(v *Value) bool { break } v.reset(OpRISCV64AND) - v0 := b.NewValue0(v.Pos, OpRISCV64SRL, t) - v1 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64) - v1.AddArg(x) - v0.AddArg2(v1, y) - v2 := b.NewValue0(v.Pos, OpNeg32, t) - v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t) - v3.AuxInt = int64ToAuxInt(32) - v4 := b.NewValue0(v.Pos, OpZeroExt16to64, typ.UInt64) - v4.AddArg(y) - v3.AddArg(v4) + v0 := b.NewValue0(v.Pos, OpRISCV64SRLW, t) + v0.AddArg2(x, y) + v1 := b.NewValue0(v.Pos, OpNeg32, t) + v2 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t) + v2.AuxInt = int64ToAuxInt(32) + v3 := b.NewValue0(v.Pos, OpZeroExt16to64, typ.UInt64) + v3.AddArg(y) v2.AddArg(v3) - v.AddArg2(v0, v2) + v1.AddArg(v2) + v.AddArg2(v0, v1) return true } // match: (Rsh32Ux16 x y) // cond: shiftIsBounded(v) - // result: (SRL (ZeroExt32to64 x) y) + // result: (SRLW x y) for { x := v_0 y := v_1 if !(shiftIsBounded(v)) { break } - v.reset(OpRISCV64SRL) - v0 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64) - v0.AddArg(x) - v.AddArg2(v0, y) + v.reset(OpRISCV64SRLW) + v.AddArg2(x, y) return true } return false @@ -7082,7 +7169,7 @@ func rewriteValueRISCV64_OpRsh32Ux32(v *Value) bool { typ := &b.Func.Config.Types // match: (Rsh32Ux32 x y) // cond: !shiftIsBounded(v) - // result: (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] (ZeroExt32to64 y)))) + // result: (AND (SRLW x y) (Neg32 (SLTIU [32] (ZeroExt32to64 y)))) for { t := v.Type x := v_0 @@ -7091,33 +7178,29 @@ func rewriteValueRISCV64_OpRsh32Ux32(v *Value) bool { break } v.reset(OpRISCV64AND) - v0 := b.NewValue0(v.Pos, OpRISCV64SRL, t) - v1 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64) - v1.AddArg(x) - v0.AddArg2(v1, y) - v2 := b.NewValue0(v.Pos, OpNeg32, t) - v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t) - v3.AuxInt = int64ToAuxInt(32) - v4 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64) - v4.AddArg(y) - v3.AddArg(v4) + v0 := b.NewValue0(v.Pos, OpRISCV64SRLW, t) + v0.AddArg2(x, y) + v1 := b.NewValue0(v.Pos, OpNeg32, t) + v2 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t) + v2.AuxInt = int64ToAuxInt(32) + v3 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64) + v3.AddArg(y) v2.AddArg(v3) - v.AddArg2(v0, v2) + v1.AddArg(v2) + v.AddArg2(v0, v1) return true } // match: (Rsh32Ux32 x y) // cond: shiftIsBounded(v) - // result: (SRL (ZeroExt32to64 x) y) + // result: (SRLW x y) for { x := v_0 y := v_1 if !(shiftIsBounded(v)) { break } - v.reset(OpRISCV64SRL) - v0 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64) - v0.AddArg(x) - v.AddArg2(v0, y) + v.reset(OpRISCV64SRLW) + v.AddArg2(x, y) return true } return false @@ -7126,10 +7209,9 @@ func rewriteValueRISCV64_OpRsh32Ux64(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] b := v.Block - typ := &b.Func.Config.Types // match: (Rsh32Ux64 x y) // cond: !shiftIsBounded(v) - // result: (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] y))) + // result: (AND (SRLW x y) (Neg32 (SLTIU [32] y))) for { t := v.Type x := v_0 @@ -7138,31 +7220,27 @@ func rewriteValueRISCV64_OpRsh32Ux64(v *Value) bool { break } v.reset(OpRISCV64AND) - v0 := b.NewValue0(v.Pos, OpRISCV64SRL, t) - v1 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64) - v1.AddArg(x) - v0.AddArg2(v1, y) - v2 := b.NewValue0(v.Pos, OpNeg32, t) - v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t) - v3.AuxInt = int64ToAuxInt(32) - v3.AddArg(y) - v2.AddArg(v3) - v.AddArg2(v0, v2) + v0 := b.NewValue0(v.Pos, OpRISCV64SRLW, t) + v0.AddArg2(x, y) + v1 := b.NewValue0(v.Pos, OpNeg32, t) + v2 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t) + v2.AuxInt = int64ToAuxInt(32) + v2.AddArg(y) + v1.AddArg(v2) + v.AddArg2(v0, v1) return true } // match: (Rsh32Ux64 x y) // cond: shiftIsBounded(v) - // result: (SRL (ZeroExt32to64 x) y) + // result: (SRLW x y) for { x := v_0 y := v_1 if !(shiftIsBounded(v)) { break } - v.reset(OpRISCV64SRL) - v0 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64) - v0.AddArg(x) - v.AddArg2(v0, y) + v.reset(OpRISCV64SRLW) + v.AddArg2(x, y) return true } return false @@ -7174,7 +7252,7 @@ func rewriteValueRISCV64_OpRsh32Ux8(v *Value) bool { typ := &b.Func.Config.Types // match: (Rsh32Ux8 x y) // cond: !shiftIsBounded(v) - // result: (AND (SRL (ZeroExt32to64 x) y) (Neg32 (SLTIU [32] (ZeroExt8to64 y)))) + // result: (AND (SRLW x y) (Neg32 (SLTIU [32] (ZeroExt8to64 y)))) for { t := v.Type x := v_0 @@ -7183,33 +7261,29 @@ func rewriteValueRISCV64_OpRsh32Ux8(v *Value) bool { break } v.reset(OpRISCV64AND) - v0 := b.NewValue0(v.Pos, OpRISCV64SRL, t) - v1 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64) - v1.AddArg(x) - v0.AddArg2(v1, y) - v2 := b.NewValue0(v.Pos, OpNeg32, t) - v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t) - v3.AuxInt = int64ToAuxInt(32) - v4 := b.NewValue0(v.Pos, OpZeroExt8to64, typ.UInt64) - v4.AddArg(y) - v3.AddArg(v4) + v0 := b.NewValue0(v.Pos, OpRISCV64SRLW, t) + v0.AddArg2(x, y) + v1 := b.NewValue0(v.Pos, OpNeg32, t) + v2 := b.NewValue0(v.Pos, OpRISCV64SLTIU, t) + v2.AuxInt = int64ToAuxInt(32) + v3 := b.NewValue0(v.Pos, OpZeroExt8to64, typ.UInt64) + v3.AddArg(y) v2.AddArg(v3) - v.AddArg2(v0, v2) + v1.AddArg(v2) + v.AddArg2(v0, v1) return true } // match: (Rsh32Ux8 x y) // cond: shiftIsBounded(v) - // result: (SRL (ZeroExt32to64 x) y) + // result: (SRLW x y) for { x := v_0 y := v_1 if !(shiftIsBounded(v)) { break } - v.reset(OpRISCV64SRL) - v0 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64) - v0.AddArg(x) - v.AddArg2(v0, y) + v.reset(OpRISCV64SRLW) + v.AddArg2(x, y) return true } return false @@ -7221,7 +7295,7 @@ func rewriteValueRISCV64_OpRsh32x16(v *Value) bool { typ := &b.Func.Config.Types // match: (Rsh32x16 x y) // cond: !shiftIsBounded(v) - // result: (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] (ZeroExt16to64 y))))) + // result: (SRAW x (OR y (ADDI [-1] (SLTIU [32] (ZeroExt16to64 y))))) for { t := v.Type x := v_0 @@ -7229,36 +7303,32 @@ func rewriteValueRISCV64_OpRsh32x16(v *Value) bool { if !(!shiftIsBounded(v)) { break } - v.reset(OpRISCV64SRA) + v.reset(OpRISCV64SRAW) v.Type = t - v0 := b.NewValue0(v.Pos, OpSignExt32to64, typ.Int64) - v0.AddArg(x) - v1 := b.NewValue0(v.Pos, OpRISCV64OR, y.Type) - v2 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type) - v2.AuxInt = int64ToAuxInt(-1) - v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type) - v3.AuxInt = int64ToAuxInt(32) - v4 := b.NewValue0(v.Pos, OpZeroExt16to64, typ.UInt64) - v4.AddArg(y) - v3.AddArg(v4) + v0 := b.NewValue0(v.Pos, OpRISCV64OR, y.Type) + v1 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type) + v1.AuxInt = int64ToAuxInt(-1) + v2 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type) + v2.AuxInt = int64ToAuxInt(32) + v3 := b.NewValue0(v.Pos, OpZeroExt16to64, typ.UInt64) + v3.AddArg(y) v2.AddArg(v3) - v1.AddArg2(y, v2) - v.AddArg2(v0, v1) + v1.AddArg(v2) + v0.AddArg2(y, v1) + v.AddArg2(x, v0) return true } // match: (Rsh32x16 x y) // cond: shiftIsBounded(v) - // result: (SRA (SignExt32to64 x) y) + // result: (SRAW x y) for { x := v_0 y := v_1 if !(shiftIsBounded(v)) { break } - v.reset(OpRISCV64SRA) - v0 := b.NewValue0(v.Pos, OpSignExt32to64, typ.Int64) - v0.AddArg(x) - v.AddArg2(v0, y) + v.reset(OpRISCV64SRAW) + v.AddArg2(x, y) return true } return false @@ -7270,7 +7340,7 @@ func rewriteValueRISCV64_OpRsh32x32(v *Value) bool { typ := &b.Func.Config.Types // match: (Rsh32x32 x y) // cond: !shiftIsBounded(v) - // result: (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] (ZeroExt32to64 y))))) + // result: (SRAW x (OR y (ADDI [-1] (SLTIU [32] (ZeroExt32to64 y))))) for { t := v.Type x := v_0 @@ -7278,36 +7348,32 @@ func rewriteValueRISCV64_OpRsh32x32(v *Value) bool { if !(!shiftIsBounded(v)) { break } - v.reset(OpRISCV64SRA) + v.reset(OpRISCV64SRAW) v.Type = t - v0 := b.NewValue0(v.Pos, OpSignExt32to64, typ.Int64) - v0.AddArg(x) - v1 := b.NewValue0(v.Pos, OpRISCV64OR, y.Type) - v2 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type) - v2.AuxInt = int64ToAuxInt(-1) - v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type) - v3.AuxInt = int64ToAuxInt(32) - v4 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64) - v4.AddArg(y) - v3.AddArg(v4) + v0 := b.NewValue0(v.Pos, OpRISCV64OR, y.Type) + v1 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type) + v1.AuxInt = int64ToAuxInt(-1) + v2 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type) + v2.AuxInt = int64ToAuxInt(32) + v3 := b.NewValue0(v.Pos, OpZeroExt32to64, typ.UInt64) + v3.AddArg(y) v2.AddArg(v3) - v1.AddArg2(y, v2) - v.AddArg2(v0, v1) + v1.AddArg(v2) + v0.AddArg2(y, v1) + v.AddArg2(x, v0) return true } // match: (Rsh32x32 x y) // cond: shiftIsBounded(v) - // result: (SRA (SignExt32to64 x) y) + // result: (SRAW x y) for { x := v_0 y := v_1 if !(shiftIsBounded(v)) { break } - v.reset(OpRISCV64SRA) - v0 := b.NewValue0(v.Pos, OpSignExt32to64, typ.Int64) - v0.AddArg(x) - v.AddArg2(v0, y) + v.reset(OpRISCV64SRAW) + v.AddArg2(x, y) return true } return false @@ -7316,10 +7382,9 @@ func rewriteValueRISCV64_OpRsh32x64(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] b := v.Block - typ := &b.Func.Config.Types // match: (Rsh32x64 x y) // cond: !shiftIsBounded(v) - // result: (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] y)))) + // result: (SRAW x (OR y (ADDI [-1] (SLTIU [32] y)))) for { t := v.Type x := v_0 @@ -7327,34 +7392,30 @@ func rewriteValueRISCV64_OpRsh32x64(v *Value) bool { if !(!shiftIsBounded(v)) { break } - v.reset(OpRISCV64SRA) + v.reset(OpRISCV64SRAW) v.Type = t - v0 := b.NewValue0(v.Pos, OpSignExt32to64, typ.Int64) - v0.AddArg(x) - v1 := b.NewValue0(v.Pos, OpRISCV64OR, y.Type) - v2 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type) - v2.AuxInt = int64ToAuxInt(-1) - v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type) - v3.AuxInt = int64ToAuxInt(32) - v3.AddArg(y) - v2.AddArg(v3) - v1.AddArg2(y, v2) - v.AddArg2(v0, v1) + v0 := b.NewValue0(v.Pos, OpRISCV64OR, y.Type) + v1 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type) + v1.AuxInt = int64ToAuxInt(-1) + v2 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type) + v2.AuxInt = int64ToAuxInt(32) + v2.AddArg(y) + v1.AddArg(v2) + v0.AddArg2(y, v1) + v.AddArg2(x, v0) return true } // match: (Rsh32x64 x y) // cond: shiftIsBounded(v) - // result: (SRA (SignExt32to64 x) y) + // result: (SRAW x y) for { x := v_0 y := v_1 if !(shiftIsBounded(v)) { break } - v.reset(OpRISCV64SRA) - v0 := b.NewValue0(v.Pos, OpSignExt32to64, typ.Int64) - v0.AddArg(x) - v.AddArg2(v0, y) + v.reset(OpRISCV64SRAW) + v.AddArg2(x, y) return true } return false @@ -7366,7 +7427,7 @@ func rewriteValueRISCV64_OpRsh32x8(v *Value) bool { typ := &b.Func.Config.Types // match: (Rsh32x8 x y) // cond: !shiftIsBounded(v) - // result: (SRA (SignExt32to64 x) (OR y (ADDI [-1] (SLTIU [32] (ZeroExt8to64 y))))) + // result: (SRAW x (OR y (ADDI [-1] (SLTIU [32] (ZeroExt8to64 y))))) for { t := v.Type x := v_0 @@ -7374,36 +7435,32 @@ func rewriteValueRISCV64_OpRsh32x8(v *Value) bool { if !(!shiftIsBounded(v)) { break } - v.reset(OpRISCV64SRA) + v.reset(OpRISCV64SRAW) v.Type = t - v0 := b.NewValue0(v.Pos, OpSignExt32to64, typ.Int64) - v0.AddArg(x) - v1 := b.NewValue0(v.Pos, OpRISCV64OR, y.Type) - v2 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type) - v2.AuxInt = int64ToAuxInt(-1) - v3 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type) - v3.AuxInt = int64ToAuxInt(32) - v4 := b.NewValue0(v.Pos, OpZeroExt8to64, typ.UInt64) - v4.AddArg(y) - v3.AddArg(v4) + v0 := b.NewValue0(v.Pos, OpRISCV64OR, y.Type) + v1 := b.NewValue0(v.Pos, OpRISCV64ADDI, y.Type) + v1.AuxInt = int64ToAuxInt(-1) + v2 := b.NewValue0(v.Pos, OpRISCV64SLTIU, y.Type) + v2.AuxInt = int64ToAuxInt(32) + v3 := b.NewValue0(v.Pos, OpZeroExt8to64, typ.UInt64) + v3.AddArg(y) v2.AddArg(v3) - v1.AddArg2(y, v2) - v.AddArg2(v0, v1) + v1.AddArg(v2) + v0.AddArg2(y, v1) + v.AddArg2(x, v0) return true } // match: (Rsh32x8 x y) // cond: shiftIsBounded(v) - // result: (SRA (SignExt32to64 x) y) + // result: (SRAW x y) for { x := v_0 y := v_1 if !(shiftIsBounded(v)) { break } - v.reset(OpRISCV64SRA) - v0 := b.NewValue0(v.Pos, OpSignExt32to64, typ.Int64) - v0.AddArg(x) - v.AddArg2(v0, y) + v.reset(OpRISCV64SRAW) + v.AddArg2(x, y) return true } return false diff --git a/src/cmd/compile/internal/test/testdata/arith_test.go b/src/cmd/compile/internal/test/testdata/arith_test.go index 2b8cd9fad34..cd7b5bc2c4a 100644 --- a/src/cmd/compile/internal/test/testdata/arith_test.go +++ b/src/cmd/compile/internal/test/testdata/arith_test.go @@ -268,6 +268,70 @@ func testOverflowConstShift(t *testing.T) { } } +//go:noinline +func rsh64x64ConstOverflow8(x int8) int64 { + return int64(x) >> 9 +} + +//go:noinline +func rsh64x64ConstOverflow16(x int16) int64 { + return int64(x) >> 17 +} + +//go:noinline +func rsh64x64ConstOverflow32(x int32) int64 { + return int64(x) >> 33 +} + +func testArithRightShiftConstOverflow(t *testing.T) { + allSet := int64(-1) + if got, want := rsh64x64ConstOverflow8(0x7f), int64(0); got != want { + t.Errorf("rsh64x64ConstOverflow8 failed: got %v, want %v", got, want) + } + if got, want := rsh64x64ConstOverflow16(0x7fff), int64(0); got != want { + t.Errorf("rsh64x64ConstOverflow16 failed: got %v, want %v", got, want) + } + if got, want := rsh64x64ConstOverflow32(0x7ffffff), int64(0); got != want { + t.Errorf("rsh64x64ConstOverflow32 failed: got %v, want %v", got, want) + } + if got, want := rsh64x64ConstOverflow8(int8(-1)), allSet; got != want { + t.Errorf("rsh64x64ConstOverflow8 failed: got %v, want %v", got, want) + } + if got, want := rsh64x64ConstOverflow16(int16(-1)), allSet; got != want { + t.Errorf("rsh64x64ConstOverflow16 failed: got %v, want %v", got, want) + } + if got, want := rsh64x64ConstOverflow32(int32(-1)), allSet; got != want { + t.Errorf("rsh64x64ConstOverflow32 failed: got %v, want %v", got, want) + } +} + +//go:noinline +func rsh64Ux64ConstOverflow8(x uint8) uint64 { + return uint64(x) >> 9 +} + +//go:noinline +func rsh64Ux64ConstOverflow16(x uint16) uint64 { + return uint64(x) >> 17 +} + +//go:noinline +func rsh64Ux64ConstOverflow32(x uint32) uint64 { + return uint64(x) >> 33 +} + +func testRightShiftConstOverflow(t *testing.T) { + if got, want := rsh64Ux64ConstOverflow8(0xff), uint64(0); got != want { + t.Errorf("rsh64Ux64ConstOverflow8 failed: got %v, want %v", got, want) + } + if got, want := rsh64Ux64ConstOverflow16(0xffff), uint64(0); got != want { + t.Errorf("rsh64Ux64ConstOverflow16 failed: got %v, want %v", got, want) + } + if got, want := rsh64Ux64ConstOverflow32(0xffffffff), uint64(0); got != want { + t.Errorf("rsh64Ux64ConstOverflow32 failed: got %v, want %v", got, want) + } +} + // test64BitConstMult tests that rewrite rules don't fold 64 bit constants // into multiply instructions. func test64BitConstMult(t *testing.T) { @@ -918,6 +982,8 @@ func TestArithmetic(t *testing.T) { testShiftCX(t) testSubConst(t) testOverflowConstShift(t) + testArithRightShiftConstOverflow(t) + testRightShiftConstOverflow(t) testArithConstShift(t) testArithRshConst(t) testLargeConst(t) diff --git a/test/codegen/shift.go b/test/codegen/shift.go index 32cfaffae00..50d60426d0e 100644 --- a/test/codegen/shift.go +++ b/test/codegen/shift.go @@ -22,12 +22,42 @@ func rshConst64Ux64(v uint64) uint64 { return v >> uint64(33) } +func rshConst64Ux64Overflow32(v uint32) uint64 { + // riscv64:"MOV\t\\$0,",-"SRL" + return uint64(v) >> 32 +} + +func rshConst64Ux64Overflow16(v uint16) uint64 { + // riscv64:"MOV\t\\$0,",-"SRL" + return uint64(v) >> 16 +} + +func rshConst64Ux64Overflow8(v uint8) uint64 { + // riscv64:"MOV\t\\$0,",-"SRL" + return uint64(v) >> 8 +} + func rshConst64x64(v int64) int64 { // ppc64x:"SRAD" // riscv64:"SRAI\t",-"OR",-"SLTIU" return v >> uint64(33) } +func rshConst64x64Overflow32(v int32) int64 { + // riscv64:"SRAIW",-"SLLI",-"SRAI\t" + return int64(v) >> 32 +} + +func rshConst64x64Overflow16(v int16) int64 { + // riscv64:"SLLI","SRAI",-"SRAIW" + return int64(v) >> 16 +} + +func rshConst64x64Overflow8(v int8) int64 { + // riscv64:"SLLI","SRAI",-"SRAIW" + return int64(v) >> 8 +} + func lshConst32x64(v int32) int32 { // ppc64x:"SLW" // riscv64:"SLLI",-"AND",-"SLTIU", -"MOVW"