Add ConvTranspose3D #2075

Merged
1 commit, merged Sep 27, 2022
14 changes: 9 additions & 5 deletions schema/current/MNN_generated.h
@@ -247,7 +247,8 @@ enum OpType {
OpType_Conv2DBackPropFilter = 265,
OpType_TrainableParam = 266,
OpType_BatchNorm = 267,
- OpType_ZeroGrad = 268,
+ OpType_ConvTranspose3D = 268,
+ OpType_ZeroGrad = 269,
OpType_Extra = 512,
OpType_ConvInt8 = 513,
OpType_Int8ToFloat = 514,
@@ -263,7 +264,7 @@ enum OpType {
OpType_MAX = OpType_GridSample
};

- inline const OpType (&EnumValuesOpType())[173] {
+ inline const OpType (&EnumValuesOpType())[174] {
static const OpType values[] = {
OpType_AbsVal,
OpType_QuantizedAdd,
@@ -426,6 +427,7 @@ inline const OpType (&EnumValuesOpType())[173] {
OpType_Conv2DBackPropFilter,
OpType_TrainableParam,
OpType_BatchNorm,
+ OpType_ConvTranspose3D,
OpType_ZeroGrad,
OpType_Extra,
OpType_ConvInt8,
@@ -712,6 +714,7 @@ inline const char * const *EnumNamesOpType() {
"Conv2DBackPropFilter",
"TrainableParam",
"BatchNorm",
"ConvTranspose3D",
"ZeroGrad",
"",
"",
@@ -955,7 +958,6 @@ inline const char * const *EnumNamesOpType() {
"",
"",
"",
"",
"Extra",
"ConvInt8",
"Int8ToFloat",
@@ -7424,12 +7426,13 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() {
{ flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 },
+ { flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 }
};
static const flatbuffers::TypeFunction type_refs[] = {
OpTypeTypeTable
};
- static const int64_t values[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 512, 513, 514, 515, 516, 517, 518, 600, 601, 603, 604 };
+ static const int64_t values[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 512, 513, 514, 515, 516, 517, 518, 600, 601, 603, 604 };
static const char * const names[] = {
"AbsVal",
"QuantizedAdd",
@@ -7592,6 +7595,7 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() {
"Conv2DBackPropFilter",
"TrainableParam",
"BatchNorm",
"ConvTranspose3D",
"ZeroGrad",
"Extra",
"ConvInt8",
@@ -7606,7 +7610,7 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() {
"GridSample"
};
static const flatbuffers::TypeTable tt = {
- flatbuffers::ST_ENUM, 173, type_codes, type_refs, values, names
+ flatbuffers::ST_ENUM, 174, type_codes, type_refs, values, names
};
return &tt;
}
2 changes: 1 addition & 1 deletion schema/default/MNN.fbs
@@ -180,7 +180,7 @@ enum OpType : int {
Conv2DBackPropFilter,
TrainableParam,
BatchNorm,
-
+ ConvTranspose3D,
// Use for self defined grad
ZeroGrad,

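Note that ConvTranspose3D is inserted before ZeroGrad rather than appended, so ZeroGrad is renumbered from 268 to 269. That renumbering is what drives the regenerated MNN_generated.h above: the enum table grows from 173 to 174 entries, 269 is appended to the values array, and the name and type-code tables each gain a slot. A model serialized with the old schema that stored ZeroGrad (268) would decode as ConvTranspose3D after this change.
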
162 changes: 162 additions & 0 deletions source/geometry/GeometryConv3D.cpp
@@ -131,9 +131,171 @@ class GeometryConv3D : public GeometryComputer {
}
};

class GeometryConvTranspose3D : public GeometryConv3D {
public:
virtual bool
onCompute(const Op *op, const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, Context &context,
CommandBuffer &res) const override {
auto input = inputs[0];
auto output = outputs[0];
MNN_ASSERT(TensorUtils::getDescribe(input)->dimensionFormat != MNN_DATA_FORMAT_NHWC);
MNN_ASSERT(TensorUtils::getDescribe(output)->dimensionFormat != MNN_DATA_FORMAT_NHWC);
auto biasData = op->main_as_Convolution3D()->bias();
auto weightData = op->main_as_Convolution3D()->weight();
auto common = op->main_as_Convolution3D()->common();
auto kernels = common->kernels();
auto strides = common->strides();
auto pads = common->pads();
auto dilates = common->dilates();
const int kernelDepth = kernels->Get(0), kernelHeight = kernels->Get(1), kernelWidth = kernels->Get(2);
const int strideDepth = strides->Get(0), strideHeight = strides->Get(1), strideWidth = strides->Get(2);
const int dilateDepth = dilates->Get(0), dilateHeight = dilates->Get(1), dilateWidth = dilates->Get(2);
const int padDepth = pads->Get(0), padHeight = pads->Get(1), padWidth = pads->Get(2);
const int outputDepth = output->length(2), outputHeight = output->length(3), outputWidth = output->length(4);
const int inputDepth = input->length(2), inputHeight = input->length(3), inputWidth = input->length(4);
const int inputChannel = input->length(1), batch = input->length(0), outputChannel = output->length(1);

auto weightTensor = context.allocConst(op, {static_cast<int>(weightData->size())}, halide_type_of<float>());
::memcpy(weightTensor.get()->host<float>(), weightData->data(), weightData->size() * sizeof(float));
auto weight = weightTensor.get();
auto biasTensor = context.allocConst(op, {outputChannel}, halide_type_of<float>());
::memcpy(biasTensor.get()->host<float>(), biasData->data(), biasData->size() * sizeof(float));
auto bias = biasTensor.get();

Tensor *A = nullptr;
Tensor *B = nullptr;
{
// B: Input n, ic, id, ih, iw -> ic, n * id * ih * iw
std::shared_ptr<Tensor> dest(Tensor::createDevice<float>({inputChannel, batch * inputDepth * inputHeight * inputWidth}));
res.extras.emplace_back(dest);
B = dest.get();
auto des = TensorUtils::getDescribe(dest.get());
des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
des->regions.resize(1);
auto& reg = des->regions[0];
reg.origin = input;
reg.size[0] = inputChannel;
reg.size[1] = batch;
reg.size[2] = inputDepth * inputHeight * inputWidth;
reg.src.offset = 0;
reg.src.stride[0] = inputDepth * inputHeight * inputWidth;
reg.src.stride[1] = inputChannel * inputDepth * inputHeight * inputWidth;
reg.src.stride[2] = 1;
reg.dst.offset = 0;
reg.dst.stride[0] = inputDepth * inputHeight * inputWidth * batch;
reg.dst.stride[1] = inputDepth * inputHeight * inputWidth;
reg.dst.stride[2] = 1;
}
{
// A: Weight ic, oc, kd, kh, kw -> ic, oc*kd*kh*kw
std::shared_ptr<Tensor> kernel(Tensor::createDevice<float>({inputChannel, outputChannel * kernelDepth * kernelHeight * kernelWidth}));
A = kernel.get();
GeometryComputerUtils::makeRawAddressRef(kernel.get(), weight, 0, inputChannel * kernelDepth * kernelHeight * kernelWidth * outputChannel);
res.extras.emplace_back(std::move(kernel));
}
{
// C = MatMul(A^T, B): (oc*kd*kh*kw, ic) x (ic, batch*id*ih*iw)
std::shared_ptr<Tensor> C(Tensor::createDevice<float>({outputChannel * kernelDepth * kernelHeight * kernelWidth, batch * inputDepth * inputHeight * inputWidth}));
res.command.emplace_back(GeometryComputerUtils::makeMatMul(A, B, C.get(), nullptr, true, false));
res.extras.emplace_back(C);

// Col2Im:
// 1. C -> C': one (batch, oc, od, oh, ow) slice per kernel tap, 2. C' -> C'': batch, oc, od, oh, ow (reduce_sum over the slices)
// 3. C'' -> C'' + bias, 4. posttreat(C'' + bias)
std::shared_ptr<Tensor> C_(Tensor::createDevice<float>({1, batch*outputChannel*kernelDepth*kernelHeight*kernelWidth, batch * outputChannel * outputDepth * outputHeight * outputWidth}));
res.extras.emplace_back(C_);
{
std::shared_ptr<Tensor> im2ColTemp(Tensor::createDevice<float>({outputChannel * kernelDepth * kernelHeight * kernelWidth, batch * inputDepth * inputHeight * inputWidth}));
GeometryConvUtils::im2Col3d(im2ColTemp.get(), output, outputChannel, kernelDepth, kernelHeight, kernelWidth,
batch, inputDepth, inputHeight, inputWidth,
outputDepth, outputHeight, outputWidth,
strideDepth, strideHeight, strideWidth,
dilateDepth, dilateHeight, dilateWidth,
padDepth, padHeight, padWidth);
auto des = TensorUtils::getDescribe(C_.get());
des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
auto originDes = TensorUtils::getDescribe(im2ColTemp.get());
des->regions = std::move(originDes->regions);
// Swap src and dst, from im2col3d->col2im3d
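// Each swapped region scatters one slice of C into the output layout; overlapping
// writes go to distinct slices of C_ (offset by idx) so the ReductionType_SUM
// command below can fold them together.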
int idx = 0;
for (auto& reg : des->regions) {
reg.origin = C.get();
auto temp = reg.src;
reg.src = std::move(reg.dst);
reg.dst = std::move(temp);
reg.dst.offset += outputChannel * outputDepth * outputHeight * outputWidth * batch * idx;
idx++;
}
}
std::shared_ptr<Tensor> C__(Tensor::createDevice<float>({1, 1, batch * outputChannel * outputDepth * outputHeight * outputWidth}));
res.extras.emplace_back(C__);
res.command.emplace_back(GeometryComputerUtils::makeReduce(ReductionType_SUM, C_.get(), C__.get()));
{
std::shared_ptr<Tensor> biasLarge(Tensor::createDevice<float>({1, 1, batch * outputChannel * outputDepth * outputHeight * outputWidth}));
res.extras.emplace_back(biasLarge);
auto des = TensorUtils::getDescribe(biasLarge.get());
des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
des->regions.resize(1);
auto& reg = des->regions[0];
reg.origin = bias;
reg.size[0] = batch;
reg.size[1] = outputChannel;
reg.size[2] = outputDepth * outputHeight * outputWidth;
reg.src.offset = 0;
reg.src.stride[0] = 0;
reg.src.stride[1] = 1;
reg.src.stride[2] = 0;
reg.dst.offset = 0;
reg.dst.stride[0] = outputChannel * outputDepth * outputHeight * outputWidth;
reg.dst.stride[1] = outputDepth * outputHeight * outputWidth;
reg.dst.stride[2] = 1;
std::shared_ptr<Tensor> temp(Tensor::createDevice<float>({1, 1, batch * outputDepth * outputHeight * outputWidth * outputChannel}));
res.extras.emplace_back(temp);
res.command.emplace_back(GeometryComputerUtils::makeBinary(BinaryOpOperation_ADD, C__.get(), biasLarge.get(), temp.get()));
C__ = temp;
}

// Activation
float minValue = 0.0f, maxValue = 0.0f;
bool needPostTreat = false;
if (common->relu()) {
needPostTreat = true;
minValue = 0.0f;
maxValue = std::numeric_limits<float>::max();
}
if (common->relu6()) {
needPostTreat = true;
minValue = 0.0f;
maxValue = 6.0f;
}
if (needPostTreat) {
flatbuffers::FlatBufferBuilder builder;
builder.Finish(GeometryConvUtils::makeRelu6(builder, minValue, maxValue));
std::shared_ptr<Tensor> C2(new Tensor);
C2->buffer().type = halide_type_of<float>();
C2->buffer().dimensions = 3;
C2->setLength(0, 1);
C2->setLength(1, 1);
C2->setLength(2, batch * outputDepth * outputHeight * outputWidth * outputChannel);
TensorUtils::getDescribe(C2.get())->dimensionFormat = MNN_DATA_FORMAT_NCHW;
auto cmd = GeometryComputerUtils::makeCommand(builder, {C__.get()}, {C2.get()});
res.command.emplace_back(cmd);
res.extras.emplace_back(C2);
C__ = C2;
}
GeometryComputerUtils::makeRawAddressRef(outputs[0], C__.get(), 0, batch * outputChannel * outputDepth * outputHeight * outputWidth);
}
return true;
}
};

static void _create() {
std::shared_ptr<GeometryComputer> comp(new GeometryConv3D);
GeometryComputer::registerGeometryComputer(comp, {OpType_Convolution3D});

std::shared_ptr<GeometryComputer> comp2(new GeometryConvTranspose3D);
GeometryComputer::registerGeometryComputer(comp2, {OpType_ConvTranspose3D});

}

REGISTER_GEOMETRY(GeometryConv3D, _create);
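
The geometry above is easier to follow against the plain scatter-add definition of a transposed convolution: every input element is multiplied by the kernel and accumulated into a strided window of the output, and the per-tap slices produced by the swapped im2col regions are exactly what the ReductionType_SUM command folds together. Below is a minimal single-channel 1D sketch of that scatter-add (illustrative code only, not part of this PR; the function name is hypothetical):

#include <cstdio>
#include <vector>

// Output length follows ShapeConvTranspose3D: (in - 1) * stride + (k - 1) * dilation + 1 - 2 * pad.
std::vector<float> convTranspose1d(const std::vector<float>& in, const std::vector<float>& w,
                                   int stride, int dilation, int pad) {
    const int k      = static_cast<int>(w.size());
    const int outLen = (static_cast<int>(in.size()) - 1) * stride + (k - 1) * dilation + 1 - 2 * pad;
    std::vector<float> out(outLen, 0.0f);
    for (size_t i = 0; i < in.size(); ++i) {
        for (int j = 0; j < k; ++j) {
            // Scatter: input element i contributes to output position i * stride + j * dilation - pad.
            const int o = static_cast<int>(i) * stride + j * dilation - pad;
            if (o >= 0 && o < outLen) {
                out[o] += in[i] * w[j];
            }
        }
    }
    return out;
}

int main() {
    // stride 2, dilation 1, pad 0: {1, 2, 3} -> {1, 1, 3, 2, 5, 3, 3}
    for (float v : convTranspose1d({1, 2, 3}, {1, 1, 1}, 2, 1, 0)) {
        std::printf("%g ", v);
    }
    std::printf("\n");
    return 0;
}

The MatMul + region-swap + reduce construction in the geometry pass computes the same sums, just batched as one large GEMM followed by a layout rewrite instead of per-element loops.
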
58 changes: 58 additions & 0 deletions source/shape/ShapeConvTranspose3D.cpp
@@ -0,0 +1,58 @@
//
// ShapeConvTranspose3D.cpp
// MNN
//
// Created by MNN on 2019/01/10.
// Copyright © 2018, Alibaba Group Holding Limited
//

#include <math.h>
#include "shape/SizeComputer.hpp"
#include "core/Macro.h"
#include "core/TensorUtils.hpp"
#include <iostream>
namespace MNN {
class ConvTranspose3DSizeComputer : public SizeComputer {
public:
virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
const std::vector<Tensor*>& outputs) const override {
// MNN_ASSERT(1 == inputs.size());
MNN_ASSERT(1 == outputs.size());

auto layer = op->main_as_Convolution3D()->common();
auto input = inputs[0];
if (input->buffer().dimensions != 5) {
return false;
}

auto& outputBuffer = outputs[0]->buffer();
outputBuffer.dimensions = input->buffer().dimensions;
outputBuffer.dim[0].extent = input->buffer().dim[0].extent;
outputBuffer.dim[1].extent = layer->outputCount();

for (int i = 0; i < 3; ++i) {
const int inputLength = input->length(i + 2), stride = (*layer->strides())[i];
if (inputLength <= 0) {
return false;
}
int outputLength;
if (layer->padMode() == PadMode_SAME) {
// For a transposed convolution, SAME padding means the output is stride times the input.
outputLength = inputLength * stride;
} else {
const int pad = (*layer->pads())[i], kernel = (*layer->kernels())[i], dilate = (*layer->dilates())[i];
const int dilatedKernel = (kernel - 1) * dilate + 1;
// Inverse of the forward convolution's outputLength = (inputLength + 2 * pad - dilatedKernel) / stride + 1
outputLength = (inputLength - 1) * stride + dilatedKernel - 2 * pad;
}
outputBuffer.dim[i + 2].extent = outputLength;
}

outputBuffer.type = input->getType();

TensorUtils::getDescribe(outputs[0])->dimensionFormat = TensorUtils::getDescribe(inputs[0])->dimensionFormat;
return true;
}
};

REGISTER_SHAPE(ConvTranspose3DSizeComputer, OpType_ConvTranspose3D);
} // namespace MNN
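
As a quick check of the formula: with inputLength = 4, stride = 2, kernel = 3, dilation = 1, and pad = 1, dilatedKernel = (3 - 1) * 1 + 1 = 3 and outputLength = (4 - 1) * 2 + 3 - 2 * 1 = 7. That inverts the forward convolution's shape rule, since (7 + 2 * 1 - 3) / 2 + 1 = 4: a ConvTranspose3D with the same hyperparameters restores the spatial extent the matching Convolution3D reduced.
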
2 changes: 2 additions & 0 deletions source/shape/ShapeRegister.cpp
@@ -107,6 +107,7 @@ extern void ___SpaceToBatchNDSizeComputer__OpType_SpaceToBatchND__();
extern void ___PackComputer__OpType_Pack__();
extern void ___DeconvolutionSizeComputer__OpType_Deconvolution__();
extern void ___DeconvolutionSizeComputer__OpType_DeconvolutionDepthwise__();
+ extern void ___ConvTranspose3DSizeComputer__OpType_ConvTranspose3D__();

void registerShapeOps() {
___ShapeSizeComputer__OpType_Shape__();
@@ -216,5 +217,6 @@ ___SpaceToBatchNDSizeComputer__OpType_SpaceToBatchND__();
___PackComputer__OpType_Pack__();
___DeconvolutionSizeComputer__OpType_Deconvolution__();
___DeconvolutionSizeComputer__OpType_DeconvolutionDepthwise__();
+ ___ConvTranspose3DSizeComputer__OpType_ConvTranspose3D__();
}
}
(additional changed file: the converter's ONNX Conv3D transform; file path not captured)
@@ -41,13 +41,17 @@ static EXPRP _transformConv3D(EXPRP expr) {
auto& weightShape = weightInfo->dim;

auto extraParam = expr->get()->main_as_Extra();

+ std::string originalOpType(extraParam->type()->c_str());
+ bool isDeconv = originalOpType == "ConvTranspose";
int co = weightShape[0];
int ci = weightShape[1];
int depth = weightShape[2];
int kh = weightShape[3];
int kw = weightShape[4];

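+ // ONNX stores ConvTranspose weights as (C_in, C_out/group, kD, kH, kW), the reverse of Conv's (C_out, C_in/group, kD, kH, kW), so co and ci are read swapped.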
+ if (isDeconv) {
+ co = weightShape[1];
+ ci = weightShape[0];
+ }
std::unique_ptr<Convolution3DT> conv3d(new MNN::Convolution3DT);
const float* weightDataPtr = weight->readMap<float>();
conv3d->weight.resize(weightInfo->size);
@@ -81,13 +85,18 @@ static EXPRP _transformConv3D(EXPRP expr) {
}

common->relu = common->relu6 = false;
- common->outputCount = co;
- common->inputCount = ci * common->group;
+ if (isDeconv) {
+ common->outputCount = co * common->group; // deconv weight is (ci, co/group, ...): scale co by group
+ common->inputCount = ci;
+ } else {
+ common->outputCount = co;
+ common->inputCount = ci * common->group; // conv weight is (co, ci/group, ...): scale ci by group
+ }
common->kernels = std::vector<int>({depth, kh, kw});

std::unique_ptr<OpT> newOp(new OpT);
newOp->name = expr->name();
- newOp->type = OpType_Convolution3D;
+ newOp->type = isDeconv ? OpType_ConvTranspose3D : OpType_Convolution3D;
newOp->main.type = OpParameter_Convolution3D;
newOp->main.value = conv3d.release();
