Skip to content

Commit

Permalink
implement comprehensions as collect of a Generator
Browse files Browse the repository at this point in the history
this removes `static_typeof` and `type_goto`

fixes #7258
  • Loading branch information
JeffBezanson committed Jun 3, 2016
1 parent 3f59431 commit b99ea48
Show file tree
Hide file tree
Showing 17 changed files with 48 additions and 292 deletions.
4 changes: 2 additions & 2 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -619,8 +619,8 @@ typed_hcat(T::Type) = Array{T}(0)
## cat: special cases
vcat{T}(X::T...) = T[ X[i] for i=1:length(X) ]
vcat{T<:Number}(X::T...) = T[ X[i] for i=1:length(X) ]
hcat{T}(X::T...) = T[ X[j] for i=1, j=1:length(X) ]
hcat{T<:Number}(X::T...) = T[ X[j] for i=1, j=1:length(X) ]
hcat{T}(X::T...) = T[ X[j] for i=1:1, j=1:length(X) ]
hcat{T<:Number}(X::T...) = T[ X[j] for i=1:1, j=1:length(X) ]

vcat(X::Number...) = hvcat_fill(Array{promote_typeof(X...)}(length(X)), X)
hcat(X::Number...) = hvcat_fill(Array{promote_typeof(X...)}(1,length(X)), X)
Expand Down
4 changes: 2 additions & 2 deletions base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,8 @@ function ctranspose(A::AbstractMatrix)
end
ctranspose{T<:Real}(A::AbstractVecOrMat{T}) = transpose(A)

transpose(x::AbstractVector) = [ transpose(v) for i=1, v in x ]
ctranspose{T}(x::AbstractVector{T}) = T[ ctranspose(v) for i=1, v in x ] #Fixme comprehension
transpose(x::AbstractVector) = [ transpose(v) for i=1:1, v in x ]
ctranspose{T}(x::AbstractVector{T}) = T[ ctranspose(v) for i=1:1, v in x ] #Fixme comprehension

_cumsum_type{T<:Number}(v::AbstractArray{T}) = typeof(+zero(T))
_cumsum_type(v) = typeof(v[1]+v[1])
Expand Down
5 changes: 3 additions & 2 deletions base/generator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ Generator(f, c1, c...) = Generator(a->f(a...), zip(c1, c...))

Generator{T,I}(::Type{T}, iter::I) = Generator{I,Type{T}}(T, iter)

start(g::Generator) = start(g.iter)
done(g::Generator, s) = done(g.iter, s)
start(g::Generator) = (@_inline_meta; start(g.iter))
done(g::Generator, s) = (@_inline_meta; done(g.iter, s))
function next(g::Generator, s)
@_inline_meta
v, s2 = next(g.iter, s)
g.f(v), s2
end
Expand Down
111 changes: 5 additions & 106 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,8 @@ type InferenceState
atypes #::Type # type sig
sp::SimpleVector # static parameters
label_counter::Int # index of the current highest label for this function
fedbackvars::Dict{SSAValue, Bool}
mod::Module
currpc::LineNum
static_typeof::Bool

# info on the state of inference and the linfo
linfo::LambdaInfo
Expand All @@ -71,7 +69,6 @@ type InferenceState
backedges::Vector{Tuple{InferenceState, Vector{LineNum}}}
# iteration fixed-point detection
fixedpoint::Bool
typegotoredo::Bool
inworkq::Bool
optimize::Bool
inferred::Bool
Expand Down Expand Up @@ -157,13 +154,13 @@ type InferenceState

inmodule = isdefined(linfo, :def) ? linfo.def.module : current_module() # toplevel thunks are inferred in the current module
frame = new(
atypes, sp, nl, Dict{SSAValue, Bool}(), inmodule, 0, false,
atypes, sp, nl, inmodule, 0,
linfo, linfo, la, s, Union{}, W, n,
cur_hand, handler_at, n_handlers,
ssavalue_uses, ssavalue_init,
ObjectIdDict(), #Dict{InferenceState, Vector{LineNum}}(),
Vector{Tuple{InferenceState, Vector{LineNum}}}(),
false, false, false, optimize, false, nothing)
false, false, optimize, false, nothing)
push!(active, frame)
nactive[] += 1
return frame
Expand Down Expand Up @@ -1070,8 +1067,6 @@ function abstract_eval(e::ANY, vtypes::VarTable, sv::InferenceState)
return abstract_eval_constant(e)
end
e = e::Expr
# handle:
# call null new & static_typeof
if is(e.head,:call)
t = abstract_eval_call(e, vtypes, sv)
elseif is(e.head,:null)
Expand Down Expand Up @@ -1105,42 +1100,6 @@ function abstract_eval(e::ANY, vtypes::VarTable, sv::InferenceState)
t = abstract_eval_constant(val)
end
end
elseif is(e.head,:static_typeof)
var = e.args[1]
t = widenconst(abstract_eval(var, vtypes, sv))
if isa(t,DataType) && typeseq(t,t.name.primary)
# remove unnecessary typevars
t = t.name.primary
end
if is(t,Bottom)
# if we haven't gotten fed-back type info yet, return Bottom. otherwise
# Bottom is the actual type of the variable, so return Type{Bottom}.
if get!(sv.fedbackvars, var, false)
t = Type{Bottom}
else
sv.static_typeof = true
end
elseif isleaftype(t)
t = Type{t}
elseif isleaftype(sv.atypes)
if isa(t,TypeVar)
t = Type{t.ub}
else
t = Type{t}
end
else
# if there is any type uncertainty in the arguments, we are
# effectively predicting what static_typeof will say when
# the function is compiled with actual arguments. in that case
# abstract types yield Type{<:T} instead of Type{T}.
# this doesn't really model the situation perfectly, but
# "isleaftype(inference_stack.types)" should be good enough.
if isa(t,TypeVar) || isvarargtype(t)
t = Type{t}
else
t = Type{TypeVar(:_,t)}
end
end
elseif is(e.head,:method)
t = (length(e.args) == 1) ? Any : Void
elseif is(e.head,:copyast)
Expand Down Expand Up @@ -1631,23 +1590,19 @@ function typeinf_frame(frame)
W = frame.ip
s = frame.stmt_types
n = frame.nstmts
@label restart_typeinf
while !isempty(W)
# make progress on the active ip set
local pc::Int = first(W), pc´::Int
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
#print(pc,": ",s[pc],"\n")
delete!(W, pc)
frame.currpc = pc
frame.static_typeof = false
frame.cur_hand = frame.handler_at[pc]
stmt = frame.linfo.code[pc]
changes = abstract_interpret(stmt, s[pc]::Array{Any,1}, frame)
if changes === ()
# if there was a Expr(:static_typeof) on this line,
# need to continue to the next pc even though the return type was Bottom
# otherwise, this line threw an error and there is no need to continue
frame.static_typeof || break
# this line threw an error and there is no need to continue
break
changes = s[pc]
end
if frame.cur_hand !== ()
Expand Down Expand Up @@ -1697,26 +1652,6 @@ function typeinf_frame(frame)
s[l] = newstate
end
end
elseif is(hd, :type_goto)
for i = 2:length(stmt.args)
var = stmt.args[i]::SSAValue
# Store types that need to be fed back via type_goto
# in ssavalue_init. After finishing inference, if any
# of these types changed, start over with the fed-back
# types known from the beginning.
# See issue #3821 (using !typeseq instead of !subtype),
# and issue #7810.
id = var.id+1
vt = frame.linfo.ssavaluetypes[id]
ot = frame.ssavalue_init[id]
if ot===NF || !(vtot && otvt)
frame.ssavalue_init[id] = vt
if get(frame.fedbackvars, var, false)
frame.typegotoredo = true
end
end
frame.fedbackvars[var] = true
end
elseif is(hd, :return)
pc´ = n + 1
rt = abstract_eval(stmt.args[1], s[pc], frame)
Expand Down Expand Up @@ -1786,39 +1721,6 @@ function typeinf_frame(frame)
end

if finished || frame.fixedpoint
if frame.typegotoredo
# if any type_gotos changed, clear state and restart.
frame.typegotoredo = false
for ll = 2:length(s)
s[ll] = ()
end
empty!(W)
push!(W, 1)
frame.cur_hand = ()
frame.handler_at = Any[ () for i=1:n ]
frame.n_handlers = 0
frame.linfo.ssavaluetypes[:] = frame.ssavalue_init
@goto restart_typeinf
else
# if a static_typeof was never reached,
# use Union{} as its real type and continue
# running type inference from its uses
# (one of which is the static_typeof)
# TODO: this restart should happen just before calling finish()
for (fbvar, seen) in frame.fedbackvars
if !seen
frame.fedbackvars[fbvar] = true
id = (fbvar::SSAValue).id + 1
for r in frame.ssavalue_uses[id]
if !is(s[r], ()) # s[r] === () => unreached statement
push!(W, r)
end
end
@goto restart_typeinf
end
end
end

if finished
finish(frame)
else # fixedpoint propagation
Expand Down Expand Up @@ -2000,7 +1902,7 @@ function eval_annotate(e::ANY, vtypes::ANY, sv::InferenceState, undefs, pass)

e = e::Expr
head = e.head
if is(head,:static_typeof) || is(head,:line) || is(head,:const)
if is(head,:line) || is(head,:const)
return e
elseif is(head,:(=))
e.args[2] = eval_annotate(e.args[2], vtypes, sv, undefs, pass)
Expand Down Expand Up @@ -2222,9 +2124,6 @@ function effect_free(e::ANY, sv, allow_volatile::Bool)
end
if isa(e,Expr)
e = e::Expr
if e.head === :static_typeof
return true
end
if e.head === :static_parameter
return true
end
Expand Down
1 change: 0 additions & 1 deletion base/iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,6 @@ immutable IteratorND{I,N}
end
new{I,N}(iter, shape)
end
(::Type{IteratorND}){I<:AbstractProdIterator}(p::I) = IteratorND(p, size(p))
end

start(i::IteratorND) = start(i.iter)
Expand Down
2 changes: 1 addition & 1 deletion base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ maximum(r::UnitRange) = isempty(r) ? throw(ArgumentError("range must be non-empt
minimum(r::Range) = isempty(r) ? throw(ArgumentError("range must be non-empty")) : min(first(r), last(r))
maximum(r::Range) = isempty(r) ? throw(ArgumentError("range must be non-empty")) : max(first(r), last(r))

ctranspose(r::Range) = [x for _=1, x=r]
ctranspose(r::Range) = [x for _=1:1, x=r]
transpose(r::Range) = r'

# Ranges are immutable
Expand Down
20 changes: 10 additions & 10 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2802,8 +2802,8 @@ end

function vcat(X::SparseMatrixCSC...)
num = length(X)
mX = [ size(x, 1) for x in X ]
nX = [ size(x, 2) for x in X ]
mX = Int[ size(x, 1) for x in X ]
nX = Int[ size(x, 2) for x in X ]
m = sum(mX)
n = nX[1]

Expand All @@ -2820,7 +2820,7 @@ function vcat(X::SparseMatrixCSC...)
Ti = promote_type(Ti, eltype(X[i].rowval))
end

nnzX = [ nnz(x) for x in X ]
nnzX = Int[ nnz(x) for x in X ]
nnz_res = sum(nnzX)
colptr = Array{Ti}(n + 1)
rowval = Array{Ti}(nnz_res)
Expand Down Expand Up @@ -2862,8 +2862,8 @@ end

function hcat(X::SparseMatrixCSC...)
num = length(X)
mX = [ size(x, 1) for x in X ]
nX = [ size(x, 2) for x in X ]
mX = Int[ size(x, 1) for x in X ]
nX = Int[ size(x, 2) for x in X ]
m = mX[1]
for i = 2 : num
if mX[i] != m; throw(DimensionMismatch("")); end
Expand All @@ -2874,7 +2874,7 @@ function hcat(X::SparseMatrixCSC...)
Ti = promote_type(map(x->eltype(x.rowval), X)...)

colptr = Array{Ti}(n + 1)
nnzX = [ nnz(x) for x in X ]
nnzX = Int[ nnz(x) for x in X ]
nnz_res = sum(nnzX)
rowval = Array{Ti}(nnz_res)
nzval = Array{Tv}(nnz_res)
Expand Down Expand Up @@ -2930,16 +2930,16 @@ Concatenate matrices block-diagonally. Currently only implemented for sparse mat
"""
function blkdiag(X::SparseMatrixCSC...)
num = length(X)
mX = [ size(x, 1) for x in X ]
nX = [ size(x, 2) for x in X ]
mX = Int[ size(x, 1) for x in X ]
nX = Int[ size(x, 2) for x in X ]
m = sum(mX)
n = sum(nX)

Tv = promote_type(map(x->eltype(x.nzval), X)...)
Ti = promote_type(map(x->eltype(x.rowval), X)...)

colptr = Array{Ti}(n + 1)
nnzX = [ nnz(x) for x in X ]
nnzX = Int[ nnz(x) for x in X ]
nnz_res = sum(nnzX)
rowval = Array{Ti}(nnz_res)
nzval = Array{Tv}(nnz_res)
Expand Down Expand Up @@ -3180,7 +3180,7 @@ function trace{Tv}(A::SparseMatrixCSC{Tv})
s
end

diag(A::SparseMatrixCSC) = [d for d in SpDiagIterator(A)]
diag{Tv}(A::SparseMatrixCSC{Tv}) = Tv[d for d in SpDiagIterator(A)]

function diagm{Tv,Ti}(v::SparseMatrixCSC{Tv,Ti})
if (size(v,1) != 1 && size(v,2) != 1)
Expand Down
3 changes: 1 addition & 2 deletions src/alloc.c
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,12 @@ jl_sym_t *null_sym; jl_sym_t *body_sym;
jl_sym_t *method_sym; jl_sym_t *core_sym;
jl_sym_t *enter_sym; jl_sym_t *leave_sym;
jl_sym_t *exc_sym; jl_sym_t *error_sym;
jl_sym_t *static_typeof_sym;
jl_sym_t *globalref_sym;
jl_sym_t *new_sym; jl_sym_t *using_sym;
jl_sym_t *const_sym; jl_sym_t *thunk_sym;
jl_sym_t *anonymous_sym; jl_sym_t *underscore_sym;
jl_sym_t *abstracttype_sym; jl_sym_t *bitstype_sym;
jl_sym_t *compositetype_sym; jl_sym_t *type_goto_sym;
jl_sym_t *compositetype_sym;
jl_sym_t *global_sym; jl_sym_t *list_sym;
jl_sym_t *dot_sym; jl_sym_t *newvar_sym;
jl_sym_t *boundscheck_sym; jl_sym_t *inbounds_sym;
Expand Down
14 changes: 1 addition & 13 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3096,7 +3096,7 @@ static void emit_stmtpos(jl_value_t *expr, jl_codectx_t *ctx)
if (jl_is_expr(expr)) {
jl_sym_t *head = ((jl_expr_t*)expr)->head;
// some expression types are metadata and can be ignored in statement position
if (head == line_sym || head == type_goto_sym || head == meta_sym)
if (head == line_sym || head == meta_sym)
return;
// fall-through
}
Expand Down Expand Up @@ -3270,18 +3270,6 @@ static jl_cgval_t emit_expr(jl_value_t *expr, jl_codectx_t *ctx)
else if (head == null_sym) {
return ghostValue(jl_void_type);
}
else if (head == static_typeof_sym) {
jl_value_t *extype = expr_type((jl_value_t*)ex, ctx);
if (jl_is_type_type(extype)) {
extype = jl_tparam0(extype);
if (jl_is_typevar(extype))
extype = ((jl_tvar_t*)extype)->ub;
}
else {
extype = (jl_value_t*)jl_any_type;
}
return mark_julia_const(extype);
}
else if (head == new_sym) {
jl_value_t *ty = expr_type(args[0], ctx);
size_t nargs = jl_array_len(ex->args);
Expand Down
3 changes: 0 additions & 3 deletions src/interpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,6 @@ static jl_value_t *eval(jl_value_t *e, jl_value_t **locals, jl_lambda_info_t *la
else if (ex->head == exc_sym) {
return jl_exception_in_transit;
}
else if (ex->head == static_typeof_sym) {
return (jl_value_t*)jl_any_type;
}
else if (ex->head == method_sym) {
jl_sym_t *fname = (jl_sym_t*)args[0];
assert(jl_expr_nargs(ex) != 1 || jl_is_symbol(fname));
Expand Down
2 changes: 0 additions & 2 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -3758,7 +3758,6 @@ void jl_init_types(void)
exc_sym = jl_symbol("the_exception");
enter_sym = jl_symbol("enter");
leave_sym = jl_symbol("leave");
static_typeof_sym = jl_symbol("static_typeof");
new_sym = jl_symbol("new");
const_sym = jl_symbol("const");
global_sym = jl_symbol("global");
Expand All @@ -3769,7 +3768,6 @@ void jl_init_types(void)
abstracttype_sym = jl_symbol("abstract_type");
bitstype_sym = jl_symbol("bits_type");
compositetype_sym = jl_symbol("composite_type");
type_goto_sym = jl_symbol("type_goto");
toplevel_sym = jl_symbol("toplevel");
dot_sym = jl_symbol(".");
boundscheck_sym = jl_symbol("boundscheck");
Expand Down
Loading

0 comments on commit b99ea48

Please sign in to comment.