kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

HF model does not work on Torch/XLA #184

Closed · TiesdeKok closed this issue 2 years ago

TiesdeKok commented 2 years ago

Disclaimer: this isn't a high-priority issue; it's a somewhat weird and unusual use case. Mostly a PSA in case anyone else runs into the same problem.

The HF version of the GPT-J model does not load using Torch/XLA.

Error:

RuntimeError: torch_xla/csrc/tensor_methods.cpp:896 : Check failed: xla::ShapeUtil::Compatible(shapes.back(), tensor_shape) 
<stack trace>
f32[1,5,16]{2,1,0} vs. f16[1,5,16]{2,1,0}

The issue appears to be triggered by this check: https://github.com/pytorch/xla/blob/master/torch_xla/csrc/tensor_methods.cpp#L939
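For reference, the failing check compares XLA shapes, which include the element dtype as well as the dimensions, so a single float32 parameter in an otherwise float16 model is enough to trip it. A quick way to see whether any parameters are still float32 (a minimal diagnostic sketch, assuming `model` is loaded as in the repro below):

import torch
from collections import Counter

# Tally the parameter dtypes; a clean fp16 checkpoint should show only torch.float16.
print(Counter(p.dtype for p in model.parameters()))

# Name any parameters that are still float32.
for name, p in model.named_parameters():
    if p.dtype == torch.float32:
        print(name, tuple(p.shape))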

A reproducible example:

import torch
import torch_xla.core.xla_model as xm
from transformers import AutoTokenizer, AutoModelForCausalLM

dev = xm.xla_device()

# Add a dedicated pad token on top of the standard GPT-2 vocabulary.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", eos_token='<|endoftext|>', pad_token='<|pad|>')

# Load the fp16 branch of the checkpoint.
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16)
# Grow the embedding matrix to cover the new pad token.
model.resize_token_embeddings(len(tokenizer))
model = model.to(dev)

prompt = "This is a prompt "
encodings_input = tokenizer(prompt, return_tensors="pt")['input_ids']
encodings_input = encodings_input.to(dev)

# Greedy decoding (temperature is ignored when do_sample=False); the RuntimeError above surfaces during generation.
res = model.generate(encodings_input, do_sample=False, temperature=0.25, max_length=50, eos_token_id=50256, pad_token_id=50257)

I ran this on a v2-8 TPU with torch-xla-1.10.
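A possible workaround, on the assumption that resize_token_embeddings allocates the new embedding matrix in float32 while the rest of the checkpoint stays float16 (which would be consistent with the f32 vs f16 message above), is to re-cast the whole model to fp16 after resizing and before moving it to the XLA device. A minimal, untested sketch:

model.resize_token_embeddings(len(tokenizer))
model = model.to(torch.float16)  # re-cast any freshly allocated fp32 weights to fp16
model = model.to(dev)            # then move to the TPU device

This only papers over the dtype mismatch; the embedding rows added for the new pad token are newly initialized either way.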

kingoflolz commented 2 years ago

I would suggest filing an issue on the Hugging Face or PyTorch/XLA repo, as this does not involve any code from this repository.