System Info
The trace failed with:
torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
version:
transformers==4.43.4
torch==2.1.0
Who can help?
No response
Information
[ ] The official example scripts
[ ] My own modified scripts
Tasks
[ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
[ ] My own task or dataset (give details below)
Reproduction
Script to reproduce
import torch
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained('distilbert/distilbert-base-uncased')
model = AutoModel.from_pretrained("distilbert/distilbert-base-uncased", load_in_8bit=True, torchscript=True,)
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt')
with torch.inference_mode():
    model.eval()
    traced_model = torch.jit.trace(model, [encoded_input["input_ids"], encoded_input["attention_mask"]], check_trace=True)
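For context only, and not as a diagnosis of this report: the "Graphs differed across invocations!" message is raised by the check_trace=True sanity check used in the script above. That check roughly works by tracing the module a second time with the same inputs and comparing the two recorded graphs. The toy sketch below shows one way the check can fail, namely Python-side state that changes between calls so that each trace records different ops. StatefulModule, StatelessModule and trace_check_fails are hypothetical names used only for illustration. Nothing here assumes that the 8-bit DistilBERT path behaves the same way.

```python
import torch


class StatefulModule(torch.nn.Module):
    """Hypothetical toy module whose forward records different ops on alternating calls."""

    def __init__(self):
        super().__init__()
        self.calls = 0  # plain Python int; the tracer does not record this update

    def forward(self, x):
        self.calls += 1
        # The branch depends on Python state, so each trace records a different op.
        if self.calls % 2 == 1:
            return x + 1
        return x * 2


class StatelessModule(torch.nn.Module):
    """Hypothetical toy module that records the same op on every call."""

    def forward(self, x):
        return x + 1


def trace_check_fails(module, example):
    """Return True if torch.jit.trace's check_trace sanity check raises."""
    try:
        torch.jit.trace(module, example, check_trace=True)
    except torch.jit._trace.TracingCheckError:  # same class as in the error above
        return True
    return False


example = torch.ones(2)
print(trace_check_fails(StatefulModule(), example))   # True
print(trace_check_fails(StatelessModule(), example))  # False
```

With torch.ones(2), both branches return the same values (1 + 1 == 1 * 2). The failure therefore comes from the graph comparison itself, not from an output mismatch.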
Expected behavior
The trace succeeds and passes the check_trace=True sanity check. :)