Skip to content

Commit

Permalink
inference: fix too conservative effects for recursive cycles (#54323)
Browse files Browse the repository at this point in the history
The `:terminates` effect bit must be conservatively tainted unless
recursion cycle has been fully resolved. As for other effects, there's
no need to taint them at this moment because they will be tainted as we
try to resolve the cycle.

- fixes #52938
- xref #51092
  • Loading branch information
aviatesk committed Jun 5, 2024
1 parent 583981f commit 65aeaf6
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 18 deletions.
24 changes: 18 additions & 6 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,15 +236,21 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
# with no active ip's, frame is done
frames = frame.callers_in_cycle
isempty(frames) && push!(frames, frame)
valid_worlds = WorldRange()
cycle_valid_worlds = WorldRange()
cycle_effects = EFFECTS_TOTAL
for caller in frames
@assert !(caller.dont_work_on_me)
caller.dont_work_on_me = true
# might might not fully intersect these earlier, so do that now
valid_worlds = intersect(caller.valid_worlds, valid_worlds)
# converge the world age range and effects for this cycle here:
# all frames in the cycle should have the same bits of `valid_worlds` and `effects`
# that are simply the intersection of each partial computation, without having
# dependencies on each other (unlike rt and exct)
cycle_valid_worlds = intersect(cycle_valid_worlds, caller.valid_worlds)
cycle_effects = merge_effects(cycle_effects, caller.ipo_effects)
end
for caller in frames
caller.valid_worlds = valid_worlds
caller.valid_worlds = cycle_valid_worlds
caller.ipo_effects = cycle_effects
finish(caller, caller.interp)
end
for caller in frames
Expand Down Expand Up @@ -864,7 +870,8 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
update_valid_age!(caller, frame.valid_worlds)
isinferred = is_inferred(frame)
edge = isinferred ? mi : nothing
effects = isinferred ? frame.result.ipo_effects : adjust_effects(Effects(), method) # effects are adjusted already within `finish` for ipo_effects
effects = isinferred ? frame.result.ipo_effects : # effects are adjusted already within `finish` for ipo_effects
adjust_effects(effects_for_cycle(frame.ipo_effects), method)
exc_bestguess = refine_exception_type(frame.exc_bestguess, effects)
# propagate newly inferred source to the inliner, allowing efficient inlining w/o deserialization:
# note that this result is cached globally exclusively, so we can use this local result destructively
Expand All @@ -877,11 +884,16 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
# return the current knowledge about this cycle
frame = frame::InferenceState
update_valid_age!(caller, frame.valid_worlds)
effects = adjust_effects(Effects(), method)
effects = adjust_effects(effects_for_cycle(frame.ipo_effects), method)
exc_bestguess = refine_exception_type(frame.exc_bestguess, effects)
return EdgeCallResult(frame.bestguess, exc_bestguess, nothing, effects)
end

# The `:terminates` effect bit must be conservatively tainted unless recursion cycle has
# been fully resolved. As for other effects, there's no need to taint them at this moment
# because they will be tainted as we try to resolve the cycle.
effects_for_cycle(effects::Effects) = Effects(effects; terminates=false)

function cached_return_type(code::CodeInstance)
rettype = code.rettype
isdefined(code, :rettype_const) || return rettype
Expand Down
10 changes: 9 additions & 1 deletion test/compiler/AbstractInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,18 @@ function CC.concrete_eval_eligible(interp::Issue48097Interp,
end
@overlay Issue48097MT @noinline Core.throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} = return
issue48097(; kwargs...) = return 42
@test_broken fully_eliminated(; interp=Issue48097Interp(), retval=42) do
@test fully_eliminated(; interp=Issue48097Interp(), retval=42) do
issue48097(; a=1f0, b=1.0)
end

# https://github.com/JuliaLang/julia/issues/52938
@newinterp Issue52938Interp
@MethodTable ISSUE_52938_MT
CC.method_table(interp::Issue52938Interp) = CC.OverlayMethodTable(CC.get_inference_world(interp), ISSUE_52938_MT)
inner52938(x, types::Type, args...; kwargs...) = x
outer52938(x) = @inline inner52938(x, Tuple{}; foo=Ref(42), bar=1)
@test fully_eliminated(outer52938, (Any,); interp=Issue52938Interp(), retval=Argument(2))

# Should not concrete-eval overlayed methods in semi-concrete interpretation
@newinterp OverlaySinInterp
@MethodTable OverlaySinMT
Expand Down
12 changes: 6 additions & 6 deletions test/compiler/effects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,13 @@ Base.@assume_effects :terminates_globally function recur_termination1(x)
0 x < 20 || error("bad fact")
return x * recur_termination1(x-1)
end
@test_broken Core.Compiler.is_foldable(Base.infer_effects(recur_termination1, (Int,)))
@test Core.Compiler.is_foldable(Base.infer_effects(recur_termination1, (Int,)))
@test Core.Compiler.is_terminates(Base.infer_effects(recur_termination1, (Int,)))
function recur_termination2()
Base.@assume_effects :total !:terminates_globally
recur_termination1(12)
end
@test_broken fully_eliminated(recur_termination2)
@test fully_eliminated(recur_termination2)
@test fully_eliminated() do; recur_termination2(); end

Base.@assume_effects :terminates_globally function recur_termination21(x)
Expand All @@ -104,15 +104,15 @@ Base.@assume_effects :terminates_globally function recur_termination21(x)
return recur_termination22(x)
end
recur_termination22(x) = x * recur_termination21(x-1)
@test_broken Core.Compiler.is_foldable(Base.infer_effects(recur_termination21, (Int,)))
@test_broken Core.Compiler.is_foldable(Base.infer_effects(recur_termination22, (Int,)))
@test Core.Compiler.is_foldable(Base.infer_effects(recur_termination21, (Int,)))
@test Core.Compiler.is_foldable(Base.infer_effects(recur_termination22, (Int,)))
@test Core.Compiler.is_terminates(Base.infer_effects(recur_termination21, (Int,)))
@test Core.Compiler.is_terminates(Base.infer_effects(recur_termination22, (Int,)))
function recur_termination2x()
Base.@assume_effects :total !:terminates_globally
recur_termination21(12) + recur_termination22(12)
end
@test_broken fully_eliminated(recur_termination2x)
@test fully_eliminated(recur_termination2x)
@test fully_eliminated() do; recur_termination2x(); end

# anonymous function support for `@assume_effects`
Expand Down Expand Up @@ -921,7 +921,7 @@ unknown_sparam_nothrow2(x::Ref{Ref{T}}) where T = (T; nothing)
abstractly_recursive1() = abstractly_recursive2()
abstractly_recursive2() = (Core.Compiler._return_type(abstractly_recursive1, Tuple{}); 1)
abstractly_recursive3() = abstractly_recursive2()
@test Core.Compiler.is_terminates(Base.infer_effects(abstractly_recursive3, ()))
@test_broken Core.Compiler.is_terminates(Base.infer_effects(abstractly_recursive3, ()))
actually_recursive1(x) = actually_recursive2(x)
actually_recursive2(x) = (x <= 0) ? 1 : actually_recursive1(x - 1)
actually_recursive3(x) = actually_recursive2(x)
Expand Down
18 changes: 13 additions & 5 deletions test/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1595,13 +1595,21 @@ end
@testset let T = T
for f = Any[sin, cos, tan, log, log2, log10, log1p, exponent, sqrt, cbrt, fourthroot,
asin, atan, acos, sinh, cosh, tanh, asinh, acosh, atanh, exp, exp2, exp10, expm1]
@testset let f = f
@test Base.infer_return_type(f, (T,)) != Union{}
@test Core.Compiler.is_foldable(Base.infer_effects(f, (T,)))
@testset let f = f,
rt = Base.infer_return_type(f, (T,)),
effects = Base.infer_effects(f, (T,))
@test rt != Union{}
@test Core.Compiler.is_foldable(effects)
end
end
@test Core.Compiler.is_foldable(Base.infer_effects(^, (T,Int)))
@test Core.Compiler.is_foldable(Base.infer_effects(^, (T,T)))
@static if !(Sys.iswindows()&&Int==Int32) # COMBAK debug this
@testset let effects = Base.infer_effects(^, (T,Int))
@test Core.Compiler.is_foldable(effects)
end
end # @static
@testset let effects = Base.infer_effects(^, (T,T))
@test Core.Compiler.is_foldable(effects)
end
end
end
end;
Expand Down

0 comments on commit 65aeaf6

Please sign in to comment.