Skip to content

Commit

Permalink
effects: support callsite @assume_effects annotation
Browse files Browse the repository at this point in the history
`@assume_effects` on method definitions is very useful, but at some
times, its application can be too broad on generic signatures. We've
addressed this by defining methods specialized for particular types and
applying `@assume_effects` to them. However, this method has downsides,
such as being restricted to pre-set types or increasing nearly identical
method definitions.

To remedy this, this commit introduces a support for callsite
`@assume_effects` annotation, enabling finer control over its
application. Now the annotations for `@ccall` are just specific
instances of callsite `@assume_effects` annotation, so it might be a
good idea to replace `@ccall` annotations with callsite
`@assume_effects` annotations.
  • Loading branch information
aviatesk committed Dec 5, 2023
1 parent cc0638b commit 19f95a1
Show file tree
Hide file tree
Showing 12 changed files with 141 additions and 57 deletions.
22 changes: 14 additions & 8 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2692,14 +2692,7 @@ function abstract_eval_foreigncall(interp::AbstractInterpreter, e::Expr, vtypes:
cconv = e.args[5]
if isa(cconv, QuoteNode) && (v = cconv.value; isa(v, Tuple{Symbol, UInt16}))
override = decode_effects_override(v[2])
effects = Effects(effects;
consistent = override.consistent ? ALWAYS_TRUE : effects.consistent,
effect_free = override.effect_free ? ALWAYS_TRUE : effects.effect_free,
nothrow = override.nothrow ? true : effects.nothrow,
terminates = override.terminates_globally ? true : effects.terminates,
notaskstate = override.notaskstate ? true : effects.notaskstate,
inaccessiblememonly = override.inaccessiblememonly ? ALWAYS_TRUE : effects.inaccessiblememonly,
noub = override.noub ? ALWAYS_TRUE : override.noub_if_noinbounds ? NOUB_IF_NOINBOUNDS : effects.noub)
effects = override_effects(effects, override)
end
return RTEffects(t, Any, effects)
end
Expand Down Expand Up @@ -2754,12 +2747,25 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
# N.B.: This only applies to the effects of the statement itself.
# It is possible for arguments (GlobalRef/:static_parameter) to throw,
# but these will be recomputed during SSA construction later.
override = decode_effects_override_from_ssaflag(get_curr_ssaflag(sv))
effects = override_effects(effects, override)
set_curr_ssaflag!(sv, flags_for_effects(effects), IR_FLAGS_EFFECTS)
merge_effects!(interp, sv, effects)

return RTEffects(rt, exct, effects)
end

function override_effects(effects::Effects, override::EffectsOverride)
return Effects(effects;
consistent = override.consistent ? ALWAYS_TRUE : effects.consistent,
effect_free = override.effect_free ? ALWAYS_TRUE : effects.effect_free,
nothrow = override.nothrow ? true : effects.nothrow,
terminates = override.terminates_globally ? true : effects.terminates,
notaskstate = override.notaskstate ? true : effects.notaskstate,
inaccessiblememonly = override.inaccessiblememonly ? ALWAYS_TRUE : effects.inaccessiblememonly,
noub = override.noub ? ALWAYS_TRUE : override.noub_if_noinbounds ? NOUB_IF_NOINBOUNDS : effects.noub)
end

function isdefined_globalref(g::GlobalRef)
return ccall(:jl_globalref_boundp, Cint, (Any,), g) != 0
end
Expand Down
1 change: 1 addition & 0 deletions base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ function EffectsOverride(
noub,
noub_if_noinbounds)
end
const NUM_EFFECTS_OVERRIDES = 9 # sync with julia.h

# essential files and libraries
include("essentials.jl")
Expand Down
3 changes: 3 additions & 0 deletions base/compiler/effects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,6 @@ function decode_effects_override(e::UInt16)
!iszero(e & (0x0001 << 7)),
!iszero(e & (0x0001 << 8)))
end

decode_effects_override_from_ssaflag(flag::UInt32) =
decode_effects_override(UInt16((flag >> NUM_IR_FLAGS) & (1 << NUM_EFFECTS_OVERRIDES - 1)))
9 changes: 8 additions & 1 deletion base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,14 @@ frame_world(sv::IRInterpretationState) = sv.world
callers_in_cycle(sv::InferenceState) = sv.callers_in_cycle
callers_in_cycle(sv::IRInterpretationState) = ()

is_effect_overridden(sv::AbsIntState, effect::Symbol) = is_effect_overridden(frame_instance(sv), effect)
function is_effect_overridden(sv::AbsIntState, effect::Symbol)
if is_effect_overridden(frame_instance(sv), effect)
return true
elseif is_effect_overridden(decode_effects_override_from_ssaflag(get_curr_ssaflag(sv)), effect)
return true
end
return false
end
function is_effect_overridden(linfo::MethodInstance, effect::Symbol)
def = linfo.def
return isa(def, Method) && is_effect_overridden(def, effect)
Expand Down
2 changes: 2 additions & 0 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ const IR_FLAG_EFIIMO = one(UInt32) << 9
# This is :inaccessiblememonly == INACCESSIBLEMEM_OR_ARGMEMONLY
const IR_FLAG_INACCESSIBLE_OR_ARGMEM = one(UInt32) << 10

const NUM_IR_FLAGS = 11 # sync with julia.h

const IR_FLAGS_EFFECTS = IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW | IR_FLAG_CONSISTENT | IR_FLAG_NOUB

has_flag(curr::UInt32, flag::UInt32) = (curr & flag) == flag
Expand Down
41 changes: 22 additions & 19 deletions base/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -708,33 +708,36 @@ the call is generally total, it may however throw.
"""
macro assume_effects(args...)
lastex = args[end]
inner = unwrap_macrocalls(lastex)
if is_function_def(inner)
ex = lastex
idx = length(args)-1
if is_function_def(unwrap_macrocalls(lastex))
override = compute_assumed_settings(args[begin:end-1])
return esc(pushmeta!(lastex, form_purity_expr(override)))
elseif isexpr(lastex, :macrocall) && lastex.args[1] === Symbol("@ccall")
ex = lastex
idx = length(args)-1
override = compute_assumed_settings(args[begin:end-1])
lastex.args[1] = GlobalRef(Base, Symbol("@ccall_effects"))
insert!(lastex.args, 3, Core.Compiler.encode_effects_override(override))
return esc(lastex)
elseif compute_assumed_setting(EffectsOverride(), lastex) === nothing
# call site annotation case
override = compute_assumed_settings(args[begin:end-1])
return Expr(:block,
form_purity_expr(override),
Expr(:local, Expr(:(=), :val, esc(lastex))),
Expr(:purity), # region end token
:val)
else # anonymous function case
ex = nothing
idx = length(args)
override = compute_assumed_settings(args)
return Expr(:meta, form_purity_expr(override))
end
end

function compute_assumed_settings(settings)
override = EffectsOverride()
for i = 1:idx
setting = args[i]
for setting in settings
override = compute_assumed_setting(override, setting)
override === nothing &&
throw(ArgumentError("@assume_effects $setting not supported"))
end
if is_function_def(inner)
return esc(pushmeta!(ex, form_purity_expr(override)))
elseif isexpr(ex, :macrocall) && ex.args[1] === Symbol("@ccall")
ex.args[1] = GlobalRef(Base, Symbol("@ccall_effects"))
insert!(ex.args, 3, Core.Compiler.encode_effects_override(override))
return esc(ex)
else # anonymous function case
return Expr(:meta, form_purity_expr(override))
end
return override
end

using Core.Compiler: EffectsOverride
Expand Down
11 changes: 2 additions & 9 deletions doc/src/devdocs/ast.md
Original file line number Diff line number Diff line change
Expand Up @@ -689,15 +689,8 @@ A (usually temporary) container for holding lowered source code.

* `ssaflags`

Statement-level flags for each expression in the function. Many of these are reserved, but not yet implemented:

* 0x01 << 0 = statement is marked as `@inbounds`
* 0x01 << 1 = statement is marked as `@inline`
* 0x01 << 2 = statement is marked as `@noinline`
* 0x01 << 3 = statement is within a block that leads to `throw` call
* 0x01 << 4 = statement may be removed if its result is unused, in particular it is thus be both pure and effect free
* 0x01 << 5-6 = <unused>
* 0x01 << 7 = <reserved> has out-of-band info
Statement-level 32 bits flags for each expression in the function.
See the definition of `jl_code_info_t` in julia.h for more details.

* `linetable`

Expand Down
2 changes: 1 addition & 1 deletion src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -4877,7 +4877,7 @@ f(x) = yt(x)
(cons (car e) args)))

;; metadata expressions
((lineinfo line meta inbounds loopinfo gc_preserve_end aliasscope popaliasscope inline noinline)
((lineinfo line meta inbounds loopinfo gc_preserve_end aliasscope popaliasscope inline noinline purity)
(let ((have-ret? (and (pair? code) (pair? (car code)) (eq? (caar code) 'return))))
(cond ((eq? (car e) 'line)
(set! current-loc e)
Expand Down
26 changes: 17 additions & 9 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,20 +265,28 @@ typedef union __jl_purity_overrides_t {
uint16_t bits;
} _jl_purity_overrides_t;

#define NUM_EFFECTS_OVERRIDES 9
#define NUM_IR_FLAGS 11

// This type describes a single function body
typedef struct _jl_code_info_t {
// ssavalue-indexed arrays of properties:
jl_array_t *code; // Any array of statements
jl_value_t *codelocs; // Int32 array of indices into the line table
jl_value_t *ssavaluetypes; // types of ssa values (or count of them)
jl_array_t *ssaflags; // flags associated with each statement:
// 0 = inbounds
// 1 = inline
// 2 = noinline
// 3 = <reserved> strict-ieee (strictfp)
// 4 = effect-free (may be deleted if unused)
// 5-6 = <unused>
// 7 = has out-of-band info
jl_array_t *ssaflags; // 32 bits flags associated with each statement:
// 1 << 0 = inbounds region
// 1 << 1 = callsite inline region
// 1 << 2 = callsite noinline region
// 1 << 3 = throw block
// 1 << 4 = :effect_free
// 1 << 5 = :nothrow
// 1 << 6 = :consistent
// 1 << 7 = :refined
// 1 << 8 = :noub
// 1 << 9 = :effect_free_if_inaccessiblememonly
// 1 << 10 = :inaccessiblemem_or_argmemonly
// 1 << 11-18 = callsite effects overrides
// miscellaneous data:
jl_value_t *method_for_inference_limit_heuristics; // optional method used during inference
jl_value_t *linetable; // Table of locations [TODO: make this volatile like slotnames]
Expand Down Expand Up @@ -1250,7 +1258,7 @@ STATIC_INLINE void jl_array_uint8_set(void *a, size_t i, uint8_t x) JL_NOTSAFEPO
assert(i < jl_array_len(a));
jl_array_data(a, uint8_t)[i] = x;
}
STATIC_INLINE void jl_array_uint32_set(void *a, size_t i, uint8_t x) JL_NOTSAFEPOINT
STATIC_INLINE void jl_array_uint32_set(void *a, size_t i, uint32_t x) JL_NOTSAFEPOINT
{
assert(i < jl_array_len(a));
assert(jl_typetagis(a, jl_array_uint32_type) || jl_typetagis(a, jl_array_int32_type));
Expand Down
2 changes: 1 addition & 1 deletion src/macroexpand.scm
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@
,(resolve-expansion-vars-with-new-env (caddr arg) env m parent-scope inarg))
(unescape-global-lhs arg env m parent-scope inarg)))
(cdr e))))
((using import export meta line inbounds boundscheck loopinfo inline noinline) (map unescape e))
((using import export meta line inbounds boundscheck loopinfo inline noinline purity) (map unescape e))
((macrocall) e) ; invalid syntax anyways, so just act like it's quoted.
((symboliclabel) e)
((symbolicgoto) e)
Expand Down
41 changes: 32 additions & 9 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,11 @@ static void jl_code_info_set_ir(jl_code_info_t *li, jl_expr_t *ir)
li->ssaflags = jl_alloc_array_1d(jl_array_uint32_type, n);
jl_gc_wb(li, li->ssaflags);
int inbounds_depth = 0; // number of stacked inbounds
// isempty(inline_flags): no user annotation
// last(inline_flags) == 1: inline region
// last(inline_flags) == 0: noinline region
// isempty(inline_flags): no user callsite inline annotation
// last(inline_flags) == 1: callsite inline region
// last(inline_flags) == 0: callsite noinline region
arraylist_t *inline_flags = arraylist_new((arraylist_t*)malloc_s(sizeof(arraylist_t)), 0);
arraylist_t *purity_exprs = arraylist_new((arraylist_t*)malloc_s(sizeof(arraylist_t)), 0);
for (j = 0; j < n; j++) {
jl_value_t *st = bd[j];
int is_flag_stmt = 0;
Expand All @@ -327,7 +328,7 @@ static void jl_code_info_set_ir(jl_code_info_t *li, jl_expr_t *ir)
else if (ma == (jl_value_t*)jl_no_constprop_sym)
li->constprop = 2;
else if (jl_is_expr(ma) && ((jl_expr_t*)ma)->head == jl_purity_sym) {
if (jl_expr_nargs(ma) == 9) {
if (jl_expr_nargs(ma) == NUM_EFFECTS_OVERRIDES) {
li->purity.overrides.ipo_consistent = jl_unbox_bool(jl_exprarg(ma, 0));
li->purity.overrides.ipo_effect_free = jl_unbox_bool(jl_exprarg(ma, 1));
li->purity.overrides.ipo_nothrow = jl_unbox_bool(jl_exprarg(ma, 2));
Expand Down Expand Up @@ -381,6 +382,18 @@ static void jl_code_info_set_ir(jl_code_info_t *li, jl_expr_t *ir)
}
bd[j] = jl_nothing;
}
else if (jl_is_expr(st) && ((jl_expr_t*)st)->head == jl_purity_sym) {
is_flag_stmt = 1;
size_t na = jl_expr_nargs(st);
if (na == 9) {
arraylist_push(purity_exprs, (void*)st);
}
else {
assert(na == 0);
arraylist_pop(purity_exprs);
}
bd[j] = jl_nothing;
}
else if (jl_is_expr(st) && ((jl_expr_t*)st)->head == jl_boundscheck_sym) {
// Don't set IR_FLAG_INBOUNDS on boundscheck at the same level
is_flag_stmt = 1;
Expand All @@ -394,19 +407,29 @@ static void jl_code_info_set_ir(jl_code_info_t *li, jl_expr_t *ir)
if (is_flag_stmt)
jl_array_uint32_set(li->ssaflags, j, 0);
else {
uint8_t flag = 0;
uint32_t flag = 0;
if (inbounds_depth > 0)
flag |= IR_FLAG_INBOUNDS;
if (inline_flags->len > 0) {
void* inline_flag = inline_flags->items[inline_flags->len - 1];
void* inline_flag = inline_flags->items[inline_flags->len-1];
flag |= 1 << (inline_flag ? 1 : 2);
}
int n_purity_exprs = purity_exprs->len;
if (n_purity_exprs > 0) {
// apply all purity overrides
for (int i = 0; i < n_purity_exprs; i++) {
void* purity_expr = purity_exprs->items[i];
for (int j = 0; j < NUM_EFFECTS_OVERRIDES; j++) {
flag |= jl_unbox_bool(jl_exprarg((jl_value_t*)purity_expr, j)) ? (1 << (NUM_IR_FLAGS+j)) : 0;
}
}
}
jl_array_uint32_set(li->ssaflags, j, flag);
}
}
assert(inline_flags->len == 0); // malformed otherwise
arraylist_free(inline_flags);
free(inline_flags);
assert(inline_flags->len == 0 && purity_exprs->len == 0); // malformed otherwise
arraylist_free(inline_flags); arraylist_free(purity_exprs);
free(inline_flags); free(purity_exprs);
jl_array_t *vinfo = (jl_array_t*)jl_exprarg(ir, 1);
jl_array_t *vis = (jl_array_t*)jl_array_ptr_ref(vinfo, 0);
size_t nslots = jl_array_nrows(vis);
Expand Down
38 changes: 38 additions & 0 deletions test/compiler/effects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1302,3 +1302,41 @@ function getindex_nothrow(xs::Vector{Int}, i::Int)
end
end
@test Core.Compiler.is_nothrow(Base.infer_effects(getindex_nothrow, (Vector{Int}, Int)))

# callsite `@assume_effects` annotation
let ast = code_lowered((Int,)) do x
Base.@assume_effects :total identity(x)
end |> only
ssaflag = ast.ssaflags[findfirst(!iszero, ast.ssaflags)::Int]
override = Core.Compiler.decode_effects_override_from_ssaflag(ssaflag)
# if this gets broken, check if this is synced with expr.jl
@test override.consistent && override.effect_free && override.nothrow &&
override.terminates_globally && !override.terminates_locally &&
override.notaskstate && override.inaccessiblememonly &&
override.noub && !override.noub_if_noinbounds
end
@test Base.infer_effects((Float64,)) do x
isinf(x) && return 0.0
return Base.@assume_effects :nothrow sin(x)
end |> Core.Compiler.is_nothrow
let effects = Base.infer_effects((Vector{Float64},)) do xs
isempty(xs) && return 0.0
Base.@assume_effects :nothrow begin
x = Base.@assume_effects :noub @inbounds xs[1]
isinf(x) && return 0.0
return sin(x)
end
end
# all nested overrides should be applied
@test Core.Compiler.is_nothrow(effects)
@test Core.Compiler.is_noub(effects)
end
@test Base.infer_effects((Int,)) do x
res = 1
0 x < 20 || error("bad fact")
Base.@assume_effects :terminates_locally while x > 1
res *= x
x -= 1
end
return res
end |> Core.Compiler.is_terminates

0 comments on commit 19f95a1

Please sign in to comment.