A self-created tool to convert ONNX files (NCHW) to TensorFlow/TFLite/Keras format (NHWC). The purpose of this tool is to solve the massive Transpose extrapolation problem in onnx-tensorflow (onnx-tf). I don't need a star, but please send me pull requests.
https://github.com/PINTO0309/onnx2tf/wiki/model_status
-
onnx-tensorflow is a very useful tool, but the performance of the generated TensorFlow models is significantly degraded because a large number of Transpose OPs are extrapolated before and after each OP during the format conversion from NCHW to NHWC. Therefore, I made this tool myself as a derivative of onnx-tensorflow that does not extrapolate Transpose.
-
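To make the cost concrete, here is a toy NumPy sketch (not code from onnx-tf or onnx2tf): a chain of three per-channel bias additions is computed once by wrapping every OP in a Transpose pair, roughly the pattern described above, and once by transposing the input a single time and staying in NHWC. The results match, but the first approach spends six Transposes where one suffices. The helper names are illustrative only.

```python
import numpy as np

def nchw_to_nhwc(x: np.ndarray) -> np.ndarray:
    return np.transpose(x, (0, 2, 3, 1))

def nhwc_to_nchw(x: np.ndarray) -> np.ndarray:
    return np.transpose(x, (0, 3, 1, 2))

def add_channel_bias_nhwc(x: np.ndarray, bias: np.ndarray) -> np.ndarray:
    # A per-channel OP written for the channels-last (NHWC) layout
    return x + bias.reshape(1, 1, 1, -1)

rng = np.random.default_rng(0)
x_nchw = rng.standard_normal((1, 3, 4, 4)).astype(np.float32)
bias = np.array([0.1, 0.2, 0.3], dtype=np.float32)
num_ops = 3

# Strategy A: keep the graph in NCHW and wrap every OP in a Transpose pair
wrapped, wrapped_transposes = x_nchw, 0
for _ in range(num_ops):
    wrapped = nhwc_to_nchw(add_channel_bias_nhwc(nchw_to_nhwc(wrapped), bias))
    wrapped_transposes += 2

# Strategy B: transpose the input once and keep NHWC for the whole graph
native, native_transposes = nchw_to_nhwc(x_nchw), 1
for _ in range(num_ops):
    native = add_channel_bias_nhwc(native, bias)

print(wrapped_transposes, native_transposes)       # 6 1
print(np.allclose(nchw_to_nhwc(wrapped), native))  # True
```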
Most of the internal processing of the tool is written from scratch, but some of the more complex OPs have been adapted from onnx-tensorflow. I am very grateful to the engineers at IBM / LeapMind / Microsoft for developing onnx-tensorflow.
-
I have incorporated all my knowledge of optimizing models for other frameworks such as TFLite, EdgeTPU, TensorFlow.js and Myriad, gained from my years of experience implementing openvino2tensorflow and tflite2tensorflow. It probably has the best model optimization performance and conversion efficiency, and the lowest rate of conversion errors, of any tool I have created so far.
-
For the list of supported layers, see Supported layers.
-
If you are having trouble with conversion errors, searching resolved or open issues will almost always solve your problem. Issues are a knowledge base for engineers around the world.
-
Contributors to this repository should first read the Contribution Guide.
-
All OPs are decomposed into primitive operations as much as possible. This is beneficial for lateral deployment of models to frameworks other than TFLite. Therefore, OPs belonging to tf.keras.layers are almost never used, and the tool consists only of tf.xxx OPs (except for a very few OPs).
-
Because I do not want to add more package dependencies, I do not use tensorflow_addons (tfa); its OPs are replaced with standard TensorFlow OPs instead.
-
It handles not only conversions of 4-dimensional inputs, such as NCHW to NHWC, but also inputs with 3, 5, or even more dimensions, for example NCDHW to NDHWC. However, since 1-D, 2-D, 3-D and 6-D inputs may produce patterns that are mechanically difficult to convert, you can supply parameters to externally modify the tool's behavior. See Parameter replacement.
-
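The layout change generalizes to any rank: the channel axis (axis 1 in ONNX) moves to the end, and the inverse permutation moves it back. A minimal NumPy sketch of computing both; the helper names are illustrative and not onnx2tf functions.

```python
from typing import List

import numpy as np

def channels_last_perm(ndim: int) -> List[int]:
    """Perm for NC[spatial...] -> N[spatial...]C, e.g. NCHW -> NHWC."""
    if ndim < 3:
        raise ValueError("expected at least N, C and one spatial axis")
    return [0] + list(range(2, ndim)) + [1]

def inverse_perm(perm: List[int]) -> List[int]:
    """Perm that undoes `perm`, e.g. [0, 3, 1, 2] for [0, 2, 3, 1]."""
    inv = [0] * len(perm)
    for new_axis, old_axis in enumerate(perm):
        inv[old_axis] = new_axis
    return inv

for ndim in (3, 4, 5, 6):
    perm = channels_last_perm(ndim)
    print(ndim, perm, inverse_perm(perm))
# 3 [0, 2, 1] [0, 2, 1]
# 4 [0, 2, 3, 1] [0, 3, 1, 2]
# 5 [0, 2, 3, 4, 1] [0, 4, 1, 2, 3]
# 6 [0, 2, 3, 4, 5, 1] [0, 5, 1, 2, 3, 4]

x = np.zeros((2, 3, 4, 5, 6))  # NCDHW
y = np.transpose(x, channels_last_perm(x.ndim))  # NDHWC
print(y.shape)  # (2, 4, 5, 6, 3)
```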
If there are undefined dimensions in the input OP, the model structure is not fully optimized and conversion errors are very likely to occur.
-
Immediately after a Reshape OP that compresses or decompresses dimensions, there is a 95% probability that the model transformation will be disrupted and errors will occur. Examples include patterns such as [1,200,200,5] -> [1,200,-1], [10,20,30,40,50] -> [10,2,10,30,10,4,50], or Flatten. See #8 Not able to reshape input in replace.json, #15 Conv layer shape wrong, or #18 Question about channel_transpose in common_functions.py.
-
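Why a Reshape is fragile after the layout change can be seen with a toy NumPy tensor: copying the ONNX Reshape target shape verbatim onto NHWC data gives the right shape but the wrong element order, while transposing back to NCHW first restores the ONNX result. This is a sketch for illustration; a pre-Reshape transpose with perm [0,3,1,2] of this kind appears to be what the pre_process_transpose_perm entry in the param_replacement.json example expresses.

```python
import numpy as np

# Toy NCHW tensor: N=1, C=2, H=2, W=3, values 0..11
x_nchw = np.arange(12).reshape(1, 2, 2, 3)

# What the ONNX graph means: Reshape [N, C, H, W] -> [N, C, H*W]
expected = x_nchw.reshape(1, 2, -1)

# After the NCHW -> NHWC conversion the same data is laid out as [N, H, W, C]
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))

# Copying the Reshape target shape verbatim: right shape, wrong element order
naive = x_nhwc.reshape(1, 2, -1)

# Transposing back to NCHW before the Reshape restores the ONNX semantics
fixed = np.transpose(x_nhwc, (0, 3, 1, 2)).reshape(1, 2, -1)

print(expected[0, 0])                   # [0 1 2 3 4 5]
print(naive[0, 0])                      # [0 6 1 7 2 8]
print(np.array_equal(fixed, expected))  # True
```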
TensorFlow's Convolution has no equivalent of ONNX's explicit Padding specification. Therefore, a Pad OP is inserted immediately before any Convolution whose Padding size is greater than 1.
-
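ONNX Conv stores explicit padding as pads = [x1_begin, x2_begin, ..., x1_end, x2_end], which can be asymmetric, while TensorFlow convolutions only offer "SAME" or "VALID". A sketch of mapping the ONNX pads into the per-axis [[before, after], ...] list that both np.pad and tf.pad accept, after which the convolution itself would run with padding="VALID". The helper name is hypothetical; onnx2tf's internals may differ.

```python
from typing import List

import numpy as np

def onnx_pads_to_nhwc_paddings(pads: List[int]) -> List[List[int]]:
    """ONNX pads [x1_begin, x2_begin, ..., x1_end, x2_end] ->
    [[before, after], ...] for a tensor laid out as N, spatial..., C."""
    if len(pads) % 2 != 0:
        raise ValueError("pads must contain a begin and an end value per spatial axis")
    half = len(pads) // 2
    spatial = [[b, e] for b, e in zip(pads[:half], pads[half:])]
    return [[0, 0]] + spatial + [[0, 0]]  # never pad N or C

# Asymmetric 2D padding: top=1, left=0, bottom=2, right=1
paddings = onnx_pads_to_nhwc_paddings([1, 0, 2, 1])
print(paddings)  # [[0, 0], [1, 2], [0, 1], [0, 0]]

x = np.ones((1, 4, 4, 8), dtype=np.float32)  # NHWC
x_padded = np.pad(x, paddings)  # zero (constant) padding
print(x_padded.shape)  # (1, 7, 5, 8)
```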
Supports conversion to TensorFlow saved_model and TFLite (Float32/Float16/INT8).
-
ONNX files exceeding the 2GB Protocol Buffers file size limit are not supported. Therefore, the external data format is not supported at this initial stage of tool development.
-
If there are ONNX OPs that are not supported by TensorFlow, use simple-onnx-processing-tools to replace them with harmless OPs in advance, and then use this tool to convert the model. In other words, with enough effort you can convert any model.
-
ONNX splitting and merging, OP generation, rewriting OP attributes, BGR<->RGB conversion, converting to JSON and editing in an IDE, batch size changes for undefined dimensions, and various other processing can be done with simple-onnx-processing-tools. Therefore, for models with very complex structures, it is recommended to modify the structure beforehand and then convert to TFLite.
-
BatchNormalization supports only inference mode.
-
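Inference mode means the stored running mean and variance are used instead of batch statistics, so BatchNormalization reduces to a per-channel multiply and add. A minimal NumPy sketch of that folding for channels-last data, assuming ONNX's default epsilon of 1e-5; the function name is illustrative, not onnx2tf's.

```python
import numpy as np

def batchnorm_inference_nhwc(x, scale, bias, mean, var, epsilon=1e-5):
    """Inference-mode BatchNormalization for a channels-last tensor.

    The stored statistics are folded into one multiplier and one offset per
    channel; the 1D parameter arrays broadcast against the last (C) axis.
    """
    multiplier = scale / np.sqrt(var + epsilon)
    offset = bias - mean * multiplier
    return x * multiplier + offset

rng = np.random.default_rng(0)
c = 4
x = rng.standard_normal((2, 5, 5, c))  # NHWC
scale = rng.standard_normal(c)
bias = rng.standard_normal(c)
mean = rng.standard_normal(c)
var = rng.uniform(0.5, 2.0, c)

# Textbook form of the same computation
reference = (x - mean) / np.sqrt(var + 1e-5) * scale + bias
print(np.allclose(batchnorm_inference_nhwc(x, scale, bias, mean, var), reference))  # True
```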
LayerNormalization supports only inference mode.
-
Only opset=11 or higher is supported.
-
If you do not like the generated TFLite OP name, edit it using tflite2json2tflite.
-
The generated Keras models cannot be used for retraining. If you want to train, you must build your own model.
-
When converting to TensorFlow.js, CoreML, etc., generate a saved_model with the --output_signaturedefs option, then convert the generated saved_model with the respective converter: tensorflowjs_converter, coremltools, edgetpu_compiler.
-
Many ONNX OPs are not supported by EdgeTPU. Therefore, if you need to generate an EdgeTPU model, specify --replace_***_to_pseudo_*** when converting your model. onnx2tf will attempt to replace the OP with an EdgeTPU-compatible OP whenever possible.
-
The main factors that cause accuracy degradation after model conversion are as follows:
- differences in Padding specifications
- differences in Python's division specification during model transformation (errors due to round-half-to-even rounding)
- division without taking epsilon into account
- deprecated TrueDivision
- differences in support for powers
- differences in interpolation specifications during resizing
- differences in the arithmetic precision supported by each operation
- calculation errors due to scaling up or down by specifying a scale when resizing images
The above differences often cannot be dealt with by simply converting the model in a straightforward manner. Therefore, you need to replace such operations in the model yourself in advance with operations that are less prone to errors.
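Two of these factors are easy to reproduce in plain Python/NumPy. The sketch below shows round-half-to-even tie-breaking, and how a Resize specified by a scale can end up with an effective scale different from the requested one. It assumes the half-pixel mapping x_original = (x_resized + 0.5) / scale - 0.5; which scale a given runtime plugs into that formula is implementation-dependent, which is exactly the kind of mismatch listed above.

```python
import math

import numpy as np

# 1) Python's round() and NumPy's np.round() both round exact .5 ties
#    to the nearest even integer ("banker's rounding").
print(round(2.5), round(3.5))               # 2 4
print(np.round(np.array([0.5, 1.5, 2.5])))  # [0. 2. 2.]

def round_half_away_from_zero(v: float) -> float:
    # A different tie-breaking rule that other kernels may use instead
    return math.copysign(math.floor(abs(v) + 0.5), v)

print(round_half_away_from_zero(2.5), round_half_away_from_zero(-2.5))  # 3.0 -3.0

# 2) Resize by scale: the output length is floored, so the effective scale
#    (output / input) can differ from the requested scale.
in_size, scale = 7, 0.5
out_size = math.floor(in_size * scale)  # 3
effective_scale = out_size / in_size    # 3/7, not 0.5

def half_pixel_source(x_resized: int, s: float) -> float:
    # Source coordinate sampled for output pixel x_resized
    return (x_resized + 0.5) / s - 0.5

print(half_pixel_source(2, scale))            # 4.5
print(half_pixel_source(2, effective_scale))  # about 5.333
```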
- Support for INT8 Quantization, Full INT8 Quantization, INT8 Quantization with INT16 activation, Full INT8 Quantization with INT16 activation and Dynamic Range Quantization.
- Support for Per-Channel Quantization and Per-Tensor Quantization.
- Support for GroupConvolution.
- TFLite does not support TrueDiv (INT), so TrueDiv is avoided if possible.
- Implement the Resize process for 5D tensors.
- Add process to replace Asin with pseudo-Asin.
- Add process to replace Acos with pseudo-Acos.
- Add process to replace Abs with pseudo-Abs.
- Add process to replace GatherND with pseudo-GatherND.
- Add process to replace HardSwish with pseudo-HardSwish.
- Add process to replace GridSample with pseudo-GridSample.
- Add process to replace PRelu with pseudo-PRelu.
- Add process to replace LeakyRelu with pseudo-LeakyRelu.
- Add process to replace Power with pseudo-Power.
- Add process to replace Neg with pseudo-Neg.
- Add process to replace ArgMax with pseudo-ArgMax.
- Add process to replace Erf with pseudo-Erf.
- Added option to fix dynamic batch size N to a specified number.
- Added option to overwrite dynamic shape input OPs with static shape. (--overwrite_input_shape)
- Output in Keras H5 format.
- Automatically run onnx-simplifier (onnxsim) backend and optimize onnx files before model transformation.
- Added the ability to automatically generate each OP name and assign OP names to ONNX files in the old format.
- Supports model splitting. Interrupts model transformation at the specified output name and outputs the model partitioned into subgraphs.
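The README does not show how each pseudo OP is built. The identities below are standard ways to express some of these OPs with simpler primitive operations, written in NumPy so they can be checked; treat them as illustrations, not as the exact subgraphs onnx2tf emits. Defaults follow the ONNX definitions (LeakyRelu alpha = 0.01, HardSwish alpha = 1/6, beta = 0.5).

```python
import numpy as np

def pseudo_abs(x):
    # |x| = sqrt(x * x); note x * x overflows for |x| > ~1.8e19 in float32
    return np.sqrt(x * x)

def pseudo_neg(x):
    # -x = x * (-1)
    return x * np.asarray(-1, dtype=x.dtype)

def pseudo_prelu(x, slope):
    # PRelu(x) = max(x, 0) + slope * min(x, 0)
    return np.maximum(x, 0) + slope * np.minimum(x, 0)

def pseudo_leakyrelu(x, alpha=0.01):
    # LeakyRelu is PRelu with a single constant slope
    return pseudo_prelu(x, alpha)

def pseudo_hardswish(x):
    # ONNX HardSwish: x * max(0, min(1, x / 6 + 0.5)), i.e. x * relu6(x + 3) / 6
    return x * np.minimum(np.maximum(x / 6 + 0.5, 0), 1)

x = np.linspace(-5, 5, 21, dtype=np.float32)
checks = {
    "abs": bool(np.allclose(pseudo_abs(x), np.abs(x))),
    "neg": bool(np.allclose(pseudo_neg(x), -x)),
    "prelu": bool(np.allclose(pseudo_prelu(x, 0.25), np.where(x > 0, x, 0.25 * x))),
    "leakyrelu": bool(np.allclose(pseudo_leakyrelu(x), np.where(x > 0, x, 0.01 * x))),
    "hardswish": bool(np.allclose(pseudo_hardswish(x), x * np.clip(x + 3, 0, 6) / 6)),
}
print(checks)  # every value is expected to be True
```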
Video speed is adjusted approximately 50 times slower than actual speed.
- onnx
- onnx-simplifier
- onnx_graphsurgeon
- simple_onnx_processing_tools
- tensorflow==2.10.0
- HostPC
$ docker run --rm -it \
  -v `pwd`:/workdir \
  -w /workdir \
  ghcr.io/pinto0309/onnx2tf:1.3.14
or
$ pip install -U onnx \
  && pip install -U nvidia-pyindex \
  && pip install -U onnx-graphsurgeon \
  && pip install -U onnxsim \
  && pip install -U simple_onnx_processing_tools \
  && pip install -U onnx2tf
or
$ pip install -e .
or
- Google Colaboratory Python3.8+
!sudo add-apt-repository -y ppa:deadsnakes/ppa
!sudo apt-get -y update
!sudo apt-get -y install python3.9
!sudo apt-get -y install python3.9-dev
!sudo apt-get -y install python3-pip
!sudo apt-get -y install python3.9-distutils
!python3.9 -m pip install -U setuptools \
  && python3.9 -m pip install -U pip \
  && python3.9 -m pip install -U distlib
!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.7 1
!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 2
!python3.9 -m pip install tensorflow==2.10.0 \
  && python3.9 -m pip install -U onnx \
  && python3.9 -m pip install -U nvidia-pyindex \
  && python3.9 -m pip install -U onnx-graphsurgeon \
  && python3.9 -m pip install -U onnxsim \
  && python3.9 -m pip install -U simple_onnx_processing_tools \
  && python3.9 -m pip install -U onnx2tf \
  && python3.9 -m pip install -U protobuf==3.20.3
Run test.
# Float32, Float16
$ wget https://github.com/PINTO0309/onnx2tf/releases/download/0.0.2/resnet18-v1-7.onnx
$ onnx2tf -i resnet18-v1-7.onnx
# INT8 Quantization
$ wget https://github.com/PINTO0309/onnx2tf/releases/download/1.1.1/emotion-ferplus-8.onnx
# INT8 Quantization (per-channel)
$ onnx2tf -i emotion-ferplus-8.onnx -oiqt
# INT8 Quantization (per-tensor)
$ onnx2tf -i emotion-ferplus-8.onnx -oiqt -qt per-tensor
# Parameter replacement (Resize,Transpose,Softmax)
$ rm replace.json
$ wget https://github.com/PINTO0309/onnx2tf/releases/download/1.1.27/human_segmentation_pphumanseg_2021oct.onnx
$ wget https://github.com/PINTO0309/onnx2tf/releases/download/1.1.27/replace.json
$ onnx2tf -i human_segmentation_pphumanseg_2021oct.onnx -prf replace.json
$ onnx2tf -h
usage: onnx2tf
[-h]
(-i INPUT_ONNX_FILE_PATH | -V)
[-o OUTPUT_FOLDER_PATH]
[-osd]
[-oh5]
[-oiqt]
[-qt {per-channel,per-tensor}]
[-qcind INPUT_NAME NUMPY_FILE_PATH MEAN STD]
[-ioqd {int8,uint8}]
[-nuo]
[-nuonag]
[-b BATCH_SIZE]
[-ois OVERWRITE_INPUT_SHAPE [OVERWRITE_INPUT_SHAPE ...]]
[-k KEEP_NCW_OR_NCHW_OR_NCDHW_INPUT_NAMES [KEEP_NCW_OR_NCHW_OR_NCDHW_INPUT_NAMES ...]]
[-kt KEEP_NWC_OR_NHWC_OR_NDHWC_INPUT_NAMES [KEEP_NWC_OR_NHWC_OR_NDHWC_INPUT_NAMES ...]]
[-kat KEEP_SHAPE_ABSOLUTELY_INPUT_NAMES [KEEP_SHAPE_ABSOLUTELY_INPUT_NAMES ...]]
[-onimc OUTPUT_NAMES [OUTPUT_NAMES ...]]
[-dgc]
[-ebu]
[-rari64 | -rarf32 | -rafi64 | -raff32]
[-fasr FUSED_ARGMAX_SCALE_RATIO]
[-rasin]
[-racos]
[-rabs]
[-rpr]
[-rlr]
[-rpw]
[-rgn]
[-rng]
[-rhs]
[-rerf]
[-me]
[-prf PARAM_REPLACEMENT_FILE]
[-n]
optional arguments:
-h, --help
show this help message and exit
-i INPUT_ONNX_FILE_PATH, --input_onnx_file_path INPUT_ONNX_FILE_PATH
Input onnx file path.
-V, --version
Show version and exit.
-o OUTPUT_FOLDER_PATH, --output_folder_path OUTPUT_FOLDER_PATH
Output folder path. Default: "saved_model"
-osd, --output_signaturedefs
Signature is added to the output for serving or for conversion
to other model formats. However, this can significantly reduce the speed
of model conversion and significantly increase the size of the model.
-oh5, --output_h5
Output in Keras H5 format.
-oiqt, --output_integer_quantized_tflite
Output of integer quantized tflite.
-qt {per-channel,per-tensor}, --quant_type {per-channel,per-tensor}
Selects whether "per-channel" or "per-tensor" quantization is used.
Default: "per-channel"
-qcind INPUT_NAME NUMPY_FILE_PATH MEAN STD, \
--quant_calib_input_op_name_np_data_path INPUT_NAME NUMPY_FILE_PATH MEAN STD
Input OP name, path of the calibration data file (Numpy) for quantization, and mean and std.
The specification can be omitted only when there is a single input OP and it is 4D image data.
If omitted, it is automatically calibrated using 20 normalized MS-COCO images.
The type of the input OP must be Float32.
Data for calibration must be pre-normalized to a range of 0 to 1.
-qcind {input_op_name} {numpy_file_path} {mean} {std}
A Numpy file path must be specified for each input OP.
Normalize the value of the input OP based on the tensor specified in mean and std.
(input_value - mean) / std
Tensors in Numpy file format must be in dimension order after conversion to TF.
Note that this is intended for deployment on low-resource devices,
so the batch size is limited to 1 only.
e.g.
The example below shows a case where there are three input OPs.
Assume input0 is 128x128 RGB image data.
In addition, input0 should be a value that has been divided by 255
in the preprocessing and normalized to a range between 0 and 1.
input1 and input2 assume the input of something that is not an image.
Because input1 and input2 assume something that is not an image,
the divisor is not 255 when normalizing from 0 to 1.
"n" is the number of calibration data.
ONNX INPUT shapes:
input0: [n,3,128,128]
mean: [1,3,1,1] -> [[[[0.485]],[[0.456]],[[0.406]]]]
std : [1,3,1,1] -> [[[[0.229]],[[0.224]],[[0.225]]]]
input1: [n,64,64]
mean: [1,64] -> [[0.1, ..., 0.64]]
std : [1,64] -> [[0.05, ..., 0.08]]
input2: [n,5]
mean: [1] -> [0.3]
std : [1] -> [0.07]
TensorFlow INPUT shapes (Numpy file ndarray shapes):
input0: [n,128,128,3]
mean: [1,1,1,3] -> [[[[0.485, 0.456, 0.406]]]]
std : [1,1,1,3] -> [[[[0.229, 0.224, 0.225]]]]
input1: [n,64,64]
mean: [1,64] -> [[0.1, ..., 0.64]]
std : [1,64] -> [[0.05, ..., 0.08]]
input2: [n,5]
mean: [1] -> [0.3]
std : [1] -> [0.07]
-qcind "input0" "../input0.npy" [[[[0.485, 0.456, 0.406]]]] [[[[0.229, 0.224, 0.225]]]]
-qcind "input1" "./input1.npy" [[0.1, ..., 0.64]] [[0.05, ..., 0.08]]
-qcind "input2" "input2.npy" [0.3] [0.07]
-ioqd {int8,uint8}, --input_output_quant_dtype {int8,uint8}
Input and Output dtypes when doing Full INT8 Quantization.
"int8"(default) or "uint8"
-nuo, --not_use_onnxsim
No optimization by onnx-simplifier is performed.
If this option is used, the probability of a conversion error is very high.
-nuonag, --not_use_opname_auto_generate
Automatic generation of each OP name in the old format ONNX file
and assignment of OP name are not performed.
-b BATCH_SIZE, --batch_size BATCH_SIZE
Fixes the dynamic batch size to the specified numeric batch size.
A value of 1 or more must be specified.
-ois OVERWRITE_INPUT_SHAPE [OVERWRITE_INPUT_SHAPE ...], \
--overwrite_input_shape OVERWRITE_INPUT_SHAPE [OVERWRITE_INPUT_SHAPE ...]
Overwrite the input shape.
The format is
"i1:dim0,...,dimN" "i2:dim0,...,dimN" "i3:dim0,...,dimN"
When there is only one input, for example,
"data:1,3,224,224"
When there are multiple inputs, for example,
"data1:1,3,224,224" "data2:1,3,112" "data3:5"
A value of 1 or more must be specified.
Numerical values other than dynamic dimensions are ignored.
If --batch_size is specified at the same time, --batch_size is ignored.
-k KEEP_NCW_OR_NCHW_OR_NCDHW_INPUT_NAMES [KEEP_NCW_OR_NCHW_OR_NCDHW_INPUT_NAMES ...], \
--keep_ncw_or_nchw_or_ncdhw_input_names KEEP_NCW_OR_NCHW_OR_NCDHW_INPUT_NAMES \
[KEEP_NCW_OR_NCHW_OR_NCDHW_INPUT_NAMES ...]
Keeps the NCW, NCHW, or NCDHW layout of the input shape for the specified INPUT OP names.
If a nonexistent INPUT OP name is specified, it is ignored.
Valid only for 3D, 4D and 5D input tensors.
e.g. --keep_ncw_or_nchw_or_ncdhw_input_names "input0" "input1" "input2"
-kt KEEP_NWC_OR_NHWC_OR_NDHWC_INPUT_NAMES [KEEP_NWC_OR_NHWC_OR_NDHWC_INPUT_NAMES ...], \
--keep_nwc_or_nhwc_or_ndhwc_input_names KEEP_NWC_OR_NHWC_OR_NDHWC_INPUT_NAMES \
[KEEP_NWC_OR_NHWC_OR_NDHWC_INPUT_NAMES ...]
Keeps the NWC, NHWC, or NDHWC layout of the input shape for the specified INPUT OP names.
If a nonexistent INPUT OP name is specified, it is ignored.
If the input OP name is the same as the input OP name specified
in the keep_ncw_or_nchw_or_ncdhw_input_names option, it is ignored.
Valid only for 3D, 4D and 5D input tensors.
e.g. --keep_nwc_or_nhwc_or_ndhwc_input_names "input0" "input1" "input2"
-kat KEEP_SHAPE_ABSOLUTELY_INPUT_NAMES [KEEP_SHAPE_ABSOLUTELY_INPUT_NAMES ...], \
--keep_shape_absolutely_input_names KEEP_SHAPE_ABSOLUTELY_INPUT_NAMES \
[KEEP_SHAPE_ABSOLUTELY_INPUT_NAMES ...]
Names of the INPUT OPs whose shapes are kept unconditionally.
If a nonexistent INPUT OP name is specified, it is ignored.
e.g. --keep_shape_absolutely_input_names "input0" "input1" "input2"
-onimc OUTPUT_NAMES [OUTPUT_NAMES ...], \
--output_names_to_interrupt_model_conversion OUTPUT_NAMES [OUTPUT_NAMES ...]
Output names that interrupt model conversion.
Interrupts model transformation at the specified output name and outputs the
model partitioned into subgraphs.
e.g. --output_names_to_interrupt_model_conversion "output0" "output1" "output2"
-dgc, --disable_group_convolution
Disable GroupConvolution and replace it with SeparableConvolution for
output to saved_model format.
-ebu, --enaable_batchmatmul_unfold
BatchMatMul is separated batch by batch to generate a primitive MatMul.
-rari64, --replace_argmax_to_reducemax_and_indicies_is_int64
Replace ArgMax with a ReduceMax. The returned indices are int64.
Only one of replace_argmax_to_reducemax_and_indicies_is_int64
and replace_argmax_to_reducemax_and_indicies_is_float32
and replace_argmax_to_fused_argmax_and_indicies_is_int64
and replace_argmax_to_fused_argmax_and_indicies_is_float32 can be specified.
-rarf32, --replace_argmax_to_reducemax_and_indicies_is_float32
Replace ArgMax with a ReduceMax. The returned indices are float32.
Only one of replace_argmax_to_reducemax_and_indicies_is_int64
and replace_argmax_to_reducemax_and_indicies_is_float32
and replace_argmax_to_fused_argmax_and_indicies_is_int64
and replace_argmax_to_fused_argmax_and_indicies_is_float32 can be specified.
-rafi64, --replace_argmax_to_fused_argmax_and_indicies_is_int64
Replace ArgMax with a Fused_ArgMax. The returned indices are int64.
It improves inference speed at the cost of a small sacrifice in accuracy.
See: https://github.com/tensorflow/models/tree/master/official/projects/edgetpu/vision#argmax-fusion-to-improve-segmentation-model-latency
Currently, only 4D tensors are supported.
Only one of replace_argmax_to_reducemax_and_indicies_is_int64
and replace_argmax_to_reducemax_and_indicies_is_float32
and replace_argmax_to_fused_argmax_and_indicies_is_int64
and replace_argmax_to_fused_argmax_and_indicies_is_float32 can be specified.
-raff32, --replace_argmax_to_fused_argmax_and_indicies_is_float32
Replace ArgMax with a Fused_ArgMax. The returned indices are float32.
It improves inference speed at the cost of a small sacrifice in accuracy.
See: https://github.com/tensorflow/models/tree/master/official/projects/edgetpu/vision#argmax-fusion-to-improve-segmentation-model-latency
Currently, only 4D tensors are supported.
Only one of replace_argmax_to_reducemax_and_indicies_is_int64
and replace_argmax_to_reducemax_and_indicies_is_float32
and replace_argmax_to_fused_argmax_and_indicies_is_int64
and replace_argmax_to_fused_argmax_and_indicies_is_float32 can be specified.
-fasr FUSED_ARGMAX_SCALE_RATIO, --fused_argmax_scale_ratio FUSED_ARGMAX_SCALE_RATIO
For Fused ArgMax.
Scale ratio when generating Fused ArgMax.
0.0 < fused_argmax_scale_ratio <= 1.0
Default: 0.5
-rasin, --replace_asin_to_pseudo_asin
Replace Asin with a pseudo Asin.
-racos, --replace_acos_to_pseudo_acos
Replace Acos with a pseudo Acos.
-rabs, --replace_abs_to_pseudo_abs
Replace Abs with a pseudo Abs.
-rpr, --replace_prelu_to_pseudo_prelu
Replace PReLU with a pseudo PReLU.
-rlr, --replace_leakyrelu_to_pseudo_leakyrelu
Replace LeakyReLU with a pseudo LeakyReLU.
-rpw, --replace_power_to_pseudo_power
Replace Power with a pseudo Power.
-rgn, --replace_gathernd_to_pseudo_gathernd
Replace GatherND with a pseudo GatherND.
-rng, --replace_neg_to_pseudo_neg
Replace Neg with a pseudo Neg.
-rhs, --replace_hardswish_to_pseudo_hardswish
Replace HardSwish with a pseudo HardSwish.
-rerf, --replace_erf_to_pseudo_erf
Replace Erf with a pseudo Erf.
-me, --mvn_epsilon
For MeanVarianceNormalization.
The number to be added to the variance to avoid division by zero
when normalizing the value.
(input_tensor - mean) / tf.sqrt(variance + mvn_epsilon)
Default: 0.0000000001
-prf PARAM_REPLACEMENT_FILE, --param_replacement_file PARAM_REPLACEMENT_FILE
Parameter replacement file path. (.json)
-n, --non_verbose
Suppress all information logs. Only error logs are displayed.
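The calibration requirements for -qcind (pre-normalize to 0..1, store the Numpy data in the TensorFlow dimension order, give mean and std shaped for the converted layout) can be sketched in NumPy. Random pixels stand in for real preprocessed images, and the sample count and file name are placeholders; the final list uses the nested-list form shown in the convert() docstring.

```python
import os
import tempfile

import numpy as np

rng = np.random.default_rng(0)
n = 20  # number of calibration samples (arbitrary for this sketch)

# Stand-in for real RGB images in the ONNX layout [n, 3, 128, 128]
images_nchw = rng.integers(0, 256, size=(n, 3, 128, 128)).astype(np.uint8)

# 1) Pre-normalize to the range 0..1 (image data: divide by 255)
calib_nchw = images_nchw.astype(np.float32) / 255.0

# 2) Store in the TensorFlow dimension order [n, 128, 128, 3]
calib_nhwc = np.transpose(calib_nchw, (0, 2, 3, 1))
npy_path = os.path.join(tempfile.mkdtemp(), "input0.npy")
np.save(npy_path, calib_nhwc)

# 3) mean/std shaped [1, 1, 1, 3] so they broadcast over channels-last data
mean_list = [[[[0.485, 0.456, 0.406]]]]
std_list = [[[[0.229, 0.224, 0.225]]]]
mean = np.asarray(mean_list, dtype=np.float32)
std = np.asarray(std_list, dtype=np.float32)

# The normalization described above: (input_value - mean) / std
normalized = (np.load(npy_path) - mean) / std
print(calib_nhwc.shape, mean.shape, normalized.shape)
# (20, 128, 128, 3) (1, 1, 1, 3) (20, 128, 128, 3)

# One entry per input OP: [input name, Numpy file path, mean, std]
qcind = [["input0", npy_path, mean_list, std_list]]
```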
>>> from onnx2tf import convert
>>> help(convert)
Help on function convert in module onnx2tf:
convert(
input_onnx_file_path: Union[str, NoneType] = '',
onnx_graph: Union[onnx.onnx_ml_pb2.ModelProto, NoneType] = None,
output_folder_path: Union[str, NoneType] = 'saved_model',
output_signaturedefs: Optional[bool] = False,
output_h5: Optional[bool] = False,
output_integer_quantized_tflite: Optional[bool] = False,
quant_type: Optional[str] = 'per-channel',
quant_calib_input_op_name_np_data_path: Optional[List] = None,
input_output_quant_dtype: Optional[str] = 'int8',
not_use_onnxsim: Optional[bool] = False,
not_use_opname_auto_generate: Optional[bool] = False,
batch_size: Union[int, NoneType] = None,
overwrite_input_shape: Union[List[str], NoneType] = None,
keep_ncw_or_nchw_or_ncdhw_input_names: Union[List[str], NoneType] = None,
keep_nwc_or_nhwc_or_ndhwc_input_names: Union[List[str], NoneType] = None,
keep_shape_absolutely_input_names: Optional[List[str]] = None,
output_names_to_interrupt_model_conversion: Union[List[str], NoneType] = None,
disable_group_convolution: Union[bool, NoneType] = False,
enaable_batchmatmul_unfold: Optional[bool] = False,
replace_argmax_to_reducemax_and_indicies_is_int64: Union[bool, NoneType] = False,
replace_argmax_to_reducemax_and_indicies_is_float32: Union[bool, NoneType] = False,
replace_argmax_to_fused_argmax_and_indicies_is_int64: Union[bool, NoneType] = False,
replace_argmax_to_fused_argmax_and_indicies_is_float32: Union[bool, NoneType] = False,
fused_argmax_scale_ratio: Union[float, NoneType] = 0.5,
replace_asin_to_pseudo_asin: Union[bool, NoneType] = False,
replace_acos_to_pseudo_acos: Union[bool, NoneType] = False,
replace_abs_to_pseudo_abs: Union[bool, NoneType] = False,
replace_prelu_to_pseudo_prelu: Union[bool, NoneType] = False,
replace_leakyrelu_to_pseudo_leakyrelu: Union[bool, NoneType] = False,
replace_power_to_pseudo_power: Optional[bool] = False,
replace_gathernd_to_pseudo_gathernd: Optional[bool] = False,
replace_neg_to_pseudo_neg: Optional[bool] = False,
replace_hardswish_to_pseudo_hardswish: Optional[bool] = False,
replace_erf_to_pseudo_erf: Optional[bool] = False,
mvn_epsilon: Union[float, NoneType] = 0.0000000001,
param_replacement_file: Optional[str] = '',
non_verbose: Union[bool, NoneType] = False
) -> keras.engine.training.Model
Convert ONNX to TensorFlow models.
Parameters
----------
input_onnx_file_path: Optional[str]
Input onnx file path.
Either input_onnx_file_path or onnx_graph must be specified.
onnx_graph: Optional[onnx.ModelProto]
onnx.ModelProto.
Either input_onnx_file_path or onnx_graph must be specified.
If onnx_graph is specified, input_onnx_file_path is ignored and onnx_graph is processed.
output_folder_path: Optional[str]
Output tensorflow model folder path.
Default: "saved_model"
output_signaturedefs: Optional[bool]
Signature is added to the output for serving or for conversion
to other model formats. However, this can significantly reduce the speed
of model conversion and significantly increase the size of the model.
output_h5: Optional[bool]
Output in Keras H5 format.
output_integer_quantized_tflite: Optional[bool]
Output of integer quantized tflite.
quant_type: Optional[str]
Selects whether "per-channel" or "per-tensor" quantization is used.
Default: "per-channel"
quant_calib_input_op_name_np_data_path: Optional[List]
--quant_calib_input_op_name_np_data_path INPUT_NAME NUMPY_FILE_PATH MEAN STD
Input OP name, path of the calibration data file (Numpy) for quantization, and mean and std.
The specification can be omitted only when there is a single input OP and it is 4D image data.
If omitted, it is automatically calibrated using 20 normalized MS-COCO images.
The type of the input OP must be Float32.
Data for calibration must be pre-normalized to a range of 0 to 1.
-qcind {input_op_name} {numpy_file_path} {mean} {std}
A Numpy file path must be specified for each input OP.
Normalize the value of the input OP based on the tensor specified in mean and std.
(input_value - mean) / std
Tensors in Numpy file format must be in dimension order after conversion to TF.
Note that this is intended for deployment on low-resource devices,
so the batch size is limited to 1 only.
e.g.
The example below shows a case where there are three input OPs.
Assume input0 is 128x128 RGB image data.
In addition, input0 should be a value that has been divided by 255
in the preprocessing and normalized to a range between 0 and 1.
input1 and input2 assume the input of something that is not an image.
Because input1 and input2 assume something that is not an image,
the divisor is not 255 when normalizing from 0 to 1.
"n" is the number of calibration data.
ONNX INPUT shapes:
input0: [n,3,128,128]
mean: [1,3,1,1] -> [[[[0.485]],[[0.456]],[[0.406]]]]
std : [1,3,1,1] -> [[[[0.229]],[[0.224]],[[0.225]]]]
input1: [n,64,64]
mean: [1,64] -> [[0.1, ..., 0.64]]
std : [1,64] -> [[0.05, ..., 0.08]]
input2: [n,5]
mean: [1] -> [0.3]
std : [1] -> [0.07]
TensorFlow INPUT shapes (Numpy file ndarray shapes):
input0: [n,128,128,3]
mean: [1,1,1,3] -> [[[[0.485, 0.456, 0.406]]]]
std : [1,1,1,3] -> [[[[0.229, 0.224, 0.225]]]]
input1: [n,64,64]
mean: [1,64] -> [[0.1, ..., 0.64]]
std : [1,64] -> [[0.05, ..., 0.08]]
input2: [n,5]
mean: [1] -> [0.3]
std : [1] -> [0.07]
qcind=[
["input0","../input0.npy",[[[[0.485, 0.456, 0.406]]]],[[[[0.229, 0.224, 0.225]]]]],
["input1","./input1.npy",[[0.1, ..., 0.64]],[[0.05, ..., 0.08]]],
["input2","input2.npy",[0.3],[0.07]],
]
input_output_quant_dtype: Optional[str]
Input and Output dtypes when doing Full INT8 Quantization.
"int8"(default) or "uint8"
not_use_onnxsim: Optional[bool]
No optimization by onnx-simplifier is performed.
If this option is used, the probability of a conversion error is very high.
not_use_opname_auto_generate: Optional[bool]
Automatic generation of each OP name in the old format ONNX file
and assignment of OP name are not performed.
batch_size: Optional[int]
Fixes the dynamic batch size to the specified numeric batch size.
A value of 1 or more must be specified.
overwrite_input_shape: Optional[List[str]]
Overwrite the input shape.
The format is
['i1:dim0,dim1,...,dimN' 'i2:dim0,dim1,...,dimN' 'i3:dim0,dim1,...,dimN']
When there is only one input, for example,
['data:1,3,224,224']
When there are multiple inputs, for example,
['data1:1,3,224,224','data2:1,3,112','data3:5']
A value of 1 or more must be specified.
Numerical values other than dynamic dimensions are ignored.
If batch_size is specified at the same time, batch_size is ignored.
keep_ncw_or_nchw_or_ncdhw_input_names: Optional[List[str]]
Keeps the NCW, NCHW, or NCDHW layout of the input shape for the specified INPUT OP names.
If a nonexistent INPUT OP name is specified, it is ignored.
Valid only for 3D, 4D and 5D input tensors.
e.g.
keep_ncw_or_nchw_or_ncdhw_input_names=['input0','input1','input2']
keep_nwc_or_nhwc_or_ndhwc_input_names: Optional[List[str]]
Keeps the NWC, NHWC, or NDHWC layout of the input shape for the specified INPUT OP names.
If a nonexistent INPUT OP name is specified, it is ignored.
If the input OP name is the same as the input OP name specified
in the keep_ncw_or_nchw_or_ncdhw_input_names option, it is ignored.
Valid only for 3D, 4D and 5D input tensors.
e.g.
keep_nwc_or_nhwc_or_ndhwc_input_names=['input0','input1','input2']
keep_shape_absolutely_input_names: Optional[List[str]]
Names of the INPUT OPs whose shapes are kept unconditionally.
If a nonexistent INPUT OP name is specified, it is ignored.
e.g.
keep_shape_absolutely_input_names=['input0','input1','input2']
output_names_to_interrupt_model_conversion: Optional[List[str]]
Output names that interrupt model conversion.
Interrupts model transformation at the specified output name
and outputs the model partitioned into subgraphs.
e.g.
output_names_to_interrupt_model_conversion=['output0','output1','output2']
disable_group_convolution: Optional[bool]
Disable GroupConvolution and replace it with SeparableConvolution for
output to saved_model format.
enaable_batchmatmul_unfold: Optional[bool]
BatchMatMul is separated batch by batch to generate a primitive MatMul.
replace_argmax_to_reducemax_and_indicies_is_int64: Optional[bool]
Replace ArgMax with a ReduceMax. The returned indices are int64.
Only one of replace_argmax_to_reducemax_and_indicies_is_int64 and
replace_argmax_to_reducemax_and_indicies_is_float32 and
replace_argmax_to_fused_argmax_and_indicies_is_int64 and
replace_argmax_to_fused_argmax_and_indicies_is_float32 can be specified.
Default: False
replace_argmax_to_reducemax_and_indicies_is_float32: Optional[bool]
Replace ArgMax with a ReduceMax. The returned indices are float32.
Only one of replace_argmax_to_reducemax_and_indicies_is_int64 and
replace_argmax_to_reducemax_and_indicies_is_float32 and
replace_argmax_to_fused_argmax_and_indicies_is_int64 and
replace_argmax_to_fused_argmax_and_indicies_is_float32 can be specified.
Default: False
replace_argmax_to_fused_argmax_and_indicies_is_int64: Optional[bool]
Replace ArgMax with a Fused_ArgMax. The returned indices are int64.
It improves inference speed at the cost of a small sacrifice in accuracy.
See: https://github.com/tensorflow/models/tree/master/official/projects/edgetpu/vision#argmax-fusion-to-improve-segmentation-model-latency
Currently, only 4D tensors are supported.
Only one of replace_argmax_to_reducemax_and_indicies_is_int64 and
replace_argmax_to_reducemax_and_indicies_is_float32 and
replace_argmax_to_fused_argmax_and_indicies_is_int64 and
replace_argmax_to_fused_argmax_and_indicies_is_float32 can be specified.
Default: False
replace_argmax_to_fused_argmax_and_indicies_is_float32: Optional[bool]
Replace ArgMax with a Fused_ArgMax. The returned indices are float32.
It improves inference speed at the cost of a small sacrifice in accuracy.
See: https://github.com/tensorflow/models/tree/master/official/projects/edgetpu/vision#argmax-fusion-to-improve-segmentation-model-latency
Currently, only 4D tensors are supported.
Only one of replace_argmax_to_reducemax_and_indicies_is_int64 and
replace_argmax_to_reducemax_and_indicies_is_float32 and
replace_argmax_to_fused_argmax_and_indicies_is_int64 and
replace_argmax_to_fused_argmax_and_indicies_is_float32 can be specified.
Default: False
fused_argmax_scale_ratio: Optional[float]
For Fused ArgMax.
Scale ratio when generating Fused ArgMax.
0.0 < fused_argmax_scale_ratio <= 1.0
Default: 0.5
replace_asin_to_pseudo_asin: Optional[bool]
Replace Asin with a pseudo Asin.
replace_acos_to_pseudo_acos: Optional[bool]
Replace Acos with a pseudo Acos.
replace_abs_to_pseudo_abs: Optional[bool]
Replace Abs with a pseudo Abs.
replace_prelu_to_pseudo_prelu: Optional[bool]
Replace PReLU with a pseudo PReLU.
replace_leakyrelu_to_pseudo_leakyrelu: Optional[bool]
Replace LeakyReLU with a pseudo LeakyReLU.
replace_power_to_pseudo_power: Optional[bool]
Replace Power with a pseudo Power.
replace_gathernd_to_pseudo_gathernd: Optional[bool]
Replace GatherND with a pseudo GatherND.
replace_neg_to_pseudo_neg: Optional[bool]
Replace Neg with a pseudo Neg.
replace_hardswish_to_pseudo_hardswish: Optional[bool]
Replace HardSwish with a pseudo HardSwish.
replace_erf_to_pseudo_erf: Optional[bool]
Replace Erf with a pseudo Erf.
mvn_epsilon: Optional[float]
For MeanVarianceNormalization.
The number to be added to the variance to avoid division by zero
when normalizing the value.
(input_tensor - mean) / tf.sqrt(variance + mvn_epsilon)
Default: 0.0000000001
param_replacement_file: Optional[str]
Parameter replacement file path. (.json)
non_verbose: Optional[bool]
Suppress all information logs. Only error logs are displayed.
Default: False
Returns
----------
model: tf.keras.Model
Model
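When calling convert() from Python, it can help to assemble the keyword arguments in one place and enforce the rules stated in the docstring: at most one ArgMax replacement flag, 0.0 < fused_argmax_scale_ratio <= 1.0, and input dimensions of 1 or more. The helpers below are hypothetical conveniences, not part of onnx2tf; only the keyword names are taken from the signature above.

```python
from typing import Dict, List, Optional, Sequence

# The four mutually exclusive ArgMax replacement flags (names as in the signature)
ARGMAX_FLAGS = (
    "replace_argmax_to_reducemax_and_indicies_is_int64",
    "replace_argmax_to_reducemax_and_indicies_is_float32",
    "replace_argmax_to_fused_argmax_and_indicies_is_int64",
    "replace_argmax_to_fused_argmax_and_indicies_is_float32",
)

def format_input_shapes(shapes: Dict[str, Sequence[int]]) -> List[str]:
    """{"data": [1, 3, 224, 224]} -> ["data:1,3,224,224"]"""
    formatted = []
    for name, dims in shapes.items():
        if any(int(d) < 1 for d in dims):
            raise ValueError(f"{name}: every dimension must be 1 or more")
        formatted.append(f"{name}:" + ",".join(str(int(d)) for d in dims))
    return formatted

def build_convert_kwargs(
    onnx_path: str,
    output_dir: str = "saved_model",
    input_shapes: Optional[Dict[str, Sequence[int]]] = None,
    argmax_flag: Optional[str] = None,
    fused_argmax_scale_ratio: float = 0.5,
) -> dict:
    kwargs = {"input_onnx_file_path": onnx_path, "output_folder_path": output_dir}
    if input_shapes:
        kwargs["overwrite_input_shape"] = format_input_shapes(input_shapes)
    if argmax_flag is not None:
        if argmax_flag not in ARGMAX_FLAGS:
            raise ValueError(f"unknown ArgMax flag: {argmax_flag}")
        kwargs[argmax_flag] = True  # a single choice, so never more than one flag
        if "fused" in argmax_flag:
            if not 0.0 < fused_argmax_scale_ratio <= 1.0:
                raise ValueError("fused_argmax_scale_ratio must be in (0.0, 1.0]")
            kwargs["fused_argmax_scale_ratio"] = fused_argmax_scale_ratio
    return kwargs

kwargs = build_convert_kwargs(
    "model.onnx",
    input_shapes={"data": [1, 3, 224, 224]},
    argmax_flag="replace_argmax_to_fused_argmax_and_indicies_is_int64",
)
print(kwargs["overwrite_input_shape"])  # ['data:1,3,224,224']
```

With onnx2tf installed and a real model in place, the dictionary would be passed as convert(**kwargs). The path model.onnx and the input name data are placeholders.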
This tool is used to convert NCW to NWC, NCHW to NHWC, NCDHW to NDHWC, NCDDHW to NDDHWC, and NCDDDDDDHW to NDDDDDDHWC. Because of this, as stated in the Key Concepts, the conversion will inevitably break down at some point in the model. You need to look at the entire conversion log to see which OP transpositions are failing and correct them yourself. I deliberately explain very little, because I know that no matter how much detail I put in the README, you guys will not read it at all. The attribute, INPUT constant, or INPUT Initializer of an OP can be replaced with a specified value.
Starting from v1.3.0
, almost all OPs except for some special OPs support pre- and post-transposition by pre_process_transpose
and post_process_transpose
.
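When the only problem is the position of the channel axis, the perms used for pre- and post-transposition follow a simple pattern. The sketch below uses hypothetical helper names that are not part of onnx2tf. It builds the channel-first to channel-last perm for any rank (for example `[0, 2, 3, 1]` for NCHW to NHWC) and its inverse (for example `[0, 3, 1, 2]`), and shows how a Transpose reorders a shape.

```python
from typing import List, Sequence


def channels_last_perm(rank: int) -> List[int]:
    """Perm that moves axis 1 (channels) to the end.

    rank 3: NCW -> NWC, rank 4: NCHW -> NHWC, rank 5: NCDHW -> NDHWC.
    """
    if rank < 3:
        raise ValueError("expected at least N, C and one spatial axis")
    return [0] + list(range(2, rank)) + [1]


def channels_first_perm(rank: int) -> List[int]:
    """Inverse perm: moves the last axis (channels) back to position 1."""
    if rank < 3:
        raise ValueError("expected at least N, C and one spatial axis")
    return [0, rank - 1] + list(range(1, rank - 1))


def transpose_shape(shape: Sequence[int], perm: Sequence[int]) -> List[int]:
    """Shape after a Transpose: output axis i is input axis perm[i]."""
    return [shape[p] for p in perm]


if __name__ == "__main__":
    for rank in (3, 4, 5):
        print(rank, channels_last_perm(rank), channels_first_perm(rank))
    nchw = [1, 3, 224, 224]
    nhwc = transpose_shape(nchw, channels_last_perm(4))
    print(nhwc)
    print(transpose_shape(nhwc, channels_first_perm(4)))
```

Applying a perm and then its inverse restores the original shape. That is a quick way to sanity-check a hand-written perm before you put it into a replacement file.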
Please don't post low-level questions like these as issues:
- "A conversion error occurs."
- "Output results are wrong."
- Convert option

```
--param_replacement_file param_replacement.json
```

- param_replacement.json (the `#` comments are explanatory only; standard JSON does not allow comments, so remove them in a real file)

```
{
  "format_version": 1,
  "operations": [
    {
      "op_name": "StatefulPartitionedCall/Tile_4",
      "param_target": "inputs", # attributes, inputs, outputs or op
      "param_name": "const_fold_opt__677",
      "values": [1,1,17] # Disable parameter transposition or overwrite parameters
    },
    {
      "op_name": "StatefulPartitionedCall/Cast_3",
      "param_target": "attributes", # attributes, inputs, outputs or op
      "param_name": "to",
      "values": 1 # Disable parameter transposition or overwrite "to" parameters
    },
    {
      "op_name": "Resize__697",
      "param_target": "inputs",
      "param_name": "Concat__696:0",
      "values": [26,26] # Replacement of unk__x (Resize OP, sizes height/width parameter)
    },
    {
      "op_name": "Transpose__927",
      "param_target": "attributes",
      "param_name": "perm",
      "values": [0,1,2,3] # Disable parameter transposition or overwrite "perm" parameters
    },
    {
      "op_name": "StatefulPartitionedCall/functional_1/max_unpooling2d_2/Reshape_1",
      "param_target": "inputs",
      "param_name": "const_fold_opt__911",
      "values": [4,131072] # Overwrite "shape" parameters
    },
    {
      "op_name": "Reshape_25",
      "param_target": "outputs",
      "param_name": "onnx::InstanceNormalization_270",
      "post_process_transpose_perm": [0,2,1] # Extrapolate 3D Transpose after Reshape
    },
    {
      "op_name": "Reshape_30",
      "param_target": "outputs",
      "param_name": "onnx::Mul_275",
      "post_process_transpose_perm": [0,2,3,1] # Extrapolate 4D Transpose after Reshape
    },
    {
      "op_name": "flatten_1127",
      "param_target": "inputs",
      "param_name": "dropout0",
      "pre_process_transpose_perm": [0,3,1,2]
    },
    {
      "op_name": "/Slice",
      "param_target": "op",
      "begin": [0,0,1,0],
      "end": [0,0,0,0],
      "end_mask": 15
    },
    {
      "op_name": "/Slice_1",
      "param_target": "op",
      "begin": [0,0,0,0],
      "end": [0,0,39,0],
      "end_mask": 11
    }
  ]
}
```
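The commented sample above is not valid JSON as written, so it can be easier to generate the file from Python. The following is a minimal sketch that assumes you copy `op_name` and `param_name` values from your own conversion log; the names below are reused from the sample only for illustration. `check_operations` is a hypothetical sanity check, and it only knows the `param_target` values that appear in this README.

```python
import json
from typing import Any, Dict, List

# Targets that appear in this README's examples and table (assumed complete here).
KNOWN_TARGETS = {"attributes", "inputs", "outputs", "op"}

# Entries reuse names from the sample purely for illustration;
# take real op_name / param_name values from your own conversion log.
replacements: Dict[str, Any] = {
    "format_version": 1,
    "operations": [
        {
            "op_name": "StatefulPartitionedCall/Cast_3",
            "param_target": "attributes",
            "param_name": "to",
            "values": 1,  # 1 = float32 in the Cast type table
        },
        {
            "op_name": "Reshape_30",
            "param_target": "outputs",
            "param_name": "onnx::Mul_275",
            "post_process_transpose_perm": [0, 2, 3, 1],
        },
    ],
}


def check_operations(data: Dict[str, Any]) -> List[str]:
    """Hypothetical check: report entries missing op_name or using an unknown target."""
    problems = []
    for i, op in enumerate(data.get("operations", [])):
        if "op_name" not in op:
            problems.append(f"operations[{i}]: missing op_name")
        if op.get("param_target") not in KNOWN_TARGETS:
            problems.append(f"operations[{i}]: unknown param_target {op.get('param_target')!r}")
    return problems


if __name__ == "__main__":
    assert not check_operations(replacements)
    # json.dumps emits strict JSON (no comments), ready to save as param_replacement.json.
    print(json.dumps(replacements, indent=2))
```

To produce the file, write the printed output (or `json.dump(replacements, f, indent=2)`) to `param_replacement.json`. Then pass it with `--param_replacement_file`.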
Replacement Supported OPs

| No. | OP type | Remarks |
|---|---|---|
| 1 | Add | 1. "param_target": "inputs"<br>`pre_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor before the Add operation (pre-processing).<br>2. "param_target": "outputs"<br>`post_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor after the Add operation (post-processing). |
| 2 | Cast | Type values for `to`:<br>float16: 10, int8: 3, float32: 1, int16: 5, float64: 11, int32: 6, bool: 9, int64: 7, uint8: 2, uint16: 4, uint32: 12, uint64: 13 |
| 3 | Concat | 1. "param_target": "attributes"<br>`axis`: value of `axis`<br>2. "param_target": "outputs"<br>`post_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor after the Concat operation (post-processing). |
| 4 | ConvTranspose | `ConvTranspose` uses a special replacement: it ignores all automatic conversions and generates `tf.nn.conv1d_transpose`, `tf.nn.conv2d_transpose`, or `tf.nn.conv3d_transpose` directly from the parameters you specify.<br>https://www.tensorflow.org/api_docs/python/tf/nn/conv1d_transpose<br>https://www.tensorflow.org/api_docs/python/tf/nn/conv2d_transpose<br>https://www.tensorflow.org/api_docs/python/tf/nn/conv3d_transpose<br>1. "param_target": "op"<br>`output_shape`: value of `output_shape`<br>`strides`: value of `strides`<br>`padding`: value of `padding`<br>`dilations`: value of `dilations` |
| 5 | Div | 1. "param_target": "inputs"<br>`values`: value of `input`<br>`pre_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor before the Div operation (pre-processing).<br>2. "param_target": "outputs"<br>`post_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor after the Div operation (post-processing). |
| 6 | Expand | 1. "param_target": "inputs"<br>`values`: value of `shape`<br>`pre_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor before the Expand operation (pre-processing).<br>2. "param_target": "outputs"<br>`post_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor after the Expand operation (post-processing). |
| 7 | Flatten | 1. "param_target": "attributes"<br>`axis`: value of `axis`<br>2. "param_target": "inputs"<br>`pre_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor before the Flatten operation (pre-processing).<br>3. "param_target": "outputs"<br>`post_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor after the Flatten operation (post-processing). |
| 8 | Gemm | |
| 9 | Gather | 1. "param_target": "inputs"<br>`values`: value of `indices`<br>`pre_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor before the Gather operation (pre-processing).<br>2. "param_target": "outputs"<br>`post_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor after the Gather operation (post-processing). |
| 10 | MatMul | 1. "param_target": "inputs"<br>`pre_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor before the MatMul operation (pre-processing).<br>2. "param_target": "outputs"<br>`post_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor after the MatMul operation (post-processing). |
| 11 | Mul | 1. "param_target": "inputs"<br>`values`: value of `input`<br>`pre_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor before the Mul operation (pre-processing).<br>2. "param_target": "outputs"<br>`post_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor after the Mul operation (post-processing). |
| 12 | NonMaxSuppression | |
| 13 | ReduceL1<br>ReduceL2<br>ReduceLogSum<br>ReduceLogSumExp<br>ReduceMax<br>ReduceMean<br>ReduceMin<br>ReduceProd<br>ReduceSum<br>ReduceSumSquare | 1. "param_target": "attributes"<br>`axes`: value of `axes`<br>`keepdims`: value of `keepdims`<br>2. "param_target": "inputs"<br>`pre_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor before the ReduceXX operation (pre-processing).<br>3. "param_target": "outputs"<br>`post_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor after the ReduceXX operation (post-processing). |
| 14 | Unsqueeze | 1. "param_target": "inputs"<br>`pre_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor before the Unsqueeze operation (pre-processing).<br>2. "param_target": "outputs"<br>`post_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor after the Unsqueeze operation (post-processing).<br>3. "param_target": "op"<br>`new_shape`: directly specifies the shape after Unsqueeze processing. |
| 15 | Reshape | 1. "param_target": "inputs"<br>`values`: value of `shape`<br>`pre_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor before the Reshape operation (pre-processing).<br>2. "param_target": "outputs"<br>`post_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor after the Reshape operation (post-processing). |
| 16 | Resize | 1. "param_target": "attributes"<br>`coordinate_transformation_mode`: value of `coordinate_transformation_mode`<br>`extrapolation_value`: value of `extrapolation_value`<br>`mode`: value of `mode`<br>2. "param_target": "inputs"<br>`values`: value of `roi`, `scales`, or `sizes`; `scales`=[scale_h,scale_w], `sizes`=[h,w]<br>`pre_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor before the Resize operation (pre-processing).<br>3. "param_target": "outputs"<br>`post_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor after the Resize operation (post-processing). |
| 17 | Slice | `Slice` uses a special replacement: it ignores all automatic conversions and generates `tf.strided_slice` directly from all of the `tf.strided_slice` parameters you specify.<br>https://www.tensorflow.org/api_docs/python/tf/strided_slice<br>See replace_slice.json for a sample description.<br>1. "param_target": "op"<br>`begin`: value of `begin`<br>`end`: value of `end`<br>`strides`: value of `strides`<br>`begin_mask`: value of `begin_mask`<br>`end_mask`: value of `end_mask`<br>`ellipsis_mask`: value of `ellipsis_mask`<br>`new_axis_mask`: value of `new_axis_mask`<br>`shrink_axis_mask`: value of `shrink_axis_mask` |
| 18 | Softmax | 1. "param_target": "attributes"<br>`axis`: value of `axis`. The transpositions corresponding to the specified axis are extrapolated before and after `Softmax`.<br>2. "param_target": "inputs"<br>`values`: value of `tensor` |
| 19 | Split | 1. "param_target": "inputs"<br>`values`: value of `split`<br>2. "param_target": "attributes"<br>`axis`: value of `axis`<br>`num_outputs`: value of `num_outputs` |
| 20 | Sub | 1. "param_target": "inputs"<br>`values`: value of `input`<br>`pre_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor before the Sub operation (pre-processing).<br>2. "param_target": "outputs"<br>`post_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor after the Sub operation (post-processing). |
| 21 | Tile | 1. "param_target": "inputs"<br>`values`: value of `input`<br>`pre_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor before the Tile operation (pre-processing).<br>2. "param_target": "outputs"<br>`post_process_transpose_perm`: a Transpose with the specified perm is applied to the tensor after the Tile operation (post-processing). |
| 22 | Transpose | 1. "param_target": "attributes"<br>`perm`: value of `perm`<br>2. "param_target": "inputs"<br>`values`: value of `tensor` |
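For the Slice replacement, `begin`, `end`, and the masks follow `tf.strided_slice` semantics: if bit i of `begin_mask` or `end_mask` is set, `begin[i]` or `end[i]` is ignored and the fullest possible range is used on axis i. The sketch below translates the two Slice entries from the sample param_replacement.json into Python slices to show which positions they keep. It does not handle `ellipsis_mask`, `new_axis_mask`, or `shrink_axis_mask`, and it assumes Python's index clipping matches TensorFlow's for these simple cases. The input shape `[1, 16, 40, 8]` and the helper names are made up for illustration.

```python
from typing import List, Optional, Sequence


def mask_to_slices(
    begin: Sequence[int],
    end: Sequence[int],
    strides: Optional[Sequence[int]] = None,
    begin_mask: int = 0,
    end_mask: int = 0,
) -> List[slice]:
    """Turn strided_slice-style begin/end/masks into one Python slice per axis.

    If bit i of begin_mask (end_mask) is set, begin[i] (end[i]) is ignored and the
    fullest possible range is used; a None bound expresses that in Python.
    """
    rank = len(begin)
    steps = list(strides) if strides is not None else [1] * rank
    slices = []
    for i in range(rank):
        start = None if begin_mask & (1 << i) else begin[i]
        stop = None if end_mask & (1 << i) else end[i]
        slices.append(slice(start, stop, steps[i]))
    return slices


def sliced_shape(shape: Sequence[int], slices: Sequence[slice]) -> List[int]:
    """Shape after slicing, using Python's index clipping on each axis."""
    return [len(range(dim)[s]) for dim, s in zip(shape, slices)]


if __name__ == "__main__":
    shape = [1, 16, 40, 8]  # made-up input shape
    # "/Slice": end_mask 15 = 0b1111, so every end value is ignored.
    s1 = mask_to_slices([0, 0, 1, 0], [0, 0, 0, 0], end_mask=15)
    # "/Slice_1": end_mask 11 = 0b1011, so only axis 2 uses its end value (39).
    s2 = mask_to_slices([0, 0, 0, 0], [0, 0, 39, 0], end_mask=11)
    print(s1, sliced_shape(shape, s1))
    print(s2, sliced_shape(shape, s2))
```

Both entries keep 39 of the 40 positions on axis 2. `/Slice` drops the first position (begin 1), and `/Slice_1` drops the last one (end 39).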
- https://github.com/onnx/onnx/blob/main/docs/Operators.md
- ✔️: Supported
- Help wanted: Pull requests are welcome

| OP | Status |
|---|---|
| Abs | ✔️ |
| Acosh | ✔️ |
| Acos | ✔️ |
| Add | ✔️ |
| And | ✔️ |
| ArgMax | ✔️ |
| ArgMin | ✔️ |
| Asinh | ✔️ |
| Asin | ✔️ |
| Atanh | ✔️ |
| Atan | ✔️ |
| AveragePool | ✔️ |
| BatchNormalization | ✔️ |
| Bernoulli | ✔️ |
| BitShift | ✔️ |
| BitwiseAnd | Help wanted |
| BitwiseNot | Help wanted |
| BitwiseOr | Help wanted |
| BitwiseXor | Help wanted |
| Cast | ✔️ |
| Ceil | ✔️ |
| Celu | ✔️ |
| CenterCropPad | Help wanted |
| Clip | ✔️ |
| Col2Im | Help wanted |
| Compress | ✔️ |
| ConcatFromSequence | ✔️ |
| Concat | ✔️ |
| ConstantOfShape | ✔️ |
| Constant | ✔️ |
| Conv | ✔️ |
| ConvTranspose | ✔️ |
| Cosh | ✔️ |
| Cos | ✔️ |
| CumSum | ✔️ |
| DepthToSpace | ✔️ |
| Det | ✔️ |
| DequantizeLinear | ✔️ |
| DFT | Help wanted |
| Div | ✔️ |
| Dropout | ✔️ |
| DynamicQuantizeLinear | ✔️ |
| Einsum | ✔️ |
| Elu | ✔️ |
| Equal | ✔️ |
| Erf | ✔️ |
| Expand | ✔️ |
| Exp | ✔️ |
| EyeLike | ✔️ |
| Flatten | ✔️ |
| Floor | ✔️ |
| FusedConv | ✔️ |
| GatherElements | ✔️ |
| GatherND | ✔️ |
| Gather | ✔️ |
| Gemm | ✔️ |
| GlobalAveragePool | ✔️ |
| GlobalLpPool | ✔️ |
| GlobalMaxPool | ✔️ |
| GreaterOrEqual | ✔️ |
| Greater | ✔️ |
| GridSample | ✔️ |
| GroupNormalization | Help wanted |
| GRU | Help wanted |
| Hardmax | ✔️ |
| HardSigmoid | ✔️ |
| HardSwish | ✔️ |
| Identity | ✔️ |
| If | ✔️ |
| Input | ✔️ |
| InstanceNormalization | ✔️ |
| Inverse | ✔️ |
| IsInf | ✔️ |
| IsNaN | ✔️ |
| LayerNormalization | ✔️ |
| LeakyRelu | ✔️ |
| LessOrEqual | ✔️ |
| Less | ✔️ |
| Log | ✔️ |
| LogSoftmax | ✔️ |
| Loop | Help wanted |
| LpNormalization | ✔️ |
| LRN | ✔️ |
| LSTM | Help wanted |
| MatMul | ✔️ |
| MatMulInteger | ✔️ |
| MaxPool | ✔️ |
| Max | ✔️ |
| MaxRoiPool | Help wanted |
| MaxUnpool | ✔️ |
| Mean | ✔️ |
| MeanVarianceNormalization | ✔️ |
| MelWeightMatrix | Help wanted |
| Min | ✔️ |
| Mish | ✔️ |
| Mod | ✔️ |
| Mul | ✔️ |
| Multinomial | ✔️ |
| Neg | ✔️ |
| NonMaxSuppression | ✔️ |
| NonZero | ✔️ |
| Optional | Help wanted |
| OptionalGetElement | Help wanted |
| OptionalHasElement | Help wanted |
| Not | ✔️ |
| OneHot | ✔️ |
| Or | ✔️ |
| Pad | ✔️ |
| Pow | ✔️ |
| PRelu | ✔️ |
| QLinearAdd | ✔️ |
| QLinearConcat | ✔️ |
| QLinearConv | ✔️ |
| QLinearLeakyRelu | ✔️ |
| QLinearMatMul | ✔️ |
| QLinearMul | ✔️ |
| QLinearSigmoid | ✔️ |
| QLinearSoftmax | ✔️ |
| QuantizeLinear | ✔️ |
| RandomNormalLike | ✔️ |
| RandomNormal | ✔️ |
| RandomUniformLike | ✔️ |
| RandomUniform | ✔️ |
| Range | ✔️ |
| Reciprocal | ✔️ |
| ReduceL1 | ✔️ |
| ReduceL2 | ✔️ |
| ReduceLogSum | ✔️ |
| ReduceLogSumExp | ✔️ |
| ReduceMax | ✔️ |
| ReduceMean | ✔️ |
| ReduceMin | ✔️ |
| ReduceProd | ✔️ |
| ReduceSum | ✔️ |
| ReduceSumSquare | ✔️ |
| Relu | ✔️ |
| Reshape | ✔️ |
| Resize | ✔️ |
| ReverseSequence | ✔️ |
| RNN | Help wanted |
| RoiAlign | ✔️ |
| Round | ✔️ |
| Scatter | ✔️ |
| ScatterElements | ✔️ |
| ScatterND | ✔️ |
| Scan | Help wanted |
| Selu | ✔️ |
| SequenceAt | ✔️ |
| SequenceConstruct | ✔️ |
| SequenceEmpty | ✔️ |
| SequenceErase | ✔️ |
| SequenceInsert | ✔️ |
| SequenceLength | ✔️ |
| Shape | ✔️ |
| Shrink | ✔️ |
| Sigmoid | ✔️ |
| Sign | ✔️ |
| Sinh | ✔️ |
| Sin | ✔️ |
| Size | ✔️ |
| Slice | ✔️ |
| Softmax | ✔️ |
| Softplus | ✔️ |
| Softsign | ✔️ |
| SpaceToDepth | ✔️ |
| Split | ✔️ |
| SplitToSequence | ✔️ |
| Sqrt | ✔️ |
| Squeeze | ✔️ |
| STFT | Help wanted |
| StringNormalizer | Help wanted |
| Sub | ✔️ |
| Sum | ✔️ |
| Tanh | ✔️ |
| Tan | ✔️ |
| TfIdfVectorizer | Help wanted |
| ThresholdedRelu | ✔️ |
| Tile | ✔️ |
| TopK | ✔️ |
| Transpose | ✔️ |
| Trilu | ✔️ |
| Unique | ✔️ |
| Unsqueeze | ✔️ |
| Upsample | ✔️ |
| Where | ✔️ |
| Xor | ✔️ |
- YOLOv7-tiny with Post-Process (NMS) ONNX to TFLite Float32: https://github.com/PINTO0309/onnx2tf/releases/download/0.0.33/yolov7_tiny_head_0.768_post_480x640.onnx
  (Comparison images: onnx2tf vs. onnx-tensorflow, the latter super redundant + broken)
- YOLACT-Edge MobileNetV2 with Post-Process (MultiClass-NMS) ONNX to TFLite Float32: https://github.com/PINTO0309/onnx2tf/releases/download/1.0.11/yolact_edge_mobilenetv2_550x550.onnx
- MoveNet MultiPose ONNX to TFLite Float32 (`Cast` and `TrueDiv` standard OP support): https://github.com/PINTO0309/onnx2tf/releases/download/1.0.24/movenet_multipose_lightning_192x256_p6.onnx
ONNX files for testing: https://github.com/PINTO0309/onnx2tf/releases/tag/1.1.28
No. | Model | Pass |
---|---|---|
1 | age_googlenet.onnx | ✔️ |
2 | arcfaceresnet100-8.onnx | ✔️ |
3 | baseline_simplified.onnx | ✔️ |
4 | bvlcalexnet-12.onnx | ✔️ |
5 | caffenet-12.onnx | ✔️ |
6 | convtranspose_3_1_5_2.onnx | ✔️ |
7 | convtranspose_4_5_2_2.onnx | ✔️ |
8 | convtranspose_5_5_6_1.onnx | ✔️ |
9 | convtranspose_6_5_5_8.onnx | ✔️ |
10 | convtranspose_7_1_3_4.onnx | ✔️ |
11 | damoyolo_tinynasL20_T_192x192_post.onnx | ✔️ |
12 | densenet-12.onnx | ✔️ |
13 | digits.onnx | ✔️ |
14 | detr_demo.onnx | ✔️ |
15 | efficientformer_l1.onnx | ✔️ |
16 | efficientnet-lite4-11_nchw.onnx | ✔️ |
17 | effnet_opset11_dynamic_axis.onnx | ✔️ |
18 | emotion-ferplus-8_rename.onnx | ✔️ |
19 | face_detection_yunet_2022mar.onnx | ✔️ |
20 | face_recognition_sface_2021dec-act_int8-wt_int8-quantized.onnx | ✔️ |
21 | face_recognition_sface_2021dec.onnx | ✔️ |
22 | faster_rcnn-10.onnx | ✔️ |
23 | fastestdet.onnx | ✔️ |
24 | fused_conv_clip.onnx | ✔️ |
25 | fused_conv_hardsigmoid.onnx | ✔️ |
26 | fused_conv_leakyrelu.onnx | ✔️ |
27 | fused_conv_relu.onnx | ✔️ |
28 | fused_conv_sigmoid.onnx | ✔️ |
29 | fused_conv_tanh.onnx | ✔️ |
30 | gender_googlenet.onnx | ✔️ |
31 | handpose_estimation_mediapipe_2022may.onnx | ✔️ |
32 | iat_llie_180x320.onnx | ✔️ |
33 | if_p1_11.onnx | ✔️ |
34 | if_p2_11.onnx | ✔️ |
35 | if_p3_11.onnx | ✔️ |
36 | imageclassifier.onnx | ✔️ |
37 | inception-v2-9.onnx | ✔️ |
38 | inverse11.onnx | ✔️ |
39 | mnist-12.onnx | ✔️ |
40 | mobilenetv2-12.onnx | ✔️ |
41 | mosaic_11.onnx | ✔️ |
42 | mosaic-9.onnx | ✔️ |
43 | movenet_multipose_lightning_192x256_p6.onnx | ✔️ |
44 | nanodet-plus-m_416.onnx | ✔️ |
45 | object_tracking_dasiamrpn_kernel_cls1_2021nov.onnx | ✔️ |
46 | object_tracking_dasiamrpn_kernel_r1_2021nov.onnx | ✔️ |
47 | object_tracking_dasiamrpn_model_2021nov.onnx | ✔️ |
48 | pidnet_S_cityscapes_192x320.onnx | ✔️ |
49 | qlinear_conv_tensor_test.onnx | ✔️ |
50 | rcnn-ilsvrc13-9.onnx | ✔️ |
51 | regnet_x_400mf.onnx | ✔️ |
52 | ResNet101-DUC-12.onnx | ✔️ |
53 | resnet18-v1-7.onnx | ✔️ |
54 | resnet50-v1-12.onnx | ✔️ |
55 | resnet50-v2-7.onnx | ✔️ |
56 | retinanet-9.onnx | ✔️ |
57 | sinet_320_op.onnx | ✔️ |
58 | squeezenet1.0-12.onnx | ✔️ |
59 | super-resolution-10.onnx | ✔️ |
60 | tinyyolov2-8.onnx | ✔️ |
61 | version-RFB-640.onnx | ✔️ |
62 | vit-b-32_textual.onnx | ✔️ |
63 | vit-b-32_visual.onnx | ✔️ |
64 | yolact_edge_mobilenetv2_550x550.onnx | ✔️ |
65 | yolact_regnetx_600mf_d2s_31classes_512x512.onnx | ✔️ |
66 | yolact_regnetx_800mf_20classes_512x512.onnx | ✔️ |
67 | yolo_free_nano_crowdhuman_192x320_post.onnx | ✔️ |
68 | yolov7_tiny_head_0.768_post_480x640.onnx | ✔️ |
69 | yolox_nano_192x192.onnx | ✔️ |
70 | yolox_nano_416x416.onnx | ✔️ |
71 | yolox_s.onnx | ✔️ |
72 | yolox_x_crowdhuman_mot17_bytetrack.onnx | ✔️ |
73 | zero_dce_640_dele.onnx | ✔️ |
74 | zfnet512-12.onnx | ✔️ |