Skip to content

Commit

Permalink
WIP: allow @generated begin ... end inside a function to provide an…
Browse files Browse the repository at this point in the history
… optional optimizer
  • Loading branch information
JeffBezanson committed Aug 9, 2017
1 parent 00f0d23 commit bbb9284
Show file tree
Hide file tree
Showing 13 changed files with 111 additions and 60 deletions.
13 changes: 11 additions & 2 deletions base/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,17 @@ end

macro generated(f)
if isa(f, Expr) && (f.head === :function || is_short_function_def(f))
pushmeta!(f, :generated)
return Expr(:escape, f)
body = f.args[2]
lno = body.args[1]
return Expr(:escape,
Expr(f.head, f.args[1],
Expr(:block,
lno,
Expr(:meta, :generator, body),
Expr(:meta, :generated_only),
Expr(:return, nothing))))
elseif isa(f, Expr) && f.head === :block
return Expr(:escape, Expr(:meta, :generator, f))
else
error("invalid syntax; @generated must be used with a function definition")
end
Expand Down
13 changes: 11 additions & 2 deletions base/methodshow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ function argtype_decl(env, n, sig::DataType, i::Int, nargs, isva::Bool) # -> (ar
return s, string_with_env(env, t)
end

function method_argnames(m::Method)
if !isdefined(m, :source)
gm = first(methods(m.generator))
return method_argnames(gm)[length(m.sparam_syms)+2 : end]
end
argnames = Vector{Any}(m.nargs)
ccall(:jl_fill_argnames, Void, (Any, Any), m.source, argnames)
return argnames
end

function arg_decl_parts(m::Method)
tv = Any[]
sig = m.sig
Expand All @@ -52,8 +62,7 @@ function arg_decl_parts(m::Method)
file = m.file
line = m.line
if isdefined(m, :source) || isdefined(m, :generator)
argnames = Vector{Any}(m.nargs)
ccall(:jl_fill_argnames, Void, (Any, Any), isdefined(m, :source) ? m.source : m.generator.inferred, argnames)
argnames = method_argnames(m)
show_env = ImmutableDict{Symbol, Any}()
for t in tv
show_env = ImmutableDict(show_env, :unionall_env => t)
Expand Down
3 changes: 2 additions & 1 deletion base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,8 @@ function length(mt::MethodTable)
end
isempty(mt::MethodTable) = (mt.defs === nothing)

uncompressed_ast(m::Method) = uncompressed_ast(m, isdefined(m,:source) ? m.source : m.generator.inferred)
uncompressed_ast(m::Method) = isdefined(m,:source) ? uncompressed_ast(m, m.source) :
uncompressed_ast(first(methods(m.generator)))
uncompressed_ast(m::Method, s::CodeInfo) = s
uncompressed_ast(m::Method, s::Array{UInt8,1}) = ccall(:jl_uncompress_ast, Any, (Any, Any), m, s)::CodeInfo

Expand Down
6 changes: 4 additions & 2 deletions src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ jl_sym_t *meta_sym; jl_sym_t *compiler_temp_sym;
jl_sym_t *inert_sym; jl_sym_t *vararg_sym;
jl_sym_t *unused_sym; jl_sym_t *static_parameter_sym;
jl_sym_t *polly_sym; jl_sym_t *inline_sym;
jl_sym_t *propagate_inbounds_sym; jl_sym_t *generated_sym;
jl_sym_t *propagate_inbounds_sym; jl_sym_t *generator_sym;
jl_sym_t *generated_only_sym;
jl_sym_t *isdefined_sym; jl_sym_t *nospecialize_sym;

static uint8_t flisp_system_image[] = {
Expand Down Expand Up @@ -437,7 +438,8 @@ void jl_init_frontend(void)
propagate_inbounds_sym = jl_symbol("propagate_inbounds");
isdefined_sym = jl_symbol("isdefined");
nospecialize_sym = jl_symbol("nospecialize");
generated_sym = jl_symbol("generated");
generator_sym = jl_symbol("generator");
generated_only_sym = jl_symbol("generated_only");
}

JL_DLLEXPORT void jl_lisp_prompt(void)
Expand Down
7 changes: 5 additions & 2 deletions src/ast.scm
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,11 @@
(and (if one (length= e 3) (length> e 2))
(eq? (car e) 'meta) (eq? (cadr e) 'nospecialize)))

(define (generated-meta? e)
(and (pair? e) (eq? (car e) 'meta) (any (lambda (x) (eq? x 'generated)) (cdr e))))
(define (generator-meta? e)
(and (length= e 3) (eq? (car e) 'meta) (eq? (cadr e) 'generator)))

(define (generated_only-meta? e)
(and (length= e 2) (eq? (car e) 'meta) (eq? (cadr e) 'generated_only)))

;; flatten nested expressions with the given head
;; (op (op a b) c) => (op a b c)
Expand Down
2 changes: 0 additions & 2 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1197,8 +1197,6 @@ jl_llvm_functions_t jl_compile_linfo(jl_method_instance_t **pli, jl_code_info_t
li->inferred &&
// and there is something to delete (test this before calling jl_ast_flag_inlineable)
li->inferred != jl_nothing &&
// don't delete the code for the generator
li != li->def.method->generator &&
// don't delete inlineable code, unless it is constant
(li->jlcall_api == 2 || !jl_ast_flag_inlineable((jl_array_t*)li->inferred)) &&
// don't delete code when generating a precompile file
Expand Down
2 changes: 1 addition & 1 deletion src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -1417,7 +1417,7 @@ static jl_value_t *jl_deserialize_value_method(jl_serializer_state *s, jl_value_
m->unspecialized = (jl_method_instance_t*)jl_deserialize_value(s, (jl_value_t**)&m->unspecialized);
if (m->unspecialized)
jl_gc_wb(m, m->unspecialized);
m->generator = (jl_method_instance_t*)jl_deserialize_value(s, (jl_value_t**)&m->generator);
m->generator = jl_deserialize_value(s, (jl_value_t**)&m->generator);
if (m->generator)
jl_gc_wb(m, m->generator);
m->invokes.unknown = jl_deserialize_value(s, (jl_value_t**)&m->invokes);
Expand Down
3 changes: 1 addition & 2 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2045,7 +2045,7 @@ void jl_init_types(void)
jl_simplevector_type,
jl_any_type,
jl_any_type, // jl_method_instance_type
jl_any_type, // jl_method_instance_type
jl_any_type,
jl_array_any_type,
jl_any_type,
jl_int32_type,
Expand Down Expand Up @@ -2158,7 +2158,6 @@ void jl_init_types(void)
#endif
jl_svecset(jl_methtable_type->types, 8, jl_int32_type); // uint32_t
jl_svecset(jl_method_type->types, 10, jl_method_instance_type);
jl_svecset(jl_method_type->types, 11, jl_method_instance_type);
jl_svecset(jl_method_instance_type->types, 12, jl_voidpointer_type);
jl_svecset(jl_method_instance_type->types, 13, jl_voidpointer_type);
jl_svecset(jl_method_instance_type->types, 14, jl_voidpointer_type);
Expand Down
40 changes: 28 additions & 12 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,17 @@
(map (lambda (x) (replace-outer-vars x renames))
(cdr e))))))

(define (make-generator-function name sp-names arg-names body)
(let ((arg-names (append sp-names
(map (lambda (n)
(if (eq? n '|#self#|) (gensy) n))
arg-names))))
(let ((body (insert-after-meta body ;; don't specialize on generator arguments
`((meta nospecialize ,@arg-names)))))
`(block
(global ,name)
(function (call ,name ,@arg-names) ,body)))))

;; construct the (method ...) expression for one primitive method definition,
;; assuming optional and keyword args are already handled
(define (method-def-expr- name sparams argl body (rett '(core Any)))
Expand Down Expand Up @@ -328,7 +339,14 @@
(error "function argument and static parameter names must be distinct")))
(if (or (and name (not (sym-ref? name))) (eq? name 'true) (eq? name 'false))
(error (string "invalid function name \"" (deparse name) "\"")))
(let* ((types (llist-types argl))
(let* ((generator (let ((found (find generator-meta? body)))
(if found
(let* ((gname (symbol (string (gensy) "#" (current-julia-module-counter))))
(gf (make-generator-function gname names (llist-vars argl) (caddr (car found)))))
(set-car! (cddar found) gname)
(list gf))
'())))
(types (llist-types argl))
(body (method-lambda-expr argl body rett))
;; HACK: the typevars need to be bound to ssavalues, since this code
;; might be moved to a different scope by closure-convert.
Expand Down Expand Up @@ -360,8 +378,10 @@
(call (core svec) ,@temps)))
,body))))
(if (symbol? name)
`(block (method ,name) ,mdef (unnecessary ,name)) ;; return the function
mdef)))))
`(block ,@generator (method ,name) ,mdef (unnecessary ,name)) ;; return the function
(if (not (null? generator))
`(block ,@generator ,mdef)
mdef))))))

;; wrap expr in nested scopes assigning names to vals
(define (scopenest names vals expr)
Expand Down Expand Up @@ -411,11 +431,8 @@
keynames))
;; list of function's initial line number and meta nodes (empty if none)
(prologue (extract-method-prologue body))
(annotations (append (if (any generated-meta? prologue)
'((meta generated))
'())
(map (lambda (a) `(meta nospecialize ,(arg-name (cadr (caddr a)))))
(filter nospecialize-meta? kargl))))
(annotations (map (lambda (a) `(meta nospecialize ,(arg-name (cadr (caddr a)))))
(filter nospecialize-meta? kargl)))
;; body statements
(stmts (cdr body))
(positional-sparams
Expand Down Expand Up @@ -565,10 +582,9 @@
'()))

(define (without-generated stmts)
(map (lambda (x) (if (generated-meta? x)
(filter (lambda (e) (not (eq? e 'generated))) x)
x))
stmts))
(filter (lambda (x) (not (or (generator-meta? x)
(generated_only-meta? x))))
stmts))

;; keep only sparams used by `expr` or other sparams
(define (filter-sparams expr sparams)
Expand Down
2 changes: 1 addition & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ typedef struct _jl_method_t {
jl_svec_t *sparam_syms; // symbols giving static parameter names
jl_value_t *source; // original code template (jl_code_info_t, but may be compressed), null for builtins
struct _jl_method_instance_t *unspecialized; // unspecialized executable method instance, or null
struct _jl_method_instance_t *generator; // executable code-generating function if available
jl_value_t *generator; // executable code-generating function if available
jl_array_t *roots; // pointers in generated code (shared to reduce memory), or null

// cache of specializations of this method for invoke(), i.e.
Expand Down
3 changes: 2 additions & 1 deletion src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,8 @@ extern jl_sym_t *pure_sym; extern jl_sym_t *simdloop_sym;
extern jl_sym_t *meta_sym; extern jl_sym_t *list_sym;
extern jl_sym_t *inert_sym; extern jl_sym_t *static_parameter_sym;
extern jl_sym_t *polly_sym; extern jl_sym_t *inline_sym;
extern jl_sym_t *propagate_inbounds_sym; extern jl_sym_t *generated_sym;
extern jl_sym_t *propagate_inbounds_sym; extern jl_sym_t *generator_sym;
extern jl_sym_t *generated_only_sym;
extern jl_sym_t *isdefined_sym; extern jl_sym_t *nospecialize_sym;

void jl_register_fptrs(uint64_t sysimage_base, const char *base, const int32_t *offsets,
Expand Down
72 changes: 40 additions & 32 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -245,24 +245,23 @@ jl_code_info_t *jl_new_code_info_from_ast(jl_expr_t *ast)
}

// invoke (compiling if necessary) the jlcall function pointer for a method template
STATIC_INLINE jl_value_t *jl_call_staged(jl_svec_t *sparam_vals, jl_method_instance_t *generator,
STATIC_INLINE jl_value_t *jl_call_staged(jl_method_t *def, jl_value_t *generator, jl_svec_t *sparam_vals,
jl_value_t **args, uint32_t nargs)
{
jl_generic_fptr_t fptr;
fptr.fptr = generator->fptr;
fptr.jlcall_api = generator->jlcall_api;
if (__unlikely(fptr.fptr == NULL || fptr.jlcall_api == 0)) {
size_t world = generator->def.method->min_world;
const char *F = jl_compile_linfo(&generator, (jl_code_info_t*)generator->inferred, world, &jl_default_cgparams).functionObject;
fptr = jl_generate_fptr(generator, F, world);
size_t spl = jl_svec_len(sparam_vals);
jl_value_t **gargs;
size_t totargs = 1 + spl + nargs + def->isva;
JL_GC_PUSHARGS(gargs, totargs);
gargs[0] = generator;
memcpy(&gargs[1], jl_svec_data(sparam_vals), spl * sizeof(void*));
memcpy(&gargs[1+spl], args, nargs * sizeof(void*));
if (def->isva) {
gargs[totargs-1] = jl_f_tuple(NULL, &gargs[1+spl+def->nargs-1], nargs - (def->nargs-1));
gargs[1+spl+def->nargs-1] = gargs[totargs-1];
}
assert(jl_svec_len(generator->def.method->sparam_syms) == jl_svec_len(sparam_vals));
if (fptr.jlcall_api == 1)
return fptr.fptr1(args[0], &args[1], nargs-1);
else if (fptr.jlcall_api == 3)
return fptr.fptr3(sparam_vals, args[0], &args[1], nargs-1);
else
abort(); // shouldn't have inferred any other calling convention
jl_value_t *code = jl_apply(gargs, 1+spl+def->nargs);
JL_GC_POP();
return code;
}

// return a newly allocated CodeInfo for the function signature
Expand All @@ -275,9 +274,11 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo)
jl_expr_t *ex = NULL;
jl_value_t *linenum = NULL;
jl_svec_t *sparam_vals = env;
jl_method_instance_t *generator = linfo->def.method->generator;
jl_value_t *generator = linfo->def.method->generator;
jl_method_t *gen_meth = jl_gf_mtable(generator)->defs.leaf->func.method;
assert(generator != NULL);
assert(linfo != generator);
assert(jl_is_method(gen_meth));
jl_code_info_t *func = NULL;
JL_GC_PUSH4(&ex, &linenum, &sparam_vals, &func);
jl_ptls_t ptls = jl_get_ptls_states();
Expand All @@ -292,13 +293,14 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo)
// need to eval macros in the right module
ptls->current_task->current_module = ptls->current_module = linfo->def.method->module;
// and the right world
ptls->world_age = generator->def.method->min_world;
ptls->world_age = gen_meth->min_world;

ex = jl_exprn(lambda_sym, 2);

jl_array_t *argnames = jl_alloc_vec_any(linfo->def.method->nargs);
jl_array_t *argnames = jl_alloc_vec_any(linfo->def.method->nargs + jl_svec_len(sparam_vals) + 1);
jl_array_ptr_set(ex->args, 0, argnames);
jl_fill_argnames((jl_array_t*)generator->inferred, argnames);
jl_fill_argnames((jl_array_t*)gen_meth->source, argnames);
jl_array_del_beg(argnames, jl_svec_len(sparam_vals) + 1);

// build the rest of the body to pass to expand
jl_expr_t *scopeblock = jl_exprn(jl_symbol("scope-block"), 1);
Expand All @@ -319,7 +321,7 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo)
// invoke code generator
assert(jl_nparams(tt) == jl_array_len(argnames) ||
(linfo->def.method->isva && (jl_nparams(tt) >= jl_array_len(argnames) - 1)));
jl_value_t *generated_body = jl_call_staged(sparam_vals, generator, jl_svec_data(tt->parameters), jl_nparams(tt));
jl_value_t *generated_body = jl_call_staged(linfo->def.method, generator, sparam_vals, jl_svec_data(tt->parameters), jl_nparams(tt));
jl_array_ptr_set(body->args, 2, generated_body);

if (jl_is_code_info(generated_body)) {
Expand Down Expand Up @@ -398,7 +400,7 @@ jl_method_instance_t *jl_get_specialized(jl_method_t *m, jl_value_t *types, jl_s
return new_linfo;
}

static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src, int *isstaged)
static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src, jl_value_t **generator, int *gen_only)
{
uint8_t j;
uint8_t called = 0;
Expand Down Expand Up @@ -470,12 +472,17 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src, int *issta
}
st = jl_nothing;
}
else {
for (size_t j=0; j < jl_expr_nargs(st); j++) {
if (jl_exprarg(st, j) == (jl_value_t*)generated_sym) {
*isstaged = 1; break;
}
else if (jl_expr_nargs(st) == 2 && jl_exprarg(st, 0) == (jl_value_t*)generator_sym) {
jl_value_t *gname = jl_exprarg(st, 1);
*generator = jl_get_global(m->module, (jl_sym_t*)gname);
if (*generator == NULL) {
jl_error("invalid @generated function; try placing it in global scope");
}
st = jl_nothing;
}
else if (jl_expr_nargs(st) == 1 && jl_exprarg(st, 0) == (jl_value_t*)generated_only_sym) {
*gen_only = 1;
st = jl_nothing;
}
}
else {
Expand Down Expand Up @@ -530,7 +537,6 @@ static jl_method_t *jl_new_method(
jl_svec_t *tvars)
{
size_t i, l = jl_svec_len(tvars);
int isstaged = 0;
jl_svec_t *sparam_syms = jl_alloc_svec_uninit(l);
for (i = 0; i < l; i++) {
jl_svecset(sparam_syms, i, ((jl_tvar_t*)jl_svecref(tvars, i))->name);
Expand All @@ -547,12 +553,14 @@ static jl_method_t *jl_new_method(
m->sig = (jl_value_t*)sig;
m->isva = isva;
m->nargs = nargs;
jl_method_set_source(m, definition, &isstaged);
if (isstaged) {
// create and store generator for generated functions
m->generator = jl_get_specialized(m, (jl_value_t*)jl_anytuple_type, jl_emptysvec);
jl_value_t *gen = NULL; int gen_only = 0;
jl_method_set_source(m, definition, &gen, &gen_only);
if (gen) {
m->generator = gen;
jl_gc_wb(m, m->generator);
m->generator->inferred = (jl_value_t*)m->source;
}
if (gen_only) {
assert(gen);
m->source = NULL;
}

Expand Down
5 changes: 5 additions & 0 deletions src/utils.scm
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,8 @@
(without (cdr alst) remove)))))

(define (caddddr x) (car (cdr (cdr (cdr (cdr x))))))

(define (find p lst)
(cond ((atom? lst) #f)
((p (car lst)) lst)
(else (find p (cdr lst)))))

0 comments on commit bbb9284

Please sign in to comment.