Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Use cuDNN for conv bias and bias grad (#20771)
Browse files Browse the repository at this point in the history
* Use cuDNN for conv bias and bias grad

* Environment variables to use native add-bias and bias-grad

* Handle 3D tensors in cuDNN legacy API

* Fix AMP for ndarray.numpy

* Remove env vars, used for benchmarking

Co-authored-by: Vladimir Cherepanov <[email protected]>
  • Loading branch information
mk-61 and Vladimir Cherepanov committed Dec 17, 2021
1 parent b555b54 commit f0ef9d8
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 67 deletions.
6 changes: 3 additions & 3 deletions python/mxnet/amp/loss_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ def has_overflow(self, params):
"""Check gradients for overflow."""
if is_np_array():
all_finite_f = ndarray.numpy._internal.multi_all_finite
ones_f = ndarray.numpy.ones
ones_f = lambda ctx: ndarray.numpy.ones((1,), device=ctx)
else:
all_finite_f = ndarray.multi_all_finite
ones_f = ndarray.ones
ones_f = lambda ctx: ndarray.ones((1,), ctx=ctx)
with ag.pause():
chunk_size = 200
valid_params = [p._grad[0] for p in params if p._grad is not None]
gpu_output = ones_f((1,), ctx=valid_params[0].context)
gpu_output = ones_f(valid_params[0].context)
nb_params = len(valid_params)
for idx in range(0, nb_params, chunk_size):
all_finite_f(*valid_params[idx:idx+chunk_size],
Expand Down
9 changes: 0 additions & 9 deletions src/common/cuda/cudnn_cxx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,6 @@ std::vector<Descriptor> GetSomeAttrs(size_t max_n,
return ret;
}

std::vector<int64_t> PackedStrides(const std::vector<size_t>& order,
const std::vector<int64_t>& dims) {
CHECK_EQ(order.size(), dims.size());
std::vector<int64_t> ret(dims.size(), 1);
for (size_t i = dims.size() - 1; i--;)
ret[order[i]] = dims[order[i + 1]] * ret[order[i + 1]];
return ret;
}

std::vector<Descriptor> GetPlans(cudnnBackendHeurMode_t h_mode,
cudnnHandle_t handle,
const Descriptor& op_graph,
Expand Down
10 changes: 8 additions & 2 deletions src/common/cuda/cudnn_cxx.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,14 @@ std::vector<Descriptor> GetSomeAttrs(size_t max_n,
cudnnBackendDescriptorType_t type);

// Order sets layout, as a permutation of dims, with N,C,<spacial dims> being identity.
std::vector<int64_t> PackedStrides(const std::vector<size_t>& order,
const std::vector<int64_t>& dims);
template <typename T>
std::vector<T> PackedStrides(const std::vector<size_t>& order, const std::vector<T>& dims) {
CHECK_EQ(order.size(), dims.size());
std::vector<T> ret(dims.size(), 1);
for (size_t i = dims.size() - 1; i--;)
ret[order[i]] = dims[order[i + 1]] * ret[order[i + 1]];
return ret;
}

// Given an engine config's `notes`, return whether that config is compatible, i.e. does
// the config have all of the required notes and none of the notes that are being excluded.
Expand Down
119 changes: 105 additions & 14 deletions src/operator/cudnn_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,10 @@

#include <dmlc/parameter.h>

#include <algorithm>
#include <cstdlib>
#include <iomanip>
#include <iterator>
#include <limits>
#include <numeric>
#include <sstream>
#include <string>
#include <utility>
Expand Down Expand Up @@ -79,10 +77,6 @@ size_t LayoutInfo::ChannelIdx() const {
return channel_last ? 1 + n_space_dims : 1;
}

std::vector<int64_t> LayoutInfo::Strides(const std::vector<int64_t>& dims) const {
return PackedStrides(Order(), dims);
}

LayoutInfo GetLayoutInfo(mshadow::LayoutFlag layout) {
static std::unordered_map<mshadow::LayoutFlag, LayoutInfo> layout_map{
{mshadow::kNCW, {1, false}},
Expand Down Expand Up @@ -165,14 +159,8 @@ Descriptor MakeTensorDesc(int64_t uid,
for (size_t i = 0; i < dims.size(); ++i)
dims[i] = blob.shape_[rev_order[i]];
auto strides = li.Strides(dims);
if (li.n_space_dims == 1 && expand_1d) {
dims.insert(dims.begin() + 2, 1);
std::vector<size_t> order(dims.size());
std::iota(order.begin(), order.end(), 0);
if (li.channel_last)
std::rotate(order.begin() + 1, order.begin() + 2, order.end());
strides = PackedStrides(order, dims);
}
if (expand_1d)
li.ExpandIf1d(&dims, &strides);
return MakeTensorDesc(
uid, CudnnType(static_cast<mshadow::TypeFlag>(blob.type_flag_)), dims, strides, is_virtual);
}
Expand Down Expand Up @@ -758,6 +746,109 @@ void ConvWgrad::Exec(const cudnn_cxx::Descriptor& plan,
CUDNN_CALL(cudnnBackendExecute(s->dnn_handle_, plan.get(), var_pack.get()));
}

struct LegacyTensorDestroyer {
using pointer = cudnnTensorDescriptor_t;

void operator()(cudnnTensorDescriptor_t desc) {
CUDNN_CALL_NONFATAL(cudnnDestroyTensorDescriptor(desc));
}
};

using LegacyTensor = std::unique_ptr<cudnnTensorDescriptor_t, LegacyTensorDestroyer>;

LegacyTensor MakeLegacyTensor() {
cudnnTensorDescriptor_t desc{};
CUDNN_CALL(cudnnCreateTensorDescriptor(&desc));
return LegacyTensor(desc);
}

union ScalingParam {
double d;
float f;
};

std::pair<ScalingParam, ScalingParam> AlphaBeta(int type_flag, double init_a, double init_b) {
ScalingParam a, b;
switch (type_flag) {
case kFloat64:
a.d = init_a;
b.d = init_b;
break;
case kFloat32: // fallthrough
case kFloat16:
a.f = init_a;
b.f = init_b;
break;
default:
LOG(FATAL) << "Unexpected type: " << type_flag;
}
return {a, b};
}

void SetLegacyTensor(cudnnTensorDescriptor_t desc, const TBlob& blob, const LayoutInfo& li) {
std::vector<int> dims(blob.shape_.ndim());
CHECK_EQ(dims.size(), li.n_space_dims + 2);
auto rev_order = ReverseOrder(li.Order());
for (size_t i = 0; i < dims.size(); ++i)
dims[i] = blob.shape_[rev_order[i]];
auto strides = li.Strides(dims);
li.ExpandIf1d(&dims, &strides);
auto type = static_cast<mshadow::TypeFlag>(blob.type_flag_);
CUDNN_CALL(cudnnSetTensorNdDescriptor(desc, CudnnType(type), dims.size(), &dims[0], &strides[0]));
}

void SetLegacyCTensorExpandDims(cudnnTensorDescriptor_t desc,
const TBlob& blob,
const LayoutInfo& li) {
std::vector<int> dims(li.n_space_dims + 2, 1);
dims[1] = blob.shape_[0];
std::vector<int> strides(dims.size(), 1);
strides[0] = blob.shape_[0];
li.ExpandIf1d(&dims, &strides);
auto type = static_cast<mshadow::TypeFlag>(blob.type_flag_);
CUDNN_CALL(cudnnSetTensorNdDescriptor(desc, CudnnType(type), dims.size(), &dims[0], &strides[0]));
}

bool LegacyAddBias(const OpContext& ctx, const LayoutInfo& li, const TBlob& y, const TBlob& b) {
thread_local auto y_desc = MakeLegacyTensor();
thread_local auto b_desc = MakeLegacyTensor();

auto s = ctx.get_stream<gpu>();
auto [alpha, beta] = AlphaBeta(y.type_flag_, 1.0, 1.0); // NOLINT(whitespace/braces)

SetLegacyTensor(y_desc.get(), y, li);
SetLegacyCTensorExpandDims(b_desc.get(), b, li);

auto err =
cudnnAddTensor(s->dnn_handle_, &alpha, b_desc.get(), b.dptr_, &beta, y_desc.get(), y.dptr_);
if (err == CUDNN_STATUS_NOT_SUPPORTED)
return false;
CHECK_EQ(err, CUDNN_STATUS_SUCCESS);
return true;
}

bool LegacyBiasGrad(const OpContext& ctx,
const LayoutInfo& li,
bool add_to,
const TBlob& db,
const TBlob& dy) {
thread_local auto db_desc = MakeLegacyTensor();
thread_local auto dy_desc = MakeLegacyTensor();

auto s = ctx.get_stream<gpu>();
auto [alpha, beta] = AlphaBeta(dy.type_flag_, 1.0, add_to ? 1.0 : 0.0); // NOLINT(*)

SetLegacyCTensorExpandDims(db_desc.get(), db, li);
SetLegacyTensor(dy_desc.get(), dy, li);

auto err = cudnnConvolutionBackwardBias(
s->dnn_handle_, &alpha, dy_desc.get(), dy.dptr_, &beta, db_desc.get(), db.dptr_);
if (err == CUDNN_STATUS_NOT_SUPPORTED)
return false;
CHECK_EQ(err, CUDNN_STATUS_SUCCESS);
return true;
}

} // namespace cudnn
} // namespace op
} // namespace mxnet
Expand Down
28 changes: 27 additions & 1 deletion src/operator/cudnn_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@

#include <mxnet/op_attr_types.h>

#include <algorithm>
#include <mutex>
#include <numeric>
#include <tuple>
#include <unordered_map>
#include <utility>
Expand Down Expand Up @@ -89,7 +91,23 @@ struct LayoutInfo {

std::vector<size_t> Order() const;
size_t ChannelIdx() const;
std::vector<int64_t> Strides(const std::vector<int64_t>& dims) const;

template <typename T>
std::vector<T> Strides(const std::vector<T>& dims) const {
return cudnn_cxx::PackedStrides(Order(), dims);
}

template <typename T>
void ExpandIf1d(std::vector<T>* dims, std::vector<T>* strides) const {
if (n_space_dims != 1)
return;
dims->insert(dims->begin() + 2, 1);
std::vector<size_t> order(dims->size());
std::iota(order.begin(), order.end(), 0);
if (channel_last)
std::rotate(order.begin() + 1, order.begin() + 2, order.end());
*strides = cudnn_cxx::PackedStrides(order, *dims);
}
};

LayoutInfo GetLayoutInfo(mshadow::LayoutFlag layout);
Expand Down Expand Up @@ -246,6 +264,14 @@ struct ConvWgrad {
const TBlob& dw);
};

bool LegacyAddBias(const OpContext& ctx, const LayoutInfo& li, const TBlob& y, const TBlob& b);

bool LegacyBiasGrad(const OpContext& ctx,
const LayoutInfo& li,
bool add_to,
const TBlob& db,
const TBlob& dy);

} // namespace cudnn
} // namespace op
} // namespace mxnet
Expand Down
46 changes: 27 additions & 19 deletions src/operator/nn/convolution.cu
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,18 @@ void ConvolutionCompute<gpu>(const nnvm::NodeAttrs& attrs,
if (ok && !param.no_bias) {
CHECK_EQ(inputs[conv::kBias].shape_.ndim(), 1);
auto layout = static_cast<mshadow::LayoutFlag>(param.layout.value());
int k = inputs[conv::kBias].shape_.Size();
auto b = inputs[conv::kBias].reshape(cudnn::ExpandChannelDims(layout, k));
BinaryBroadcastRTCCompute{"add"}( // NOLINT(whitespace/braces)
attrs,
ctx,
{outputs[conv::kOut], b},
{kWriteInplace},
{outputs[conv::kOut]});
auto li = cudnn::GetLayoutInfo(layout);
if (li.channel_last ||
!cudnn::LegacyAddBias(ctx, li, outputs[conv::kOut], inputs[conv::kBias])) {
int k = inputs[conv::kBias].shape_.Size();
auto b = inputs[conv::kBias].reshape(cudnn::ExpandChannelDims(layout, k));
BinaryBroadcastRTCCompute{"add"}( // NOLINT(whitespace/braces)
attrs,
ctx,
{outputs[conv::kOut], b},
{kWriteInplace},
{outputs[conv::kOut]});
}
}
if (!ok) {
if (!param.cudnn_off)
Expand Down Expand Up @@ -137,17 +141,21 @@ void ConvolutionGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
cudnn::Exec<cudnn::ConvWgrad>(
ctx, conv_param, inputs[1 + conv::kData], inputs[0], outputs[conv::kWeight]));
if (ok && !param.no_bias && req[conv::kBias] != kNullOp) {
auto li = cudnn::GetLayoutInfo(static_cast<mshadow::LayoutFlag>(param.layout.value()));
if (li.channel_last) {
// This kernel should be faster.
auto y_grad = FlattenAs2DHead<gpu, DType>(inputs[0], ctx);
AddBiasGrad(outputs[conv::kBias], y_grad, req[conv::kBias], param.num_filter, ctx);
} else {
TShape axes{static_cast<int>(li.ChannelIdx())};
TShape small =
ReduceAxesShapeImpl(inputs[0].shape_, dmlc::optional<mxnet::TShape>(axes), true, true);
ReduceAxesRTCComputeImpl(
ctx, {inputs[0]}, {req[conv::kBias]}, {outputs[conv::kBias]}, small, "red::sum{}");
auto li = cudnn::GetLayoutInfo(static_cast<mshadow::LayoutFlag>(param.layout.value()));
auto add_to = req[conv::kBias] == kAddTo;
if (li.channel_last ||
!cudnn::LegacyBiasGrad(ctx, li, add_to, outputs[conv::kBias], inputs[0])) {
if (li.channel_last) {
// This kernel should be faster.
auto y_grad = FlattenAs2DHead<gpu, DType>(inputs[0], ctx);
AddBiasGrad(outputs[conv::kBias], y_grad, req[conv::kBias], param.num_filter, ctx);
} else {
TShape axes{static_cast<int>(li.ChannelIdx())};
TShape small = ReduceAxesShapeImpl(
inputs[0].shape_, dmlc::optional<mxnet::TShape>(axes), true, true);
ReduceAxesRTCComputeImpl(
ctx, {inputs[0]}, {req[conv::kBias]}, {outputs[conv::kBias]}, small, "red::sum{}");
}
}
}
if (!ok) {
Expand Down
50 changes: 31 additions & 19 deletions src/operator/nn/deconvolution.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,18 @@ void DeconvolutionCompute<gpu>(const nnvm::NodeAttrs& attrs,
if (ok && !param.no_bias) {
CHECK_EQ(inputs[deconv::kBias].shape_.ndim(), 1);
auto layout = static_cast<mshadow::LayoutFlag>(param.layout.value());
int k = inputs[deconv::kBias].shape_.Size();
auto b = inputs[deconv::kBias].reshape(cudnn::ExpandChannelDims(layout, k));
BinaryBroadcastRTCCompute{"add"}( // NOLINT(whitespace/braces)
attrs,
ctx,
{outputs[deconv::kOut], b},
{kWriteInplace},
{outputs[deconv::kOut]});
auto li = cudnn::GetLayoutInfo(layout);
if (li.channel_last ||
!cudnn::LegacyAddBias(ctx, li, outputs[deconv::kOut], inputs[deconv::kBias])) {
int k = inputs[deconv::kBias].shape_.Size();
auto b = inputs[deconv::kBias].reshape(cudnn::ExpandChannelDims(layout, k));
BinaryBroadcastRTCCompute{"add"}( // NOLINT(whitespace/braces)
attrs,
ctx,
{outputs[deconv::kOut], b},
{kWriteInplace},
{outputs[deconv::kOut]});
}
}
if (!ok) {
if (!param.cudnn_off)
Expand Down Expand Up @@ -115,17 +119,25 @@ void DeconvolutionGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
cudnn::Exec<cudnn::ConvWgrad>(
ctx, conv_param, inputs[0], inputs[1 + deconv::kData], outputs[deconv::kWeight]));
if (ok && !param.no_bias && req[deconv::kBias] != kNullOp) {
auto li = cudnn::GetLayoutInfo(static_cast<mshadow::LayoutFlag>(param.layout.value()));
if (li.channel_last) {
// This kernel should be faster.
auto y_grad = FlattenAs2DHead<gpu, DType>(inputs[0], ctx);
AddBiasGrad(outputs[deconv::kBias], y_grad, req[deconv::kBias], param.num_filter, ctx);
} else {
TShape axes{static_cast<int>(li.ChannelIdx())};
TShape small =
ReduceAxesShapeImpl(inputs[0].shape_, dmlc::optional<mxnet::TShape>(axes), true, true);
ReduceAxesRTCComputeImpl(
ctx, {inputs[0]}, {req[deconv::kBias]}, {outputs[deconv::kBias]}, small, "red::sum{}");
auto li = cudnn::GetLayoutInfo(static_cast<mshadow::LayoutFlag>(param.layout.value()));
auto add_to = req[conv::kBias] == kAddTo;
if (li.channel_last ||
!cudnn::LegacyBiasGrad(ctx, li, add_to, outputs[deconv::kBias], inputs[0])) {
if (li.channel_last) {
// This kernel should be faster.
auto y_grad = FlattenAs2DHead<gpu, DType>(inputs[0], ctx);
AddBiasGrad(outputs[deconv::kBias], y_grad, req[deconv::kBias], param.num_filter, ctx);
} else {
TShape axes{static_cast<int>(li.ChannelIdx())};
TShape small = ReduceAxesShapeImpl(
inputs[0].shape_, dmlc::optional<mxnet::TShape>(axes), true, true);
ReduceAxesRTCComputeImpl(ctx,
{inputs[0]},
{req[deconv::kBias]},
{outputs[deconv::kBias]},
small,
"red::sum{}");
}
}
}
if (!ok) {
Expand Down

0 comments on commit f0ef9d8

Please sign in to comment.