[Relay][AlterLayout] Broadcast with scalar shape #4577

Merged: 1 commit, Dec 27, 2019
22 changes: 22 additions & 0 deletions src/relay/pass/pattern_util.h
@@ -199,6 +199,28 @@ inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
  return *channels;
}

/*!
 * \brief Check if an expression is a single-value tensor (scalar),
 *        i.e. a tensor whose every dimension equals the constant 1.
 * \param expr The expression to check.
 * \return True if expr is a single-value (scalar) tensor.
 */
inline bool IsScalar(const Expr& expr) {
  if (auto tensor_type = expr->checked_type().as<TensorTypeNode>()) {
    for (auto dim_index_expr : tensor_type->shape) {
      if (auto dim_index = dim_index_expr.as<IntImm>()) {
        if (dim_index->value != 1) {
          return false;
        }
      } else {
        // A symbolic (non-constant) dimension cannot be proven to be 1.
        return false;
      }
    }
  } else {
    // Non-tensor types are never scalars.
    return false;
  }
  return true;
}
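For reference, this predicate accepts a rank-0 tensor, (1,), (1, 1), and so on; any dimension other than a constant 1 fails the check. A minimal Python sketch of the same logic (illustration only, not TVM API; is_scalar_shape is a made-up name):

def is_scalar_shape(shape):
    """True when every dimension is the constant 1, e.g. (), (1,), (1, 1)."""
    return all(isinstance(dim, int) and dim == 1 for dim in shape)

assert is_scalar_shape(())           # rank-0 tensor
assert is_scalar_shape((1,))         # the multiplier1 shape in the test below
assert is_scalar_shape((1, 1))       # the multiplier2 shape in the test below
assert not is_scalar_shape((64,))    # a length-64 vector is not a scalar
assert not is_scalar_shape(("n",))   # symbolic dims fail, as in the C++ check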

/*!
 * \brief Create a Constant with a scalar
 *
5 changes: 5 additions & 0 deletions src/relay/pass/transform_layout.h
@@ -119,6 +119,11 @@ class TransformMemorizer : public NodeRef {
    Expr input_expr = raw;
    Layout new_src_layout = src_layout;
    if (src_layout.ndim_primal() < dst_layout.ndim_primal()) {
      // If the input is a scalar, no layout transformation is needed: a scalar
      // broadcasts against the other operand even when that operand's layout
      // has been transformed.
      if (IsScalar(input_expr)) {
        return raw;
      }
      int num_new_axis = dst_layout.ndim_primal() - src_layout.ndim_primal();
      new_src_layout = src_layout.ExpandPrimal(dst_layout);
      input_expr = MakeExpandDims(input_expr, 0, num_new_axis);
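The effect of the early return: a scalar-shaped operand reaches the broadcast op unchanged, with no expand_dims or layout_transform wrapped around it. A numpy illustration of why this is safe (numpy is used here only to show the broadcasting semantics; the (1, 4, 500, 500, 16) shape is an assumed NCHW16c blocking of the test's (1, 64, 500, 500) data):

import numpy as np

# Data in a blocked NCHW16c layout: (N, C//16, H, W, 16c).
nchw16c = np.zeros((1, 4, 500, 500, 16), dtype="float32")
scalar = np.ones((1, 1), dtype="float32")  # scalar-shaped multiplier

# The scalar broadcasts against any layout of the other operand,
# so no layout_transform is required before the binary op.
out = nchw16c * scalar
assert out.shape == nchw16c.shape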
65 changes: 65 additions & 0 deletions tests/python/relay/test_pass_alter_op_layout.py
@@ -318,6 +318,70 @@ def expected():

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_alter_layout_broadcast_scalar_op():
    """Test altering the layout of a conv2d.
    Broadcast operands with scalar shapes should be broadcast directly:
    no expand_dims or layout_transform is inserted for them, while the
    non-scalar bias is still transformed to the new layout.
    """
    def before():
        x = relay.var("x", shape=(1, 500, 500, 64))
        kernel = relay.var('kernel', shape=(3, 3, 64, 64), dtype='float32')
        bias = relay.var("bias", shape=(64,))
        multiplier1 = relay.var('multiplier1', shape=(1, ), dtype='float32')
        multiplier2 = relay.var('multiplier2', shape=(1, 1), dtype='float32')

        y = relay.nn.conv2d(x, kernel,
                            data_layout='NHWC',
                            kernel_layout="HWIO",
                            kernel_size=(3, 3))
        y = relay.add(bias, y)
        y = relay.nn.relu(y)

        y = relay.multiply(multiplier1, y)
        y = relay.multiply(y, multiplier2)
        y = relay.Function(analysis.free_vars(y), y)
        return y

    def alter_conv2d(attrs, inputs, tinfos):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs['data_layout'] = 'NCHW16c'
        return relay.nn.conv2d(data, weight, **new_attrs)

    def expected():
        x = relay.var("x", shape=(1, 500, 500, 64))
        kernel = relay.var('kernel', shape=(3, 3, 64, 64), dtype='float32')
        bias = relay.var("bias", shape=(64,))
        multiplier1 = relay.var('multiplier1', shape=(1, ), dtype='float32')
        multiplier2 = relay.var('multiplier2', shape=(1, 1), dtype='float32')

        b = relay.expand_dims(bias, axis=0, num_newaxis=3)
        b = relay.layout_transform(b, "NHWC", "NCHW16c")

        y = relay.layout_transform(x, "NHWC", "NCHW16c")
        y = relay.nn.conv2d(y, kernel,
                            data_layout='NCHW16c',
                            kernel_layout="HWIO",
                            kernel_size=(3, 3))

        y = relay.add(b, y)
        y = relay.nn.relu(y)

        y = relay.multiply(multiplier1, y)
        y = relay.multiply(y, multiplier2)
        y = relay.layout_transform(y, "NCHW16c", "NHWC")
        y = relay.Function(analysis.free_vars(y), y)
        return y

    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
        a = before()
        a = run_opt_pass(a, [transform.CanonicalizeOps(),
                             transform.AlterOpLayout()])
        b = run_opt_pass(expected(), transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
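For quick inspection, a hypothetical helper (count_ops is not part of this file; it assumes the test module's relay and analysis imports) can count the ops the pass actually inserted:

def count_ops(func, op_name):
    """Count calls to op_name in a Relay function (hypothetical helper)."""
    count = [0]
    def visit(node):
        if isinstance(node, relay.Call) and getattr(node.op, "name", "") == op_name:
            count[0] += 1
    analysis.post_order_visit(func, visit)
    return count[0]

# Per expected() above, count_ops(a, "layout_transform") == 3: one each for
# the input, the bias, and the final output. Neither multiplier1 nor
# multiplier2 is wrapped in a layout_transform or expand_dims.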


def test_alter_layout_scalar():
"""Test alternating the layout of a conv2d.
The layout of broadcast operators and the weight should be changed accordingly.
@@ -980,6 +1044,7 @@ def expected():
    test_alter_layout_dual_path()
    test_alter_layout_resnet()
    test_alter_layout_broadcast_op()
    test_alter_layout_broadcast_scalar_op()
    test_alter_layout_scalar()
    test_alter_layout_concatenate()
    test_alter_layout_nchw_upsamping_op()