From 0ba19fe6ec83c7d9341b2c7527c9cc4aa4b8b678 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Sat, 25 Nov 2023 11:23:33 +0900 Subject: [PATCH] irinterp: improve semi-concrete interpretation accuracy (#52275) By enforcing re-inference on calls with all constant arguments. While it's debatable whether this approach is the most efficient, it was the easiest choice given that `used_ssas` based on `IncrementaCompact` wasn't an option for irinterp. - fixes #52202 - fixes #50037 --- base/compiler/abstractinterpretation.jl | 6 ++++-- base/compiler/ssair/ir.jl | 2 +- base/compiler/ssair/irinterp.jl | 12 ++++++++++++ stdlib/LinearAlgebra/test/matmul.jl | 4 ++++ test/compiler/inference.jl | 7 +++++++ 5 files changed, 28 insertions(+), 3 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 24d5eb52c1ffb..8c6840a4b0706 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -903,12 +903,14 @@ end is_all_const_arg(arginfo::ArgInfo, start::Int) = is_all_const_arg(arginfo.argtypes, start::Int) function is_all_const_arg(argtypes::Vector{Any}, start::Int) for i = start:length(argtypes) - a = widenslotwrapper(argtypes[i]) - isa(a, Const) || isconstType(a) || issingletontype(a) || return false + argtype = widenslotwrapper(argtypes[i]) + is_const_argtype(argtype) || return false end return true end +is_const_argtype(@nospecialize argtype) = isa(argtype, Const) || isconstType(argtype) || issingletontype(argtype) + any_conditional(argtypes::Vector{Any}) = any(@nospecialize(x)->isa(x, Conditional), argtypes) any_conditional(arginfo::ArgInfo) = any_conditional(arginfo.argtypes) diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index 5c347846e6e18..9b89b30854cdf 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -956,7 +956,7 @@ function insert_node_here!(compact::IncrementalCompact, newinst::NewInstruction, return inst end -function delete_inst_here!(compact) +function delete_inst_here!(compact::IncrementalCompact) # Delete the statement, update refcounts etc compact[SSAValue(compact.result_idx-1)] = nothing # Pretend that we never compacted this statement in the first place diff --git a/base/compiler/ssair/irinterp.jl b/base/compiler/ssair/irinterp.jl index 9f8bc17beca7f..be072ca64e1c4 100644 --- a/base/compiler/ssair/irinterp.jl +++ b/base/compiler/ssair/irinterp.jl @@ -278,6 +278,15 @@ end populate_def_use_map!(tpdum::TwoPhaseDefUseMap, ir::IRCode) = populate_def_use_map!(tpdum, BBScanner(ir)) +function is_all_const_call(@nospecialize(stmt), interp::AbstractInterpreter, irsv::IRInterpretationState) + isexpr(stmt, :call) || return false + @inbounds for i = 2:length(stmt.args) + argtype = abstract_eval_value(interp, stmt.args[i], nothing, irsv) + is_const_argtype(argtype) || return false + end + return true +end + function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IRInterpretationState; externally_refined::Union{Nothing,BitSet} = nothing) (; ir, tpdum, ssa_refined) = irsv @@ -302,6 +311,9 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR if has_flag(flag, IR_FLAG_REFINED) any_refined = true sub_flag!(inst, IR_FLAG_REFINED) + elseif is_all_const_call(stmt, interp, irsv) + # force reinference on calls with all constant arguments + any_refined = true end for ur in userefs(stmt) val = ur[] diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index 30cc74694b3f4..afdc02534728a 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -23,6 +23,10 @@ mul_wrappers = [ @test @inferred(f(A)) === A g(A) = LinearAlgebra.wrap(A, 'T') @test @inferred(g(A)) === transpose(A) + # https://github.com/JuliaLang/julia/issues/52202 + @test Base.infer_return_type((Vector{Float64},)) do v + LinearAlgebra.wrap(v, 'N') + end == Vector{Float64} end @testset "matrices with zero dimensions" begin diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 9e06d120e501f..88bd4d8fba007 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -5592,6 +5592,13 @@ end |> only === Float64 @test Base.infer_exception_type(c::Missing -> c ? 1 : 2) == TypeError @test Base.infer_exception_type(c::Any -> c ? 1 : 2) == TypeError +# semi-concrete interpretation accuracy +# https://github.com/JuliaLang/julia/issues/50037 +@inline countvars50037(bitflags::Int, var::Int) = bitflags >> 0 +@test Base.infer_return_type() do var::Int + Val(countvars50037(1, var)) +end == Val{1} + # Issue #52168 f52168(x, t::Type) = x::NTuple{2, Base.inferencebarrier(t)::Type} @test f52168((1, 2.), Any) === (1, 2.)