weipu-zhang / STORM


How to reproduce the results in Table 2? #5

Closed hutchinsonian closed 3 months ago

hutchinsonian commented 3 months ago

Thanks for the great work, I'm new to this field. Following the README, I ran ./train.sh first and then ./eval.sh. I got this line, step, episode_avg_return 100000,1555.5, in the MsPacman-life_done-wm_2L512D8H-100k-seed1.csv file. Is there any connection between this result and Table 2? @weipu-zhang

weipu-zhang commented 3 months ago

Hi,

Could you try more seeds and compare them with the plot in Appendix A, Figure 6? We developed on MsPacman for some time, so I'd expect it to reach a similar performance. The result in Table 2 is the average over 5 seeds.

(The results were generated on 3090/4090 GPUs with bfloat16; we haven't verified the full performance under fp32/fp16. I'm not sure about your device, but if it's the same setup, please ignore this.)
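Averaging the final eval return over seeds, as described above, can be sketched as follows. This is a hypothetical helper, not part of the STORM codebase; it assumes each run writes a CSV with a `step,episode_avg_return` header like the file named in the question, and that the seed number appears in the filename.

```python
# Sketch: average the final eval return over several seeds.
# Assumes eval CSVs with a "step,episode_avg_return" layout and
# filenames like "MsPacman-...-seed{N}.csv" (the pattern is an assumption).
import csv
import glob

def final_return(path):
    """Return the episode_avg_return of the last row in an eval CSV."""
    with open(path, newline="") as f:
        rows = list(csv.DictReader(f))
    return float(rows[-1]["episode_avg_return"])

def average_over_seeds(pattern):
    """Average the final return across all CSVs matching a glob pattern."""
    paths = sorted(glob.glob(pattern))
    returns = [final_return(p) for p in paths]
    return sum(returns) / len(returns)

# e.g. average_over_seeds("MsPacman-life_done-wm_2L512D8H-100k-seed*.csv")
```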

hutchinsonian commented 3 months ago

Thank you for your reply. I checked Appendix A, Figure 6. Does the horizontal axis mean training steps? In the default settings, this is 100k. Does this mean the eval result of 1555.5 corresponds to the value at 100k on the horizontal axis in this figure?

weipu-zhang commented 3 months ago

No, that's already the final result.

The 400k corresponds to Atari-gym's raw sample steps (60 FPS). Most algorithms use frame skip (skip 4 -> 15 FPS) to speed up training/inference, and their authors use the term "Atari 100k" to describe the resulting 100k agent interactions. This convention is usually mentioned in only one sentence of the experiments/method section, so the setting can be unclear to readers. Following DreamerV3, we use 400k in the plot to make this explicit.
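The bookkeeping above can be made concrete with a small sketch (these are illustrative helper functions, not code from the repo): with frame skip 4, 100k agent interactions consume 400k raw environment frames, and the agent effectively acts at 15 FPS instead of 60.

```python
# Sketch of the "Atari 100k" convention: one agent interaction repeats
# the chosen action for FRAME_SKIP raw environment frames.
FRAME_SKIP = 4

def env_frames(agent_steps, frame_skip=FRAME_SKIP):
    """Raw Atari frames consumed by a given number of agent interactions."""
    return agent_steps * frame_skip

def effective_fps(raw_fps=60, frame_skip=FRAME_SKIP):
    """Decision frequency of the agent after frame skipping."""
    return raw_fps / frame_skip

# env_frames(100_000) -> 400_000, effective_fps() -> 15.0
```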

hutchinsonian commented 3 months ago

Thank you very much for your reply again!

My device is an A6000, which supports bfloat16.

I did not modify the configuration during training, and the seed was 1. I found that among the world model losses, only termination_loss and reconstruction_loss have a downward trend. Does this mean there is something wrong with my run?


weipu-zhang commented 3 months ago

That looks good to me. MsPacman encounters more diverse rewards in later episodes, so the reward loss may go up; for games like Pong it simply goes down. The dynamics loss can also go up because the visual encoder is trained jointly with the sequence model, so it behaves like a bootstrapped update (similar to the value loss on the policy side: the absolute loss value doesn't reflect learning progress).

hutchinsonian commented 3 months ago

You answered my question, thanks :)