pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Model support for `tacotron2` with Torch_XLA2 #8185

Open ManfeiBai opened 1 month ago

ManfeiBai commented 1 month ago

Fix the model test for tacotron2.py

  1. Set up the environment according to "Run a model under torch_xla2".
  2. Run the model test from run_torchbench/ with `python models/your_target_model_name.py`.
  3. Fix the failure.
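Steps 2 and 3 can be scripted when several models need checking. The sketch below is hypothetical and not part of run_torchbench. It assumes the current directory is run_torchbench/, and it treats an exit status of 0 as a passing test. It reuses the `JAX_ENABLE_X64=true JAX_PLATFORMS=cpu` variables from the repro command later in this thread. The helper name `run_model_test` is illustrative only.

```python
import os
import subprocess
import sys


def run_model_test(script, *args, cwd=None):
    """Run one model script in a child Python process.

    Hypothetical helper. Assumption: exit status 0 means the test passed.
    Returns (passed, combined stdout + stderr) so the failure can be inspected.
    """
    env = dict(os.environ)
    # Same environment variables as the CPU repro command in this issue.
    env.update({"JAX_ENABLE_X64": "true", "JAX_PLATFORMS": "cpu"})
    proc = subprocess.run(
        [sys.executable, script, *args],
        cwd=cwd,
        env=env,
        capture_output=True,
        text=True,
    )
    return proc.returncode == 0, proc.stdout + proc.stderr
```

For example, `passed, log = run_model_test("models/tacotron2.py")` returns `False` together with the traceback when the model fails.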

Please use this guide when fixing the failure:

Also refer to these PRs:

barney-s commented 2 weeks ago

Model requires CUDA?

```
% JAX_ENABLE_X64=true JAX_PLATFORMS=cpu python models/tacotron2.py
/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:337: UserWarning: Device capability of jax unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.
  warnings.warn(
Traceback (most recent call last):
  File "/usr/local/google/home/barni/workspace/pytorch-tpu/run_torchbench/models/tacotron2.py", line 61, in <module>
    sys.exit(main())
  File "/usr/local/google/home/barni/workspace/pytorch-tpu/run_torchbench/models/tacotron2.py", line 21, in main
    benchmark = benchmark_cls(test="eval", device = "cpu")
  File "/usr/local/google/home/barni/workspace/pytorch-tpu/run_torchbench/benchmark/torchbenchmark/util/model.py", line 43, in __call__
    obj = type.__call__(cls, *args, **kwargs)
  File "/usr/local/google/home/barni/workspace/pytorch-tpu/run_torchbench/benchmark/torchbenchmark/models/tacotron2/__init__.py", line 32, in __init__
    raise NotImplementedError(
NotImplementedError: Tacotron2 doesn't support CPU because load_model assumes CUDA.
```

ManfeiBai commented 2 weeks ago

Let's skip this model for now, since torchbench's Tacotron2 cannot be constructed on CPU.
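One way a test runner could record this decision is an explicit skip list. The sketch below is hypothetical: `SKIPPED_MODELS` and `run_or_skip` are illustrative names, not torchbench or torch_xla2 APIs, and `run` stands for whatever callable builds and executes the benchmark.

```python
from typing import Callable, Dict, Tuple

# Hypothetical skip list mapping model name -> reason.
# Only the tacotron2 entry comes from this issue.
SKIPPED_MODELS: Dict[str, str] = {
    "tacotron2": "torchbench's Tacotron2 raises NotImplementedError on CPU "
                 "because load_model assumes CUDA.",
}


def run_or_skip(name: str, run: Callable[[], None]) -> Tuple[str, str]:
    """Return (status, detail), where status is 'skipped', 'passed', or 'failed'."""
    if name in SKIPPED_MODELS:
        # Skip before calling run(), so the benchmark is never constructed.
        return "skipped", SKIPPED_MODELS[name]
    try:
        run()
    except Exception as exc:
        # Any other error is a real failure that needs investigating.
        return "failed", f"{type(exc).__name__}: {exc}"
    return "passed", ""
```

The sketch uses an explicit list rather than catching `NotImplementedError` everywhere. A blanket catch could also hide genuine operator gaps that happen to raise the same exception type.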