diff --git a/test/tensor.jl b/test/tensor.jl index 2ad6edc2f0..9f9ae9b872 100644 --- a/test/tensor.jl +++ b/test/tensor.jl @@ -452,6 +452,18 @@ end mC = reshape(permutedims(C, ipC), (loA, loB)) @test mC ≈ mA * mB + # simple case with plan storage + opA = CUTENSOR.CUTENSOR_OP_IDENTITY + opB = CUTENSOR.CUTENSOR_OP_IDENTITY + opC = CUTENSOR.CUTENSOR_OP_IDENTITY + opOut = CUTENSOR.CUTENSOR_OP_IDENTITY + plan = CUTENSOR.contraction_plan(dA, indsA, opA, dB, indsB, opB, dC, indsC, opC, opOut) + dC = CUTENSOR.contraction!(1, dA, indsA, opA, dB, indsB, opB, + 0, dC, indsC, opC, opOut, plan=plan) + C = collect(dC) + mC = reshape(permutedims(C, ipC), (loA, loB)) + @test mC ≈ mA * mB + # with non-trivial α α = rand(eltyC) dC = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB,