pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Model support for `soft_actor_critic` with Torch_XLA2 #8180

Open ManfeiBai opened 1 month ago

ManfeiBai commented 1 month ago

Fix the model test for soft_actor_critic.py

  1. Set up the environment as described in "Run a model under torch_xla2".
  2. Run the model test from run_torchbench/ with `python models/your_target_model_name.py` (for this issue, `python models/soft_actor_critic.py`).
  3. Fix the failure.

Please refer to this guide when fixing the failure:

Also refer to these PRs:

barney-s commented 3 weeks ago

% JAX_PLATFORMS=cpu python models/soft_actor_critic.py
/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:337: UserWarning: Device capability of jax unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.
  warnings.warn(
Traceback (most recent call last):
  File "/usr/local/google/home/barni/workspace/pytorch-tpu/run_torchbench/models/soft_actor_critic.py", line 61, in <module>
    sys.exit(main())
  File "/usr/local/google/home/barni/workspace/pytorch-tpu/run_torchbench/models/soft_actor_critic.py", line 22, in main
    model, example = benchmark.get_module()
  File "/usr/local/google/home/barni/workspace/pytorch-tpu/run_torchbench/benchmark/torchbenchmark/models/soft_actor_critic/__init__.py", line 206, in get_module
    next_state, reward, done, info, _unused = self.train_env.step(action)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/gym/core.py", line 460, in step
    return self.env.step(self.action(action))
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/gym/wrappers/time_limit.py", line 50, in step
    observation, reward, terminated, truncated, info = self.env.step(action)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/gym/wrappers/order_enforcing.py", line 37, in step
    return self.env.step(action)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/gym/wrappers/env_checker.py", line 37, in step
    return env_step_passive_checker(self.env, action)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/gym/utils/passive_env_checker.py", line 233, in env_step_passive_checker
    if not isinstance(terminated, (bool, np.bool8)):
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/numpy/__init__.py", line 410, in __getattr__
    raise AttributeError("module {!r} has no attribute "
AttributeError: module 'numpy' has no attribute 'bool8'. Did you mean: 'bool'?
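
The traceback ends inside gym's `passive_env_checker.py`, which references `np.bool8`. The error suggests the installed NumPy no longer provides that alias (as far as I know, it was deprecated in NumPy 1.24 and is absent from NumPy 2.x). One possible workaround, sketched below, is to restore the alias before the environment is stepped. This is an assumption about how to unblock the test, not a fix confirmed in this thread, and `ensure_numpy_bool8_alias` is a hypothetical helper name.

```python
import warnings

import numpy as np


def ensure_numpy_bool8_alias() -> bool:
    """Restore np.bool8 as an alias of np.bool_ when NumPy lacks it.

    Hypothetical helper (not part of gym, NumPy, or torch_xla2).
    Returns True if the alias was added, False if it already existed.
    """
    with warnings.catch_warnings():
        # NumPy 1.24-1.26 emit a DeprecationWarning when np.bool8 is accessed.
        warnings.simplefilter("ignore", DeprecationWarning)
        if hasattr(np, "bool8"):
            return False
    # np.bool_ is the scalar type that np.bool8 used to point to.
    np.bool8 = np.bool_
    return True


# Call this before any gym environment is stepped, for example near the
# top of the model test script.
patched = ensure_numpy_bool8_alias()

# The check that failed in passive_env_checker.py can now be evaluated.
terminated = np.bool_(True)
is_bool_like = isinstance(terminated, (bool, np.bool8))
```

Pinning NumPy below 2.0 in the test environment would probably avoid the error as well, without patching the module at runtime. Which option is preferable depends on what the rest of run_torchbench requires.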