zbzhu99 / decision-diffuser-jax

A JAX implementation of "Is Conditional Generative Modeling all you need for Decision-Making?"
MIT License

Can you share the training logs, e.g., for the diffusion model or inverse dynamics? #3

Closed return-sleep closed 5 months ago

return-sleep commented 7 months ago

Can you please share the training logs so that we can understand the model better? I'm curious how to avoid overfitting when training the inverse dynamics model. BTW, your survey "Diffusion Models for Reinforcement Learning: A Survey" is an interesting work. Thanks!

github-actions[bot] commented 7 months ago

Congrats on your first issue!

zbzhu99 commented 7 months ago

Thanks for your interest in our survey paper!

Sure, here is the link to my training logs on the D4RL hopper datasets: https://drive.google.com/file/d/1VffG8cCve55c3i-X0rO41HbdynaS2vcu/view?usp=sharing. This repo is my own implementation of Decision Diffuser in JAX, and the results on most datasets do not match the authors' reported numbers. I have not had time to examine the reason for the poor performance; I will investigate when I am free.

May I ask why you are concerned about overfitting when training the inverse dynamics model? The official implementation of DD just trains the inverse dynamics model on all the data and does not consider overfitting at all. Do you have evidence that overfitting harms performance?
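For reference, the setup being discussed amounts to something like the following: an MLP mapping a pair of adjacent states (s_t, s_{t+1}) to the action a_t, trained with a plain MSE loss on every transition pair, with no held-out split. This is only an illustrative sketch in raw JAX; the layer sizes, initialization, and optimizer are my assumptions, not taken from the DD codebase.

```python
import jax
import jax.numpy as jnp

def init_params(key, obs_dim, act_dim, hidden=256):
    # Hypothetical 2-layer MLP; sizes are illustrative.
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (2 * obs_dim, hidden)) * 0.02,
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, act_dim)) * 0.02,
        "b2": jnp.zeros(act_dim),
    }

def predict_action(params, s_t, s_next):
    # Inverse dynamics: concatenate the adjacent states and regress the action.
    x = jnp.concatenate([s_t, s_next], axis=-1)
    h = jax.nn.relu(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]

def mse_loss(params, s_t, s_next, a_t):
    return jnp.mean((predict_action(params, s_t, s_next) - a_t) ** 2)

@jax.jit
def sgd_step(params, s_t, s_next, a_t, lr=1e-3):
    # Plain SGD on all data, no validation split -- the point under discussion.
    grads = jax.grad(mse_loss)(params, s_t, s_next, a_t)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```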

return-sleep commented 7 months ago

Thank you for sharing!

My concern about overfitting comes from a previous experiment of mine with a simple behavioral cloning model that maps states directly to actions. When I trained it on all the data without a validation set, I found that its performance varied greatly across random seeds, and that minimal training loss did not correspond to optimal performance. I'm not sure whether this phenomenon is expected, or whether there was a problem with my experiment.
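The usual mitigation for this is exactly the validation split being asked about: hold out a fraction of transitions, checkpoint on validation loss, and stop when it plateaus. A minimal early-stopping skeleton, where `train_step` and `eval_loss` are placeholders for whatever model is being trained (hypothetical names, not from either codebase):

```python
import numpy as np

def train_with_early_stopping(data, train_step, eval_loss, state,
                              val_frac=0.1, patience=10, max_epochs=1000,
                              seed=0):
    # Randomly hold out a validation split instead of training on all data.
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(data))
    n_val = max(1, int(val_frac * len(data)))
    val_idx, train_idx = idx[:n_val], idx[n_val:]
    train_data = [data[i] for i in train_idx]
    val_data = [data[i] for i in val_idx]

    best_loss, best_state, bad_evals = np.inf, state, 0
    for epoch in range(max_epochs):
        state = train_step(state, train_data)
        val_loss = eval_loss(state, val_data)
        if val_loss < best_loss:
            # Keep the checkpoint with the best validation loss,
            # not the one with minimal training loss.
            best_loss, best_state, bad_evals = val_loss, state, 0
        else:
            bad_evals += 1
            if bad_evals >= patience:
                break
    return best_state, best_loss
```

Selecting the checkpoint by validation loss also makes the result less sensitive to exactly when training is stopped, which may reduce the seed-to-seed variance described above.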

I noticed the relevant code in DD for the inverse dynamics model, and training this way makes me concerned about the model's generalization ability. I'm also curious how it ensures reachability between adjacent states in a generated trajectory. When I tried trajectory prediction with a DD-like model in the latent space of images, I found it possible to generate trajectories that violate the dynamics: the action inferred by inverse dynamics may not produce the desired next state.
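One way to quantify the violation described above: for each adjacent pair (s_t, s_{t+1}) in a generated trajectory, query the inverse dynamics model for an action, execute that action in the true dynamics, and measure the gap between the generated next state and the state actually reached. This is a sketch under my own assumptions; `inv_dyn` and `true_step` stand in for the learned model and the simulator.

```python
import numpy as np

def dynamics_gap(traj, inv_dyn, true_step):
    """Mean L2 distance between generated next states and the states
    actually reached by executing the inverse-dynamics actions."""
    gaps = []
    for s_t, s_gen_next in zip(traj[:-1], traj[1:]):
        a_t = inv_dyn(s_t, s_gen_next)      # action proposed by inverse dynamics
        s_real_next = true_step(s_t, a_t)   # state the true dynamics reaches
        gaps.append(np.linalg.norm(s_gen_next - s_real_next))
    return float(np.mean(gaps))
```

A gap near zero means the generated trajectory is dynamically reachable under the inferred actions; a large gap is exactly the failure mode reported above, where the inferred action does not lead to the generated next state.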