Skip to content

Commit

Permalink
Use loop inverted loop lowering
Browse files Browse the repository at this point in the history
The primary idea of the new iteration protocol is that for
a function like:
```
function iterate(itr)
   done(itr) ? nothing : next(itr)
end
```
we can fuse the `done` comparison into the loop condition and
recover the same loop structure we had before (while retaining
the flexibility of not requiring the done function to be separate),
i.e. for
```
y = iterate(itr)
y === nothing && break
```
we want to have after inlining and early optimization:
```
done(itr) && break
y = next(itr)
```
LLVM performs this optimization in jump threading. However, we run
into a problem. At the top of the loop we have:
```
y = iterate
top:
%cond = y === nothing
br i1 %cond, %exit, %loop
....
```
We'd want to thread over the `top` block (this makes sense, since
by the discussion above, we need to merge our condition into the
loop exit condition). However, LLVM (quite sensibly) refuses to
thread over loop headers and since `top` is both a loop header
and a loop exit, we fail to perform the appropriate transformation.

However, there's a simple fix. Instead of emitting a foor loop as
```
y = iterate(itr)
while y !== nothing
    x, state = y
    ...
    y = iterate(itr, state)
end
```
we can emit it as
```
y = iterate(itr)
if y !== nothing
    while true
       x, state = y
       ...
       y = iterate(itr, state)
       y === nothing && break
    end
end
```
This transformation is known as `loop inversion` (or a special
case of `loop rotation`. In our case the primary benefit is
that we can fuse the condition contained in the initial `iterate`
call into the bypass if, which then lets LLVM understand our loop
structure.

Co-authored-by: Jeff Bezanson <[email protected]>
  • Loading branch information
Keno and Jeff Bezanson committed May 18, 2018
1 parent 1a1d6b6 commit 62fbad2
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
9 changes: 5 additions & 4 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6507,7 +6507,7 @@ static std::unique_ptr<Module> emit_function(
};

// Codegen Phi nodes
std::map<BasicBlock *, BasicBlock*> BB_rewrite_map;
std::map<std::pair<BasicBlock *, BasicBlock*>, BasicBlock*> BB_rewrite_map;
std::vector<llvm::PHINode*> ToDelete;
for (auto &tup : ctx.PhiNodes) {
jl_cgval_t phi_result;
Expand All @@ -6526,8 +6526,9 @@ static std::unique_ptr<Module> emit_function(
Value *V = NULL;
BasicBlock *IncomingBB = come_from_bb[edge];
BasicBlock *FromBB = IncomingBB;
if (BB_rewrite_map.count(FromBB)) {
FromBB = BB_rewrite_map[IncomingBB];
std::pair<BasicBlock *, BasicBlock*> LookupKey(IncomingBB, PhiBB);
if (BB_rewrite_map.count(LookupKey)) {
FromBB = BB_rewrite_map[LookupKey];
}
#ifndef JL_NDEBUG
bool found_pred = false;
Expand Down Expand Up @@ -6681,7 +6682,7 @@ static std::unique_ptr<Module> emit_function(
// Check any phi nodes in the Phi block to see if by splitting the edges,
// we made things inconsistent
if (FromBB != ctx.builder.GetInsertBlock()) {
BB_rewrite_map[IncomingBB] = ctx.builder.GetInsertBlock();
BB_rewrite_map[LookupKey] = ctx.builder.GetInsertBlock();
for (BasicBlock::iterator I = PhiBB->begin(); isa<PHINode>(I); ++I) {
PHINode *PN = cast<PHINode>(I);
ssize_t BBIdx = PN->getBasicBlockIndex(FromBB);
Expand Down
17 changes: 13 additions & 4 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1606,10 +1606,11 @@
;; TODO avoid `local declared twice` error from this
;;,@(if outer `((local ,lhs)) '())
,@(if outer `((require-existing-local ,lhs)) '())
(_while
(call (|.| (core Intrinsics) 'not_int) (call (core ===) ,next (null)))
(block ,body
(= ,next (call (top iterate) ,coll ,state)))))))))))
(if (call (top not_int) (call (core ===) ,next (null)))
(_do_while
(block ,body
(= ,next (call (top iterate) ,coll ,state)))
(call (|.| (core Intrinsics) 'not_int) (call (core ===) ,next (null))))))))))))

;; wrap `expr` in a function appropriate for consuming values from given ranges
(define (func-for-generator-ranges expr range-exprs flat outervars)
Expand Down Expand Up @@ -3644,6 +3645,14 @@ f(x) = yt(x)
(compile (caddr e) break-labels #f #f)
(emit `(goto ,topl))
(mark-label endl)))
((_do_while)
(let* ((endl (make-label))
(topl (make&mark-label)))
(compile (cadr e) break-labels #f #f)
(let ((test (compile-cond (caddr e) break-labels)))
(emit `(gotoifnot ,test ,endl)))
(emit `(goto ,topl))
(mark-label endl)))
((break-block)
(let ((endl (make-label)))
(begin0 (compile (caddr e)
Expand Down

0 comments on commit 62fbad2

Please sign in to comment.