google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Error when wrapping env with TorchWrapper #415

Closed. SumeetBatra closed this issue 8 months ago

SumeetBatra commented 8 months ago

Hello all,

I am creating a brax Env and wrapping it with TorchWrapper like so:

env = envs.create(**spec, backend='spring')
env = torch_wrapper.TorchWrapper(env, device=spec['device'])

But when I call env.reset(), I get the following error:

TypeError: AutoResetWrapper.reset() missing 1 required positional argument: 'rng'

Looking at TorchWrapper's reset() method, it seems that rng is not passed down from this wrapper to the lower-level brax wrappers that require it. Is that the cause?

btaba commented 8 months ago

Take a look at the torch notebook.

It looks like you need to wrap the env in a gym wrapper before applying TorchWrapper, something like:

  env = envs.create(env_name, batch_size=num_envs, ...)
  env = gym_wrapper.VectorGymWrapper(env)
  # automatically convert between jax ndarrays and torch tensors:
  env = torch_wrapper.TorchWrapper(env, device=device)
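
To illustrate why the ordering matters, here is a toy sketch that does not import brax at all. The classes FunctionalEnv, GymStyleAdapter and OuterWrapper are hypothetical stand-ins (not brax's actual code) for, respectively, an env whose reset() needs an explicit rng, a gym-style layer that owns the rng, and an outer wrapper that calls reset() with no arguments. The assumption is that TorchWrapper behaves like the outer wrapper here, which matches the error in the original report:

```python
class FunctionalEnv:
    """Hypothetical stand-in for a functional env: reset needs an explicit rng."""

    def reset(self, rng):
        return {"obs": [0.0, 0.0], "rng": rng}


class GymStyleAdapter:
    """Hypothetical stand-in for a gym-style layer: owns the rng, exposes reset()."""

    def __init__(self, env, seed=0):
        self._env = env
        self._rng = seed

    def reset(self):
        # Toy replacement for splitting a JAX PRNG key before each reset.
        self._rng += 1
        return self._env.reset(self._rng)["obs"]


class OuterWrapper:
    """Hypothetical stand-in for a wrapper that calls reset() with no arguments."""

    def __init__(self, env):
        self._env = env

    def reset(self):
        return self._env.reset()


def try_reset(env):
    """Return the reset result, or the TypeError message if the call fails."""
    try:
        return env.reset()
    except TypeError as exc:
        return f"TypeError: {exc}"


# Outer wrapper directly on the functional env: nobody supplies rng.
broken = try_reset(OuterWrapper(FunctionalEnv()))
# Gym-style layer in between: it supplies rng, so the outer reset() works.
working = try_reset(OuterWrapper(GymStyleAdapter(FunctionalEnv())))

print(broken)
print(working)
```

The first call fails with a "missing 1 required positional argument: 'rng'" TypeError, while the second succeeds because the middle layer provides the rng itself.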