From faa283228d8e4aa391dd0877b7388996e9a0e223 Mon Sep 17 00:00:00 2001
From: Wang Jiajun
Date: Wed, 18 Dec 2019 10:57:35 +0800
Subject: [PATCH] Add im2col and col2im operator (#16502)

* add im2col

* add col2im

* fix typo

* add docs

* add unittest

* more tests

* fix lint

* fix doc

* fix request

* trigger CI
---
 src/operator/nn/im2col-inl.h           | 259 +++++++++++++++++++++++
 src/operator/nn/im2col.cc              | 272 +++++++++++++++++++++++++
 src/operator/nn/im2col.cu              |  45 ++++
 tests/python/unittest/test_operator.py | 147 +++++++++++++
 4 files changed, 723 insertions(+)
 create mode 100644 src/operator/nn/im2col-inl.h
 create mode 100644 src/operator/nn/im2col.cc
 create mode 100644 src/operator/nn/im2col.cu

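Both compute functions and the FInferShape handlers in this patch derive the number of sliding positions per spatial axis from the usual convolution formula (see DilatedKernelSize() and the shape loops below). A minimal Python sketch of that arithmetic, mirroring the compute_output_size helper in the unit tests at the end of the patch:

    def output_size(spatial, kernel, stride=1, dilate=1, pad=0):
        # Effective kernel extent after dilation, as in DilatedKernelSize().
        dilated_kernel = 1 + (kernel - 1) * dilate
        # Number of valid sliding positions along this axis.
        return (spatial + 2 * pad - dilated_kernel) // stride + 1

    # A 30-pixel axis with kernel 3, stride 2, dilation 2 and padding 1:
    assert output_size(30, kernel=3, stride=2, dilate=2, pad=1) == 14
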
diff --git a/src/operator/nn/im2col-inl.h b/src/operator/nn/im2col-inl.h
new file mode 100644
index 000000000000..b5caa035f911
--- /dev/null
+++ b/src/operator/nn/im2col-inl.h
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file im2col-inl.h
+ * \brief
+ * \author Jiajun Wang
+*/
+
+#ifndef MXNET_OPERATOR_NN_IM2COL_INL_H_
+#define MXNET_OPERATOR_NN_IM2COL_INL_H_
+#include <vector>
+#include "../mxnet_op.h"
+#include "../mshadow_op.h"
+#include "../elemwise_op_common.h"
+#include "./im2col.h"
+
+namespace mxnet {
+namespace op {
+
+struct Im2colParam : public dmlc::Parameter<Im2colParam> {
+  mxnet::TShape kernel;
+  mxnet::TShape stride;
+  mxnet::TShape dilate;
+  mxnet::TShape pad;
+  DMLC_DECLARE_PARAMETER(Im2colParam) {
+    DMLC_DECLARE_FIELD(kernel).describe("Sliding kernel size: (w,), (h, w) or (d, h, w).");
+    DMLC_DECLARE_FIELD(stride).set_default(mxnet::TShape(0, 0))
+    .describe("The stride between adjacent sliding blocks in spatial dimension: "
+              "(w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.");
+    DMLC_DECLARE_FIELD(dilate).set_default(mxnet::TShape(0, 0))
+    .describe("The spacing between adjacent kernel points: (w,), (h, w) or (d, h, w). "
+              "Defaults to 1 for each dimension.");
+    DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape(0, 0))
+    .describe("The zero-value padding size on both sides of spatial dimension: "
+              "(w,), (h, w) or (d, h, w). Defaults to no padding.");
+  }
+
+  index_t DilatedKernelSize(int dim) const {
+    return 1 + (kernel[dim] - 1) * dilate[dim];
+  }
+};  // struct Im2colParam
+
+template<typename xpu>
+void Im2colCompute(const nnvm::NodeAttrs& attrs,
+                   const OpContext& ctx,
+                   const std::vector<TBlob>& inputs,
+                   const std::vector<OpReqType>& req,
+                   const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  const Im2colParam& param = nnvm::get<Im2colParam>(attrs.parsed);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const mxnet::TShape im_shape = inputs[0].shape_;
+  const mxnet::TShape col_shape = outputs[0].shape_;
+  const index_t num = im_shape[0];
+
+  const int spatial_size = param.kernel.ndim();
+  mxnet::TShape col_buffer_shape(1 + spatial_size, 1);
+  col_buffer_shape[0] = col_shape[1];
+  for (int i = 0; i < spatial_size; ++i) {
+    const index_t pad_size = im_shape[i + 2] + 2 * param.pad[i];
+    const index_t output_size = (pad_size - param.DilatedKernelSize(i)) / param.stride[i] + 1;
+    col_buffer_shape[i + 1] = output_size;
+  }
+
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 4, DType> im = inputs[0].get_with_shape<xpu, 4, DType>(
+      Shape4(im_shape[0], im_shape[1], im_shape[2], im_shape[3]), s);
+    Tensor<xpu, 3, DType> col = outputs[0].get_with_shape<xpu, 3, DType>(
+      Shape3(col_shape[0], col_shape[1], col_shape[2]), s);
+
+    if (req[0] == kNullOp) return;
+    if (req[0] != kAddTo) {
+      for (index_t n = 0; n < num; ++n) {
+        im2col(s, im[n].dptr_, im_shape, col_buffer_shape,
+               param.kernel, param.pad, param.stride, param.dilate, col[n].dptr_);
+      }
+    } else {
+      Tensor<xpu, 2, DType> tcol = ctx.requested[0]
+        .get_space_typed<xpu, 2, DType>(Shape2(col_shape[1], col_shape[2]), s);
+      for (index_t n = 0; n < num; ++n) {
+        im2col(s, im[n].dptr_, im_shape, col_buffer_shape,
+               param.kernel, param.pad, param.stride, param.dilate, tcol.dptr_);
+        Tensor<xpu, 2, DType> ocol = col[n];
+        ocol += tcol;
+      }
+    }
+  });
+}
+
+template<typename xpu>
+void Im2colGradCompute(const nnvm::NodeAttrs& attrs,
+                       const OpContext& ctx,
+                       const std::vector<TBlob>& inputs,
+                       const std::vector<OpReqType>& req,
+                       const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  const Im2colParam& param = nnvm::get<Im2colParam>(attrs.parsed);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+
+  const mxnet::TShape im_shape = outputs[0].shape_;
+  const mxnet::TShape col_shape = inputs[0].shape_;
+  const index_t num = im_shape[0];
+
+  const int spatial_size = param.kernel.ndim();
+  mxnet::TShape col_buffer_shape(1 + spatial_size, 1);
+  col_buffer_shape[0] = col_shape[1];
+  for (int i = 0; i < spatial_size; ++i) {
+    const index_t pad_size = im_shape[i + 2] + 2 * param.pad[i];
+    const index_t output_size = (pad_size - param.DilatedKernelSize(i)) / param.stride[i] + 1;
+    col_buffer_shape[i + 1] = output_size;
+  }
+
+  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    Tensor<xpu, 4, DType> im_grad = outputs[0].get_with_shape<xpu, 4, DType>(
+      Shape4(im_shape[0], im_shape[1], im_shape[2], im_shape[3]), s);
+    Tensor<xpu, 3, DType> col_grad = inputs[0].get_with_shape<xpu, 3, DType>(
+      Shape3(col_shape[0], col_shape[1], col_shape[2]), s);
+
+    for (index_t n = 0; n < num; ++n) {
+      col2im(s, col_grad[n].dptr_, im_shape, col_buffer_shape,
+             param.kernel, param.pad, param.stride, param.dilate,
+             im_grad[n].dptr_, req[0]);
+    }
+  });
+}
.describe("The stride between adjacent sliding blocks in spatial dimension: " + "(w,), (h, w) or (d, h, w). Defaults to 1 for each dimension."); + DMLC_DECLARE_FIELD(dilate).set_default(mxnet::TShape(0, 0)) + .describe("The spacing between adjacent kernel points: (w,), (h, w) or (d, h, w). " + "Defaults to 1 for each dimension."); + DMLC_DECLARE_FIELD(pad).set_default(mxnet::TShape(0, 0)) + .describe("The zero-value padding size on both sides of spatial dimension: " + "(w,), (h, w) or (d, h, w). Defaults to no padding."); + } + + index_t DilatedKernelSize(int dim) const { + return 1 + (kernel[dim] - 1) * dilate[dim]; + } +}; // struct Col2imParam + +template +void Col2imCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + const Col2imParam& param = nnvm::get(attrs.parsed); + Stream *s = ctx.get_stream(); + const mxnet::TShape im_shape = outputs[0].shape_; + const mxnet::TShape col_shape = inputs[0].shape_; + const index_t num = im_shape[0]; + + const int spatial_size = param.kernel.ndim(); + mxnet::TShape col_buffer_shape(1 + spatial_size, 1); + col_buffer_shape[0] = col_shape[1]; + for (int i = 0; i < spatial_size; ++i) { + const index_t pad_size = im_shape[i + 2] + 2 * param.pad[i]; + const index_t output_size = (pad_size - param.DilatedKernelSize(i)) / param.stride[i] + 1; + col_buffer_shape[i + 1] = output_size; + } + + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor im = outputs[0].get_with_shape( + Shape4(im_shape[0], im_shape[1], im_shape[2], im_shape[3]), s); + Tensor col = inputs[0].get_with_shape( + Shape3(col_shape[0], col_shape[1], col_shape[2]), s); + + for (index_t n = 0; n < num; ++n) { + col2im(s, col[n].dptr_, im_shape, col_buffer_shape, + param.kernel, param.pad, param.stride, param.dilate, + im[n].dptr_, req[0]); + } + }); +} + +template +void Col2imGradCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + const Col2imParam& param = nnvm::get(attrs.parsed); + Stream *s = ctx.get_stream(); + + const mxnet::TShape im_shape = inputs[0].shape_; + const mxnet::TShape col_shape = outputs[0].shape_; + const index_t batch_size = im_shape[0]; + + const int spatial_size = param.kernel.ndim(); + mxnet::TShape col_buffer_shape(1 + spatial_size, 1); + col_buffer_shape[0] = im_shape[1]; + for (int i = 0; i < spatial_size; ++i) { + const index_t pad_size = im_shape[i + 2] + 2 * param.pad[i]; + const index_t output_size = (pad_size - param.DilatedKernelSize(i)) / param.stride[i] + 1; + col_buffer_shape[i + 1] = output_size; + } + + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + Tensor im_grad = inputs[0].get_with_shape( + Shape4(im_shape[0], im_shape[1], im_shape[2], im_shape[3]), s); + Tensor col_grad = outputs[0].get_with_shape( + Shape3(col_shape[0], col_shape[1], col_shape[2]), s); + + if (req[0] == kNullOp) return; + if (req[0] != kAddTo) { + for (index_t n = 0; n < batch_size; ++n) { + im2col(s, im_grad[n].dptr_, im_shape, col_buffer_shape, + param.kernel, param.pad, param.stride, param.dilate, col_grad[n].dptr_); + } + } else { + Tensor tgrad = ctx.requested[0] + .get_space_typed(Shape2(col_shape[1], col_shape[2]), s); + for (index_t n = 0; n < batch_size; ++n) { + im2col(s, im_grad[n].dptr_, im_shape, col_buffer_shape, + param.kernel, param.pad, param.stride, param.dilate, tgrad.dptr_); + Tensor cgrad = 
diff --git a/src/operator/nn/im2col.cc b/src/operator/nn/im2col.cc
new file mode 100644
index 000000000000..ae493f1bc594
--- /dev/null
+++ b/src/operator/nn/im2col.cc
@@ -0,0 +1,272 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file im2col.cc
+ * \brief
+ * \author Jiajun Wang
+*/
+
+#include "./im2col-inl.h"
+#include "../operator_common.h"
+#include "mxnet/op_attr_types.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(Im2colParam);
+DMLC_REGISTER_PARAMETER(Col2imParam);
+
+template<typename PType>
+void SlidingParser(nnvm::NodeAttrs* attrs) {
+  using namespace mshadow;
+  PType param_;
+  try {
+    param_.Init(attrs->dict);
+  } catch (const dmlc::ParamError& e) {
+    std::ostringstream os;
+    os << e.what();
+    os << ", in operator " << attrs->op->name << "("
+       << "name=\"" << attrs->name << "\"";
+    for (const auto& k : attrs->dict) {
+      os << ", " << k.first << "=\"" << k.second << "\"";
+    }
+    os << ")";
+    throw dmlc::ParamError(os.str());
+  }
+
+  if (param_.kernel.ndim() == 1) {
+    if (param_.stride.ndim() == 0) param_.stride = Shape1(1);
+    if (param_.dilate.ndim() == 0) param_.dilate = Shape1(1);
+    if (param_.pad.ndim() == 0) param_.pad = Shape1(0);
+  } else if (param_.kernel.ndim() == 2) {
+    if (param_.stride.ndim() == 0) param_.stride = Shape2(1, 1);
+    if (param_.dilate.ndim() == 0) param_.dilate = Shape2(1, 1);
+    if (param_.pad.ndim() == 0) param_.pad = Shape2(0, 0);
+  } else {
+    CHECK_EQ(param_.kernel.ndim(), 3U) << param_.kernel.ndim() << "D convolution not supported";
+    if (param_.stride.ndim() == 0) param_.stride = Shape3(1, 1, 1);
+    if (param_.dilate.ndim() == 0) param_.dilate = Shape3(1, 1, 1);
+    if (param_.pad.ndim() == 0) param_.pad = Shape3(0, 0, 0);
+  }
+  CHECK_EQ(param_.kernel.ndim(), param_.stride.ndim())
+    << "Stride must have the same number of dimensions as kernel_size, "
+    << "but kernel_size is set to " << param_.kernel << " while stride is "
+    << param_.stride;
+  CHECK_EQ(param_.kernel.ndim(), param_.dilate.ndim())
+    << "Dilate must have the same number of dimensions as kernel_size, "
+    << "but kernel_size is set to " << param_.kernel << " while dilate is "
+    << param_.dilate;
+  CHECK_EQ(param_.kernel.ndim(), param_.pad.ndim())
+    << "Padding must have the same number of dimensions as kernel_size, "
+    << "but kernel_size is set to " << param_.kernel << " while padding is "
+    << param_.pad;
+  attrs->parsed = std::move(param_);
+}
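As a front-end illustration of the defaulting behaviour implemented in SlidingParser above: omitting stride, dilate or pad is equivalent to passing 1, 1 and 0 per spatial dimension. A sketch, assuming the usual generated mx.nd bindings for the operator:

    import mxnet as mx

    data = mx.nd.uniform(shape=(2, 3, 10, 8))
    # These two calls should produce identical results, per SlidingParser's defaults:
    col_a = mx.nd.im2col(data, kernel=(3, 3))
    col_b = mx.nd.im2col(data, kernel=(3, 3), stride=(1, 1), dilate=(1, 1), pad=(0, 0))
    assert (col_a == col_b).asnumpy().all()
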
+
+NNVM_REGISTER_OP(im2col)
+.describe(R"(Extract sliding blocks from input array.
+
+This operator is used in the vanilla convolution implementation to transform the sliding
+blocks of an image into a column matrix, so that the convolution can be computed as a
+matrix multiplication between the column matrix and the convolution weight. Because of
+this close relation to convolution, the concepts of **kernel**, **stride**, **dilate**
+and **pad** in this operator are inherited from the convolution operation.
+
+Given the input data of shape :math:`(N, C, *)`, where :math:`N` is the batch size,
+:math:`C` is the channel size, and :math:`*` is the arbitrary spatial dimension,
+the output column array always has shape :math:`(N, C \times \prod(\text{kernel}), W)`,
+where :math:`C \times \prod(\text{kernel})` is the block size and :math:`W` is the
+number of blocks, which equals the spatial size of the convolution output under the
+same parameters. Only 1-D, 2-D and 3-D spatial dimensions are supported.
+
+)" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(SlidingParser<Im2colParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"output"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", [](const nnvm::NodeAttrs& attrs,
+    mxnet::ShapeVector *in_shape, mxnet::ShapeVector *out_shape) {
+  using namespace mshadow;
+  CHECK_EQ(in_shape->size(), 1U);
+  const Im2colParam& param = nnvm::get<Im2colParam>(attrs.parsed);
+  if (mxnet::op::shape_is_none(in_shape->at(0))) {
+    return false;
+  }
+
+  CHECK_GT(param.kernel.Size(), 0U) \
+    << "incorrect kernel size: " << param.kernel;
+  CHECK_GT(param.stride.Size(), 0U) \
+    << "incorrect stride size: " << param.stride;
+  CHECK_GT(param.dilate.Size(), 0U) \
+    << "incorrect dilate size: " << param.dilate;
+
+  index_t out_dim = 1;
+  mxnet::TShape dshape(in_shape->at(0));
+  for (int i = 0; i < param.kernel.ndim(); ++i) {
+    const index_t pad_size = dshape[i + 2] + 2 * param.pad[i];
+    const index_t dilated_kernel_size = param.DilatedKernelSize(i);
+    CHECK_LE(dilated_kernel_size, pad_size)
+      << "kernel size exceeds input";
+    const index_t output_size = (pad_size - dilated_kernel_size) / param.stride[i] + 1;
+    out_dim *= output_size;
+  }
+  SHAPE_ASSIGN_CHECK(*out_shape, 0, Shape3(dshape[0], dshape[1] * param.kernel.Size(), out_dim));
+  return true;
+})
+.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs& attrs,
+    std::vector<int> *in_type, std::vector<int> *out_type) {
+  CHECK_EQ(in_type->size(), 1U);
+  if (mxnet::op::type_is_none(in_type->at(0))) {
+    return false;
+  }
+
+  int dtype = in_type->at(0);
+  TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
+  return true;
+})
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+.set_attr<FCompute>("FCompute", Im2colCompute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_im2col"})
+.add_argument("data", "NDArray-or-Symbol", "Input array to extract sliding blocks.")
+.add_arguments(Im2colParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_im2col)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(SlidingParser<Im2colParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute", Im2colGradCompute<cpu>);
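A usage sketch for the operator registered above; the input shape is illustrative, and the output shape follows the docstring formula:

    import mxnet as mx

    data = mx.nd.uniform(shape=(2, 3, 10, 8))  # (N, C, H, W)
    col = mx.nd.im2col(data, kernel=(3, 3), stride=(1, 1), dilate=(1, 1), pad=(1, 1))
    # Block size is C * prod(kernel) = 3 * 9 = 27; with pad=1 and stride=1 the
    # number of blocks W equals the input spatial size, 10 * 8 = 80.
    assert col.shape == (2, 27, 80)
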
+
+NNVM_REGISTER_OP(col2im)
+.describe(R"(Combine the output column matrix of im2col back into an image array.
+
+Like :class:`~mxnet.ndarray.im2col`, this operator is also used in the vanilla convolution
+implementation. Despite the name, col2im is not the reverse operation of im2col. Since
+neighbouring sliding blocks may overlap, the column elements cannot simply be written
+back into the image. Instead, they are accumulated (i.e., summed) into the output image,
+exactly as in the gradient computation, so col2im is the gradient of im2col and vice versa.
+
+Using the notation of im2col, given an input column array of shape
+:math:`(N, C \times \prod(\text{kernel}), W)`, this operator accumulates the column elements
+into an output array of shape :math:`(N, C, \text{output_size}[0], \text{output_size}[1], \dots)`.
+Only 1-D, 2-D and 3-D spatial dimensions are supported.
+
+)" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(SlidingParser<Col2imParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"output"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", [](const nnvm::NodeAttrs& attrs,
+    mxnet::ShapeVector *in_shape, mxnet::ShapeVector *out_shape) {
+  using namespace mshadow;
+  CHECK_EQ(in_shape->size(), 1U);
+  const Col2imParam& param = nnvm::get<Col2imParam>(attrs.parsed);
+  if (mxnet::op::shape_is_none(in_shape->at(0))) {
+    return false;
+  }
+
+  CHECK_EQ(param.kernel.ndim(), param.output_size.ndim())
+    << "Output size must have the same number of dimensions as kernel_size, "
+    << "but kernel_size is set to " << param.kernel << " while output size is "
+    << param.output_size;
+
+  CHECK_GT(param.output_size.Size(), 0U) \
+    << "incorrect output size: " << param.output_size;
+  CHECK_GT(param.kernel.Size(), 0U) \
+    << "incorrect kernel size: " << param.kernel;
+  CHECK_GT(param.stride.Size(), 0U) \
+    << "incorrect stride size: " << param.stride;
+  CHECK_GT(param.dilate.Size(), 0U) \
+    << "incorrect dilate size: " << param.dilate;
+
+  const int spatial_size = param.kernel.ndim();
+  mxnet::TShape dshape(in_shape->at(0));
+
+  index_t out_dim = 1;
+  for (int i = 0; i < spatial_size; ++i) {
+    const index_t pad_size = param.output_size[i] + 2 * param.pad[i];
+    const index_t dilated_kernel_size = param.DilatedKernelSize(i);
+    CHECK_LE(dilated_kernel_size, pad_size)
+      << "kernel size exceeds output size";
+    const index_t output_size = (pad_size - dilated_kernel_size) / param.stride[i] + 1;
+    out_dim *= output_size;
+  }
+
+  CHECK_EQ(dshape[2], out_dim)
+    << "output size does not match convolution parameters";
+  CHECK_EQ(dshape[1] % param.kernel.Size(), 0)
+    << "the second dim of input shape should be a multiple of kernel size";
+
+  mxnet::TShape oshape(param.kernel.ndim() + 2, 1);
+  oshape[0] = dshape[0];
+  oshape[1] = dshape[1] / param.kernel.Size();
+  for (int i = 0; i < spatial_size; ++i) {
+    oshape[i + 2] = param.output_size[i];
+  }
+  SHAPE_ASSIGN_CHECK(*out_shape, 0, oshape);
+  return true;
+})
+.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs& attrs,
+    std::vector<int> *in_type, std::vector<int> *out_type) {
+  CHECK_EQ(in_type->size(), 1U);
+  if (mxnet::op::type_is_none(in_type->at(0))) {
+    return false;
+  }
+
+  int dtype = in_type->at(0);
+  TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
+  return true;
+})
+.set_attr<FCompute>("FCompute", Col2imCompute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_col2im"})
+.add_argument("data", "NDArray-or-Symbol", "Input array to combine sliding blocks.")
+.add_arguments(Col2imParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_col2im)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(SlidingParser<Col2imParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+.set_attr<FCompute>("FCompute", Col2imGradCompute<cpu>);
+
+}  // namespace op
+}  // namespace mxnet
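A hedged sketch of the accumulation semantics described in the col2im docstring, mirroring the round-trip identity checked in the unit tests below (all shapes illustrative):

    import mxnet as mx

    data = mx.nd.uniform(shape=(2, 3, 10, 8))
    kwargs = {'kernel': (3, 3), 'stride': (2, 2), 'dilate': (1, 1), 'pad': (1, 1)}
    col = mx.nd.im2col(data, **kwargs)
    im = mx.nd.col2im(col, (10, 8), **kwargs)
    # Every pixel is summed once per sliding block covering it, so the round
    # trip scales each pixel by its coverage count:
    count = mx.nd.col2im(mx.nd.ones_like(col), (10, 8), **kwargs)
    assert mx.nd.max(mx.nd.abs(im - data * count)).asscalar() < 1e-5
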
diff --git a/src/operator/nn/im2col.cu b/src/operator/nn/im2col.cu
new file mode 100644
index 000000000000..94d5b504611f
--- /dev/null
+++ b/src/operator/nn/im2col.cu
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file im2col.cu
+ * \brief
+ * \author Jiajun Wang
+*/
+
+#include "./im2col-inl.h"
+
+namespace mxnet {
+namespace op {
+
+NNVM_REGISTER_OP(im2col)
+.set_attr<FCompute>("FCompute", Im2colCompute<gpu>);
+
+NNVM_REGISTER_OP(_backward_im2col)
+.set_attr<FCompute>("FCompute", Im2colGradCompute<gpu>);
+
+NNVM_REGISTER_OP(col2im)
+.set_attr<FCompute>("FCompute", Col2imCompute<gpu>);
+
+NNVM_REGISTER_OP(_backward_col2im)
+.set_attr<FCompute>("FCompute", Col2imGradCompute<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 66031d20d65b..d59c3063f95a 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -9374,6 +9374,153 @@ def check_random_uniform():
         assertRaises(MXNetError, mx.nd.random_uniform, alpha, beta, shape)
 
 
+@with_seed()
+def test_im2col_col2im():
+    def compute_output_size(spatial, kernel, stride=1, dilate=1, pad=0):
+        pad_size = spatial + 2 * pad
+        dilated_kernel = dilate * (kernel - 1) + 1
+        return (pad_size - dilated_kernel) // stride + 1
+
+    def build_kwargs(kernel, stride=1, dilate=1, pad=0):
+        return {'kernel': (kernel, kernel),
+                'stride': (stride, stride),
+                'dilate': (dilate, dilate),
+                'pad': (pad, pad)}
+
+    # use im2col to compute convolution
+    def test_conv_compute(input_shape, num_filter, kernel, stride=1, dilate=1, pad=0):
+        batch_size = input_shape[0]
+        channel = input_shape[1]
+        kwargs = build_kwargs(kernel, stride, dilate, pad)
+        data = mx.nd.uniform(shape=input_shape)
+        col = mx.nd.im2col(data, **kwargs)
+        w = mx.nd.uniform(shape=(num_filter, channel, kernel, kernel))
+        c1 = mx.nd.dot(col.transpose((0, 2, 1)), w.reshape(num_filter, -1).T).transpose((0, 2, 1))
+        hos = compute_output_size(input_shape[2], kernel, stride, dilate, pad)
+        wos = compute_output_size(input_shape[3], kernel, stride, dilate, pad)
+        c1 = c1.reshape((batch_size, num_filter, hos, wos))
+
+        c2 = mx.nd.Convolution(data, num_filter=num_filter, weight=w, no_bias=True, **kwargs)
+        assert_almost_equal(c1.asnumpy(), c2.asnumpy(), rtol=1e-5, atol=1e-5)
+
+    test_conv_compute(
+        input_shape = (5, 3, 30, 20),
+        num_filter = 10,
+        kernel = 3
+    )
+
+    test_conv_compute(
+        input_shape = (5, 3, 30, 20),
+        num_filter = 10,
+        kernel = 3,
+        stride = 2
+    )
+
+    test_conv_compute(
+        input_shape = (5, 3, 30, 20),
+        num_filter = 10,
+        kernel = 3,
+        stride = 2,
+        dilate = 2
+    )
+
+    test_conv_compute(
+        input_shape = (5, 3, 30, 20),
+        num_filter = 10,
+        kernel = 3,
+        stride = 2,
+        dilate = 2,
+        pad = 1
+    )
+
+    # use composite of im2col and col2im to reconstruct image
+    def test_reconstruct(input_shape, kernel, stride=1, dilate=1, pad=0):
+        batch_size = input_shape[0]
+        channel = input_shape[1]
+        kwargs = build_kwargs(kernel, stride, dilate, pad)
+        data = mx.nd.uniform(shape=input_shape)
+        col = mx.nd.im2col(data, **kwargs)
+        im1 = mx.nd.col2im(col, input_shape[2:], **kwargs)
+
+        im2 = mx.nd.col2im(mx.nd.ones_like(col), input_shape[2:], **kwargs) * data
+        assert_almost_equal(im1.asnumpy(), im2.asnumpy(), rtol=1e-5, atol=1e-5)
+
+    test_reconstruct(
+        input_shape = (5, 3, 30, 20),
+        kernel = 3
+    )
+
+    test_reconstruct(
+        input_shape = (5, 3, 30, 20),
+        kernel = 3,
+        stride = 2
+    )
+
+    test_reconstruct(
+        input_shape = (5, 3, 30, 20),
+        kernel = 3,
+        stride = 2,
+        dilate = 2
+    )
+
+    test_reconstruct(
+        input_shape = (5, 3, 30, 20),
+        kernel = 3,
+        stride = 2,
+        dilate = 2,
+        pad = 1
+    )
+
+    # test gradient
+    # the grad of im2col is col2im, and vice versa
+    def test_grad(input_shape, kernel, stride=1, dilate=1, pad=0):
+        # im2col
+        data = mx.sym.Variable('data')
+        kwargs = build_kwargs(kernel, stride, dilate, pad)
+        sym = mx.sym.im2col(data, **kwargs)
+
+        im = mx.nd.uniform(shape=input_shape)
+        col = mx.nd.im2col(im, **kwargs)
+        col_shape = col.shape
+        expected = mx.nd.col2im(col, input_shape[2:], **kwargs)
+        check_symbolic_backward(sym, [im.asnumpy()], [col.asnumpy()], [expected.asnumpy()])
+
+        # col2im
+        data = mx.sym.Variable('data')
+        sym = mx.sym.col2im(data, input_shape[2:], **kwargs)
+
+        col = mx.nd.uniform(shape=col_shape)
+        im = mx.nd.col2im(col, input_shape[2:], **kwargs)
+        expected = mx.nd.im2col(im, **kwargs)
+        check_symbolic_backward(sym, [col.asnumpy()], [im.asnumpy()], [expected.asnumpy()])
+
+    test_grad(
+        input_shape = (5, 3, 30, 20),
+        kernel = 3
+    )
+
+    test_grad(
+        input_shape = (5, 3, 30, 20),
+        kernel = 3,
+        stride = 2
+    )
+
+    test_grad(
+        input_shape = (5, 3, 30, 20),
+        kernel = 3,
+        stride = 2,
+        dilate = 2
+    )
+
+    test_grad(
+        input_shape = (5, 3, 30, 20),
+        kernel = 3,
+        stride = 2,
+        dilate = 2,
+        pad = 1
+    )
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
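The im2col/col2im gradient duality that test_grad verifies symbolically can also be observed through autograd; a minimal sketch with illustrative shapes (the standard mx.autograd API is assumed here, it is not part of the patch):

    import mxnet as mx
    from mxnet import autograd

    x = mx.nd.uniform(shape=(1, 1, 4, 4))
    x.attach_grad()
    kwargs = {'kernel': (2, 2), 'stride': (1, 1), 'dilate': (1, 1), 'pad': (0, 0)}
    with autograd.record():
        col = mx.nd.im2col(x, **kwargs)
    col.backward(mx.nd.ones_like(col))
    # The gradient of im2col is col2im applied to the head gradient:
    expected = mx.nd.col2im(mx.nd.ones_like(col), (4, 4), **kwargs)
    assert mx.nd.max(mx.nd.abs(x.grad - expected)).asscalar() < 1e-5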