Closed: TiesdeKok closed this issue 2 years ago.
Disclaimer: This isn't a high-priority issue; it's a somewhat unusual use case. This is mostly a PSA in case anyone else runs into the same problem.
The Hugging Face (HF) version of the GPT-J model does not load under PyTorch/XLA.
Error:
```
RuntimeError: torch_xla/csrc/tensor_methods.cpp:896 : Check failed: xla::ShapeUtil::Compatible(shapes.back(), tensor_shape) <stack trace> f32[1,5,16]{2,1,0} vs. f16[1,5,16]{2,1,0}
```
The issue appears to be triggered by this check: https://github.com/pytorch/xla/blob/master/torch_xla/csrc/tensor_methods.cpp#L939
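To read the error: each side is an XLA shape string such as `f32[1,5,16]{2,1,0}`, i.e. element type, dimensions, then layout (minor-to-major dimension order). The two shapes in the message have identical dimensions and layout and differ only in element type (`f32` vs `f16`), which points to a dtype mismatch rather than a shape problem. The sketch below is a simplified illustration, not torch_xla or XLA code: `parse_shape`, `compatible`, and `XlaShape` are hypothetical helpers, and `compatible` assumes that `xla::ShapeUtil::Compatible` compares element type and dimensions while ignoring layout.

```python
import re
from typing import NamedTuple, Tuple


class XlaShape(NamedTuple):
    element_type: str          # e.g. "f32", "f16", "bf16"
    dims: Tuple[int, ...]      # e.g. (1, 5, 16)
    layout: Tuple[int, ...]    # minor-to-major order, e.g. (2, 1, 0)


# Matches strings like "f32[1,5,16]{2,1,0}" (the layout part is optional).
_SHAPE_RE = re.compile(r"^(\w+)\[([\d,]*)\](?:\{([\d,]*)\})?$")


def _to_ints(text: str) -> Tuple[int, ...]:
    """Turn "1,5,16" into (1, 5, 16); an empty string becomes ()."""
    return tuple(int(x) for x in text.split(",")) if text else ()


def parse_shape(text: str) -> XlaShape:
    """Parse a (non-tuple) XLA array shape string. Hypothetical helper."""
    m = _SHAPE_RE.match(text.strip())
    if m is None:
        raise ValueError(f"unrecognized shape string: {text!r}")
    elem, dims, layout = m.groups()
    return XlaShape(elem, _to_ints(dims), _to_ints(layout or ""))


def compatible(a: XlaShape, b: XlaShape) -> bool:
    """Simplified model of the check: same element type and same
    dimensions; layout is ignored (assumption, see lead-in)."""
    return a.element_type == b.element_type and a.dims == b.dims


if __name__ == "__main__":
    lhs = parse_shape("f32[1,5,16]{2,1,0}")
    rhs = parse_shape("f16[1,5,16]{2,1,0}")
    print(lhs.dims == rhs.dims)   # True: the dimensions agree
    print(compatible(lhs, rhs))   # False: only the element type differs
```

Under these assumptions, the failing check rejects the pair purely because one tensor is float32 and the other float16.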
A reproducible example:
```python
import torch
import torch_xla.core.xla_model as xm
from transformers import AutoTokenizer, AutoModelForCausalLM

dev = xm.xla_device()

tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/gpt-j-6B", eos_token='<|endoftext|>', pad_token='<|pad|>'
)
# Load the half-precision weights from the "float16" revision of the model repo.
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16
)
# Resize the embedding matrix to the tokenizer's vocabulary size.
model.resize_token_embeddings(len(tokenizer))
model = model.to(dev)

prompt = "This is a prompt "
encodings_input = tokenizer(prompt, return_tensors="pt")['input_ids']
encodings_input = encodings_input.to(dev)

res = model.generate(
    encodings_input,
    do_sample=False,      # greedy decoding, so temperature has no effect here
    temperature=0.25,
    max_length=50,
    eos_token_id=50256,
    pad_token_id=50257,
)
```
I ran this on a v2-8 TPU with torch-xla-1.10.
I would suggest filing an issue on the Hugging Face or PyTorch repo, as this does not involve any code from this repository.