huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
132.11k stars 26.32k forks

[Distilbert] Torch jit trace failed with `load_in_8bit=True`. #33442

Open sihanwang41 opened 6 days ago

sihanwang41 commented 6 days ago

System Info

Tracing fails with:

torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!

Versions: transformers==4.43.4, torch==2.1.0

Who can help?

No response

Reproduction

Script to reproduce

import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")
# load_in_8bit=True quantizes the weights via bitsandbytes; torchscript=True configures the model for TorchScript export.
model = AutoModel.from_pretrained(
    "distilbert/distilbert-base-uncased",
    load_in_8bit=True,
    torchscript=True,
)

text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors="pt")
with torch.inference_mode():
    model.eval()
    # This call raises the TracingCheckError shown above.
    traced_model = torch.jit.trace(model, [encoded_input["input_ids"], encoded_input["attention_mask"]], check_trace=True)

Expected behavior

The trace should succeed without errors :)

LysandreJik commented 5 days ago

Unfortunately, I'm not certain that `load_in_8bit=True` and TorchScript tracing are compatible. cc @SunMarc
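
For readers unfamiliar with the "Graphs differed across invocations!" message, here is a minimal sketch of what the `check_trace=True` sanity check guards against. It does not involve DistilBERT or 8-bit loading, and it does not diagnose this issue. `GrowingGraph` and `trace_outcome` are hypothetical names made up for illustration. The toy module records a different number of operations on every call, so re-tracing it (which is what the check does) should produce a graph that no longer matches the first one.

```python
import torch


class GrowingGraph(torch.nn.Module):
    """Hypothetical toy module whose traced graph changes on every call."""

    def __init__(self):
        super().__init__()
        self.calls = 0  # plain Python state, invisible to the tracer

    def forward(self, x):
        self.calls += 1
        out = x
        # The loop is unrolled at trace time, so call N records N additions.
        for _ in range(self.calls):
            out = out + 1
        return out


def trace_outcome(check_trace):
    """Return "ok" if tracing succeeds, or the name of the check error."""
    try:
        torch.jit.trace(GrowingGraph(), torch.zeros(2), check_trace=check_trace)
        return "ok"
    except torch.jit.TracingCheckError:
        return "TracingCheckError"


print(trace_outcome(check_trace=True))
print(trace_outcome(check_trace=False))
```

Passing `check_trace=False` only skips the comparison; the traced graph still reflects a single invocation. Whether skipping the check is safe for an 8-bit model is not established anywhere in this thread.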