google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Training APG with pytorch #221

Closed: shenao-zhang closed this issue 2 years ago

shenao-zhang commented 2 years ago

Hi, I'm trying to implement BPTT (backpropagation through time, i.e. analytic policy gradients, APG) with Brax using PyTorch. However, simply converting the JAX Brax env to torch with the wrapper seems to give only the observation tensor, not a torch gradient. For instance, a simple one-step truncated BPTT in PyTorch fails with the error "element 0 of tensors does not require grad and does not have a grad_fn":

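```python
import functools
import gym
import torch
from brax.envs import create_gym_env, to_torch

gym_name = 'brax-hopper-v0'
gym.register(gym_name, entry_point=functools.partial(create_gym_env, env_name='hopper'))  # older Brax/gym APIs
env = to_torch.JaxToTorchWrapper(gym.make(gym_name, episode_length=1000))
policy = torch.nn.Linear(env.observation_space.shape[-1], env.action_space.shape[-1])  # stand-in for NN
obs = env.reset()
action = policy(obs)
obs, reward, done, _ = env.step(action)  # reward is converted from JAX and carries no grad_fn
actor_loss = -reward
actor_loss.backward()  # fails: "element 0 of tensors does not require grad ..."
```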
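For comparison, when the same truncated BPTT loss stays entirely inside JAX, `jax.grad` differentiates straight through the simulation step, which is the property a native JAX APG trainer relies on. The sketch below is only an illustration, not Brax's API: `step`, `policy`, and `bptt_loss` are hypothetical, and a toy point-mass integrator stands in for a Brax environment. Setting `horizon=1` gives the one-step truncated case from the snippet above.

```python
import jax
import jax.numpy as jnp


def step(state, action, dt=0.05):
    """Hypothetical differentiable point-mass dynamics standing in for env.step."""
    pos, vel = state
    vel = vel + dt * action
    pos = pos + dt * vel
    reward = -jnp.sum(pos ** 2)  # reward for staying near the origin
    return (pos, vel), reward


def policy(params, state):
    """Hypothetical linear feedback policy; params holds two gains."""
    pos, vel = state
    return -(params[0] * pos + params[1] * vel)


def bptt_loss(params, state, horizon=10):
    """Negative return of a `horizon`-step rollout; jax.grad backpropagates through every step."""
    def body(carry, _):
        new_state, reward = step(carry, policy(params, carry))
        return new_state, reward

    _, rewards = jax.lax.scan(body, state, None, length=horizon)
    return -jnp.sum(rewards)


state0 = (jnp.ones(2), jnp.zeros(2))  # (position, velocity)
params = jnp.array([0.5, 0.5])
grads = jax.grad(bptt_loss)(params, state0)  # gradient of the loss w.r.t. policy gains
print(grads.shape)  # (2,)
```

An optimizer step such as `params - 1e-2 * grads` would then complete one APG update in this toy setting.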
shenao-zhang commented 2 years ago

Also, I noticed a comment stating that "TDS has full differentiability, Brax has accelerator support". Can someone clarify whether, and why, Brax is not considered fully differentiable?

erikfrey commented 2 years ago

Hi there! Thanks for pointing this out. Yes, Brax is fully differentiable; we actually have our own APG trainer here:

But this is all done natively within JAX. It's an interesting idea; I had not considered passing gradients to PyTorch. Presumably we could override .grad_fn on the PyTorch side and pass jax.grad(env.step) through the JaxToTorch wrapper. I don't see why that couldn't be done!
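A minimal sketch of that idea, under assumptions: rather than overriding `.grad_fn` directly, it uses PyTorch's documented extension point, a custom `torch.autograd.Function`, whose `forward` calls a JAX function through `jax.vjp` and whose `backward` replays the stored JAX pullback. `JaxFunction`, `toy_step`, and the NumPy-based conversions are hypothetical; a real wrapper around Brax would also have to handle the state pytree that `env.step` takes and returns, and could use DLPack for zero-copy transfers.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def _to_jax(t: torch.Tensor) -> jnp.ndarray:
    # Copy through NumPy for simplicity (assumption: CPU tensors, float32).
    return jnp.asarray(t.detach().cpu().numpy())


def _to_torch(x) -> torch.Tensor:
    # torch.tensor copies, so the read-only NumPy view of a JAX array is safe.
    return torch.tensor(np.asarray(x))


class JaxFunction(torch.autograd.Function):
    """Lets torch autograd backpropagate through a single-input, single-output JAX function."""

    @staticmethod
    def forward(ctx, jax_fn, x):
        y, vjp_fn = jax.vjp(jax_fn, _to_jax(x))
        ctx.vjp_fn = vjp_fn  # keep the JAX pullback for the backward pass
        return _to_torch(y).to(x.device)

    @staticmethod
    def backward(ctx, grad_y):
        (grad_x,) = ctx.vjp_fn(_to_jax(grad_y))
        # One gradient per forward input: None for jax_fn, a tensor for x.
        return None, _to_torch(grad_x).to(grad_y.device)


def toy_step(action):
    """Hypothetical stand-in for env.step: a reward that is differentiable in the action."""
    return -jnp.sum((action - 1.0) ** 2)


action = torch.zeros(3, requires_grad=True)
reward = JaxFunction.apply(toy_step, action)  # a torch tensor with a grad_fn
(-reward).backward()
print(action.grad)  # tensor([-2., -2., -2.])
```

Because `backward` returns `None` for the `jax_fn` argument, only the tensor input receives a gradient; the returned reward now has a `grad_fn`, so a loss like `actor_loss = -reward` can call `.backward()`.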

We can look into this in a future update, but we're also happy to review a PR for this as it might be a relatively small, targeted change.

shenao-zhang commented 2 years ago

Thanks for the clarification!

amdee commented 1 year ago

@erikfrey I am interested in this feature. Would it be possible to provide some implementation guidance? I am willing to implement it.

hdadong commented 2 months ago

@erikfrey I am interested in this feature, too. Would it be possible to provide some implementation guidance? I am willing to implement it.