Skip to content

Commit

Permalink
allow CodeInfo to be returned directly from generated function genera…
Browse files Browse the repository at this point in the history
  • Loading branch information
jrevels committed Aug 2, 2017
1 parent 51fbb0c commit 898d650
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 28 deletions.
60 changes: 32 additions & 28 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -316,37 +316,41 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo)
jl_value_t *generated_body = jl_call_staged(sparam_vals, generator, jl_svec_data(tt->parameters), jl_nparams(tt));
jl_array_ptr_set(body->args, 2, generated_body);

if (linfo->def.method->sparam_syms != jl_emptysvec) {
// mark this function as having the same static parameters as the generator
size_t i, nsp = jl_svec_len(linfo->def.method->sparam_syms);
jl_expr_t *newast = jl_exprn(jl_symbol("with-static-parameters"), nsp + 1);
jl_exprarg(newast, 0) = (jl_value_t*)ex;
// (with-static-parameters func_expr sp_1 sp_2 ...)
for (i = 0; i < nsp; i++)
jl_exprarg(newast, i+1) = jl_svecref(linfo->def.method->sparam_syms, i);
ex = newast;
}
if (jl_is_code_info(generated_body)) {
func = (jl_code_info_t*)generated_body;
} else {
if (linfo->def.method->sparam_syms != jl_emptysvec) {
// mark this function as having the same static parameters as the generator
size_t i, nsp = jl_svec_len(linfo->def.method->sparam_syms);
jl_expr_t *newast = jl_exprn(jl_symbol("with-static-parameters"), nsp + 1);
jl_exprarg(newast, 0) = (jl_value_t*)ex;
// (with-static-parameters func_expr sp_1 sp_2 ...)
for (i = 0; i < nsp; i++)
jl_exprarg(newast, i+1) = jl_svecref(linfo->def.method->sparam_syms, i);
ex = newast;
}

func = (jl_code_info_t*)jl_expand((jl_value_t*)ex, linfo->def.method->module);
if (!jl_is_code_info(func)) {
if (jl_is_expr(func) && ((jl_expr_t*)func)->head == error_sym)
jl_interpret_toplevel_expr_in(linfo->def.method->module, (jl_value_t*)func, NULL, NULL);
jl_error("generated function body is not pure. this likely means it contains a closure or comprehension.");
}
func = (jl_code_info_t*)jl_expand((jl_value_t*)ex, linfo->def.method->module);
if (!jl_is_code_info(func)) {
if (jl_is_expr(func) && ((jl_expr_t*)func)->head == error_sym)
jl_interpret_toplevel_expr_in(linfo->def.method->module, (jl_value_t*)func, NULL, NULL);
jl_error("generated function body is not pure. this likely means it contains a closure or comprehension.");
}

jl_array_t *stmts = (jl_array_t*)func->code;
size_t i, l;
for (i = 0, l = jl_array_len(stmts); i < l; i++) {
jl_value_t *stmt = jl_array_ptr_ref(stmts, i);
stmt = jl_resolve_globals(stmt, linfo->def.method->module, env);
jl_array_ptr_set(stmts, i, stmt);
}
jl_array_t *stmts = (jl_array_t*)func->code;
size_t i, l;
for (i = 0, l = jl_array_len(stmts); i < l; i++) {
jl_value_t *stmt = jl_array_ptr_ref(stmts, i);
stmt = jl_resolve_globals(stmt, linfo->def.method->module, env);
jl_array_ptr_set(stmts, i, stmt);
}

// add pop_loc meta
jl_array_ptr_1d_push(stmts, jl_nothing);
jl_expr_t *poploc = jl_exprn(meta_sym, 1);
jl_array_ptr_set(stmts, jl_array_len(stmts) - 1, poploc);
jl_array_ptr_set(poploc->args, 0, jl_symbol("pop_loc"));
// add pop_loc meta
jl_array_ptr_1d_push(stmts, jl_nothing);
jl_expr_t *poploc = jl_exprn(meta_sym, 1);
jl_array_ptr_set(stmts, jl_array_len(stmts) - 1, poploc);
jl_array_ptr_set(poploc->args, 0, jl_symbol("pop_loc"));
}

ptls->in_pure_callback = last_in;
jl_lineno = last_lineno;
Expand Down
22 changes: 22 additions & 0 deletions test/staged.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,25 @@ g10178(x) = f10178(x)
# issue #22135
@generated f22135(x::T) where T = x
@test f22135(1) === Int

# PR #22440

f22440kernel(x...) = x[1] + x[1]
f22440kernel(x::AbstractFloat) = x * x
f22440kernel(::Type{T}) where {T} = one(T)
f22440kernel(::Type{T}) where {T<:AbstractFloat} = zero(T)

@generated function f22440(y)
sig, spvals, method = Base._methods_by_ftype(Tuple{typeof(f22440kernel),y}, -1, typemax(UInt))[1]
code_info = Base.uncompressed_ast(method)
body = Expr(:block, code_info.code...)
Base.Core.Inference.substitute!(body, 0, Any[], sig, Any[spvals...], 0)
return code_info
end

@test f22440(Int) === f22440kernel(Int)
@test f22440(Float64) === f22440kernel(Float64)
@test f22440(Float32) === f22440kernel(Float32)
@test f22440(0.0) === f22440kernel(0.0)
@test f22440(0.0f0) === f22440kernel(0.0f0)
@test f22440(0) === f22440kernel(0)

0 comments on commit 898d650

Please sign in to comment.