
Implement and benchmark ONNX Runtime for Inference #39

Closed
tanaysoni opened this issue Mar 12, 2020 · 16 comments
Labels
topic:speed, type:feature (New feature or request)

Comments

@tanaysoni
Contributor

Went with onnx-ecosystem, which was released a couple of weeks ago. Found that nvidia-cuda-docker was not initializing, so I ditched Docker for now and ran this notebook in an environment with PyTorch v1.4.0, Transformers v2.5.1, and ONNX Runtime v1.2.1 (CPU & GPU).

With the variables (max_seq_length=128, etc.) as originally specified, here is the result on GPU:

ONNX Runtime inference time:  0.00811

PyTorch Inference time =  0.02096
***** Verifying correctness *****
PyTorch and ORT matching numbers: True
PyTorch and ORT matching numbers: True

With max_seq_length=384, everything else the same, here is the result:

ONNX Runtime inference time:  0.0193

PyTorch Inference time =  0.0273
***** Verifying correctness *****
PyTorch and ORT matching numbers: True
PyTorch and ORT matching numbers: True

Should have more time tomorrow to examine these preliminary results and to further iterate & characterize the differences, including the notebook's variables per_gpu_eval_batch_size and eval_batch_size, both originally set to 1.

At this point I am more familiar with ALBERT_xxlarge inference performance, so eventually I may try to implement it in ONNX for an inference comparison on a larger model.

Here's another max_seq_length=384 run:
Inference-PyTorch-Bert-Model-for-High-Performance-in-ONNX-Runtime_WIP - Jupyter Notebook.pdf

Originally posted by @ahotrod in #23 (comment)

@tanaysoni
Contributor Author

Hi @ahotrod, we are testing ONNX Runtime inference with FARM. We ran a preliminary benchmark comparing it with PyTorch inference for the forward pass of a model and observed a ~2x performance gain.

We plan to implement it in FARM (deepset-ai/FARM#276) and then do an end-to-end benchmark in Haystack.
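
For context, here is a minimal sketch of the kind of PyTorch-to-ONNX export step this relies on. This is not the FARM implementation itself; the model name, output file name, and sequence length below are placeholders.

import torch
from transformers import BertForQuestionAnswering, BertTokenizer

# Placeholder checkpoint; any BERT-style QA model works the same way.
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
model = BertForQuestionAnswering.from_pretrained(model_name)
model.eval()

tokenizer = BertTokenizer.from_pretrained(model_name)
# transformers 2.x style padding to a fixed max_length
dummy = tokenizer.encode_plus("What is ONNX?", "ONNX is an open format for ML models.",
                              max_length=128, pad_to_max_length=True, return_tensors="pt")

# Export with a dynamic batch dimension so different batch sizes can be benchmarked.
torch.onnx.export(
    model,
    (dummy["input_ids"], dummy["attention_mask"], dummy["token_type_ids"]),
    "bert_squad.onnx",
    input_names=["input_ids", "input_mask", "segment_ids"],
    output_names=["start_logits", "end_logits"],
    dynamic_axes={"input_ids": {0: "batch"},
                  "input_mask": {0: "batch"},
                  "segment_ids": {0: "batch"}},
    opset_version=11,
)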

@tanaysoni
Contributor Author

Hi @ahotrod

We used the tutorial notebook to run more benchmarks comparing the performance of ONNX and PyTorch Inference with different batch sizes.

Here's the code for the benchmarks:

# %env CUDA_LAUNCH_BLOCKING=1

# ONNX Runtime Inference

import os
import time

import onnxruntime as rt
from torch.utils.data import DataLoader

# output_dir, output_model_path, and dataset are defined earlier in the notebook.
sess_options = rt.SessionOptions()

# Set graph optimization level to ORT_ENABLE_EXTENDED to enable BERT optimizations.
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED

# Enable model serialization and store the optimized graph at the desired location.
sess_options.optimized_model_filepath = os.path.join(output_dir, "optimized_model.onnx")
session = rt.InferenceSession(output_model_path, sess_options)

for batch_size in (1, 2, 4, 8, 16, 32, 64):
    runtimes = []
    for _ in range(5):
        dataloader = DataLoader(dataset=dataset, batch_size=batch_size)
        batch = next(iter(dataloader))
        batch = tuple(t.to("cpu") for t in batch)
        inputs = {
            'input_ids':      batch[0],
            'attention_mask': batch[1],
            'token_type_ids': batch[2],
        }

        # evaluate the model
        start = time.time()
        res = session.run(None, {
                    'input_ids': inputs['input_ids'].cpu().numpy(),
                    'input_mask': inputs['attention_mask'].cpu().numpy(),
                    'segment_ids': inputs['token_type_ids'].cpu().numpy()
                })
        end = time.time()
        runtimes.append(end-start)
    print(f"ONNX Runtime inference time for batch_size {batch_size}: {round(sum(runtimes)/len(runtimes), 4)}")


# PyTorch Inference
model.to("cuda")
for batch_size in (1, 2, 4, 8, 16, 32, 64):
    runtimes = []
    for _ in range(5):
        dataloader = DataLoader(dataset=dataset, batch_size=batch_size)
        batch = next(iter(dataloader))
        batch = tuple(t.to("cuda") for t in batch)
        inputs = {
            'input_ids':      batch[0],
            'attention_mask': batch[1],
            'token_type_ids': batch[2],
        }

        # evaluate the model
        start = time.time()
        outputs = model(**inputs)
        end = time.time()
        runtimes.append(end-start)
    print(f"PyTorch inference time for batch_size {batch_size}: {round(sum(runtimes)/len(runtimes), 4)}")

The benchmarks were done on an AWS EC2 p3.2xlarge (V100 GPU) instance with PyTorch v1.4.0, Transformers v2.4.0, ONNX v1.6.0, and onnxruntime-gpu v1.2.0.

Here's a comparison of inference times (in seconds):

Batch Size    ONNX      PyTorch    ONNX Speedup
1             0.0075    0.0307     4.09
2             0.0089    0.0329     3.70
4             0.0128    0.0364     2.84
8             0.0193    0.0482     2.50
16            0.0348    0.0660     1.90
32            0.0648    0.1068     1.65
64            0.1288    0.1621     1.26

It seems ONNX Runtime inference is faster than PyTorch at smaller batch sizes, but the gap narrows as the batch size increases. Is there any further optimization that could be done for ONNX Runtime with respect to batch size?

@ahotrod

ahotrod commented Mar 14, 2020

@tanaysoni There is ongoing further speed optimization of BERT with ONNX here:
Add Bert Optimization Notebooks #3204

Checks are still running as I type this, with 1 pending review. I have not had an opportunity to evaluate the changes, but after a cursory review it appears there are significant changes to nine supporting code files, and the primary change to the forked notebook is:

# Use contiguous array as input could improve performance.
ort_inputs = {'input_ids': numpy.ascontiguousarray(inputs['input_ids'].cpu().numpy()),
              'input_mask': numpy.ascontiguousarray(inputs['attention_mask'].cpu().numpy()),
              'segment_ids': numpy.ascontiguousarray(inputs['token_type_ids'].cpu().numpy())
}

# Warm up with one run.
session.run(None, ort_inputs)

# Measure the latency.
start = time.time()
results = session.run(None, ort_inputs)
end = time.time()

PyTorch cuda Inference time = 30.92 ms
ONNX Runtime cuda inference time: 9.97 ms

Note that one of the commits is "Allow test multiple batch_size".

FYI ref: Graph Optimizations in ONNX Runtime
rt.GraphOptimizationLevel.ORT_ENABLE_ALL doesn't appear to add anything for GPU, only CPU, and rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED, as used above, already adds BERT Embedding Layer Fusion.

@tholor added the type:feature (New feature or request) label on Mar 17, 2020
@tanaysoni
Contributor Author

@ahotrod thank you for all the pointers! I could reproduce the results in the newly updated notebook with a batch_size of 1. I'll now try different batch sizes and also benchmark the integration with FARM.

@ahotrod

ahotrod commented Apr 23, 2020

@tanaysoni Wondering if you have taken the next step of importing the ONNX model into TensorRT for NVIDIA CUDA inference performance gains in a production environment?

Seen any ONNX support for RoBERTa?

@tanaysoni
Contributor Author

@ahotrod we want to explore different execution providers for Haystack in the next few days. Do you have any experience or benchmarks for BERT SQuAD with TensorRT or nGraph?

I haven't yet seen an ONNX export for RoBERTa, but I would be excited to test it out in Haystack!
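
As a rough sketch of what switching execution providers looks like in ONNX Runtime (assuming an onnxruntime-gpu build that ships the TensorRT provider; the model path is a placeholder):

import onnxruntime as rt

# Which providers are available depends on how onnxruntime was built.
print(rt.get_available_providers())

session = rt.InferenceSession("optimized_model.onnx")
# Prefer TensorRT, fall back to CUDA, then CPU.
session.set_providers(["TensorrtExecutionProvider",
                       "CUDAExecutionProvider",
                       "CPUExecutionProvider"])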

@ahotrod

ahotrod commented Apr 23, 2020

@tanaysoni I have no experience/benchmark for BERT SQuAD with TensorRT or nGraph.

Historical early comments on "productionizing the (HF) models" from Oct 2019: https://medium.com/huggingface/benchmarking-transformers-pytorch-and-tensorflow-e2917fb891c2

There's an informative intro to "The Serving Problem" in this 20 April 2020 blog post:
https://blog.einstein.ai/benchmarking-tensorrt-inference-server/

The blog post also benchmarks the newly renamed NVIDIA Triton Inference Server (formerly the TensorRT Inference Server). An early positive impression is that Triton hosts models from multiple frameworks (ONNX, PyTorch, and TensorFlow) and multiple HF Transformer language models, e.g. BERT, ALBERT, GPT2, and CTRL are mentioned in the blog. A downside is that Triton is tied to proprietary NVIDIA hardware, no surprise, and may even require NVIDIA's DGX in their GPU Cloud, though I'm not sure.

Triton github: https://github.com/NVIDIA/triton-inference-server/tree/d7cc183b7611f7775e1808b0a9d25a36e3d6e055#roadmap

I have just begun looking for a reasonable cloud inferencing solution for large (ALBERT-xxlarge, RoBERTa-large, eventually Elastic-large, ... etc.) HF Transformer QA models, in either TensorFlow or PyTorch, compatible with Haystack. Reasonable in the sense that inference on a single document of 6K words or less takes a few seconds, not tens of seconds.

Looks like I will be using an AWS cloud solution. Currently working on a domain vocabulary file for AWS Transcribe.

@tanaysoni
Contributor Author

Hi @ahotrod, thank you for all the pointers!

I did a quick test running an ONNX model on TensorRT (V100 GPU) using this Dockerfile, but the benchmarks did not show performance gains. I'll have to investigate further before posting the results.

Meanwhile, we are also working on implementing an inference speed benchmarking pipeline in FARM-#321. This will help reproduce benchmarks for different models, execution providers, batch sizing, and other params.

@ahotrod

ahotrod commented May 6, 2020

@tanaysoni FYI: ONNX Conversion Script just posted.
I'll be following with interest to see the implications for my models.
Will it be coordinated/consistent with your work in Haystack/FARM?

@tholor
Member

tholor commented May 13, 2020

@ahotrod yep, I believe onnx-runtime can be a good alternative to PyTorch and is becoming increasingly popular. Great to see that Transformers is also implementing it! We will try to support this in Haystack.

On our end, we finished the implementation in FARM and recently added some benchmark scripts. Maybe the results are interesting to you: Google Spreadsheet

The speedup is particularly significant for smaller batches and when the ONNX optimizations for V100 (or similar devices) are applied.

We will work on getting a "FarmOnnxReader" into Haystack.

@ahotrod

ahotrod commented May 20, 2020

@tholor @tanaysoni

You may find this interesting: HF ONNX

@tanaysoni
Contributor Author

Hi @ahotrod, ONNX support is now added in Haystack with #157!

@raphychek

Hi @tanaysoni! Measuring inference time on the GPU like this (at least in PyTorch; I'm not so sure about ONNX) doesn't work well, as GPU executions are asynchronous. The results you have might be incorrect. You should check this link, which explains it: https://towardsdatascience.com/the-correct-way-to-measure-inference-time-of-deep-neural-networks-304a54e5187f

@tholor
Member

tholor commented Oct 1, 2021

Hey @raphychek, not sure which of our code you are referring to here? You are totally right that GPU computations are asynchronous; that's why we usually use torch.cuda.synchronize() between the GPU operations we measure, OR measure at an outer scope where the GPU was forced to sync (e.g. when copying results back to the CPU or aggregating them, as in some of the snippets above).
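
For illustration, a minimal sketch of the sync-based timing pattern described here, assuming model and inputs already live on the GPU:

import time
import torch

torch.cuda.synchronize()            # make sure previously queued GPU work is done
start = time.time()
with torch.no_grad():
    outputs = model(**inputs)       # model/inputs assumed to be on the GPU
torch.cuda.synchronize()            # wait for the forward pass to actually finish
end = time.time()
print(f"forward pass took {(end - start) * 1000:.2f} ms")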

@raphychek

raphychek commented Oct 1, 2021

Hi @tholor. I might be learning something new here, so thank you for that! What do you mean by "aggregating results", and how does it force the GPU to sync?

From my own experiments, measuring the time of GPU operations with a time.time() subtraction gave different, and sometimes inconsistent, results compared to using torch.cuda.synchronize() and measuring time with torch.cuda.Event(). Hence the part that seems suspect to me is this one, especially since your model is on CUDA (model.to("cuda") in your code):

        start = time.time()
        outputs = model(**inputs)
        end = time.time()
        runtimes.append(end-start)
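
For comparison, a minimal sketch of the torch.cuda.Event-based measurement mentioned above, again assuming model and inputs are on the GPU:

import torch

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
with torch.no_grad():
    outputs = model(**inputs)       # model/inputs assumed to be on the GPU
end_event.record()
torch.cuda.synchronize()            # ensure both events have completed
print(f"forward pass took {start_event.elapsed_time(end_event):.2f} ms")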

@tholor
Member

tholor commented Oct 4, 2021

Ah, I see, you are referring to this code snippet above. I believe the code you quoted from there could indeed be problematic, though it depends on the implementation of the model's forward pass. Imagine an operation that sums all logits and prints the sum; such an operation forces the GPU to sync. Unfortunately, the notebook linked in that comment seems to have been deleted by now.

However, the script above was just one of our earlier benchmark runs. For ONNX we later had a couple of other scripts that make explicit use of torch.cuda.synchronize(), see:
https://github.com/deepset-ai/FARM/blob/7305a17979b0a80dbe2dbebe5815450883f20627/farm/infer.py#L645

Hope this is helpful :)

masci pushed a commit that referenced this issue Nov 27, 2023
Simplify `pygraphviz` optional import