Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add comments and tidy _rewrite function. #59

Merged
merged 3 commits into from
Nov 12, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
307 changes: 239 additions & 68 deletions src/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,114 +291,285 @@ function _write_add_mul(vectorized, minus, current_sum, left_factors, inner_fact
end

"""
_rewrite(vectorized::Bool, minus::Bool, inner_factor, current_sum::Union{Nothing, Symbol}, left_factors::Vector, right_factors::Vector, new_var::Symbol=gensym())
_rewrite(
vectorized::Bool,
minus::Bool,
inner_factor,
current_sum::Union{Symbol, Nothing},
left_factors::Vector,
right_factors::Vector,
new_var::Symbol = gensym(),
)

Return `new_var, code` such that `code` is equivalent to
```julia
new_var = prod(left_factors) * inner_factor * prod(reverse(right_factors))
```
if `current_sum` is `nothing`, and is

If `current_sum` is `nothing`, and is
```julia
new_var = current_sum op prod(left_factors) * inner_factor * prod(reverse(right_factors))
```
otherwise where `op` is `+` if `!vectorized` and `!minus`, `.+` if `vectorized` and `!minus`,
`-` if `!vectorized` and `minus` and `.-` if `vectorized` and `minus`.
otherwise where `op` is `+` if `!vectorized & !minus`, `.+` if
`vectorized & !minus`, `-` if `!vectorized & minus` and `.-` if
`vectorized & minus`.
"""
function _rewrite(vectorized::Bool, minus::Bool, inner_factor, current_sum::Union{Symbol, Nothing}, left_factors::Vector, right_factors::Vector, new_var::Symbol=gensym())
function _rewrite(
vectorized::Bool,
minus::Bool,
inner_factor,
current_sum::Union{Symbol, Nothing},
left_factors::Vector,
right_factors::Vector,
new_var::Symbol = gensym(),
)
if isexpr(inner_factor, :call)
# We need to verfify that `left_factors` and `right_factors` are empty for broadcast, see `_is_decomposable_with_factors`.
# We also need to verify that `current_sum` is `nothing` otherwise we are unsure that the elements in the containers have been copied, e.g., in
# `I + (x .+ 1)`, the offdiagonal entries of `I + x` are the same as `x` so we cannot do `broadcast!(add_mul, I + x, 1)`.
if inner_factor.args[1] == :+ || inner_factor.args[1] == :- ||
(current_sum === nothing && isempty(left_factors) && isempty(right_factors) && (inner_factor.args[1] == :(.+) || inner_factor.args[1] == :(.-)))
block = Expr(:block)
if length(inner_factor.args) > 2 # not unary addition or subtraction
next_sum, code = _rewrite(vectorized, minus, inner_factor.args[2], current_sum, left_factors, right_factors)
push!(block.args, code)
start = 3
else
if (
inner_factor.args[1] == :+ ||
inner_factor.args[1] == :- ||
(
current_sum === nothing &&
isempty(left_factors) &&
isempty(right_factors) &&
(inner_factor.args[1] == :(.+) || inner_factor.args[1] == :(.-))
)
)
# There are three cases here:
# 1. scalar addition : +(args...)
# 2. scalar subtraction : -(args...)
# 3. broadcast addition or subtraction.
# For case (3), we need to verify that current_sum, left_factors,
# and right_factors are empty, otherwise we are unsure that the
# elements in the containers have been copied, e.g., in
# `I + (x .+ 1)`, the offdiagonal entries of `I + x` are the same as
# `x` so we cannot do `broadcast!(add_mul, I + x, 1)`.
code = Expr(:block)
if length(inner_factor.args) == 2
# Unary addition or subtraction.
next_sum = current_sum
start = 2
else
next_sum, new_code = _rewrite(
vectorized,
minus,
inner_factor.args[2],
current_sum,
left_factors,
right_factors,
)
push!(code.args, new_code)
start = 3
end
vectorized = vectorized || inner_factor.args[1] == :(.+) || inner_factor.args[1] == :(.-)
if inner_factor.args[1] == :- || inner_factor.args[1] == :(.-)
minus = !minus
end
return rewrite_sum(vectorized, minus, inner_factor.args[start:end], next_sum, left_factors, right_factors, new_var, block)
vectorized = (
vectorized ||
inner_factor.args[1] == :(.+) ||
inner_factor.args[1] == :(.-)
)
return rewrite_sum(
vectorized,
minus,
inner_factor.args[start:end],
next_sum,
left_factors,
right_factors,
new_var,
code,
)
elseif inner_factor.args[1] == :* && !vectorized
# We need `&& !vectorized` otherwise `x .+ A * b` would be rewritten `broadcast!(add_mul, x, A, b)`.

# we might need to recurse on multiple arguments, e.g.,
# (x+y)*(x+y)
# special case, only recurse on one argument and don't create temporary objects
if isone(mapreduce(_is_complex_expr, +, inner_factor.args)) &&
isone(mapreduce(_is_decomposable_with_factors, +, inner_factor.args))
# A multiplication expression *(args...). We need `!vectorized`
# otherwise `x .+ A * b` would be rewritten
# `broadcast!(add_mul, x, A, b)`.
# We might need to recurse on multiple arguments, e.g., (x+y)*(x+y).
# As a special case, only recurse on one argument and don't create
# temporary objects
if (
isone(mapreduce(_is_complex_expr, +, inner_factor.args)) &&
isone(mapreduce(_is_decomposable_with_factors, +, inner_factor.args))
)
# `findfirst` return the index in `2:...` so we need to add `1`.
which_idx = 1 + findfirst(2:length(inner_factor.args)) do i
_is_decomposable_with_factors(inner_factor.args[i])
end
return _rewrite(
vectorized, minus, inner_factor.args[which_idx], current_sum,
vcat(left_factors, [esc(inner_factor.args[i]) for i in 2:(which_idx - 1)]),
vcat(right_factors, [esc(inner_factor.args[i]) for i in length(inner_factor.args):-1:(which_idx + 1)]),
new_var)
vectorized,
minus,
inner_factor.args[which_idx],
current_sum,
vcat(
left_factors,
[esc(inner_factor.args[i]) for i in 2:(which_idx - 1)]
),
vcat(
right_factors,
[
esc(inner_factor.args[i])
for i in length(inner_factor.args):-1:(which_idx + 1)
],
),
new_var,
)
else
blk = Expr(:block)
code = Expr(:block)
for i in 2:length(inner_factor.args)
if _is_complex_expr(inner_factor.args[i])
new_var_, parsed = rewrite(inner_factor.args[i])
push!(blk.args, parsed)
inner_factor.args[i] = new_var_
arg = inner_factor.args[i]
if _is_complex_expr(arg) # `arg` needs rewriting.
new_arg, new_arg_code = rewrite(arg)
push!(code.args, new_arg_code)
inner_factor.args[i] = new_arg
else
inner_factor.args[i] = esc(inner_factor.args[i])
inner_factor.args[i] = esc(arg)
end
end
push!(blk.args, _write_add_mul(
vectorized, minus, current_sum, left_factors,
inner_factor.args[2:end], right_factors, new_var
))
return new_var, blk
push!(
code.args,
_write_add_mul(
vectorized,
minus,
current_sum,
left_factors,
inner_factor.args[2:end],
right_factors,
new_var,
),
)
return new_var, code
end
elseif inner_factor.args[1] == :^ && _is_complex_expr(inner_factor.args[2]) && !vectorized
# We need `&& !vectorized` otherwise `A .+ (A + A)^2` would be rewritten `broadcast!(add_mul, x, AA, AA)` where `AA` is `A + A`.
MulType = :(MA.promote_operation(*, typeof($(inner_factor.args[2])), typeof($(inner_factor.args[2]))))
if inner_factor.args[3] == 2
elseif (
inner_factor.args[1] == :^ &&
_is_complex_expr(inner_factor.args[2]) &&
!vectorized
)
# An expression like `base ^ exponent`, where the `base` is a
# non-trivial expression that also needs to be re-written. We need
# `!vectorized` otherwise `A .+ (A + A)^2` would be rewritten as
# `broadcast!(add_mul, x, AA, AA)` where `AA` is `A + A`.
MulType = :(
MA.promote_operation(
*,
typeof($(inner_factor.args[2])),
typeof($(inner_factor.args[2]))
)
)
if inner_factor.args[3] == 0
# If the exponent is 0, rewrite
# new_var = base^0
# as
# new_var = 1
return _rewrite(
vectorized,
minus,
:(one($MulType)),
current_sum,
left_factors,
right_factors,
new_var,
)
elseif inner_factor.args[3] == 1
# If the exponent is 1, rewrite
# new_var = base^1
# as
# new_var = base
return _rewrite(
vectorized,
minus,
:(convert($MulType, $(inner_factor.args[2]))),
current_sum,
left_factors,
right_factors,
new_var,
)
elseif inner_factor.args[3] == 2
# If the exponent is 2, rewrite
# new_var = base^2
# as
# new_base = base_rewrite
# new_var = base_rewrite * base_rewrite
new_var_, parsed = rewrite(inner_factor.args[2])
square_expr = _write_add_mul(
vectorized, minus, current_sum, left_factors,
(new_var_, new_var_), right_factors, new_var
vectorized,
minus,
current_sum,
left_factors,
(new_var_, new_var_),
right_factors,
new_var,
)
return new_var, Expr(:block, parsed, square_expr)
elseif inner_factor.args[3] == 1
return _rewrite(vectorized, minus, :(convert($MulType, $(inner_factor.args[2]))), current_sum, left_factors, right_factors, new_var)
elseif inner_factor.args[3] == 0
return _rewrite(vectorized, minus, :(one($MulType)), current_sum, left_factors, right_factors, new_var)
else
new_var_, parsed = rewrite(inner_factor.args[2])
power_expr = _write_add_mul(
vectorized, minus, current_sum, left_factors,
(Expr(:call, :^, new_var_, esc(inner_factor.args[3])),),
right_factors, new_var
# In the general case, rewrite
# new_var = base^exponent
# as
# new_base = base_rewrite
# new_var = base_rewrite^(exponent)
new_base, base_rewrite = rewrite(inner_factor.args[2])
new_expr = _write_add_mul(
vectorized,
minus,
current_sum,
left_factors,
(Expr(:call, :^, new_base, esc(inner_factor.args[3])),),
right_factors,
new_var,
)
return new_var, Expr(:block, parsed, power_expr)
return new_var, Expr(:block, base_rewrite, new_expr)
end
elseif inner_factor.args[1] == :/ && !vectorized
# Rewrite
# new_var = numerator / denominator
# as
# new_var = numerator * (1 / denominator)
@assert length(inner_factor.args) == 3
numerator = inner_factor.args[2]
denom = inner_factor.args[3]
return _rewrite(vectorized, minus, numerator, current_sum, left_factors, vcat(esc(:(1 / $denom)), right_factors), new_var)
elseif length(inner_factor.args) >= 2 && (isexpr(inner_factor.args[2], :generator) || isexpr(inner_factor.args[2], :flatten))
return new_var, _parse_generator(vectorized, minus, inner_factor, current_sum, left_factors, right_factors, new_var)
return _rewrite(
vectorized,
minus,
inner_factor.args[2],
current_sum,
left_factors,
vcat(esc(:(1 / $(inner_factor.args[3]))), right_factors),
new_var,
)
elseif (
length(inner_factor.args) >= 2 &&
(
isexpr(inner_factor.args[2], :generator) ||
isexpr(inner_factor.args[2], :flatten)
)
)
# A generator statement.
code = _parse_generator(
vectorized,
minus,
inner_factor,
current_sum,
left_factors,
right_factors,
new_var,
)
return new_var, code
end
elseif isexpr(inner_factor, :curly)
Base.error("The curly syntax (sum{},prod{},norm2{}) is no longer supported. Expression: `$inner_factor`.")
end
if isa(inner_factor, Expr) && _is_comparison(inner_factor)
if isexpr(inner_factor, :curly)
error(
"The curly syntax (sum{},prod{},norm2{}) is no longer supported. " *
"Expression: `$inner_factor`."
)
elseif isa(inner_factor, Expr) && _is_comparison(inner_factor)
error("Unexpected comparison in expression `$inner_factor`.")
end
if isa(inner_factor, Expr) && _has_assignment_in_ref(inner_factor)
elseif isa(inner_factor, Expr) && _has_assignment_in_ref(inner_factor)
error("Unexpected assignment in expression `$inner_factor`.")
end
# at the lowest level
return new_var, _write_add_mul(vectorized, minus, current_sum, left_factors, (esc(inner_factor),), right_factors, new_var)
# None of the special cases were hit! This probably means we are vectorized.
code = _write_add_mul(
vectorized,
minus,
current_sum,
left_factors,
(esc(inner_factor),),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bug for 1.6 is here. If inner_factor contains a sum(1 for _ = 1:0), then this evaluates it and throws an error. I'll think about what extra logic needs to be added to fix it in #60

right_factors,
new_var,
)
return new_var, code
end