[FEATURE] Dnnl sum primitive path (#21132)
* Added dnnl_sum primitive path to mxnet binary_add when shapes are the same

* Added test coverage

* Added operation check

* Random numbers for tests

* Deleted unnecessary variables

* Applied review changes
Kacper-Pietkun committed Aug 31, 2022
1 parent 8d933fd commit 3a19f0e
Showing 2 changed files with 51 additions and 27 deletions.
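
Before the diffs, a minimal sketch of the behaviour this commit targets, assuming an MXNet build with oneDNN enabled (MXNET_USE_ONEDNN=1). Shape equality is what makes an elementwise add eligible for the new dnnl_sum path; a broadcast add keeps using the existing dnnl binary primitive. The snippet uses only the public NDArray API; which kernel actually runs is decided inside the operator changed in the first file below.

import mxnet as mx

# Same shapes: eligible for the new dnnl_sum primitive path (binary_add without broadcast).
a = mx.nd.random.uniform(shape=(4, 8))
b = mx.nd.random.uniform(shape=(4, 8))
c = a + b

# Different shapes: broadcasting is required, so the existing dnnl binary primitive is used.
d = mx.nd.random.uniform(shape=(4, 1))
e = a + d

mx.nd.waitall()  # block until both additions have actually executed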
63 changes: 36 additions & 27 deletions src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
@@ -21,11 +21,12 @@
  * \file elemwise_binary_broadcast_op_basic.cc
  * \brief CPU Implementation of basic functions for elementwise binary broadcast operator.
  */
-#include "./elemwise_unary_op.h"
-#include "./elemwise_binary_op-inl.h"
-#include "./elemwise_binary_broadcast_op.h"
+#include "operator/tensor/elemwise_unary_op.h"
+#include "operator/tensor/elemwise_binary_op-inl.h"
+#include "operator/tensor/elemwise_binary_broadcast_op.h"
 #if MXNET_USE_ONEDNN == 1
-#include "../nn/dnnl/dnnl_binary-inl.h"
+#include "operator/nn/dnnl/dnnl_binary-inl.h"
+#include "operator/nn/dnnl/dnnl_sum-inl.h"
 #endif  // MXNET_USE_ONEDNN == 1
 
 namespace mxnet {
@@ -38,31 +39,39 @@ void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs,
                          const std::vector<NDArray>& inputs,
                          const std::vector<OpReqType>& req,
                          const std::vector<NDArray>& outputs) {
-  mxnet::TShape new_lshape, new_rshape, new_oshape;
-  int ndim_diff = BinaryBroadcastShapeCompact(inputs[0].shape(),
-                                              inputs[1].shape(),
-                                              outputs[0].shape(),
-                                              &new_lshape,
-                                              &new_rshape,
-                                              &new_oshape);
-  std::vector<NDArray> new_inputs;
-  std::vector<NDArray> new_outputs;
-  if (ndim_diff) {
-    new_inputs  = {inputs[0].Reshape(new_lshape), inputs[1].Reshape(new_rshape)};
-    new_outputs = {outputs[0].Reshape(new_oshape)};
-  } else if (inputs[0].shape().Size() == 1 && inputs[1].shape().Size() == 1) {
-    // BinaryBroadcastShapeCompact function doesn't reshape tensors of size (1,1,...,1)
-    // into shape (1). It is mandatory for oneDNN primitive to have this reshape done.
-    mxnet::TShape one_shape = mxnet::TShape(1, 1);
-    new_inputs  = {inputs[0].Reshape(one_shape), inputs[1].Reshape(one_shape)};
-    new_outputs = {outputs[0].Reshape(one_shape)};
-  } else {
-    new_inputs  = {inputs[0], inputs[1]};
-    new_outputs = {outputs[0]};
-  }
-
-  DNNLBinaryOpFwd& fwd = DNNLBinaryOpFwd::GetBinaryOpForward<alg>(new_inputs, new_outputs);
-  fwd.Execute(new_inputs, req, new_outputs);
+  const mxnet::TShape& input_0_shape  = inputs[0].shape();
+  const mxnet::TShape& input_1_shape  = inputs[1].shape();
+  const mxnet::TShape& output_0_shape = outputs[0].shape();
+  // We can use more efficient sum kernel, when there is no broadcast - when shapes are the
+  // same.
+  const bool same_shape = (input_0_shape == input_1_shape);
+
+  if (same_shape && alg == dnnl::algorithm::binary_add) {
+    DNNLSumFwd& fwd = DNNLSumFwd::GetCached(inputs, outputs);
+    fwd.Execute(ctx, inputs, req, outputs);
+  } else {
+    mxnet::TShape new_lshape, new_rshape, new_oshape;
+    int ndim_diff = BinaryBroadcastShapeCompact(
+        input_0_shape, input_1_shape, output_0_shape, &new_lshape, &new_rshape, &new_oshape);
+    std::vector<NDArray> new_inputs;
+    std::vector<NDArray> new_outputs;
+    if (ndim_diff) {
+      new_inputs  = {inputs[0].Reshape(new_lshape), inputs[1].Reshape(new_rshape)};
+      new_outputs = {outputs[0].Reshape(new_oshape)};
+    } else if (input_0_shape.Size() == 1 && input_1_shape.Size() == 1) {
+      // BinaryBroadcastShapeCompact function doesn't reshape tensors of size (1,1,...,1)
+      // into shape (1). It is mandatory for oneDNN primitive to have this reshape done.
+      mxnet::TShape one_shape = mxnet::TShape(1, 1);
+      new_inputs  = {inputs[0].Reshape(one_shape), inputs[1].Reshape(one_shape)};
+      new_outputs = {outputs[0].Reshape(one_shape)};
+    } else {
+      new_inputs  = {inputs[0], inputs[1]};
+      new_outputs = {outputs[0]};
+    }
+
+    DNNLBinaryOpFwd& fwd = DNNLBinaryOpFwd::GetBinaryOpForward<alg>(new_inputs, new_outputs);
+    fwd.Execute(new_inputs, req, new_outputs);
+  }
 }
 #endif

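For quick reference, a hypothetical Python helper that mirrors the dispatch condition introduced above (same_shape && alg == dnnl::algorithm::binary_add). The name takes_dnnl_sum_path and the helper itself are illustrative only and not part of the MXNet codebase.

def takes_dnnl_sum_path(lhs_shape, rhs_shape, alg):
    """Illustrative model of the new dispatch: the dnnl_sum primitive is used only
    for binary_add with identical input shapes; every other case falls back to the
    BinaryBroadcastShapeCompact + dnnl binary primitive path."""
    return tuple(lhs_shape) == tuple(rhs_shape) and alg == "binary_add"

assert takes_dnnl_sum_path((2, 3), (2, 3), "binary_add")
assert not takes_dnnl_sum_path((2, 3), (2, 1), "binary_add")  # needs broadcasting
assert not takes_dnnl_sum_path((2, 3), (2, 3), "binary_mul")  # not an add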
15 changes: 15 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -9414,6 +9414,21 @@ def test_elementwise_ops_on_misaligned_input():
     mx.nd.waitall()
     assert a[3].asscalar() == 4.0
 
+
+@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64'])
+@pytest.mark.parametrize('ndim', [1, 2, 3, 4, 5])
+@pytest.mark.parametrize('max_dim_size', [1, 2, 3, 4, 5])
+def test_broadcast_ops_on_input_with_the_same_shape(dtype, ndim, max_dim_size):
+    shape = list(rand_shape_nd(ndim, dim=max_dim_size))
+    a = np.random.uniform(low=-100, high=100, size=shape)
+    b = np.random.uniform(low=-100, high=100, size=shape)
+    expected = a + b
+    am = mx.nd.array(a)
+    bm = mx.nd.array(b)
+    cm = am + bm
+    mx.nd.waitall()
+    assert_almost_equal(cm, expected)
+
 @pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64'])
 @pytest.mark.parametrize('lead_dim', [2, 3, 4, 6, 10])
 @pytest.mark.parametrize('both_ways', [False, True])
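The check added above can also be reproduced outside the pytest parametrization. A minimal standalone sketch, assuming rand_shape_nd and assert_almost_equal come from mxnet.test_utils, as they do for the other tests in this file:

import numpy as np
import mxnet as mx
from mxnet.test_utils import rand_shape_nd, assert_almost_equal

shape = list(rand_shape_nd(3, dim=5))                 # random 3-D shape, each dim in [1, 5]
a = np.random.uniform(low=-100, high=100, size=shape)
b = np.random.uniform(low=-100, high=100, size=shape)

am, bm = mx.nd.array(a), mx.nd.array(b)
cm = am + bm                                          # identical shapes, so the sum path applies
mx.nd.waitall()
assert_almost_equal(cm, a + b)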
