Skip to content

Commit

Permalink
- build:
Browse files Browse the repository at this point in the history
	- unify schema building in core and converter;
	- add more build scripts for android;
	- add linux build script for python;

- ops impl:
	- add floor mod support in binary;
	- use eltwise impl in add/max/sub/mul binary for optimization;
	- remove fake double support in cast;
	- fix 5d support for concat;
	- add adjX and adjY support for batch matmul;
	- optimize conv2d back prop filter;
	- add pad mode support for conv3d;
	- fix bug in conv2d & conv depthwise with very small feature map;
	- optimize binary without broadcast;
	- add data types support for gather;
	- add gather ND support;
	- use uint8 data type in gather v2;
	- add transpose support for matmul;
	- add matrix band part;
	- add dim != 4 support for padding, reshape & tensor convert;
	- add pad type support for pool3d;
	- make ops based on TensorFlow Lite quantization optional;
	- add all & any support for reduction;
	- use type in parameter as output type in reduction;
	- add int support for unary;
	- add variable weight support for conv2d;
	- fix conv2d depthwise weights initialization;
	- fix type support for transpose;
	- fix grad outputs count for reduce grad and reshape grad;
	- fix priorbox & detection output;
	- fix metal softmax error;

- python:
	- add runSessionWithCallBackInfo interface;
	- add max nodes limit (1400) for visualization tool;
	- fix save error in python3;
	- align default dim;

- convert:
	- add extra design for optimization;
	- add more post converting optimizers;
	- add caffe v1 weights blob support;
	- add cast, unary, conv transpose support for onnx model;
	- optimize batchnorm, conv with variable weights, prelu, reshape, slice, upsample for onnx model;
	- add cos/sin/atan/tan support for unary for tensorflow model;
	- add any/all support for reduction for tensorflow model;
	- add elu, conv3d, pool3d support for tensorflow model;
	- optimize argmax, batchnorm, concat, batch to space, conv with variable weights, prelu, slice for tensorflow model;

- others:
	- fix size computer lock;
	- fix thread pool deadlock;
	- add express & parameters in express;
	- rewrite blitter chooser without static map;
	- add tests for expr;
  • Loading branch information
liqing committed Oct 29, 2019
1 parent 2808f97 commit d6b00d0
Show file tree
Hide file tree
Showing 333 changed files with 7,566 additions and 4,739 deletions.
17 changes: 13 additions & 4 deletions 3rd_party/flatbuffers/src/idl_gen_cpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ class CppGenerator : public BaseGenerator {
// Iterate through all definitions we haven't generated code for (enums,
// structs, and tables) and output them to a single file.
bool generate() {
bool need_flatbuffer_include = true;
code_.Clear();
code_ += "// " + std::string(FlatBuffersGeneratedWarning()) + "\n\n";

Expand All @@ -245,11 +246,19 @@ class CppGenerator : public BaseGenerator {
if (parser_.opts.gen_nullable) {
code_ += "#pragma clang system_header\n\n";
}

code_ += "#include \"flatbuffers/flatbuffers.h\"";
if (parser_.uses_flexbuffers_) {
code_ += "#include \"flatbuffers/flexbuffers.h\"";
for (auto& iter : parser_.included_files_) {
if (!iter.second.empty()) {
need_flatbuffer_include = false;
}
}
need_flatbuffer_include = need_flatbuffer_include && parser_.native_included_files_.empty();
if (need_flatbuffer_include) {
code_ += "#include \"flatbuffers/flatbuffers.h\"";
if (parser_.uses_flexbuffers_) {
code_ += "#include \"flatbuffers/flexbuffers.h\"";
}
}

code_ += "";

if (parser_.opts.include_dependence_headers) { GenIncludeDependencies(); }
Expand Down
32 changes: 20 additions & 12 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,9 @@ enable_language(ASM)
# set(CMAKE_C_COMPILER gcc)
# set(CMAKE_CXX_COMPILER g++)

option(MNN_USE_CPP11 "Enable MNN use c++11" ON)

if (NOT MSVC)
if(MNN_USE_CPP11)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=gnu99")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
else()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=gnu99")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++0x")
endif()
endif()

# build options
option(MNN_USE_CPP11 "Enable MNN use c++11" ON)
option(MNN_BUILD_HARD "Build -mfloat-abi=hard or not" OFF)
option(MNN_BUILD_SHARED_LIBS "MNN build shared or static lib" ON)
option(MNN_FORBID_MULTI_THREAD "Disable Multi Thread" OFF)
Expand All @@ -31,6 +21,12 @@ option(MNN_BUILD_TRAIN "Build Train Tools" OFF)
option(MNN_BUILD_DEMO "Build demo/exec or not" OFF)
option(MNN_BUILD_QUANTOOLS "Build Quantized Tools or not" OFF)
option(MNN_EVALUATION "Build Evaluation Tools or not" OFF)
option(MNN_BUILD_CONVERTER "Build Converter" OFF)
option(MNN_SUPPORT_TFLITE_QUAN "Enable MNN's tflite quantized op" ON)
include(cmake/macros.cmake)
if (MNN_BUILD_CONVERTER)
add_subdirectory(tools/converter)
endif()

if (MNN_USE_THREAD_POOL)
set(MNN_OPENMP OFF)
Expand All @@ -40,6 +36,19 @@ endif()
if(MNN_FORBID_MULTI_THREAD)
add_definitions(-DMNN_FORBIT_MULTI_THREADS)
endif()
if(MNN_SUPPORT_TFLITE_QUAN)
add_definitions(-DMNN_SUPPORT_TFLITE_QUAN)
endif()

# Select language-standard flags for every toolchain except MSVC.
# C sources are built with -std=gnu99 in both branches; only the C++ flag
# differs: -std=c++11 when MNN_USE_CPP11 is ON, -std=c++0x otherwise.
if (NOT MSVC)
if(MNN_USE_CPP11)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=gnu99")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
else()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=gnu99")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++0x")
endif()
endif()

# debug options
option(MNN_DEBUG "Enable MNN DEBUG" OFF)
Expand Down Expand Up @@ -84,7 +93,6 @@ if (NOT MNN_BUILD_TEST)
endif()
endif()

include(cmake/macros.cmake)

message(STATUS ">>>>>>>>>>>>>")
message(STATUS "MNN BUILD INFO:")
Expand Down
2 changes: 1 addition & 1 deletion benchmark/Readme_CN.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
基于表达式构建模型并进行benchmark:
cd /path/to/MNN
mkdir build && cd build
cmake -DMNN_SUPPORT_TRAIN=true -DMNN_BUILD_BENCHMARK=true ..
cmake -DMNN_BUILD_BENCHMARK=true ..
make -j8

运行以下命令查看help:
Expand Down
4 changes: 2 additions & 2 deletions benchmark/benchmarkExprModels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include <sys/types.h>
#endif

#include "tools/converter/source/IR/MNN_generated.h"
#include "MNN_generated.h"
#include "MNNForwardType.h"
#include "Interpreter.hpp"
#include "Expr.hpp"
Expand Down Expand Up @@ -83,7 +83,7 @@ static void displayStats(const std::string& name, const std::vector<float>& cost

static std::vector<float> runNet(VARP netOutput, const ScheduleConfig& config, int loop) {
std::unique_ptr<NetT> netTable(new NetT);
netOutput->render(netTable.get());
Variable::save({netOutput}, netTable.get());
flatbuffers::FlatBufferBuilder builder(1024);
auto offset = CreateNet(builder, netTable.get());
builder.Finish(offset);
Expand Down
35 changes: 25 additions & 10 deletions demo/exec/expressDemo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ static std::pair<VARP, VARP> _makeConvolution(int k, int ic, int oc, int size) {
auto input = _Input({1, ic, size, size}, NC4HW4);
return std::make_pair(input, _Conv(0.0f, 0.0f, input, {ic, oc}, {k, k}, SAME));
}
// Builds an (e x l) * (l x h) matrix product out of expression nodes.
// Returns {input variable, product variable} so a caller can feed the first
// element and read the result from the second.
static std::pair<VARP, VARP> _makeGEMMByMatMul(int e, int l, int h) {
    // Left-hand operand: an input placeholder of shape {e, l}.
    auto lhs = _Input({e, l});

    // Right-hand operand: a constant of shape {l, h} created from a
    // value-initialized (all zero) float buffer of l*h elements.
    std::vector<float> zeros(l*h);
    auto rhs = _Const(zeros.data(), {l, h});

    auto product = _MatMul(lhs, rhs);
    return std::make_pair(lhs, product);
}

static std::pair<VARP, VARP> _makeGEMMByConvolution(int e, int l, int h) {
auto icC4 = UP_DIV(l);
Expand Down Expand Up @@ -105,6 +112,16 @@ static void _testGEMM() {
conv.second->unMap();
}
}
for (int i=0; i<size.size(); ++i) {
conv = _makeGEMMByMatMul(size[i][0], size[i][1], size[i][2]);
AUTOTIME;
for (int v=0; v<10; ++v) {
conv.first->writeMap<float>();
conv.first->unMap();
conv.second->readMap<float>();
conv.second->unMap();
}
}
}

int main(int argc, const char* argv[]) {
Expand All @@ -118,8 +135,11 @@ int main(int argc, const char* argv[]) {
if (argc >= 3) {
device = (Optimizer::Device)atoi(argv[2]);
}
auto model = Model::load(modelFileName);
auto model = Variable::loadMap(modelFileName);
auto inputOutput = Variable::getInputAndOutput(model);
auto optimizer = Optimizer::create(device);
auto inputs = inputOutput.first;
auto outputs = inputOutput.second;
if (nullptr == optimizer) {
MNN_ERROR("Can't find optimizer for %d\n", device);
return 0;
Expand All @@ -128,10 +148,10 @@ int main(int argc, const char* argv[]) {
if (argc >= 4) {
testTime = atoi(argv[3]);
}
optimizer->onExecute(model);
model.save("temp.mnn");
auto input = model.inputs[0];
auto output = model.outputs[0];
optimizer->onExecute(Variable::mapToSequence(outputs));
Variable::save(Variable::mapToSequence(outputs), "temp.mnn");
auto input = inputs.begin()->second;
auto output = outputs.begin()->second;
//input->resize({1, 224, 224, 3});
auto inputInfo = input->getInfo();
if (nullptr == inputInfo) {
Expand All @@ -144,11 +164,6 @@ int main(int argc, const char* argv[]) {
if (output->getInfo()->order == NC4HW4) {
output = _Convert(output, NCHW);
}
//Init
bool success = output->expr().first->requireAlloc();
if (!success) {
return 0;
}
}
auto outputInfo = output->getInfo();
if (nullptr == outputInfo) {
Expand Down
13 changes: 8 additions & 5 deletions demo/exec/segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "Expr.hpp"
#include "ExprCreator.hpp"
#include "AutoTime.hpp"
#include "Optimizer.hpp"
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
Expand All @@ -32,13 +33,15 @@ int main(int argc, const char* argv[]) {
MNN_PRINT("Usage: ./segment.out model.mnn input.jpg output.jpg\n");
return 0;
}
auto net = Model::load(argv[1]);
if (net.inputs.empty() || net.outputs.empty()) {
auto net = Variable::getInputAndOutput(Variable::loadMap(argv[1]));
if (net.first.empty()) {
MNN_ERROR("Invalid Model\n");
return 0;
}
auto optimizer = Optimizer::create(Optimizer::CPU);
optimizer->onExecute(Variable::mapToSequence(net.second));

auto input = net.inputs[0];
auto input = net.first.begin()->second;
auto info = input->getInfo();
if (nullptr == info) {
MNN_ERROR("The model don't have init dim\n");
Expand All @@ -47,7 +50,7 @@ int main(int argc, const char* argv[]) {
auto shape = input->getInfo()->dim;
shape[0] = 1;
input->resize(shape);
auto output = net.outputs[0];
auto output = net.second.begin()->second;
if (nullptr == output->getInfo()) {
MNN_ERROR("Alloc memory or compute size error\n");
return 0;
Expand Down Expand Up @@ -103,7 +106,7 @@ int main(int argc, const char* argv[]) {
input->unMap();
}
{
auto originOrder = output->getInfo()->order;
//auto originOrder = output->getInfo()->order;
output = _Convert(output, NHWC);
//output = _Softmax(output, -1);
auto outputInfo = output->getInfo();
Expand Down
56 changes: 23 additions & 33 deletions express/include/Expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ class MNN_EXPRESS_PUBLIC Variable {
int size;
void* ptr = nullptr;
};
void render(NetT* dest);
const std::string& name() const {
return mName;
}
Expand All @@ -77,58 +76,59 @@ class MNN_EXPRESS_PUBLIC Variable {
T* writeMap() {
return (T*)writeInternal();
}

// Deprecated
void unMap();
static void clone(VARP dst, VARP src);

bool input(VARP src);
static void replace(VARP dst, VARP src);

static VARP create(EXPRP expr, int index = 0);

void visitOutputs(const std::function<bool(VARP)>& visit);
void visitOutputs(const std::function<bool(VARP, int)>& visit);

static void visit(VARP var, const std::function<bool(VARP)>& before, const std::function<bool(VARP)>& after);

static std::vector<VARP> load(const char* fileName);
static std::map<std::string, VARP> loadMap(const char* fileName);
static std::pair<std::map<std::string, VARP>, std::map<std::string, VARP>> getInputAndOutput(const std::map<std::string, VARP>& allVariable);
static std::vector<VARP> mapToSequence(const std::map<std::string, VARP>& source);
static std::vector<VARP> getExecuteOrder(const std::vector<VARP>& output);
static void save(const std::vector<VARP>& vars, const char* fileName);
static void save(const std::vector<VARP>& vars, NetT* dest);

size_t linkNumber() const {
return mTo.size();
}
bool visited() const {
return mVisited;
}
void setVisited(bool visited) {
mVisited = visited;
}

private:
Variable(EXPRP expr, int index) {
mFrom = expr;
mFromIndex = index;
}

void* readInternal();
void* writeInternal();
void* writeInternal(bool inform=true);
void informDirty();

friend class Expr;
int mOutputIndex = -1;
EXPRP mFrom;
int mFromIndex;
std::string mName;
std::list<WeakEXPRP> mTo;
bool mVisited = false;
std::list<std::pair<int, WeakEXPRP>> mTo;
};

class MNN_EXPRESS_PUBLIC Expr {
public:
struct Inside;
static EXPRP create(std::unique_ptr<OpT>&& op, std::vector<VARP> inputs, int outputSize = 1,
static EXPRP create(const OpT* op, std::vector<VARP> inputs, int outputSize = 1,
std::shared_ptr<Executor> executor = nullptr);
// Convenience overload for callers that hold the op in a std::unique_ptr.
// Only the raw pointer (op.get()) is forwarded to another create() call along
// with inputs, outputSize and executor. This wrapper never moves from `op`,
// so the caller's unique_ptr still owns the OpT after the call returns.
static EXPRP create(std::unique_ptr<OpT>&& op, std::vector<VARP> inputs, int outputSize = 1,
std::shared_ptr<Executor> executor = nullptr) {
return create(op.get(), inputs, outputSize, executor);
}
void setName(const std::string& name);
void setExecutor(std::shared_ptr<Executor> exe);

// After render, the expr's op is removed
void render(NetT* dest);

const Op* get() const {
return mOp;
}
Expand Down Expand Up @@ -158,9 +158,12 @@ class MNN_EXPRESS_PUBLIC Expr {
void setVisited(bool visited) {
mVisited = visited;
}
const std::string& name() const {
return mName;
}

private:
bool setContentDirty();
bool setContentDirty(int inputIndex);
bool setInfoDirty();

Expr(int outputSize);
Expand All @@ -170,8 +173,8 @@ class MNN_EXPRESS_PUBLIC Expr {
std::vector<VARP> mInputs;
std::list<WeakVARP> mOutputs;
const int mOutputSize;
std::vector<int> mOutputIndexes;

bool mValid = true;
bool mInfoDirty = true;
bool mAllocated = false;
bool mContentDirty = true;
Expand All @@ -181,19 +184,6 @@ class MNN_EXPRESS_PUBLIC Expr {
bool mVisited = false;
std::shared_ptr<Executor> mExecutor;
};
// Groups the variables that make up a model: its inputs, its outputs and a
// sequence of variables. All members are public.
class MNN_EXPRESS_PUBLIC Model {
public:
// Input variables of the model.
std::vector<VARP> inputs;
// Output variables of the model.
std::vector<VARP> outputs;

// Variable sequence; reorder() recomputes it from the outputs.
std::vector<VARP> sequence;

// Creates a Model from the file at fileName (implementation not visible here).
static Model load(const char* fileName);

// Recompute the sequence from the execute order of the outputs.
void reorder();
// Saves the model to fileName; const, so no member is modified
// (on-disk format not visible here).
void save(const char* fileName) const;
};
} // namespace Express
} // namespace MNN

Expand Down
2 changes: 2 additions & 0 deletions express/include/MathOp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ MNN_EXPRESS_PUBLIC VARP _ReduceMax(VARP x, INTS dim, bool keepDim = false);
MNN_EXPRESS_PUBLIC VARP _Sum(VARP x, INTS dim, bool keepDim = false);
MNN_EXPRESS_PUBLIC VARP _Mean(VARP x, INTS dim, bool keepDim = false);
MNN_EXPRESS_PUBLIC VARP _Prod(VARP x, INTS dim, bool keepDim = false);
MNN_EXPRESS_PUBLIC VARP _Any(VARP x, INTS dim, bool keepDim = false);
MNN_EXPRESS_PUBLIC VARP _All(VARP x, INTS dim, bool keepDim = false);
MNN_EXPRESS_PUBLIC VARP _MatMul(VARP a, VARP b, bool tranposeA = false, bool tranposeB = false);
MNN_EXPRESS_PUBLIC VARP _Normalize(VARP x, int32_t acrossSpatial, int32_t channelShared, float eps, std::vector<float> scale);

Expand Down
8 changes: 6 additions & 2 deletions express/include/NeuralNetWorkOp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ namespace Express {
enum PaddingMode {CAFFE, VALID, SAME};
enum PoolingMode {MAXPOOL, AVEPOOL};
MNN_EXPRESS_PUBLIC VARP _Input(INTS dims = {}, Dimensionformat format = NC4HW4, halide_type_t type = halide_type_of<float>());
MNN_EXPRESS_PUBLIC VARP _Clone(VARP source, bool deepCopy=false);

MNN_EXPRESS_PUBLIC VARP _Const(float value, INTS dims = {}, Dimensionformat format = NHWC);
MNN_EXPRESS_PUBLIC VARP _Const(const void* ptr, INTS dims = {}, Dimensionformat format = NHWC,
halide_type_t type = halide_type_of<float>());
Expand All @@ -20,10 +22,12 @@ MNN_EXPRESS_PUBLIC VARP _Conv(VARP weight, VARP bias, VARP x, PaddingMode pad =
MNN_EXPRESS_PUBLIC VARP _Conv(float weight, float bias, VARP x, INTS channel, INTS kernelSize, PaddingMode pad = VALID,
INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1);
MNN_EXPRESS_PUBLIC VARP _Conv(std::vector<float>&& weight, std::vector<float>&& bias, VARP x, INTS channel, INTS kernelSize,
PaddingMode pad = VALID, INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1);
PaddingMode pad = VALID, INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1, INTS pads = {0, 0});
MNN_EXPRESS_PUBLIC VARP _Deconv(VARP weight, VARP bias, VARP x, PaddingMode pad = VALID, INTS stride = {1, 1},
INTS dilate = {1, 1}, int group = 1, INTS pads = {0, 0});
MNN_EXPRESS_PUBLIC VARP _MaxPool(VARP x, INTS kernel, INTS stride, PaddingMode pad = VALID, INTS pads= {0, 0});
MNN_EXPRESS_PUBLIC VARP _AvePool(VARP x, INTS kernel, INTS stride, PaddingMode pad = VALID, INTS pads= {0, 0});
MNN_EXPRESS_PUBLIC VARP _Reshape(VARP x, INTS dim);
MNN_EXPRESS_PUBLIC VARP _Reshape(VARP x, INTS dim, Dimensionformat format);
MNN_EXPRESS_PUBLIC VARP _Reshape(VARP x, VARP shape);
MNN_EXPRESS_PUBLIC VARP _Scale(VARP x, int channels, std::vector<float>&& scales, std::vector<float>&& bias);

Expand Down
Loading

0 comments on commit d6b00d0

Please sign in to comment.