Open jnqian99 opened 2 weeks ago
This looks good. Try to come up with better names, and split different components into separate issues so we can find them faster later. A better title would be "Conversion of ReBRAC from JAX to PyTorch". Then create a new issue for changing the gym to your simulator.
I am currently running into issues while evaluating my converted algorithm:
the critic_loss is much higher, and the eval/return_mean is much lower, than in the original ReBRAC in JAX.
I am still trying to figure out why. It is probably a problem in my update_actor and update_critic in rebrac_update.py.
@pagand
A few things to consider:
The input data are the same: consider using a fixed batch instead of a random batch to ensure reproducibility
The network parameters are the same (between JAX and PyTorch)
Check that the predictions from the actor and critic are the same
Check that update_actor and update_critic return the same loss
Check that update_actor and update_critic produce the same actor and critic after backpropagation
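A minimal sketch of how the checks above could be automated, assuming both implementations can export their batches, parameters, predictions, and losses as nested dicts of NumPy arrays. The helper names `fixed_batch` and `compare_trees` are hypothetical and not part of either codebase. PyTorch tensors would first need `.detach().cpu().numpy()`, and if the JAX version uses Flax `Dense` layers, each kernel has shape (in, out) and must be transposed before comparing it with an `nn.Linear` weight of shape (out, in).

```python
import numpy as np


def fixed_batch(dataset, batch_size, seed=0):
    """Return the same rows on every call so both versions see identical data.

    `dataset` is assumed to be a dict of equally long arrays
    (e.g. observations, actions, rewards).
    """
    rng = np.random.default_rng(seed)
    n = len(next(iter(dataset.values())))
    idx = rng.choice(n, size=batch_size, replace=False)
    return {key: np.asarray(value)[idx] for key, value in dataset.items()}


def compare_trees(a, b, rtol=1e-5, atol=1e-6, path=""):
    """Recursively compare two nested dicts of arrays.

    Returns a list of (path, description) for every leaf that differs;
    an empty list means the two trees match within tolerance.
    """
    a_is_dict, b_is_dict = isinstance(a, dict), isinstance(b, dict)
    if a_is_dict != b_is_dict:
        return [(path, "structure differs")]
    if a_is_dict:
        mismatches = []
        for key in sorted(set(a) | set(b)):
            child = f"{path}/{key}"
            if key not in a or key not in b:
                mismatches.append((child, "missing on one side"))
            else:
                mismatches.extend(
                    compare_trees(a[key], b[key], rtol, atol, child)
                )
        return mismatches
    x, y = np.asarray(a), np.asarray(b)
    if x.shape != y.shape:
        return [(path, f"shape {x.shape} vs {y.shape}")]
    if not np.allclose(x, y, rtol=rtol, atol=atol):
        return [(path, f"max abs diff {np.max(np.abs(x - y)):.3e}")]
    return []
```

Running `compare_trees` on the parameters before the update, on the actor and critic outputs for the fixed batch, on the returned losses, and on the parameters after one update step should show at which stage the two versions first diverge. Small float32 differences between frameworks are expected, so the tolerances may need loosening.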
@pagand
The PyTorch ReBRAC model now generates output (see the link below) that seems similar to the JAX version:
https://wandb.ai/jnqian/TORL/runs/8c7f0fe8-bf84-4a0d-9fd8-ef73d1e32aef
I added one more run with more evaluations to smooth out the graphs a bit:
https://wandb.ai/jnqian/TORL/runs/ee4f5486-7104-4e19-ae7d-3569129b0660
@pagand
Baseline run of ReBRAC on halfcheetah-medium-v2:
https://wandb.ai/jnqian/CORL/runs/a4876f1d-be93-4616-b5d8-2ec84a1a9f5a