From 35c1f87176fecab5a8e8077fb3a62275363a5aa5 Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Fri, 19 Jun 2020 19:16:11 -0400 Subject: [PATCH] fix #36230, more efficient lowering of `if` with a chain of `&&` (#36355) --- base/expr.jl | 4 ++-- base/meta.jl | 5 +++-- src/julia-syntax.scm | 25 +++++++++++++++++++++---- test/compiler/inference.jl | 10 ++++++++++ 4 files changed, 36 insertions(+), 8 deletions(-) diff --git a/base/expr.jl b/base/expr.jl index db2d6333d01c1..3a1acc1a7d77e 100644 --- a/base/expr.jl +++ b/base/expr.jl @@ -347,9 +347,9 @@ function findmeta_block(exargs, argsmatch=args->true) for i = 1:length(exargs) a = exargs[i] if isa(a, Expr) - if (a::Expr).head === :meta && argsmatch((a::Expr).args) + if a.head === :meta && argsmatch(a.args) return i, exargs - elseif (a::Expr).head === :block + elseif a.head === :block idx, exa = findmeta_block(a.args, argsmatch) if idx != 0 return idx, exa diff --git a/base/meta.jl b/base/meta.jl index 4cf2ac3676dcd..560d66ea37311 100644 --- a/base/meta.jl +++ b/base/meta.jl @@ -62,8 +62,9 @@ true ``` """ isexpr(@nospecialize(ex), head::Symbol) = isa(ex, Expr) && ex.head === head -isexpr(@nospecialize(ex), heads::Union{Set,Vector,Tuple}) = isa(ex, Expr) && in(ex.head, heads) -isexpr(@nospecialize(ex), heads, n::Int) = isexpr(ex, heads) && length((ex::Expr).args) == n +isexpr(@nospecialize(ex), heads) = isa(ex, Expr) && in(ex.head, heads) +isexpr(@nospecialize(ex), head::Symbol, n::Int) = isa(ex, Expr) && ex.head === head && length(ex.args) == n +isexpr(@nospecialize(ex), heads, n::Int) = isa(ex, Expr) && in(ex.head, heads) && length(ex.args) == n """ Meta.show_sexpr([io::IO,], ex) diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index b292ee5bae8fb..67fa499cc9045 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -1895,6 +1895,13 @@ (else (error (string "invalid " syntax-str " \"" (deparse el) "\"")))))))) +(define (expand-if e) + (if (and (pair? (cadr e)) (eq? (car (cadr e)) '&&)) + (let ((clauses (cdr (flatten-ex '&& (cadr e))))) + `(if (&& ,@(map expand-forms clauses)) + ,@(map expand-forms (cddr e)))) + (cons (car e) (map expand-forms (cdr e))))) + ;; move an assignment into the last statement of a block to keep more statements at top level (define (sink-assignment lhs rhs) (if (and (pair? rhs) (eq? (car rhs) 'block)) @@ -2230,6 +2237,9 @@ ,(expand-forms (cadr e)) ,(expand-forms (caddr e))) (map expand-forms e))) + 'if expand-if + 'elseif expand-if + 'while (lambda (e) `(break-block loop-exit @@ -3709,7 +3719,8 @@ f(x) = yt(x) (handler-level 0) ;; exception handler nesting depth (catch-token-stack '())) ;; tokens identifying handler enter for current catch blocks (define (emit c) - (set! code (cons c code))) + (set! code (cons c code)) + c) (define (make-label) (begin0 label-counter (set! label-counter (+ 1 label-counter)))) @@ -3957,15 +3968,21 @@ f(x) = yt(x) (compile (cadr e) break-labels value tail) #f)) ((if elseif) - (let ((test `(gotoifnot ,(compile-cond (cadr e) break-labels) _)) + (let ((tests (map (lambda (clause) + (emit `(gotoifnot ,(compile-cond clause break-labels) _))) + (if (and (pair? (cadr e)) (eq? (car (cadr e)) '&&)) + (cdadr e) + (list (cadr e))))) (end-jump `(goto _)) (val (if (and value (not tail)) (new-mutable-var) #f))) - (emit test) (let ((v1 (compile (caddr e) break-labels value tail))) (if val (emit-assignment val v1)) (if (and (not tail) (or (length> e 3) val)) (emit end-jump)) - (set-car! (cddr test) (make&mark-label)) + (let ((elselabel (make&mark-label))) + (for-each (lambda (test) + (set-car! (cddr test) elselabel)) + tests)) (let ((v2 (if (length> e 3) (compile (cadddr e) break-labels value tail) '(null)))) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 788469aba82b3..321ae26f02968 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -2650,3 +2650,13 @@ end f(n) = depth(n, 1) end @test Base.return_types(TestConstPropRecursion.f, (TestConstPropRecursion.Node,)) == Any[Int] + +# issue #36230, keeping implications of all conditions in a && chain +function symcmp36230(vec) + a, b = vec[1], vec[2] + if isa(a, Symbol) && isa(b, Symbol) + return a == b + end + return false +end +@test Base.return_types(symcmp36230, (Vector{Any},)) == Any[Bool]