RobertTLange / gymnax

RL Environments in JAX 🌍
Apache License 2.0

`Pong-misc`: TypeError: select cases must have the same shapes, got [(30, 40), ()]. #76

Open HelgeS opened 1 month ago

HelgeS commented 1 month ago

When running the Pong-misc environment, the following error is raised from move_paddles.

I tried both the example notebook and gymnax-blines to make sure it is not a usage error.

Below are the stack trace and the gymnax-blines configuration I used.

$ python train.py -config agents/Pong-misc/ppo.yaml

PPO:   0%|          | 0/18751 [00:00<?, ?it/s]
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/helge/Sandbox/pt/gymnax-blines/train.py", line 76, in <module>
    main(
  File "/home/helge/Sandbox/pt/gymnax-blines/train.py", line 24, in main
    log_steps, log_return, network_ckpt = train_fn(
  File "/home/helge/Sandbox/pt/gymnax-blines/utils/ppo.py", line 271, in train_ppo
    train_state, obs, state, batch, rng_step = get_transition(
  File "/home/helge/Sandbox/pt/gymnax-blines/utils/ppo.py", line 252, in get_transition
    next_obs, next_state, reward, done, _ = rollout_manager.batch_step(
  File "/home/helge/Sandbox/pt/gymnax-blines/utils/ppo.py", line 138, in batch_step
    return jax.vmap(self.env.step, in_axes=(0, 0, 0, None))(
  File "/home/helge/Sandbox/pt/code/.venv/lib/python3.10/site-packages/gymnax/environments/environment.py", line 45, in step
    obs_st, state_st, reward, done, info = self.step_env(key, state, action, params)
  File "/home/helge/Sandbox/pt/code/.venv/lib/python3.10/site-packages/gymnax/environments/misc/pong.py", line 75, in step_env
    state = move_paddles(
  File "/home/helge/Sandbox/pt/code/.venv/lib/python3.10/site-packages/gymnax/environments/misc/pong.py", line 356, in move_paddles
    new_center_p2 = jax.lax.select(use_ai_policy, new_center_ai, new_center_self)
TypeError: select cases must have the same shapes, got [(30, 40), ()].
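
The shape mismatch itself can be reproduced outside of gymnax. The following is a minimal sketch, not the code from pong.py: the variable names are borrowed from the traceback, the array values are placeholders, and only the shapes (30, 40) and () are taken from the error message. It shows that jax.lax.select rejects cases with different shapes, while jnp.where broadcasts them. Whether broadcasting is the right fix here is an assumption; a paddle center with shape (30, 40) may well point to a bug earlier in move_paddles.

```python
import jax
import jax.numpy as jnp

# Placeholder inputs: only the shapes come from the error message.
use_ai_policy = jnp.array(True)       # scalar boolean predicate
new_center_ai = jnp.zeros((30, 40))   # unexpectedly 2D, shape (30, 40)
new_center_self = jnp.array(0.0)      # scalar, shape ()

# jax.lax.select requires both cases to have identical shapes.
try:
    jax.lax.select(use_ai_policy, new_center_ai, new_center_self)
    strict_error = None
except TypeError as err:
    strict_error = str(err)

# jnp.where follows NumPy broadcasting, so the scalar case is accepted.
# Note that this hides the mismatch rather than explaining it.
broadcast_result = jnp.where(use_ai_policy, new_center_ai, new_center_self)

print("lax.select error:", strict_error)
print("jnp.where result shape:", broadcast_result.shape)
```

Printing the shapes of new_center_ai and new_center_self just before the select call in move_paddles should show which operand picks up the extra dimensions.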

Configuration (copied from CartPole-v1):

train_config:
  train_type: "PPO"
  num_train_steps: 150000
  evaluate_every_epochs: 1000

  env_name: "Pong-misc"
  env_kwargs: {}
  env_params: {}
  num_test_rollouts: 164

  num_train_envs: 8  # Number of parallel env workers
  max_grad_norm: 0.5  # Global norm to clip gradients by
  gamma: 0.99  # Discount factor
  n_steps: 32 # "GAE n-steps"
  n_minibatch: 4 # "Number of PPO minibatches"
  lr_begin: 5e-04  # Start PPO learning rate
  lr_end: 5e-04 #  End PPO learning rate
  lr_warmup: 0.05 # Prop epochs until warmup is completed 
  epoch_ppo: 4  # "Number of PPO epochs on a single batch"
  clip_eps: 0.2 # "Clipping range"
  gae_lambda: 0.95 # "GAE lambda"
  entropy_coeff: 0.01 # "Entropy loss coefficient"
  critic_coeff: 0.5  # "Value loss coefficient"

  network_name: "Categorical-MLP"
  network_config:
    num_hidden_units: 64
    num_hidden_layers: 2

log_config:
  time_to_track: ["num_steps"]
  what_to_track: ["return"]
  verbose: false
  print_every_k_updates: 1
  overwrite: 1
  model_type: "jax"

device_config:
  num_devices: 1
  device_type: "gpu"