diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index f284928b17135..603984a9e912e 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -2692,14 +2692,7 @@ function abstract_eval_foreigncall(interp::AbstractInterpreter, e::Expr, vtypes: cconv = e.args[5] if isa(cconv, QuoteNode) && (v = cconv.value; isa(v, Tuple{Symbol, UInt16})) override = decode_effects_override(v[2]) - effects = Effects(effects; - consistent = override.consistent ? ALWAYS_TRUE : effects.consistent, - effect_free = override.effect_free ? ALWAYS_TRUE : effects.effect_free, - nothrow = override.nothrow ? true : effects.nothrow, - terminates = override.terminates_globally ? true : effects.terminates, - notaskstate = override.notaskstate ? true : effects.notaskstate, - inaccessiblememonly = override.inaccessiblememonly ? ALWAYS_TRUE : effects.inaccessiblememonly, - noub = override.noub ? ALWAYS_TRUE : override.noub_if_noinbounds ? NOUB_IF_NOINBOUNDS : effects.noub) + effects = override_effects(effects, override) end return RTEffects(t, Any, effects) end @@ -2754,12 +2747,25 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), # N.B.: This only applies to the effects of the statement itself. # It is possible for arguments (GlobalRef/:static_parameter) to throw, # but these will be recomputed during SSA construction later. + override = decode_effects_override_from_ssaflag(get_curr_ssaflag(sv)) + effects = override_effects(effects, override) set_curr_ssaflag!(sv, flags_for_effects(effects), IR_FLAGS_EFFECTS) merge_effects!(interp, sv, effects) return RTEffects(rt, exct, effects) end +function override_effects(effects::Effects, override::EffectsOverride) + return Effects(effects; + consistent = override.consistent ? ALWAYS_TRUE : effects.consistent, + effect_free = override.effect_free ? ALWAYS_TRUE : effects.effect_free, + nothrow = override.nothrow ? true : effects.nothrow, + terminates = override.terminates_globally ? true : effects.terminates, + notaskstate = override.notaskstate ? true : effects.notaskstate, + inaccessiblememonly = override.inaccessiblememonly ? ALWAYS_TRUE : effects.inaccessiblememonly, + noub = override.noub ? ALWAYS_TRUE : override.noub_if_noinbounds ? NOUB_IF_NOINBOUNDS : effects.noub) +end + function isdefined_globalref(g::GlobalRef) return ccall(:jl_globalref_boundp, Cint, (Any,), g) != 0 end diff --git a/base/compiler/compiler.jl b/base/compiler/compiler.jl index 528b8fd97869f..6868f3ea03477 100644 --- a/base/compiler/compiler.jl +++ b/base/compiler/compiler.jl @@ -71,6 +71,7 @@ function EffectsOverride( noub, noub_if_noinbounds) end +const NUM_EFFECTS_OVERRIDES = 9 # sync with julia.h # essential files and libraries include("essentials.jl") diff --git a/base/compiler/effects.jl b/base/compiler/effects.jl index b63cc95b02b9e..6211caef79d45 100644 --- a/base/compiler/effects.jl +++ b/base/compiler/effects.jl @@ -345,3 +345,6 @@ function decode_effects_override(e::UInt16) !iszero(e & (0x0001 << 7)), !iszero(e & (0x0001 << 8))) end + +decode_effects_override_from_ssaflag(flag::UInt32) = + decode_effects_override(UInt16((flag >> NUM_IR_FLAGS) & (1 << NUM_EFFECTS_OVERRIDES - 1))) diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index 885a6dbab65b7..6a2372cbf7dde 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -818,7 +818,14 @@ frame_world(sv::IRInterpretationState) = sv.world callers_in_cycle(sv::InferenceState) = sv.callers_in_cycle callers_in_cycle(sv::IRInterpretationState) = () -is_effect_overridden(sv::AbsIntState, effect::Symbol) = is_effect_overridden(frame_instance(sv), effect) +function is_effect_overridden(sv::AbsIntState, effect::Symbol) + if is_effect_overridden(frame_instance(sv), effect) + return true + elseif is_effect_overridden(decode_effects_override_from_ssaflag(get_curr_ssaflag(sv)), effect) + return true + end + return false +end function is_effect_overridden(linfo::MethodInstance, effect::Symbol) def = linfo.def return isa(def, Method) && is_effect_overridden(def, effect) diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index fe8e8d0580bcb..dbbb78680595a 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -43,6 +43,8 @@ const IR_FLAG_EFIIMO = one(UInt32) << 9 # This is :inaccessiblememonly == INACCESSIBLEMEM_OR_ARGMEMONLY const IR_FLAG_INACCESSIBLE_OR_ARGMEM = one(UInt32) << 10 +const NUM_IR_FLAGS = 11 # sync with julia.h + const IR_FLAGS_EFFECTS = IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW | IR_FLAG_CONSISTENT | IR_FLAG_NOUB has_flag(curr::UInt32, flag::UInt32) = (curr & flag) == flag diff --git a/base/expr.jl b/base/expr.jl index 168cd492b78e4..97e59659bf4e8 100644 --- a/base/expr.jl +++ b/base/expr.jl @@ -708,33 +708,36 @@ the call is generally total, it may however throw. """ macro assume_effects(args...) lastex = args[end] - inner = unwrap_macrocalls(lastex) - if is_function_def(inner) - ex = lastex - idx = length(args)-1 + if is_function_def(unwrap_macrocalls(lastex)) + override = compute_assumed_settings(args[begin:end-1]) + return esc(pushmeta!(lastex, form_purity_expr(override))) elseif isexpr(lastex, :macrocall) && lastex.args[1] === Symbol("@ccall") - ex = lastex - idx = length(args)-1 + override = compute_assumed_settings(args[begin:end-1]) + lastex.args[1] = GlobalRef(Base, Symbol("@ccall_effects")) + insert!(lastex.args, 3, Core.Compiler.encode_effects_override(override)) + return esc(lastex) + elseif compute_assumed_setting(EffectsOverride(), lastex) === nothing + # call site annotation case + override = compute_assumed_settings(args[begin:end-1]) + return Expr(:block, + form_purity_expr(override), + Expr(:local, Expr(:(=), :val, esc(lastex))), + Expr(:purity), # region end token + :val) else # anonymous function case - ex = nothing - idx = length(args) + override = compute_assumed_settings(args) + return Expr(:meta, form_purity_expr(override)) end +end + +function compute_assumed_settings(settings) override = EffectsOverride() - for i = 1:idx - setting = args[i] + for setting in settings override = compute_assumed_setting(override, setting) override === nothing && throw(ArgumentError("@assume_effects $setting not supported")) end - if is_function_def(inner) - return esc(pushmeta!(ex, form_purity_expr(override))) - elseif isexpr(ex, :macrocall) && ex.args[1] === Symbol("@ccall") - ex.args[1] = GlobalRef(Base, Symbol("@ccall_effects")) - insert!(ex.args, 3, Core.Compiler.encode_effects_override(override)) - return esc(ex) - else # anonymous function case - return Expr(:meta, form_purity_expr(override)) - end + return override end using Core.Compiler: EffectsOverride diff --git a/doc/src/devdocs/ast.md b/doc/src/devdocs/ast.md index 18ed12b3326e9..f829b27663e62 100644 --- a/doc/src/devdocs/ast.md +++ b/doc/src/devdocs/ast.md @@ -689,15 +689,8 @@ A (usually temporary) container for holding lowered source code. * `ssaflags` - Statement-level flags for each expression in the function. Many of these are reserved, but not yet implemented: - - * 0x01 << 0 = statement is marked as `@inbounds` - * 0x01 << 1 = statement is marked as `@inline` - * 0x01 << 2 = statement is marked as `@noinline` - * 0x01 << 3 = statement is within a block that leads to `throw` call - * 0x01 << 4 = statement may be removed if its result is unused, in particular it is thus be both pure and effect free - * 0x01 << 5-6 = - * 0x01 << 7 = has out-of-band info + Statement-level 32 bits flags for each expression in the function. + See the definition of `jl_code_info_t` in julia.h for more details. * `linetable` diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index 4dd7735776315..5f43029a1cb40 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -4877,7 +4877,7 @@ f(x) = yt(x) (cons (car e) args))) ;; metadata expressions - ((lineinfo line meta inbounds loopinfo gc_preserve_end aliasscope popaliasscope inline noinline) + ((lineinfo line meta inbounds loopinfo gc_preserve_end aliasscope popaliasscope inline noinline purity) (let ((have-ret? (and (pair? code) (pair? (car code)) (eq? (caar code) 'return)))) (cond ((eq? (car e) 'line) (set! current-loc e) diff --git a/src/julia.h b/src/julia.h index e87f277092629..bc2e3c7335f1a 100644 --- a/src/julia.h +++ b/src/julia.h @@ -265,20 +265,28 @@ typedef union __jl_purity_overrides_t { uint16_t bits; } _jl_purity_overrides_t; +#define NUM_EFFECTS_OVERRIDES 9 +#define NUM_IR_FLAGS 11 + // This type describes a single function body typedef struct _jl_code_info_t { // ssavalue-indexed arrays of properties: jl_array_t *code; // Any array of statements jl_value_t *codelocs; // Int32 array of indices into the line table jl_value_t *ssavaluetypes; // types of ssa values (or count of them) - jl_array_t *ssaflags; // flags associated with each statement: - // 0 = inbounds - // 1 = inline - // 2 = noinline - // 3 = strict-ieee (strictfp) - // 4 = effect-free (may be deleted if unused) - // 5-6 = - // 7 = has out-of-band info + jl_array_t *ssaflags; // 32 bits flags associated with each statement: + // 1 << 0 = inbounds region + // 1 << 1 = callsite inline region + // 1 << 2 = callsite noinline region + // 1 << 3 = throw block + // 1 << 4 = :effect_free + // 1 << 5 = :nothrow + // 1 << 6 = :consistent + // 1 << 7 = :refined + // 1 << 8 = :noub + // 1 << 9 = :effect_free_if_inaccessiblememonly + // 1 << 10 = :inaccessiblemem_or_argmemonly + // 1 << 11-18 = callsite effects overrides // miscellaneous data: jl_value_t *method_for_inference_limit_heuristics; // optional method used during inference jl_value_t *linetable; // Table of locations [TODO: make this volatile like slotnames] @@ -1250,7 +1258,7 @@ STATIC_INLINE void jl_array_uint8_set(void *a, size_t i, uint8_t x) JL_NOTSAFEPO assert(i < jl_array_len(a)); jl_array_data(a, uint8_t)[i] = x; } -STATIC_INLINE void jl_array_uint32_set(void *a, size_t i, uint8_t x) JL_NOTSAFEPOINT +STATIC_INLINE void jl_array_uint32_set(void *a, size_t i, uint32_t x) JL_NOTSAFEPOINT { assert(i < jl_array_len(a)); assert(jl_typetagis(a, jl_array_uint32_type) || jl_typetagis(a, jl_array_int32_type)); diff --git a/src/macroexpand.scm b/src/macroexpand.scm index 6e390a6c24cf2..424e921a35713 100644 --- a/src/macroexpand.scm +++ b/src/macroexpand.scm @@ -385,7 +385,7 @@ ,(resolve-expansion-vars-with-new-env (caddr arg) env m parent-scope inarg)) (unescape-global-lhs arg env m parent-scope inarg))) (cdr e)))) - ((using import export meta line inbounds boundscheck loopinfo inline noinline) (map unescape e)) + ((using import export meta line inbounds boundscheck loopinfo inline noinline purity) (map unescape e)) ((macrocall) e) ; invalid syntax anyways, so just act like it's quoted. ((symboliclabel) e) ((symbolicgoto) e) diff --git a/src/method.c b/src/method.c index 9cb7e83d57c1c..b3c63c27724d4 100644 --- a/src/method.c +++ b/src/method.c @@ -301,10 +301,11 @@ static void jl_code_info_set_ir(jl_code_info_t *li, jl_expr_t *ir) li->ssaflags = jl_alloc_array_1d(jl_array_uint32_type, n); jl_gc_wb(li, li->ssaflags); int inbounds_depth = 0; // number of stacked inbounds - // isempty(inline_flags): no user annotation - // last(inline_flags) == 1: inline region - // last(inline_flags) == 0: noinline region + // isempty(inline_flags): no user callsite inline annotation + // last(inline_flags) == 1: callsite inline region + // last(inline_flags) == 0: callsite noinline region arraylist_t *inline_flags = arraylist_new((arraylist_t*)malloc_s(sizeof(arraylist_t)), 0); + arraylist_t *purity_exprs = arraylist_new((arraylist_t*)malloc_s(sizeof(arraylist_t)), 0); for (j = 0; j < n; j++) { jl_value_t *st = bd[j]; int is_flag_stmt = 0; @@ -327,7 +328,7 @@ static void jl_code_info_set_ir(jl_code_info_t *li, jl_expr_t *ir) else if (ma == (jl_value_t*)jl_no_constprop_sym) li->constprop = 2; else if (jl_is_expr(ma) && ((jl_expr_t*)ma)->head == jl_purity_sym) { - if (jl_expr_nargs(ma) == 9) { + if (jl_expr_nargs(ma) == NUM_EFFECTS_OVERRIDES) { li->purity.overrides.ipo_consistent = jl_unbox_bool(jl_exprarg(ma, 0)); li->purity.overrides.ipo_effect_free = jl_unbox_bool(jl_exprarg(ma, 1)); li->purity.overrides.ipo_nothrow = jl_unbox_bool(jl_exprarg(ma, 2)); @@ -381,6 +382,18 @@ static void jl_code_info_set_ir(jl_code_info_t *li, jl_expr_t *ir) } bd[j] = jl_nothing; } + else if (jl_is_expr(st) && ((jl_expr_t*)st)->head == jl_purity_sym) { + is_flag_stmt = 1; + size_t na = jl_expr_nargs(st); + if (na == 9) { + arraylist_push(purity_exprs, (void*)st); + } + else { + assert(na == 0); + arraylist_pop(purity_exprs); + } + bd[j] = jl_nothing; + } else if (jl_is_expr(st) && ((jl_expr_t*)st)->head == jl_boundscheck_sym) { // Don't set IR_FLAG_INBOUNDS on boundscheck at the same level is_flag_stmt = 1; @@ -394,19 +407,29 @@ static void jl_code_info_set_ir(jl_code_info_t *li, jl_expr_t *ir) if (is_flag_stmt) jl_array_uint32_set(li->ssaflags, j, 0); else { - uint8_t flag = 0; + uint32_t flag = 0; if (inbounds_depth > 0) flag |= IR_FLAG_INBOUNDS; if (inline_flags->len > 0) { - void* inline_flag = inline_flags->items[inline_flags->len - 1]; + void* inline_flag = inline_flags->items[inline_flags->len-1]; flag |= 1 << (inline_flag ? 1 : 2); } + int n_purity_exprs = purity_exprs->len; + if (n_purity_exprs > 0) { + // apply all purity overrides + for (int i = 0; i < n_purity_exprs; i++) { + void* purity_expr = purity_exprs->items[i]; + for (int j = 0; j < NUM_EFFECTS_OVERRIDES; j++) { + flag |= jl_unbox_bool(jl_exprarg((jl_value_t*)purity_expr, j)) ? (1 << (NUM_IR_FLAGS+j)) : 0; + } + } + } jl_array_uint32_set(li->ssaflags, j, flag); } } - assert(inline_flags->len == 0); // malformed otherwise - arraylist_free(inline_flags); - free(inline_flags); + assert(inline_flags->len == 0 && purity_exprs->len == 0); // malformed otherwise + arraylist_free(inline_flags); arraylist_free(purity_exprs); + free(inline_flags); free(purity_exprs); jl_array_t *vinfo = (jl_array_t*)jl_exprarg(ir, 1); jl_array_t *vis = (jl_array_t*)jl_array_ptr_ref(vinfo, 0); size_t nslots = jl_array_nrows(vis); diff --git a/test/compiler/effects.jl b/test/compiler/effects.jl index 3e0c24f908bc9..7bd7b4bb29fca 100644 --- a/test/compiler/effects.jl +++ b/test/compiler/effects.jl @@ -1302,3 +1302,41 @@ function getindex_nothrow(xs::Vector{Int}, i::Int) end end @test Core.Compiler.is_nothrow(Base.infer_effects(getindex_nothrow, (Vector{Int}, Int))) + +# callsite `@assume_effects` annotation +let ast = code_lowered((Int,)) do x + Base.@assume_effects :total identity(x) + end |> only + ssaflag = ast.ssaflags[findfirst(!iszero, ast.ssaflags)::Int] + override = Core.Compiler.decode_effects_override_from_ssaflag(ssaflag) + # if this gets broken, check if this is synced with expr.jl + @test override.consistent && override.effect_free && override.nothrow && + override.terminates_globally && !override.terminates_locally && + override.notaskstate && override.inaccessiblememonly && + override.noub && !override.noub_if_noinbounds +end +@test Base.infer_effects((Float64,)) do x + isinf(x) && return 0.0 + return Base.@assume_effects :nothrow sin(x) +end |> Core.Compiler.is_nothrow +let effects = Base.infer_effects((Vector{Float64},)) do xs + isempty(xs) && return 0.0 + Base.@assume_effects :nothrow begin + x = Base.@assume_effects :noub @inbounds xs[1] + isinf(x) && return 0.0 + return sin(x) + end + end + # all nested overrides should be applied + @test Core.Compiler.is_nothrow(effects) + @test Core.Compiler.is_noub(effects) +end +@test Base.infer_effects((Int,)) do x + res = 1 + 0 ≤ x < 20 || error("bad fact") + Base.@assume_effects :terminates_locally while x > 1 + res *= x + x -= 1 + end + return res +end |> Core.Compiler.is_terminates