pagand / ORL_optimizer

offline RL optimizer
Convert ReBRAC from JAX to PyTorch #15

Open jnqian99 opened 2 weeks ago

jnqian99 commented 2 weeks ago

Baseline run of ReBRAC on halfcheetah-medium-v2:

https://wandb.ai/jnqian/CORL/runs/a4876f1d-be93-4616-b5d8-2ec84a1a9f5a

pagand commented 2 weeks ago

This looks good. Please use more descriptive issue titles, and split separate components into separate issues so we can find them faster later. A better title here would be "Conversion of ReBRAC from JAX to PyTorch". Then create a new issue for switching from Gym to your simulator.

jnqian99 commented 2 weeks ago

I am currently running into issues while evaluating my converted algorithm.

The critic_loss is much higher, and the eval/return_mean is much lower, than in the original JAX ReBRAC.

I am still trying to figure out why. The problem is probably in my update_actor and update_critic in rebrac_update.py.

@pagand

jnqian99 commented 2 weeks ago

A few things to consider:

Make sure the input data are the same: use a fixed batch instead of a random batch to ensure reproducibility.

Make sure the network parameters are the same in JAX and PyTorch.

Check that the predictions from the actor and critic are the same.

Check that update_actor and update_critic return the same loss.

Check that update_actor and update_critic produce the same actor and critic after backpropagation.
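
A rough sketch of how these checks could be scripted. It assumes the outputs of both implementations have been converted to NumPy arrays first (for example `np.asarray(x)` for a JAX array and `t.detach().cpu().numpy()` for a PyTorch tensor). The helper names `flax_dense_to_torch_linear`, `max_abs_diff` and `compare_named` are hypothetical and not part of this repo, and the sketch also assumes the JAX networks use Flax-style dense layers, whose kernels have shape `(in, out)`, while `torch.nn.Linear` stores its weight as `(out, in)`. Plain NumPy matrix products stand in for the two frameworks so the example runs on its own.

```python
import numpy as np


def flax_dense_to_torch_linear(kernel, bias):
    """Transpose a Flax-style (in, out) kernel into a torch.nn.Linear-style (out, in) weight."""
    return np.asarray(kernel).T.copy(), np.asarray(bias).copy()


def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two arrays of equal shape."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    return float(np.max(np.abs(a - b))) if a.size else 0.0


def compare_named(jax_vals, torch_vals, atol=1e-5):
    """Compare two dicts of named arrays; returns {name: (max_diff, within_tolerance)}."""
    if jax_vals.keys() != torch_vals.keys():
        raise ValueError("both sides must report the same names")
    report = {}
    for name in jax_vals:
        diff = max_abs_diff(jax_vals[name], torch_vals[name])
        report[name] = (diff, diff <= atol)
    return report


# Fixed batch: seed once and feed the SAME array to both implementations.
rng = np.random.default_rng(0)
batch = rng.standard_normal((4, 3))

# One dense layer with identical parameters on both sides.
kernel = rng.standard_normal((3, 2))  # Flax-style layout: (in, out)
bias = rng.standard_normal(2)
weight, torch_bias = flax_dense_to_torch_linear(kernel, bias)

jax_out = batch @ kernel + bias            # what a Flax Dense layer computes
torch_out = batch @ weight.T + torch_bias  # what torch.nn.Linear computes

report = compare_named({"critic_q": jax_out}, {"critic_q": torch_out})
for name, (diff, ok) in report.items():
    print(f"{name}: max abs diff = {diff:.2e} ({'OK' if ok else 'MISMATCH'})")
```

The same `compare_named` call can be repeated at each stage of the list: on the actor and critic predictions, on the losses returned by update_actor and update_critic, and on every parameter array after one update step. The first stage that reports a mismatch narrows down where the two implementations diverge.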

@pagand

jnqian99 commented 1 week ago

The PyTorch ReBRAC model now produces output that looks similar to the JAX version (see the link below):

https://wandb.ai/jnqian/TORL/runs/8c7f0fe8-bf84-4a0d-9fd8-ef73d1e32aef

I also added one more run with more evaluations, to smooth out the graphs a bit:

https://wandb.ai/jnqian/TORL/runs/ee4f5486-7104-4e19-ae7d-3569129b0660

@pagand