Closed by SumeetBatra.
Take a look at the torch notebook. It looks like you need to put a VectorGymWrapper between the brax env and TorchWrapper, something like:
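from brax import envs
from brax.envs.wrappers import gym as gym_wrapper
from brax.envs.wrappers import torch as torch_wrapper
env = envs.create(env_name, batch_size=num_envs, ...)
# gym-style wrapper that keeps its own PRNG key, so reset() needs no rng argument:
env = gym_wrapper.VectorGymWrapper(env)
# automatically convert between jax ndarrays and torch tensors:
env = torch_wrapper.TorchWrapper(env, device=device)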
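Why the wrapper order matters can be sketched without brax at all. The classes below are hypothetical stand-ins, not brax's real wrappers: FunctionalEnv plays the brax env whose reset needs an rng (like AutoResetWrapper.reset(rng)), StatefulVectorWrapper plays VectorGymWrapper (assumed here to own the key and pass a fresh one down on each reset), and ConvertingWrapper plays TorchWrapper, which calls reset() with no arguments. Python's random module stands in for JAX PRNG keys.

```python
import random


class FunctionalEnv:
    """Hypothetical stand-in for a brax env: reset() requires an explicit rng."""

    def reset(self, rng):
        # A real brax env would build its initial state from the PRNG key.
        return {"obs": rng % 1000, "rng": rng}


class StatefulVectorWrapper:
    """Hypothetical stand-in for VectorGymWrapper: owns the randomness, so reset() takes no args."""

    def __init__(self, env, seed=0):
        self._env = env
        # Python's random module replaces a JAX PRNG key in this sketch.
        self._rng = random.Random(seed)

    def reset(self):
        # Draw a fresh key on every reset and pass it down explicitly.
        key = self._rng.getrandbits(32)
        self._state = self._env.reset(key)
        return self._state["obs"]


class ConvertingWrapper:
    """Hypothetical stand-in for TorchWrapper: gym-style reset() with no rng, then converts output."""

    def __init__(self, env):
        self._env = env

    def reset(self):
        obs = self._env.reset()  # no rng is passed down from this layer
        return float(obs)  # placeholder for the jax -> torch conversion


# Wrapping the functional env directly reproduces the same kind of TypeError.
try:
    ConvertingWrapper(FunctionalEnv()).reset()
except TypeError as err:
    print("direct wrapping fails:", err)

# Putting the stateful wrapper in between supplies the missing rng.
env = ConvertingWrapper(StatefulVectorWrapper(FunctionalEnv(), seed=0))
print("reset ok, obs =", env.reset())
```

In this sketch the layer that owns the key is the only place an rng can come from, so wrapping the functional env with the converting wrapper directly has nothing to pass to reset(rng).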
Hello all,
I am creating a brax Env and wrapping it with TorchWrapper, but when I call env.reset(), I get the following error:
TypeError: AutoResetWrapper.reset() missing 1 required positional argument: 'rng'
Looking at TorchWrapper's reset() method, it seems that rng is not propagated from this wrapper to the lower-level brax wrappers that require it. Is that the case?