
Add torch.compile support for pytorch 2.4 #1690

Draft · wants to merge 11 commits into base: main
Conversation

Fabioomega (Contributor):

Added support for torch.compile, only for PyTorch version 2.4 or higher. Included support for all the detection models and one recognition model (parseq).

Unfortunately, triton support is only available on Linux platforms. WSL seems to work fine though, so it may be used that way.

Example use:

  1. Enable the feature by setting the environment variable USE_TRITON=YES (a Python alternative is sketched after the example below).
  2. Try the following code:
from time import time

from cv2 import imread

from doctr.io import Document
from doctr.models import ocr_predictor

t1 = time()
reader = ocr_predictor(det_arch='db_resnet50', pretrained=True)
print('Loading time:', time() - t1)

img = imread("<img>")  # path to any test image

# First run: includes the compilation overhead.
t2 = time()
d: Document = reader([img])
print('Document of the first run:', d)
print('Processing time of the first run:', time() - t2)

# Second run: the compiled graph is reused, so this should be much faster.
t3 = time()
d = reader([img])
print('Processing time of the second run:', time() - t3)
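
For step 1, the USE_TRITON variable can also be set from Python instead of the shell. A minimal sketch, assuming the flag is read when doctr is imported (so it has to be set beforehand; that ordering is an assumption based on step 1 above):

import os

# Assumption: the flag must be set before doctr is imported to take effect.
os.environ["USE_TRITON"] = "YES"

from doctr.models import ocr_predictor  # imported after the flag is set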

Fabioomega marked this pull request as a draft on August 8, 2024, 18:13
@@ -266,7 +266,8 @@ def decode_autoregressive(self, features: torch.Tensor, max_len: Optional[int] =
         ).int()

         pos_logits = []
-        for i in range(max_length):
+        i = 0
+        while i < max_length:
Contributor:
I remember there was an issue with while loops when exporting to ONNX, so we have to be careful here (needs to be checked).

Fabioomega (Contributor Author):
I changed it because the break statements caused some unnecessary complications with torch.compile. Switching to a while loop and adjusting the logic a bit helped. Hopefully it works for ONNX as well.
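
For illustration, the kind of restructuring described here might look as follows. A minimal sketch, not the actual decode_autoregressive body; the claim that break statements complicate torch.compile is taken from the discussion above:

import torch

# Original pattern: early exit via break inside a for loop.
def count_until_eos_for(logits: torch.Tensor, eos_idx: int, max_length: int) -> int:
    steps = 0
    for i in range(max_length):
        if int(logits[i].argmax()) == eos_idx:
            break
        steps += 1
    return steps

# Reworked pattern: the stop condition is folded into the while test,
# so no break is needed.
def count_until_eos_while(logits: torch.Tensor, eos_idx: int, max_length: int) -> int:
    steps = 0
    while steps < max_length and int(logits[steps].argmax()) != eos_idx:
        steps += 1
    return steps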

felixdittrich92 (Contributor) left a comment:

Hi @Fabioomega 👋

Tests already look good to me 👍
The docs section is still missing here: https://github.com/mindee/doctr/blob/main/docs/source/using_doctr/using_model_export.rst

As mentioned, a table would be great :)
For the classification models it would be enough to add both orientation models to the table (I don't think we should blow up the table by adding all the backbone models).

Todo's:

  • comments
  • unittests
  • docs

Left some comments to revert the parts that aren't needed :)

As mentioned, in follow-up PRs we can focus on fixes for the models which don't work out of the box yet 👍

@@ -76,6 +75,18 @@
" is installed and that either USE_TF or USE_TORCH is enabled."
)

if _torch_available:
felixdittrich92 (Contributor) commented on Sep 2, 2024:

@Fabioomega We can remove this.

2 options:

We pin the lower boundary to >= 2.0.0 here:

"torch>=1.12.0,<3.0.0",

and torchvision>=0.15.0,

or we mention in the docs that this requires >= 2.0.0 for compile and >= 2.4.0 for compile + fullgraph.

@odulcy-mindee wdyt?
We are already at 2.4.0, so I would prefer the >= 2.0.0 pin (in this case we only need to mention >= 2.4.0 for fullgraph (triton) support).
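
For the docs-only option, a user-side version gate could be written along these lines. A minimal sketch; the helper names are illustrative and not part of doctr's API:

import torch
from packaging.version import Version

def torch_supports_compile() -> bool:
    # torch.compile was introduced in PyTorch 2.0.0.
    return Version(torch.__version__) >= Version("2.0.0")

def torch_supports_fullgraph_compile() -> bool:
    # The >= 2.4.0 boundary for compile + fullgraph comes from the discussion above.
    return Version(torch.__version__) >= Version("2.4.0")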

@@ -104,3 +115,11 @@ def is_torch_available():
def is_tf_available():
"""Whether TensorFlow is installed."""
return _tf_available

def does_torch_have_compile_capability():
Contributor:
remove :)

Contributor:
This can be reverted completely.

doctr/models/detection/_utils/__init__.py (resolved)
Contributor: revert :) (the same comment was left on four other hunks)

Fabioomega (Contributor Author):
I have some questions about that! Wasn't the original idea to add a new argument to enable compilation? Did I misunderstand?

felixT2K (Contributor) commented on Sep 18, 2024:
That was the first thought, as your code looked like changes to the pipeline/models were needed. However, we then saw that these were not needed, which is why we only add tests here and a section on how to use it. The compilation therefore remains on the user side, which is at the same time much more flexible. :)
Additionally, this avoids adding an argument which in the end only does model = torch.compile(model, ...) and is backend dependent (PyTorch).

Contributor:
A full sample would then look like this, for example:

import requests
import torch
from doctr.models import ocr_predictor, parseq, fast_base
from doctr.io import DocumentFile

bytes_data = requests.get(
    "https://i1.rgstatic.net/publication/231831562_Another_Boring_Day_in_Paradise_Rock_and_Roll_and_the_Empowerment_of_Everyday_Life/links/57d02a2408ae601b39a05636/largepreview.png"
).content

doc = DocumentFile.from_images([bytes_data])

rec_model = torch.compile(parseq(pretrained=True))
det_model = torch.compile(fast_base(pretrained=True))
predictor = ocr_predictor(det_arch=det_model, reco_arch=rec_model, pretrained=True)

res = predictor(doc)
res.show()

The only required change here would be to also allow torch._dynamo.eval_frame.OptimizedModule in

arch, (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq)
and
if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
and
if not isinstance(arch, classification.MobileNetV3):
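
For example, the detection check could be relaxed along these lines. A minimal sketch building on the isinstance check quoted above; is_supported_det_arch is a hypothetical helper name, not doctr code:

import torch
from doctr.models import detection

def is_supported_det_arch(arch) -> bool:
    supported = (detection.DBNet, detection.LinkNet, detection.FAST)
    if isinstance(arch, torch._dynamo.eval_frame.OptimizedModule):
        # torch.compile wraps the model; inspect the underlying module instead.
        return isinstance(arch._orig_mod, supported)
    return isinstance(arch, supported)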

@@ -186,3 +186,46 @@ def test_models_onnx_export(arch_name, input_shape, output_size):
assert np.allclose(pt_logits, ort_outs[0], atol=1e-4)
except AssertionError:
pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(pt_logits - ort_outs[0]))}")

@pytest.mark.skipif(not does_torch_have_compile_capability(), reason="requires pytorch >= 2.0.0")
@pytest.mark.skipif(not is_pytorch_backend_available(), reason="requires pytorch backend to be available")
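
For context, a compile test along these lines could back the "unittests" todo item. A minimal sketch, not the PR's actual test body (which is truncated above); it assumes the detection model's forward returns a dict with an "out_map" entry when called with return_model_output=True:

import numpy as np
import torch
from doctr.models import detection

def test_compiled_model_matches_eager():
    model = detection.db_resnet50(pretrained=False).eval()
    compiled = torch.compile(model)  # fullgraph=True would additionally need torch >= 2.4 per this PR
    inp = torch.rand((1, 3, 1024, 1024))
    with torch.no_grad():
        eager_out = model(inp, return_model_output=True)["out_map"].cpu().numpy()
        compiled_out = compiled(inp, return_model_output=True)["out_map"].cpu().numpy()
    # The compiled forward pass should closely match the eager one.
    assert np.allclose(eager_out, compiled_out, atol=1e-4)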
Contributor:
Remove the first two skipif markers; CI always runs on the latest PyTorch. Same for the other tests 👍

felixdittrich92 (Contributor) commented:

Hi @Fabioomega :)

Are you still interested in working on it? :)
