Skip to content

Commit

Permalink
fix JuliaLang#36230, more efficient lowering of if with a chain of …
Browse files Browse the repository at this point in the history
…`&&` (JuliaLang#36355)
  • Loading branch information
JeffBezanson committed Jun 19, 2020
1 parent 00c41cc commit 35c1f87
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 8 deletions.
4 changes: 2 additions & 2 deletions base/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,9 @@ function findmeta_block(exargs, argsmatch=args->true)
for i = 1:length(exargs)
a = exargs[i]
if isa(a, Expr)
if (a::Expr).head === :meta && argsmatch((a::Expr).args)
if a.head === :meta && argsmatch(a.args)
return i, exargs
elseif (a::Expr).head === :block
elseif a.head === :block
idx, exa = findmeta_block(a.args, argsmatch)
if idx != 0
return idx, exa
Expand Down
5 changes: 3 additions & 2 deletions base/meta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ true
```
"""
isexpr(@nospecialize(ex), head::Symbol) = isa(ex, Expr) && ex.head === head
isexpr(@nospecialize(ex), heads::Union{Set,Vector,Tuple}) = isa(ex, Expr) && in(ex.head, heads)
isexpr(@nospecialize(ex), heads, n::Int) = isexpr(ex, heads) && length((ex::Expr).args) == n
isexpr(@nospecialize(ex), heads) = isa(ex, Expr) && in(ex.head, heads)
isexpr(@nospecialize(ex), head::Symbol, n::Int) = isa(ex, Expr) && ex.head === head && length(ex.args) == n
isexpr(@nospecialize(ex), heads, n::Int) = isa(ex, Expr) && in(ex.head, heads) && length(ex.args) == n

"""
Meta.show_sexpr([io::IO,], ex)
Expand Down
25 changes: 21 additions & 4 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1895,6 +1895,13 @@
(else
(error (string "invalid " syntax-str " \"" (deparse el) "\""))))))))

(define (expand-if e)
(if (and (pair? (cadr e)) (eq? (car (cadr e)) '&&))
(let ((clauses (cdr (flatten-ex '&& (cadr e)))))
`(if (&& ,@(map expand-forms clauses))
,@(map expand-forms (cddr e))))
(cons (car e) (map expand-forms (cdr e)))))

;; move an assignment into the last statement of a block to keep more statements at top level
(define (sink-assignment lhs rhs)
(if (and (pair? rhs) (eq? (car rhs) 'block))
Expand Down Expand Up @@ -2230,6 +2237,9 @@
,(expand-forms (cadr e)) ,(expand-forms (caddr e)))
(map expand-forms e)))

'if expand-if
'elseif expand-if

'while
(lambda (e)
`(break-block loop-exit
Expand Down Expand Up @@ -3709,7 +3719,8 @@ f(x) = yt(x)
(handler-level 0) ;; exception handler nesting depth
(catch-token-stack '())) ;; tokens identifying handler enter for current catch blocks
(define (emit c)
(set! code (cons c code)))
(set! code (cons c code))
c)
(define (make-label)
(begin0 label-counter
(set! label-counter (+ 1 label-counter))))
Expand Down Expand Up @@ -3957,15 +3968,21 @@ f(x) = yt(x)
(compile (cadr e) break-labels value tail)
#f))
((if elseif)
(let ((test `(gotoifnot ,(compile-cond (cadr e) break-labels) _))
(let ((tests (map (lambda (clause)
(emit `(gotoifnot ,(compile-cond clause break-labels) _)))
(if (and (pair? (cadr e)) (eq? (car (cadr e)) '&&))
(cdadr e)
(list (cadr e)))))
(end-jump `(goto _))
(val (if (and value (not tail)) (new-mutable-var) #f)))
(emit test)
(let ((v1 (compile (caddr e) break-labels value tail)))
(if val (emit-assignment val v1))
(if (and (not tail) (or (length> e 3) val))
(emit end-jump))
(set-car! (cddr test) (make&mark-label))
(let ((elselabel (make&mark-label)))
(for-each (lambda (test)
(set-car! (cddr test) elselabel))
tests))
(let ((v2 (if (length> e 3)
(compile (cadddr e) break-labels value tail)
'(null))))
Expand Down
10 changes: 10 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2650,3 +2650,13 @@ end
f(n) = depth(n, 1)
end
@test Base.return_types(TestConstPropRecursion.f, (TestConstPropRecursion.Node,)) == Any[Int]

# issue #36230, keeping implications of all conditions in a && chain
function symcmp36230(vec)
a, b = vec[1], vec[2]
if isa(a, Symbol) && isa(b, Symbol)
return a == b
end
return false
end
@test Base.return_types(symcmp36230, (Vector{Any},)) == Any[Bool]

0 comments on commit 35c1f87

Please sign in to comment.