pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Model support for `hf_Reformer` with Torch_XLA2 #8138

Open ManfeiBai opened 1 month ago

ManfeiBai commented 1 month ago

Fix the model test for hf_Reformer.py:

  1. Set up the environment as described in "Run a model under torch_xla2".
  2. Run the model test from run_torchbench/ with python models/your_target_model_name.py (here, models/hf_Reformer.py).
  3. Fix the failure.

Please refer to this guide when fixing the failure:

Also refer to these PRs:

barney-s commented 1 week ago

Even with the script changes, the test still fails:

% JAX_ENABLE_X64=true JAX_PLATFORMS=cpu python models/hf_Reformer.py
/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:337: UserWarning: Device capability of jax unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.
  warnings.warn(
Traceback (most recent call last):
  File "/usr/local/google/home/barni/workspace/pytorch-tpu/run_torchbench/models/hf_Reformer.py", line 61, in <module>
    sys.exit(main())
  File "/usr/local/google/home/barni/workspace/pytorch-tpu/run_torchbench/models/hf_Reformer.py", line 39, in main
    xla2_ans = model(**example)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/transformers/models/reformer/modeling_reformer.py", line 2404, in forward
    reformer_outputs = self.reformer(
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/transformers/models/reformer/modeling_reformer.py", line 2099, in forward
    encoder_outputs = self.encoder(
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/transformers/models/reformer/modeling_reformer.py", line 1721, in forward
    hidden_states = _ReversibleFunction.apply(
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/transformers/models/reformer/modeling_reformer.py", line 1609, in forward
    layer_outputs = layer(
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/transformers/models/reformer/modeling_reformer.py", line 1474, in forward
    attn_outputs = self.attention(
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/transformers/models/reformer/modeling_reformer.py", line 1307, in forward
    self_attention_outputs = self.self_attention(
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/transformers/models/reformer/modeling_reformer.py", line 1202, in forward
    assert out_vectors.shape == (
AssertionError
barni@barni ~/workspace/pytorch-tpu/run_torchbench
 %
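The assertion at modeling_reformer.py line 1202 carries no message, so the traceback does not show which shape was actually produced. A minimal debugging sketch, assuming only the Python standard library: it calls a function and, if an AssertionError escapes, collects the shapes and integer locals from the innermost frame. The helper name `shapes_at_failure` is hypothetical and is not part of torch_xla2 or transformers.

```python
import traceback


def _describe(value):
    """Return a shape tuple or a plain int, else None (value is not reported)."""
    if hasattr(value, "shape"):
        return tuple(value.shape)
    if isinstance(value, int) and not isinstance(value, bool):
        return value
    return None


def shapes_at_failure(fn, *args, **kwargs):
    """Call fn; on AssertionError, summarize locals of the frame that asserted.

    Hypothetical helper for illustration. Returns {} if fn succeeds.
    """
    try:
        fn(*args, **kwargs)
    except AssertionError as exc:
        # The last frame in the traceback is where the assert fired.
        frames = [frame for frame, _ in traceback.walk_tb(exc.__traceback__)]
        innermost = frames[-1]
        summary = {}
        for name, value in innermost.f_locals.items():
            described = _describe(value)
            if described is not None:
                summary[name] = described
        return summary
    return {}
```

In models/hf_Reformer.py this could wrap the failing call, for example `print(shapes_at_failure(lambda: model(**example)))`, to compare the shape of `out_vectors` against the integer locals visible in that frame. This assumes the wrapped tensors expose a tuple-like `.shape`.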