Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Roialign, roundings, trigonometry #33

Merged
merged 6 commits into from
Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions onnx2torch/node_converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from onnx2torch.node_converters.conv import *
from onnx2torch.node_converters.expand import *
from onnx2torch.node_converters.flatten import *
from onnx2torch.node_converters.functions import *
from onnx2torch.node_converters.gather import *
from onnx2torch.node_converters.gemm import *
from onnx2torch.node_converters.global_average_pool import *
Expand All @@ -24,6 +25,8 @@
from onnx2torch.node_converters.reduce import *
from onnx2torch.node_converters.reshape import *
from onnx2torch.node_converters.resize import *
from onnx2torch.node_converters.roialign import *
from onnx2torch.node_converters.roundings import *
from onnx2torch.node_converters.scatter_nd import *
from onnx2torch.node_converters.shape import *
from onnx2torch.node_converters.split import *
Expand Down
17 changes: 1 addition & 16 deletions onnx2torch/node_converters/activations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__all__ = ['OnnxExp', 'OnnxErf', 'OnnxHardSigmoid', 'OnnxSoftmaxV1V11']
__all__ = ['OnnxErf', 'OnnxHardSigmoid', 'OnnxSoftmaxV1V11']

import torch
from torch import nn
Expand All @@ -10,12 +10,6 @@
from onnx2torch.utils.common import onnx_mapping_from_node


class OnnxExp(nn.Module):

def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
return torch.exp(input_tensor)


class OnnxErf(nn.Module):

def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -54,15 +48,6 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint:
)


@add_converter(operation_type='Exp', version=6)
@add_converter(operation_type='Exp', version=13)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
return OperationConverterResult(
torch_module=OnnxExp(),
onnx_mapping=onnx_mapping_from_node(node=node),
)


@add_converter(operation_type='HardSigmoid', version=1)
@add_converter(operation_type='HardSigmoid', version=6)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
Expand Down
58 changes: 58 additions & 0 deletions onnx2torch/node_converters/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
__all__ = ['OnnxFunction']

import torch
from torch import nn

from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import onnx_mapping_from_node

# Exporting from pytorch to onnx operators atanh, asinh, acosh, cosh, sinh are not supported
_TORCH_FUNCTION_FROM_ONNX_TYPE = {
'Abs': torch.abs,
'Acos': torch.acos,
'Asin': torch.asin,
'Atan': torch.atan,
'Cos': torch.cos,
'Exp': torch.exp,
'Log': torch.log,
'Sign': torch.sign,
'Sin': torch.sin,
'Tan': torch.tan,
'Tanh': torch.tanh,
}


class OnnxFunction(nn.Module):

def __init__(self, function_type: str):
super().__init__()
self.function = _TORCH_FUNCTION_FROM_ONNX_TYPE[function_type]

def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
return self.function(input_tensor)


@add_converter(operation_type='Abs', version=13)
@add_converter(operation_type='Abs', version=6)
@add_converter(operation_type='Acos', version=7)
@add_converter(operation_type='Asin', version=7)
@add_converter(operation_type='Atan', version=7)
@add_converter(operation_type='Cos', version=7)
@add_converter(operation_type='Exp', version=6)
@add_converter(operation_type='Exp', version=13)
@add_converter(operation_type='Log', version=13)
@add_converter(operation_type='Log', version=6)
@add_converter(operation_type='Sign', version=13)
@add_converter(operation_type='Sign', version=9)
@add_converter(operation_type='Sin', version=7)
@add_converter(operation_type='Tan', version=7)
@add_converter(operation_type='Tanh', version=13)
@add_converter(operation_type='Tanh', version=6)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
return OperationConverterResult(
torch_module=OnnxFunction(node.operation_type),
onnx_mapping=onnx_mapping_from_node(node=node),
)
69 changes: 69 additions & 0 deletions onnx2torch/node_converters/roialign.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
__all__ = ['OnnxRoiAlign']

import torch
from torch import nn
from torchvision.ops import roi_align

from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import onnx_mapping_from_node


class OnnxRoiAlign(nn.Module):

def __init__(
self,
mode: str = 'avg',
output_height: int = 1,
output_width: int = 1,
sampling_ratio: int = 0,
spatial_scale: float = 1.0,
):
super().__init__()

if mode != 'avg':
raise NotImplementedError(f'"{mode}" roi align mode is not implemented.')

self._output_size = (output_height, output_width)
self._sampling_ratio = sampling_ratio
self._spatial_scale = spatial_scale

def forward(
self,
input_tensor: torch.Tensor,
rois: torch.Tensor,
batch_indices: torch.Tensor,
) -> torch.Tensor:
batched_rois = torch.concat([batch_indices.unsqueeze(1).to(rois.dtype), rois], dim=1)

return roi_align(
input=input_tensor,
boxes=batched_rois,
output_size=self._output_size,
spatial_scale=self._spatial_scale,
sampling_ratio=self._sampling_ratio,
aligned=False,
)


@add_converter(operation_type='RoiAlign', version=10)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
node_attributes = node.attributes
mode = node_attributes.get('mode', 'avg')
output_height = node_attributes.get('output_height', 1)
output_width = node_attributes.get('output_width', 1)
sampling_ratio = node_attributes.get('sampling_ratio', 0)
spatial_scale = node_attributes.get('spatial_scale', 1.0)

return OperationConverterResult(
torch_module=OnnxRoiAlign(
mode=mode,
ivkalgin marked this conversation as resolved.
Show resolved Hide resolved
output_height=output_height,
output_width=output_width,
sampling_ratio=sampling_ratio,
spatial_scale=spatial_scale,
),
onnx_mapping=onnx_mapping_from_node(node),
)
39 changes: 39 additions & 0 deletions onnx2torch/node_converters/roundings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
__all__ = ['OnnxRound']

import torch
from torch import nn

from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import onnx_mapping_from_node


_TORCH_ROUND_FROM_ONNX_TYPE = {
'Ceil': torch.ceil,
'Floor': torch.floor,
'Round': torch.round,
}


class OnnxRound(nn.Module):

def __init__(self, round_type: str):
super().__init__()
self.round_function = _TORCH_ROUND_FROM_ONNX_TYPE[round_type]

def forward(self, input_tensor: torch.Tensor):
return self.round_function(input_tensor)


@add_converter(operation_type='Ceil', version=13)
@add_converter(operation_type='Ceil', version=6)
@add_converter(operation_type='Floor', version=13)
@add_converter(operation_type='Floor', version=6)
@add_converter(operation_type='Round', version=11)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
return OperationConverterResult(
torch_module=OnnxRound(node.operation_type),
onnx_mapping=onnx_mapping_from_node(node=node),
)
5 changes: 2 additions & 3 deletions tests/node_converters/activations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@ def _test_activation(activation: str, data: np.ndarray, opset_version, **kwargs)
@pytest.mark.parametrize(
'activation,input_shape',
(
('Relu', [8, 3, 32, 32]),
('Erf', [8, 3, 32, 32]),
('Exp', [8, 3, 32, 32]),
('Sigmoid', [8, 3, 32, 32]),
('HardSigmoid', [8, 3, 32, 32]),
('LeakyRelu', [8, 3, 32, 32]),
('Relu', [8, 3, 32, 32]),
('Sigmoid', [8, 3, 32, 32]),
),
)
def test_common_activations(activation: str, input_shape: List[int]) -> None:
Expand Down
110 changes: 110 additions & 0 deletions tests/node_converters/roialign_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from typing import List

import numpy as np
import onnx
import pytest
from onnx.helper import make_tensor_value_info
from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def get_roi_align_input_values(): # type: ignore
x = np.array(
[
[
[
[
0.2764, 0.7150, 0.1958, 0.3416, 0.4638, 0.0259, 0.2963, 0.6518, 0.4856, 0.7250,
],
[
0.9637, 0.0895, 0.2919, 0.6753, 0.0234, 0.6132, 0.8085, 0.5324, 0.8992, 0.4467,
],
[
0.3265, 0.8479, 0.9698, 0.2471, 0.9336, 0.1878, 0.4766, 0.4308, 0.3400, 0.2162,
],
[
0.0206, 0.1720, 0.2155, 0.4394, 0.0653, 0.3406, 0.7724, 0.3921, 0.2541, 0.5799,
],
[
0.4062, 0.2194, 0.4473, 0.4687, 0.7109, 0.9327, 0.9815, 0.6320, 0.1728, 0.6119,
],
[
0.3097, 0.1283, 0.4984, 0.5068, 0.4279, 0.0173, 0.4388, 0.0430, 0.4671, 0.7119,
],
[
0.1011, 0.8477, 0.4726, 0.1777, 0.9923, 0.4042, 0.1869, 0.7795, 0.9946, 0.9689,
],
[
0.1366, 0.3671, 0.7011, 0.6234, 0.9867, 0.5585, 0.6985, 0.5609, 0.8788, 0.9928,
],
[
0.5697, 0.8511, 0.6711, 0.9406, 0.8751, 0.7496, 0.1650, 0.1049, 0.1559, 0.2514,
],
[
0.7012, 0.4056, 0.7879, 0.3461, 0.0415, 0.2998, 0.5094, 0.3727, 0.5482, 0.0502,
],
]
]
],
dtype=np.float32,
)
batch_indices = np.array([0, 0, 0], dtype=np.int64)
rois = np.array([[0, 0, 9, 9], [0, 5, 4, 9], [5, 5, 9, 9]], dtype=np.float32)
return x, batch_indices, rois


def _test_roi(
input_tensor: np.ndarray,
rois: np.ndarray,
batch_indices: np.ndarray,
**kwargs,
) -> None:
test_inputs = {'X': input_tensor, 'rois': rois, 'batch_indices': batch_indices}

node = onnx.helper.make_node(
op_type='RoiAlign',
inputs=list(test_inputs),
outputs=['y'],
**kwargs,
)
onnx_type = NP_TYPE_TO_TENSOR_TYPE[np.dtype('float32')]
outputs_info = [make_tensor_value_info(name='y', elem_type=onnx_type, shape=None)]
model = make_model_from_nodes(
nodes=node,
initializers={},
inputs_example=test_inputs,
outputs_info=outputs_info,
)
check_onnx_model(model, test_inputs)


@pytest.mark.parametrize(
'spatial_scale,sampling_ratio,output_height,output_width',
(
(1.0, 2, 5, 5),
(0.25, 0, 7, 7),
(0.125, 0, 7, 7),
(0.6, 0, 1, 1),
(None, None, None, None),
)
)
@pytest.mark.filterwarnings('ignore::torch.jit._trace.TracerWarning')
def test_roi(spatial_scale: float, sampling_ratio: int, output_height: int, output_width:int) -> None:
x, batch_indices, rois = get_roi_align_input_values()
kwargs = {}
if spatial_scale is not None:
kwargs['spatial_scale'] = spatial_scale
if sampling_ratio is not None:
kwargs['sampling_ratio'] = sampling_ratio
if output_height is not None:
kwargs['output_height'] = output_height
if output_width is not None:
kwargs['output_width'] = output_width
_test_roi(
input_tensor=x,
rois=rois,
batch_indices=batch_indices,
**kwargs,
)
Loading