jpWang / LiLT

Official PyTorch implementation of LiLT: A Simple yet Effective Language-Independent Layout Transformer for Structured Document Understanding (ACL 2022)
MIT License

LiLT cannot run inference with the Half (float16) dtype on CPU #43

Open piegu opened 1 year ago

piegu commented 1 year ago

Hi,

I wanted to run LiLT inference on CPU with the model parameters in Half (float16) dtype (on GPU, it worked).

As I'm using Transformers from Hugging Face, I ran the following code:

from transformers import AutoTokenizer, AutoModelForTokenClassification

import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

param_dtype = torch.float16
model_id = "pierreguillou/lilt-xlm-roberta-base-finetuned-with-DocLayNet-base-at-linelevel-ml384"
model = AutoModelForTokenClassification.from_pretrained(model_id, torch_dtype=param_dtype);
model.to(device);
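A minimal sketch of one way around this, assuming float16 is only needed when a GPU is present: choose the parameter dtype from the device before loading, so CPU runs stay in float32. `pick_param_dtype` is a hypothetical helper, not part of Transformers or LiLT.

```python
import torch

def pick_param_dtype(device: torch.device) -> torch.dtype:
    # Hypothetical helper: use float16 only on CUDA, where Half kernels are
    # available; fall back to float32 on CPU, where some kernels (such as
    # layer_norm in some PyTorch builds) may not support Half.
    if device.type == "cuda":
        return torch.float16
    return torch.float32

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
param_dtype = pick_param_dtype(device)
print(device, param_dtype)
# The dtype would then be passed to from_pretrained as in the code above, e.g.
# AutoModelForTokenClassification.from_pretrained(model_id, torch_dtype=param_dtype)
```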

Loading worked, but running inference with the following code failed:

with torch.no_grad():
    output = model(input_ids=input_id.to(device),
                   attention_mask=attention_mask.to(device),
                   bbox=bbox.to(device))
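If reduced precision on CPU is still wanted, one alternative (an assumption to check against your PyTorch version, not something LiLT documents) is to keep the weights in float32 and wrap the forward pass in `torch.autocast` with `bfloat16` on CPU. The toy `nn.Sequential` below is a hypothetical stand-in for the LiLT model:

```python
import torch
import torch.nn as nn

# Toy stand-in for the real model; parameters stay in float32.
toy_model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8)).eval()
inputs = torch.randn(2, 8)

# CPU autocast runs eligible ops (e.g. linear) in bfloat16 and keeps
# precision-sensitive ops in float32 according to PyTorch's autocast op lists.
with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = toy_model(inputs)

print(out.shape, out.dtype)
```

With the real model, the same `with` statement would wrap the `model(...)` call shown above, with the model loaded in float32.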

Error message:

/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py in layer_norm(input, normalized_shape, weight, bias, eps)
   2513             layer_norm, (input, weight, bias), input, normalized_shape, weight=weight, bias=bias, eps=eps
   2514         )
-> 2515     return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
   2516 
   2517 

RuntimeError: "LayerNormKernelImpl" not implemented for 'Half'
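The failing call is `torch.layer_norm` itself, so whether this error appears may depend on the installed PyTorch version. A small probe like the following (a hypothetical helper, sketched under that assumption) checks whether the current build can run `layer_norm` on float16 CPU tensors:

```python
import torch
import torch.nn.functional as F

def cpu_half_layer_norm_supported() -> bool:
    # Hypothetical helper: build float16 tensors on CPU (cast from float32)
    # and try a layer_norm call on them.
    x = torch.randn(2, 8).to(torch.float16)
    w = torch.ones(8, dtype=torch.float16)
    b = torch.zeros(8, dtype=torch.float16)
    try:
        F.layer_norm(x, (8,), w, b, 1e-5)
    except RuntimeError:
        # Builds without Half support raise the error reported above.
        return False
    return True

print(torch.__version__, cpu_half_layer_norm_supported())
```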

It looks like the float32 dtype is hard-coded in the LiLT code.

How can I solve this issue? Thanks.