Skip to content

Commit

Permalink
treat .= as syntactic sugar for broadcast! (#17510)
Browse files Browse the repository at this point in the history
* treat .= as syntactic sugar for broadcast!

* tests

* optimized .= assignment of scalars and vector copies

* .= documentation

* fix show of .= ops

* .-= tests

* NEWS for .=
  • Loading branch information
stevengj authored Jul 21, 2016
1 parent bc034fc commit cd2e260
Show file tree
Hide file tree
Showing 9 changed files with 153 additions and 88 deletions.
6 changes: 5 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ New language features
* Generators and comprehensions support filtering using `if` ([#550]) and nested
iteration using multiple `for` keywords ([#4867]).

* Broadcasting syntax: ``f.(args...)`` is equivalent to ``broadcast(f, args...)`` ([#15032]),
* Fused broadcasting syntax: ``f.(args...)`` is equivalent to ``broadcast(f, args...)`` ([#15032]),
and nested `f.(g.(args...))` calls are fused into a single `broadcast` loop ([#17300]).
Similarly, the syntax `x .= ...` is equivalent to a `broadcast!(identity, x, ...)`
call and fuses with nested "dot" calls; also, `x .+= y` and similar is now
equivalent to `x .= x .+ y`, rather than `=` ([#17510]).

* Macro expander functions are now generic, so macros can have multiple definitions
(e.g. for different numbers of arguments, or optional arguments) ([#8846], [#9627]).
Expand Down Expand Up @@ -357,3 +360,4 @@ Deprecated or removed
[#17393]: https://github.com/JuliaLang/julia/issues/17393
[#17402]: https://github.com/JuliaLang/julia/issues/17402
[#17404]: https://github.com/JuliaLang/julia/issues/17404
[#17510]: https://github.com/JuliaLang/julia/issues/17510
9 changes: 9 additions & 0 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ export broadcast_getindex, broadcast_setindex!
broadcast(f) = f()
broadcast(f, x::Number...) = f(x...)

# special cases for "X .= ..." (broadcast!) assignments
broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x)
broadcast!(f, X::AbstractArray) = fill!(X, f())
broadcast!(f, X::AbstractArray, x::Number...) = fill!(X, f(x...))
function broadcast!{T,S,N}(::typeof(identity), x::AbstractArray{T,N}, y::AbstractArray{S,N})
check_broadcast_shape(size(x), size(y))
copy!(x, y)
end

## Calculate the broadcast shape of the arguments, or error if incompatible
# array inputs
broadcast_shape() = ()
Expand Down
6 changes: 4 additions & 2 deletions base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,10 @@ show_unquoted(io::IO, ex, ::Int,::Int) = show(io, ex)
const indent_width = 4
const quoted_syms = Set{Symbol}([:(:),:(::),:(:=),:(=),:(==),:(!=),:(===),:(!==),:(=>),:(>=),:(<=)])
const uni_ops = Set{Symbol}([:(+), :(-), :(!), :(¬), :(~), :(<:), :(>:), :(), :(), :()])
const expr_infix_wide = Set{Symbol}([:(=), :(+=), :(-=), :(*=), :(/=), :(\=), :(&=),
:(|=), :($=), :(>>>=), :(>>=), :(<<=), :(&&), :(||), :(<:), :(=>), :(÷=)])
const expr_infix_wide = Set{Symbol}([
:(=), :(+=), :(-=), :(*=), :(/=), :(\=), :(^=), :(&=), :(|=), :(÷=), :(%=), :(>>>=), :(>>=), :(<<=),
:(.=), :(.+=), :(.-=), :(.*=), :(./=), :(.\=), :(.^=), :(.&=), :(.|=), :(.÷=), :(.%=), :(.>>>=), :(.>>=), :(.<<=),
:(&&), :(||), :(<:), :(=>), :($=)])
const expr_infix = Set{Symbol}([:(:), :(->), Symbol("::")])
const expr_infix_any = union(expr_infix, expr_infix_wide)
const all_ops = union(quoted_syms, uni_ops, expr_infix_any)
Expand Down
2 changes: 1 addition & 1 deletion doc/manual/arrays.rst
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ function elementwise:
1.71056 0.847604
1.73659 0.873631

Elementwise operators such as ``.+`` and ``.*`` perform broadcasting if necessary. There is also a :func:`broadcast!` function to specify an explicit destination, and :func:`broadcast_getindex` and :func:`broadcast_setindex!` that broadcast the indices before indexing. Moreover, ``f.(args...)`` is equivalent to ``broadcast(f, args...)``, providing a convenient syntax to broadcast any function (:ref:`man-dot-vectorizing`:.).
Elementwise operators such as ``.+`` and ``.*`` perform broadcasting if necessary. There is also a :func:`broadcast!` function to specify an explicit destination, and :func:`broadcast_getindex` and :func:`broadcast_setindex!` that broadcast the indices before indexing. Moreover, ``f.(args...)`` is equivalent to ``broadcast(f, args...)``, providing a convenient syntax to broadcast any function (:ref:`man-dot-vectorizing`:).

Implementation
--------------
Expand Down
13 changes: 12 additions & 1 deletion doc/manual/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -652,9 +652,20 @@ the fusion stops as soon as a "non-dot" function is encountered; for example,
in ``sin.(sort(cos.(X)))`` the ``sin`` and ``cos`` loops cannot be merged
because of the intervening ``sort`` function.

Finally, the maximum efficiency is typically achieved when the output
array of a vectorized operation is *pre-allocated*, so that repeated
calls do not allocate new arrays over and over again for the results
(:ref:`man-preallocation`:). A convenient syntax for this is
``X .= ...``, which is equivalent to ``broadcast!(identity, X, ...)``
except that, as above, the ``broadcast!`` loop is fused with any nested
"dot" calls. For example, ``X .= sin.(Y)`` is equivalent to
``broadcast!(sin, X, Y)``, overwriting ``X`` with ``sin.(Y)`` in-place.

(In future versions of Julia, operators like ``.*`` will also be handled with
the same mechanism: they will be equivalent to ``broadcast`` calls and
will be fused with other nested "dot" calls.)
will be fused with other nested "dot" calls. ``x .+= y`` is equivalent
to ``x .= x .+ y`` and will eventually result in a fused in-place assignment.
Similarly for ``.*=`` etcetera.)

Further Reading
---------------
Expand Down
5 changes: 4 additions & 1 deletion doc/manual/performance-tips.rst
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,10 @@ above, we could have passed a :class:`SubArray` rather than an :class:`Array`,
had we so desired.

Taken to its extreme, pre-allocation can make your code uglier, so
performance measurements and some judgment may be required.
performance measurements and some judgment may be required. However,
for "vectorized" (element-wise) functions, the convenient syntax
``x .= f.(y)`` can be used for in-place operations with fused loops
and no temporary arrays (:ref:`dot-vectorizing`).


Avoid string interpolation for I/O
Expand Down
177 changes: 95 additions & 82 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1418,12 +1418,12 @@
`(call ,(cadr e) ,(expand-forms a) ,(expand-forms b))))))

;; convert `a+=b` to `a=a+b`
(define (expand-update-operator- op lhs rhs declT)
(define (expand-update-operator- op op= lhs rhs declT)
(let ((e (remove-argument-side-effects lhs)))
`(block ,@(cdr e)
,(if (null? declT)
`(= ,(car e) (call ,op ,(car e) ,rhs))
`(= ,(car e) (call ,op (:: ,(car e) ,(car declT)) ,rhs))))))
`(,op= ,(car e) (call ,op ,(car e) ,rhs))
`(,op= ,(car e) (call ,op (:: ,(car e) ,(car declT)) ,rhs))))))

(define (partially-expand-ref e)
(let ((a (cadr e))
Expand All @@ -1443,31 +1443,32 @@
,@(append stmts stuff)
(call getindex ,arr ,@new-idxs))))))

(define (expand-update-operator op lhs rhs . declT)
(define (expand-update-operator op op= lhs rhs . declT)
(cond ((and (pair? lhs) (eq? (car lhs) 'ref))
;; expand indexing inside op= first, to remove "end" and ":"
(let* ((ex (partially-expand-ref lhs))
(stmts (butlast (cdr ex)))
(refex (last (cdr ex)))
(nuref `(ref ,(caddr refex) ,@(cdddr refex))))
`(block ,@stmts
,(expand-update-operator- op nuref rhs declT))))
,(expand-update-operator- op op= nuref rhs declT))))
((and (pair? lhs) (eq? (car lhs) '|::|))
;; (+= (:: x T) rhs)
(let ((e (remove-argument-side-effects (cadr lhs)))
(T (caddr lhs)))
`(block ,@(cdr e)
,(expand-update-operator op (car e) rhs T))))
,(expand-update-operator op op= (car e) rhs T))))
(else
(expand-update-operator- op lhs rhs declT))))
(expand-update-operator- op op= lhs rhs declT))))

(define (lower-update-op e)
(expand-forms
(expand-update-operator
(let ((str (string (car e))))
(symbol (string.sub str 0 (- (length str) 1))))
(cadr e)
(caddr e))))
(let ((str (string (car e))))
(expand-update-operator
(symbol (string.sub str 0 (- (length str) 1)))
(if (= (string.char str 0) #\.) '.= '=)
(cadr e)
(caddr e)))))

(define (expand-and e)
(let ((e (cdr (flatten-ex '&& e))))
Expand Down Expand Up @@ -1546,11 +1547,9 @@
(cadr expr) ;; eta reduce `x->f(x)` => `f`
`(-> ,argname (block ,@splat ,expr)))))

(define (getfield-field? x) ; whether x from (|.| f x) is a getfield call
(or (eq? (car x) 'quote) (eq? (car x) 'inert) (eq? (car x) '$)))

;; fuse nested calls to f.(args...) into a single broadcast call
(define (expand-fuse-broadcast f args)
; fuse nested calls to expr == f.(args...) into a single broadcast call,
; or a broadcast! call if lhs is non-null.
(define (expand-fuse-broadcast lhs rhs)
(define (fuse? e) (and (pair? e) (eq? (car e) 'fuse)))
(define (anyfuse? exprs)
(if (null? exprs) #f (if (fuse? (car exprs)) #t (anyfuse? (cdr exprs)))))
Expand Down Expand Up @@ -1594,72 +1593,83 @@
oldarg))
fargs args)))
(let ,fbody ,@(reverse (fuse-lets fargs args '()))))))
(define (make-fuse f args) ; check for nested (fuse f args) exprs and combine
(define (split-kwargs args) ; return (cons keyword-args positional-args) extracted from args
(define (sk args kwargs pargs)
(if (null? args)
(cons kwargs pargs)
(if (kwarg? (car args))
(sk (cdr args) (cons (car args) kwargs) pargs)
(sk (cdr args) kwargs (cons (car args) pargs)))))
(if (has-parameters? args)
(sk (reverse (cdr args)) (cdar args) '())
(sk (reverse args) '() '())))
(define (dot-to-fuse e) ; convert e == (. f (tuple args)) to (fuse f args)
(if (and (pair? e) (eq? (car e) '|.|) (not (getfield-field? (caddr e))))
(make-fuse (cadr e) (cdaddr e))
e))
(let* ((kws.args (split-kwargs args))
(kws (car kws.args))
(args (cdr kws.args)) ; fusing occurs on positional args only
(args_ (map dot-to-fuse args)))
(if (anyfuse? args_)
`(fuse ,(fuse-funcs (to-lambda f args kws) args_) ,(fuse-args args_))
`(fuse ,(to-lambda f args kws) ,args_))))
(define (dot-to-fuse e) ; convert e == (. f (tuple args)) to (fuse f args)
(define (make-fuse f args) ; check for nested (fuse f args) exprs and combine
(define (split-kwargs args) ; return (cons keyword-args positional-args) extracted from args
(define (sk args kwargs pargs)
(if (null? args)
(cons kwargs pargs)
(if (kwarg? (car args))
(sk (cdr args) (cons (car args) kwargs) pargs)
(sk (cdr args) kwargs (cons (car args) pargs)))))
(if (has-parameters? args)
(sk (reverse (cdr args)) (cdar args) '())
(sk (reverse args) '() '())))
(let* ((kws.args (split-kwargs args))
(kws (car kws.args))
(args (cdr kws.args)) ; fusing occurs on positional args only
(args_ (map dot-to-fuse args)))
(if (anyfuse? args_)
`(fuse ,(fuse-funcs (to-lambda f args kws) args_) ,(fuse-args args_))
`(fuse ,(to-lambda f args kws) ,args_))))
(if (and (pair? e) (eq? (car e) '|.|))
(let ((f (cadr e)) (x (caddr e)))
(if (or (eq? (car x) 'quote) (eq? (car x) 'inert) (eq? (car x) '$))
`(call (core getfield) ,f ,x)
(make-fuse f (cdr x))))
e))
; given e == (fuse lambda args), compress the argument list by removing (pure)
; duplicates in args, inlining literals, and moving any varargs to the end:
(define (compress-fuse e)
(define (findfarg arg args fargs) ; for arg in args, return corresponding farg
(if (eq? arg (car args))
(car fargs)
(findfarg arg (cdr args) (cdr fargs))))
(let ((f (cadr e))
(args (caddr e)))
(define (cf old-fargs old-args new-fargs new-args renames varfarg vararg)
(if (null? old-args)
(let ((nfargs (if (null? varfarg) new-fargs (cons varfarg new-fargs)))
(nargs (if (null? vararg) new-args (cons vararg new-args))))
`(fuse (-> (tuple ,@(reverse nfargs)) ,(replace-vars (caddr f) renames))
,(reverse nargs)))
(let ((farg (car old-fargs)) (arg (car old-args)))
(cond
((and (vararg? farg) (vararg? arg)) ; arg... must be the last argument
(if (null? varfarg)
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args renames farg arg)
(if (eq? (cadr vararg) (cadr arg))
(if (fuse? e)
(let ((f (cadr e))
(args (caddr e)))
(define (cf old-fargs old-args new-fargs new-args renames varfarg vararg)
(if (null? old-args)
(let ((nfargs (if (null? varfarg) new-fargs (cons varfarg new-fargs)))
(nargs (if (null? vararg) new-args (cons vararg new-args))))
`(fuse (-> (tuple ,@(reverse nfargs)) ,(replace-vars (caddr f) renames))
,(reverse nargs)))
(let ((farg (car old-fargs)) (arg (car old-args)))
(cond
((and (vararg? farg) (vararg? arg)) ; arg... must be the last argument
(if (null? varfarg)
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args (cons (cons (cadr farg) (cadr varfarg)) renames)
varfarg vararg)
(error "multiple splatted args cannot be fused into a single broadcast"))))
((number? arg) ; inline numeric literals
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args
(cons (cons farg arg) renames)
varfarg vararg))
((and (symbol? arg) (memq arg new-args)) ; combine duplicate args
; (note: calling memq for every arg is O(length(args)^2) ...
; ... would be better to replace with a hash table if args is long)
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args
(cons (cons farg (findfarg arg new-args new-fargs)) renames)
varfarg vararg))
(else
(cf (cdr old-fargs) (cdr old-args)
(cons farg new-fargs) (cons arg new-args) renames varfarg vararg))))))
(cf (cdadr f) args '() '() '() '() '())))
(let ((e (compress-fuse (make-fuse f args)))) ; an expression '(fuse func args)
(expand-forms `(call broadcast ,(from-lambda (cadr e)) ,@(caddr e)))))
new-fargs new-args renames farg arg)
(if (eq? (cadr vararg) (cadr arg))
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args (cons (cons (cadr farg) (cadr varfarg)) renames)
varfarg vararg)
(error "multiple splatted args cannot be fused into a single broadcast"))))
((number? arg) ; inline numeric literals
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args
(cons (cons farg arg) renames)
varfarg vararg))
((and (symbol? arg) (memq arg new-args)) ; combine duplicate args
; (note: calling memq for every arg is O(length(args)^2) ...
; ... would be better to replace with a hash table if args is long)
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args
(cons (cons farg (findfarg arg new-args new-fargs)) renames)
varfarg vararg))
(else
(cf (cdr old-fargs) (cdr old-args)
(cons farg new-fargs) (cons arg new-args) renames varfarg vararg))))))
(cf (cdadr f) args '() '() '() '() '()))
e)) ; (not (fuse? e))
(let ((e (compress-fuse (dot-to-fuse rhs)))) ; an expression '(fuse func args) if expr is a dot call
(if (fuse? e)
(if (null? lhs)
(expand-forms `(call broadcast ,(from-lambda (cadr e)) ,@(caddr e)))
(expand-forms `(call broadcast! ,(from-lambda (cadr e)) ,lhs ,@(caddr e))))
(if (null? lhs)
(expand-forms e)
(expand-forms `(call broadcast! identity ,lhs ,e))))))

;; table mapping expression head to a function expanding that form
(define expand-table
Expand Down Expand Up @@ -1697,13 +1707,11 @@

'|.|
(lambda (e) ; e = (|.| f x)
(let ((f (cadr e))
(x (caddr e)))
(if (getfield-field? x)
`(call (core getfield) ,(expand-forms f) ,(expand-forms x))
; otherwise, came from f.(args...) --> broadcast(f, args...),
; where we want to fuse with any nested broadcast calls.
(expand-fuse-broadcast f (cdr x)))))
(expand-fuse-broadcast '() e))

'.=
(lambda (e)
(expand-fuse-broadcast (cadr e) (caddr e)))

'|<:| syntactic-op-to-call
'|>:| syntactic-op-to-call
Expand Down Expand Up @@ -2008,11 +2016,16 @@
'%= lower-update-op
'.%= lower-update-op
'|\|=| lower-update-op
'|.\|=| lower-update-op
'&= lower-update-op
'.&= lower-update-op
'$= lower-update-op
'<<= lower-update-op
'.<<= lower-update-op
'>>= lower-update-op
'.>>= lower-update-op
'>>>= lower-update-op
'.>>>= lower-update-op

':
(lambda (e)
Expand Down
19 changes: 19 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,25 @@ let x = [1:4;]
@test sin.(f17300kw.(x, y=1)) == sin.(f17300kw.(x; y=1)) == sin.(x .+ 1)
end

# PR #17510: Fused in-place assignment
let x = [1:4;], y = x
y .= 2:5
@test y === x == [2:5;]
y .= factorial.(x)
@test y === x == [2,6,24,120]
y .= 7
@test y === x == [7,7,7,7]
y .= factorial.(3)
@test y === x == [6,6,6,6]
f17510() = 9
y .= f17510.()
@test y === x == [9,9,9,9]
y .-= 1
@test y === x == [8,8,8,8]
y .-= 1:4
@test y === x == [7,6,5,4]
end

# PR 16988
@test Base.promote_op(+, Bool) === Int
@test isa(broadcast(+, [true]), Array{Int,1})
Expand Down
4 changes: 4 additions & 0 deletions test/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -557,3 +557,7 @@ end
@test repr(:(x for x in y if aa for z in w if bb)) == ":(x for x = y if aa for z = w if bb)"
@test repr(:([x for x = y])) == ":([x for x = y])"
@test repr(:([x for x = y if z])) == ":([x for x = y if z])"

for op in (:(.=), :(.+=), :(.&=))
@test repr(parse("x $op y")) == ":(x $op y)"
end

0 comments on commit cd2e260

Please sign in to comment.