Skip to content

Commit

Permalink
OpenVINO backend
Browse files Browse the repository at this point in the history
  • Loading branch information
dkurt committed Apr 5, 2022
1 parent 15c8e9b commit 8980f4f
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 3 deletions.
72 changes: 72 additions & 0 deletions cellpose/contrib/openvino_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import io

import numpy as np
import torch
from openvino.inference_engine import IECore

ie = IECore()

def to_openvino(model):
if isinstance(model.net, OpenVINOModel):
return model
model.mkldnn = False
model.net.mkldnn = False
model.net = OpenVINOModel(model.net)
return model


class OpenVINOModel(object):
def __init__(self, model):
self._base_model = model
self._nets = {}
self._exec_nets = {}
self._model_id = "default"


def _init_model(self, inp):
if self._model_id in self._nets:
return self._nets[self._model_id], self._exec_nets[self._model_id]

# Load a new instance of the model with updated weights
if self._model_id != "default":
self._base_model.load_model(self._model_id, cpu=True)

buf = io.BytesIO()
dummy_input = torch.zeros([1] + list(inp.shape[1:])) # To avoid extra network reloading we process batch in the loop
torch.onnx.export(self._base_model, dummy_input, buf, input_names=["input"], output_names=["output", "style"])
net = ie.read_network(buf.getvalue(), b"", init_from_buffer=True)
exec_net = ie.load_network(net, "CPU")

self._nets[self._model_id] = net
self._exec_nets[self._model_id] = exec_net

return net, exec_net


def __call__(self, inp):
net, exec_net = self._init_model(inp)

batch_size = inp.shape[0]
if batch_size > 1:
out_shape = net.outputs["output"].shape
style_shape = net.outputs["style"].shape
output = np.zeros([batch_size] + out_shape[1:], np.float32)
style = np.zeros([batch_size] + style_shape[1:], np.float32)
for i in range(batch_size):
out = exec_net.infer({"input": inp[i : i + 1]})
output[i] = out["output"]
style[i] = out["style"]

return torch.tensor(output), torch.tensor(style)
else:
out = exec_net.infer({"input": inp})
return torch.tensor(out["output"]), torch.tensor(out["style"])


def load_model(self, path, cpu):
self._model_id = path
return self


def eval(self):
pass
4 changes: 2 additions & 2 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Cellpose 1.0
outputs
models
train

openvino

.. toctree::
:maxdepth: 3
Expand All @@ -61,4 +61,4 @@ Cellpose 1.0
.. toctree::
:caption: API Reference:

api
api
20 changes: 20 additions & 0 deletions docs/openvino.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
OpenVINO
------------------------------

`OpenVINO <https://github.com/openvinotoolkit/openvino>`_ is an optional backend for Cellpose which optimizes deep learning inference for Intel Architectures.

It can be installed with a primary package by adding extra suffix:

::

pip install cellpose[openvino]

Using ``openvino_utils.to_openvino``, convert PyTorch model to OpenVINO one:

::

from cellpose.contrib import openvino_utils

model = models.CellposeModel(...)

model = openvino_utils.to_openvino(model)
7 changes: 6 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
'scikit-learn',
]

openvino_deps = [
'openvino==2021.3',
]

try:
import torch
a = torch.ones(2, 3)
Expand Down Expand Up @@ -63,7 +67,8 @@
'docs': docs_deps,
'gui': gui_deps,
'distributed': distributed_deps,
'all': gui_deps + distributed_deps,
'openvino': openvino_deps,
'all': gui_deps + distributed_deps + openvino_deps,
},
include_package_data=True,
classifiers=(
Expand Down
39 changes: 39 additions & 0 deletions tests/contrib/test_openvino.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest
import numpy as np
import torch
from cellpose import io, models
from cellpose.contrib import openvino_utils


def create_model():
return models.CellposeModel(gpu=False,
pretrained_model="cyto",
net_avg=True,
device=torch.device("cpu"))

def test_unet(data_dir):
image_name = 'rgb_2D.png'
img = io.imread(str(data_dir.joinpath('2D').joinpath(image_name)))

# Get a reference results
ref_model = create_model()
ref_masks, ref_flows, ref_styles = ref_model.eval(img, net_avg=True)

# Convert model to OpenVINO format
ov_model = create_model()
ov_model = openvino_utils.to_openvino(ov_model)

out_masks, out_flows, out_styles = ov_model.eval(img, net_avg=True)

assert ref_masks.shape == out_masks.shape
assert ref_styles.shape == out_styles.shape

assert np.all(ref_masks == out_masks)
assert np.max(np.abs(ref_styles - out_styles)) < 1e-5

for ref_flow, out_flow in zip(ref_flows, out_flows):
if ref_flow is None or np.prod(ref_flow.shape) == 0:
continue

assert ref_flow.shape == out_flow.shape
assert np.max(np.abs(ref_flow - out_flow)) < 1e-4

0 comments on commit 8980f4f

Please sign in to comment.