forked from halide/Halide
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ApplySplit.cpp
163 lines (139 loc) · 8.1 KB
/
ApplySplit.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#include "ApplySplit.h"
#include "IR.h"
#include "IROperator.h"
#include "Simplify.h"
#include "Substitute.h"
namespace Halide {
namespace Internal {
using std::map;
using std::string;
using std::vector;
/** Translate one scheduling directive into the list of substitutions, let
 * definitions and predicates that relate the old loop variable(s) to the new
 * one(s).
 *
 * \param split The directive to apply. It is handled as a split, as a fuse,
 *        or otherwise as a rename/purify.
 * \param is_update Not consulted by this function; kept so that the
 *        signature stays compatible with existing callers.
 * \param prefix Prepended to every loop variable name that is referenced or
 *        defined here.
 * \param dim_extent_alignment Maps a dimension name to an Expr describing
 *        what is known about its extent. It is read to prove that a split
 *        factor divides the old extent, and is updated with what is learned
 *        about the newly created dimensions.
 * \return The generated ApplySplitResult entries, in generation order.
 *
 * Reports a user_error if the factor of a split is a constant that is zero
 * or negative.
 */
vector<ApplySplitResult> apply_split(const Split &split, bool is_update, const string &prefix,
                                     map<string, Expr> &dim_extent_alignment) {
    vector<ApplySplitResult> result;

    Expr outer = Variable::make(Int(32), prefix + split.outer);
    if (split.is_split()) {
        Expr inner = Variable::make(Int(32), prefix + split.inner);
        Expr old_max = Variable::make(Int(32), prefix + split.old_var + ".loop_max");
        Expr old_min = Variable::make(Int(32), prefix + split.old_var + ".loop_min");

        // The inner loop always runs over exactly 'factor' iterations.
        dim_extent_alignment[split.inner] = split.factor;

        Expr base = outer * split.factor + old_min;
        string base_name = prefix + split.inner + ".base";
        Expr base_var = Variable::make(Int(32), base_name);
        string old_var_name = prefix + split.old_var;
        Expr old_var = Variable::make(Int(32), old_var_name);

        map<string, Expr>::iterator iter = dim_extent_alignment.find(split.old_var);

        TailStrategy tail = split.tail;
        internal_assert(tail != TailStrategy::Auto)
            << "An explicit tail strategy should exist at this point\n";

        if (is_negative_const(split.factor) || is_zero(split.factor)) {
            // Reject invalid factors before anything else. If the
            // divisibility test below ran first, a zero or negative
            // factor whose modulo happens to simplify to zero would be
            // accepted silently.
            user_error << "Can't split " << split.old_var << " by " << split.factor
                       << ". Split factors must be strictly positive\n";
        } else if ((iter != dim_extent_alignment.end()) &&
                   is_zero(simplify(iter->second % split.factor))) {
            // We have proved that the split factor divides the
            // old extent. No need to adjust the base or add an if
            // statement.
            dim_extent_alignment[split.outer] = iter->second / split.factor;
        } else if (is_one(split.factor)) {
            // The split factor trivially divides the old extent,
            // but we know nothing new about the outer dimension.
        } else if (tail == TailStrategy::GuardWithIf) {
            // It's an exact split but we failed to prove that the
            // factor divides the extent. Use predication to avoid
            // running off the end of the original loop.

            // Bounds inference has trouble exploiting an if
            // condition. We'll directly tell it that the loop
            // variable is bounded above by the original loop max by
            // replacing the variable with a promise-clamped version
            // of it. We don't also use the original loop min because
            // it needlessly complicates the expressions and doesn't
            // actually communicate anything new.
            Expr guarded = promise_clamped(old_var, old_var, old_max);
            string guarded_var_name = prefix + split.old_var + ".guarded";
            Expr guarded_var = Variable::make(Int(32), guarded_var_name);

            result.emplace_back(old_var_name, guarded_var, ApplySplitResult::Substitution);
            result.emplace_back(guarded_var_name, guarded, ApplySplitResult::LetStmt);

            // Inject the if condition *after* doing the substitution
            // for the guarded version.
            Expr cond = likely(old_var <= old_max);
            result.emplace_back(cond);
        } else if (tail == TailStrategy::ShiftInwards) {
            // Adjust the base downwards to not compute off the
            // end of the realization.

            // We'll only mark the base as likely (triggering a loop
            // partition) if we're at or inside the innermost
            // non-trivial loop.
            base = likely_if_innermost(base);
            base = Min::make(base, old_max + (1 - split.factor));
        } else {
            // RoundUp: the base is used unmodified.
            internal_assert(tail == TailStrategy::RoundUp);
        }

        // Define the original variable as the base value computed above
        // plus the inner loop variable.
        result.emplace_back(old_var_name, base_var + inner, ApplySplitResult::LetStmt);
        result.emplace_back(base_name, base, ApplySplitResult::LetStmt);
    } else if (split.is_fuse()) {
        // Define the inner and outer in terms of the fused var
        Expr fused = Variable::make(Int(32), prefix + split.old_var);
        Expr inner_min = Variable::make(Int(32), prefix + split.inner + ".loop_min");
        Expr outer_min = Variable::make(Int(32), prefix + split.outer + ".loop_min");
        Expr inner_extent = Variable::make(Int(32), prefix + split.inner + ".loop_extent");

        // The fused index is decomposed using the inner extent as the
        // stride. Distinct names are used so that the function-level
        // 'outer' variable is not shadowed.
        Expr inner_value = fused % inner_extent + inner_min;
        Expr outer_value = fused / inner_extent + outer_min;

        result.emplace_back(prefix + split.inner, inner_value, ApplySplitResult::Substitution);
        result.emplace_back(prefix + split.outer, outer_value, ApplySplitResult::Substitution);
        result.emplace_back(prefix + split.inner, inner_value, ApplySplitResult::LetStmt);
        result.emplace_back(prefix + split.outer, outer_value, ApplySplitResult::LetStmt);

        // Maintain the known size of the fused dim if
        // possible. This is important for possible later splits.
        map<string, Expr>::iterator inner_dim = dim_extent_alignment.find(split.inner);
        map<string, Expr>::iterator outer_dim = dim_extent_alignment.find(split.outer);
        if (inner_dim != dim_extent_alignment.end() &&
            outer_dim != dim_extent_alignment.end()) {
            dim_extent_alignment[split.old_var] = inner_dim->second * outer_dim->second;
        }
    } else {
        // rename or purify
        result.emplace_back(prefix + split.old_var, outer, ApplySplitResult::Substitution);
        result.emplace_back(prefix + split.old_var, outer, ApplySplitResult::LetStmt);
    }

    return result;
}
/** Produce the (name, value) pairs that define the .loop_min, .loop_max and
 * .loop_extent symbols of the dimension(s) created by a split, a fuse or a
 * rename. Every name is built as prefix + dimension + suffix. A purify
 * produces no definitions at all.
 */
vector<std::pair<string, Expr>> compute_loop_bounds_after_split(const Split &split, const string &prefix) {
    // Full name of one loop-bound symbol of a dimension.
    const auto symbol = [&prefix](const string &dim, const char *suffix) -> string {
        return prefix + dim + suffix;
    };
    // The same symbol, as an Int(32) variable reference.
    const auto symbol_var = [&symbol](const string &dim, const char *suffix) -> Expr {
        return Variable::make(Int(32), symbol(dim, suffix));
    };

    vector<std::pair<string, Expr>> bounds;

    // Appends min/max/extent for a dimension that starts at zero and
    // covers 'extent' iterations.
    const auto add_zero_based = [&bounds, &symbol](const string &dim, const Expr &extent) {
        bounds.emplace_back(symbol(dim, ".loop_min"), 0);
        bounds.emplace_back(symbol(dim, ".loop_max"), extent - 1);
        bounds.emplace_back(symbol(dim, ".loop_extent"), extent);
    };

    if (split.is_split()) {
        // The inner loop covers exactly 'factor' iterations; the outer
        // loop covers the old extent divided by the factor, rounded up.
        const Expr old_max = symbol_var(split.old_var, ".loop_max");
        const Expr old_min = symbol_var(split.old_var, ".loop_min");
        const Expr outer_extent = (old_max - old_min + split.factor) / split.factor;

        add_zero_based(split.inner, split.factor);
        add_zero_based(split.outer, outer_extent);
    } else if (split.is_fuse()) {
        // The fused dimension covers the product of the inner and outer
        // extents.
        const Expr inner_extent = symbol_var(split.inner, ".loop_extent");
        const Expr outer_extent = symbol_var(split.outer, ".loop_extent");

        add_zero_based(split.old_var, inner_extent * outer_extent);
    } else if (split.is_rename()) {
        // The new name simply inherits the bounds of the old variable.
        bounds.emplace_back(symbol(split.outer, ".loop_min"), symbol_var(split.old_var, ".loop_min"));
        bounds.emplace_back(symbol(split.outer, ".loop_max"), symbol_var(split.old_var, ".loop_max"));
        bounds.emplace_back(symbol(split.outer, ".loop_extent"), symbol_var(split.old_var, ".loop_extent"));
    }
    // Nothing is emitted for a purify.

    return bounds;
}
} // namespace Internal
} // namespace Halide