Also, I noticed a comment in https://github.com/openai/gym/issues/2456#issue-1032765998 saying that "TDS has full differentiability, Brax has accelerator support". Can someone clarify whether, and why, Brax is not considered fully differentiable?
Hi there! Thanks for pointing this out. Yes, Brax is fully differentiable, we actually have our own APG trainer here:
https://github.com/google/brax/blob/main/brax/training/agents/apg/train.py
But this is all done natively within Jax. It's quite interesting; I hadn't considered passing gradients to PyTorch. Presumably we could override `.grad_fn` on the PyTorch side and pass `jax.grad(env.step)` through the `JaxToTorch` wrapper. I don't see why that couldn't be done!
We can look into this in a future update, but we're also happy to review a PR for this as it might be a relatively small, targeted change.
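For reference, here is a rough sketch of the kind of bridge I have in mind. This is untested and not part of the Brax or Gym API; `make_torch_fn`, the stand-in function, and the numpy round-trips are placeholders for illustration only.

```python
# Minimal sketch: expose a pure JAX function to PyTorch autograd by wrapping
# it in a torch.autograd.Function and using jax.vjp for the backward pass.
import jax
import jax.numpy as jnp
import numpy as np
import torch


def make_torch_fn(jax_fn):
    """Wrap a pure JAX function so PyTorch can backprop through it."""

    class JaxBridge(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x_torch):
            x_jax = jnp.asarray(x_torch.detach().cpu().numpy())
            y_jax, vjp_fn = jax.vjp(jax_fn, x_jax)
            ctx.vjp_fn = vjp_fn  # keep the pullback for backward
            return torch.from_numpy(np.array(y_jax))

        @staticmethod
        def backward(ctx, grad_out):
            cotangent = jnp.asarray(grad_out.detach().cpu().numpy())
            (grad_x,) = ctx.vjp_fn(cotangent)
            return torch.from_numpy(np.array(grad_x))

    return JaxBridge.apply


# Usage: gradients flow from the PyTorch loss back through the JAX function.
f = make_torch_fn(lambda x: x * jnp.sin(x))  # stand-in for env.step
x = torch.randn(3, requires_grad=True)
f(x).sum().backward()
print(x.grad)  # matches sin(x) + x*cos(x)
```

Something along these lines, plugged into the `JaxToTorch` wrapper, would presumably let the differentiable path survive the conversion.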
Thanks for the clarification!
@erikfrey I am interested in this feature. Will it be possible to provide implementation guidance? I am willing to implement it.
@erikfrey I am interested in this feature, too. Will it be possible to provide implementation guidance? I am willing to implement it.
Hi, I'm trying to implement BPTT (or analytical PG) with brax using pytorch. But it seems that simply converting the Jax brax env to torch only yields observation tensors that are detached from any autograd graph, so there is no torch gradient. For instance, I ran a simple one-step truncated BPTT in pytorch, which fails with the error `element 0 of tensors does not require grad and does not have a grad_fn`.