diff --git a/src/rewrite.jl b/src/rewrite.jl index 398ac472..f5de2f50 100644 --- a/src/rewrite.jl +++ b/src/rewrite.jl @@ -291,114 +291,285 @@ function _write_add_mul(vectorized, minus, current_sum, left_factors, inner_fact end """ - _rewrite(vectorized::Bool, minus::Bool, inner_factor, current_sum::Union{Nothing, Symbol}, left_factors::Vector, right_factors::Vector, new_var::Symbol=gensym()) + _rewrite( + vectorized::Bool, + minus::Bool, + inner_factor, + current_sum::Union{Symbol, Nothing}, + left_factors::Vector, + right_factors::Vector, + new_var::Symbol = gensym(), + ) Return `new_var, code` such that `code` is equivalent to ```julia new_var = prod(left_factors) * inner_factor * prod(reverse(right_factors)) ``` -if `current_sum` is `nothing`, and is + +If `current_sum` is `nothing`, and is ```julia new_var = current_sum op prod(left_factors) * inner_factor * prod(reverse(right_factors)) ``` -otherwise where `op` is `+` if `!vectorized` and `!minus`, `.+` if `vectorized` and `!minus`, -`-` if `!vectorized` and `minus` and `.-` if `vectorized` and `minus`. +otherwise where `op` is `+` if `!vectorized & !minus`, `.+` if +`vectorized & !minus`, `-` if `!vectorized & minus` and `.-` if +`vectorized & minus`. """ -function _rewrite(vectorized::Bool, minus::Bool, inner_factor, current_sum::Union{Symbol, Nothing}, left_factors::Vector, right_factors::Vector, new_var::Symbol=gensym()) +function _rewrite( + vectorized::Bool, + minus::Bool, + inner_factor, + current_sum::Union{Symbol, Nothing}, + left_factors::Vector, + right_factors::Vector, + new_var::Symbol = gensym(), +) if isexpr(inner_factor, :call) - # We need to verfify that `left_factors` and `right_factors` are empty for broadcast, see `_is_decomposable_with_factors`. - # We also need to verify that `current_sum` is `nothing` otherwise we are unsure that the elements in the containers have been copied, e.g., in - # `I + (x .+ 1)`, the offdiagonal entries of `I + x` are the same as `x` so we cannot do `broadcast!(add_mul, I + x, 1)`. - if inner_factor.args[1] == :+ || inner_factor.args[1] == :- || - (current_sum === nothing && isempty(left_factors) && isempty(right_factors) && (inner_factor.args[1] == :(.+) || inner_factor.args[1] == :(.-))) - block = Expr(:block) - if length(inner_factor.args) > 2 # not unary addition or subtraction - next_sum, code = _rewrite(vectorized, minus, inner_factor.args[2], current_sum, left_factors, right_factors) - push!(block.args, code) - start = 3 - else + if ( + inner_factor.args[1] == :+ || + inner_factor.args[1] == :- || + ( + current_sum === nothing && + isempty(left_factors) && + isempty(right_factors) && + (inner_factor.args[1] == :(.+) || inner_factor.args[1] == :(.-)) + ) + ) + # There are three cases here: + # 1. scalar addition : +(args...) + # 2. scalar subtraction : -(args...) + # 3. broadcast addition or subtraction. + # For case (3), we need to verify that current_sum, left_factors, + # and right_factors are empty, otherwise we are unsure that the + # elements in the containers have been copied, e.g., in + # `I + (x .+ 1)`, the offdiagonal entries of `I + x` are the same as + # `x` so we cannot do `broadcast!(add_mul, I + x, 1)`. + code = Expr(:block) + if length(inner_factor.args) == 2 + # Unary addition or subtraction. next_sum = current_sum start = 2 + else + next_sum, new_code = _rewrite( + vectorized, + minus, + inner_factor.args[2], + current_sum, + left_factors, + right_factors, + ) + push!(code.args, new_code) + start = 3 end - vectorized = vectorized || inner_factor.args[1] == :(.+) || inner_factor.args[1] == :(.-) if inner_factor.args[1] == :- || inner_factor.args[1] == :(.-) minus = !minus end - return rewrite_sum(vectorized, minus, inner_factor.args[start:end], next_sum, left_factors, right_factors, new_var, block) + vectorized = ( + vectorized || + inner_factor.args[1] == :(.+) || + inner_factor.args[1] == :(.-) + ) + return rewrite_sum( + vectorized, + minus, + inner_factor.args[start:end], + next_sum, + left_factors, + right_factors, + new_var, + code, + ) elseif inner_factor.args[1] == :* && !vectorized - # We need `&& !vectorized` otherwise `x .+ A * b` would be rewritten `broadcast!(add_mul, x, A, b)`. - - # we might need to recurse on multiple arguments, e.g., - # (x+y)*(x+y) - # special case, only recurse on one argument and don't create temporary objects - if isone(mapreduce(_is_complex_expr, +, inner_factor.args)) && - isone(mapreduce(_is_decomposable_with_factors, +, inner_factor.args)) + # A multiplication expression *(args...). We need `!vectorized` + # otherwise `x .+ A * b` would be rewritten + # `broadcast!(add_mul, x, A, b)`. + # We might need to recurse on multiple arguments, e.g., (x+y)*(x+y). + # As a special case, only recurse on one argument and don't create + # temporary objects + if ( + isone(mapreduce(_is_complex_expr, +, inner_factor.args)) && + isone(mapreduce(_is_decomposable_with_factors, +, inner_factor.args)) + ) # `findfirst` return the index in `2:...` so we need to add `1`. which_idx = 1 + findfirst(2:length(inner_factor.args)) do i _is_decomposable_with_factors(inner_factor.args[i]) end return _rewrite( - vectorized, minus, inner_factor.args[which_idx], current_sum, - vcat(left_factors, [esc(inner_factor.args[i]) for i in 2:(which_idx - 1)]), - vcat(right_factors, [esc(inner_factor.args[i]) for i in length(inner_factor.args):-1:(which_idx + 1)]), - new_var) + vectorized, + minus, + inner_factor.args[which_idx], + current_sum, + vcat( + left_factors, + [esc(inner_factor.args[i]) for i in 2:(which_idx - 1)] + ), + vcat( + right_factors, + [ + esc(inner_factor.args[i]) + for i in length(inner_factor.args):-1:(which_idx + 1) + ], + ), + new_var, + ) else - blk = Expr(:block) + code = Expr(:block) for i in 2:length(inner_factor.args) - if _is_complex_expr(inner_factor.args[i]) - new_var_, parsed = rewrite(inner_factor.args[i]) - push!(blk.args, parsed) - inner_factor.args[i] = new_var_ + arg = inner_factor.args[i] + if _is_complex_expr(arg) # `arg` needs rewriting. + new_arg, new_arg_code = rewrite(arg) + push!(code.args, new_arg_code) + inner_factor.args[i] = new_arg else - inner_factor.args[i] = esc(inner_factor.args[i]) + inner_factor.args[i] = esc(arg) end end - push!(blk.args, _write_add_mul( - vectorized, minus, current_sum, left_factors, - inner_factor.args[2:end], right_factors, new_var - )) - return new_var, blk + push!( + code.args, + _write_add_mul( + vectorized, + minus, + current_sum, + left_factors, + inner_factor.args[2:end], + right_factors, + new_var, + ), + ) + return new_var, code end - elseif inner_factor.args[1] == :^ && _is_complex_expr(inner_factor.args[2]) && !vectorized - # We need `&& !vectorized` otherwise `A .+ (A + A)^2` would be rewritten `broadcast!(add_mul, x, AA, AA)` where `AA` is `A + A`. - MulType = :(MA.promote_operation(*, typeof($(inner_factor.args[2])), typeof($(inner_factor.args[2])))) - if inner_factor.args[3] == 2 + elseif ( + inner_factor.args[1] == :^ && + _is_complex_expr(inner_factor.args[2]) && + !vectorized + ) + # An expression like `base ^ exponent`, where the `base` is a + # non-trivial expression that also needs to be re-written. We need + # `!vectorized` otherwise `A .+ (A + A)^2` would be rewritten as + # `broadcast!(add_mul, x, AA, AA)` where `AA` is `A + A`. + MulType = :( + MA.promote_operation( + *, + typeof($(inner_factor.args[2])), + typeof($(inner_factor.args[2])) + ) + ) + if inner_factor.args[3] == 0 + # If the exponent is 0, rewrite + # new_var = base^0 + # as + # new_var = 1 + return _rewrite( + vectorized, + minus, + :(one($MulType)), + current_sum, + left_factors, + right_factors, + new_var, + ) + elseif inner_factor.args[3] == 1 + # If the exponent is 1, rewrite + # new_var = base^1 + # as + # new_var = base + return _rewrite( + vectorized, + minus, + :(convert($MulType, $(inner_factor.args[2]))), + current_sum, + left_factors, + right_factors, + new_var, + ) + elseif inner_factor.args[3] == 2 + # If the exponent is 2, rewrite + # new_var = base^2 + # as + # new_base = base_rewrite + # new_var = base_rewrite * base_rewrite new_var_, parsed = rewrite(inner_factor.args[2]) square_expr = _write_add_mul( - vectorized, minus, current_sum, left_factors, - (new_var_, new_var_), right_factors, new_var + vectorized, + minus, + current_sum, + left_factors, + (new_var_, new_var_), + right_factors, + new_var, ) return new_var, Expr(:block, parsed, square_expr) - elseif inner_factor.args[3] == 1 - return _rewrite(vectorized, minus, :(convert($MulType, $(inner_factor.args[2]))), current_sum, left_factors, right_factors, new_var) - elseif inner_factor.args[3] == 0 - return _rewrite(vectorized, minus, :(one($MulType)), current_sum, left_factors, right_factors, new_var) else - new_var_, parsed = rewrite(inner_factor.args[2]) - power_expr = _write_add_mul( - vectorized, minus, current_sum, left_factors, - (Expr(:call, :^, new_var_, esc(inner_factor.args[3])),), - right_factors, new_var + # In the general case, rewrite + # new_var = base^exponent + # as + # new_base = base_rewrite + # new_var = base_rewrite^(exponent) + new_base, base_rewrite = rewrite(inner_factor.args[2]) + new_expr = _write_add_mul( + vectorized, + minus, + current_sum, + left_factors, + (Expr(:call, :^, new_base, esc(inner_factor.args[3])),), + right_factors, + new_var, ) - return new_var, Expr(:block, parsed, power_expr) + return new_var, Expr(:block, base_rewrite, new_expr) end elseif inner_factor.args[1] == :/ && !vectorized + # Rewrite + # new_var = numerator / denominator + # as + # new_var = numerator * (1 / denominator) @assert length(inner_factor.args) == 3 - numerator = inner_factor.args[2] - denom = inner_factor.args[3] - return _rewrite(vectorized, minus, numerator, current_sum, left_factors, vcat(esc(:(1 / $denom)), right_factors), new_var) - elseif length(inner_factor.args) >= 2 && (isexpr(inner_factor.args[2], :generator) || isexpr(inner_factor.args[2], :flatten)) - return new_var, _parse_generator(vectorized, minus, inner_factor, current_sum, left_factors, right_factors, new_var) + return _rewrite( + vectorized, + minus, + inner_factor.args[2], + current_sum, + left_factors, + vcat(esc(:(1 / $(inner_factor.args[3]))), right_factors), + new_var, + ) + elseif ( + length(inner_factor.args) >= 2 && + ( + isexpr(inner_factor.args[2], :generator) || + isexpr(inner_factor.args[2], :flatten) + ) + ) + # A generator statement. + code = _parse_generator( + vectorized, + minus, + inner_factor, + current_sum, + left_factors, + right_factors, + new_var, + ) + return new_var, code end - elseif isexpr(inner_factor, :curly) - Base.error("The curly syntax (sum{},prod{},norm2{}) is no longer supported. Expression: `$inner_factor`.") end - if isa(inner_factor, Expr) && _is_comparison(inner_factor) + if isexpr(inner_factor, :curly) + error( + "The curly syntax (sum{},prod{},norm2{}) is no longer supported. " * + "Expression: `$inner_factor`." + ) + elseif isa(inner_factor, Expr) && _is_comparison(inner_factor) error("Unexpected comparison in expression `$inner_factor`.") - end - if isa(inner_factor, Expr) && _has_assignment_in_ref(inner_factor) + elseif isa(inner_factor, Expr) && _has_assignment_in_ref(inner_factor) error("Unexpected assignment in expression `$inner_factor`.") end - # at the lowest level - return new_var, _write_add_mul(vectorized, minus, current_sum, left_factors, (esc(inner_factor),), right_factors, new_var) + # None of the special cases were hit! This probably means we are vectorized. + code = _write_add_mul( + vectorized, + minus, + current_sum, + left_factors, + (esc(inner_factor),), + right_factors, + new_var, + ) + return new_var, code end