
LiLT cannot make inference with the Half (float16) dtype on CPU #43

Open
piegu opened this issue Apr 30, 2023 · 0 comments

Comments

piegu commented Apr 30, 2023

Hi,

I wanted to run inference with LiLT with the model parameters in the Half (float16) dtype on CPU (I tried on GPU and it worked).

As I'm using Transformers from Hugging Face, I loaded the model with the following code:

from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load the model weights directly in float16
param_dtype = torch.float16
model_id = "pierreguillou/lilt-xlm-roberta-base-finetuned-with-DocLayNet-base-at-linelevel-ml384"
model = AutoModelForTokenClassification.from_pretrained(model_id, torch_dtype=param_dtype)
model.to(device)

That worked, but the model failed when I ran inference with the following code:

with torch.no_grad():
    output = model(
        input_ids=input_id.to(device),
        attention_mask=attention_mask.to(device),
        bbox=bbox.to(device),
    )

Error message:

/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py in layer_norm(input, normalized_shape, weight, bias, eps)
   2513             layer_norm, (input, weight, bias), input, normalized_shape, weight=weight, bias=bias, eps=eps
   2514         )
-> 2515     return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
   2516
   2517

RuntimeError: "LayerNormKernelImpl" not implemented for 'Half'
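If I understand the traceback correctly, the error is raised inside PyTorch's `layer_norm`, not in LiLT-specific code. A minimal check, independent of LiLT (my assumption: on PyTorch builds that lack a CPU Half LayerNorm kernel, this reproduces the same error; newer builds may support it):

```python
import torch

# LayerNorm module and input both cast to float16, on CPU
ln = torch.nn.LayerNorm(4).to(torch.float16)
x = torch.randn(2, 4, dtype=torch.float16)

try:
    ln(x)
    print("CPU Half LayerNorm is supported in this PyTorch build")
except RuntimeError as e:
    # e.g. RuntimeError: "LayerNormKernelImpl" not implemented for 'Half'
    print(f"Reproduced: {e}")
```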

It looks like the float32 dtype is hard-coded in the LiLT code.

How to solve this issue?
Thanks.
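As a workaround in the meantime (my own sketch, not an official fix), I pick the dtype based on the device, so float16 is only used where the kernels exist:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Half-precision CPU kernels (e.g. LayerNorm) may be unavailable,
# so fall back to float32 when running on CPU.
param_dtype = torch.float16 if device.type == "cuda" else torch.float32
```

Alternatively, casting an already-loaded model back with `model.float()` before CPU inference should avoid the error, at the cost of losing the float16 memory savings.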
