buoyancy99 / diffusion-forcing

code for "Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion"

Video UNet not learning? #13

Closed julian-q closed 4 weeks ago

julian-q commented 1 month ago

[screenshot: training loss curve]

Hi, thanks for the nice repo! I'm trying to follow the instructions for training a video model on Minecraft, and the loss has stayed flat for 60k steps.

Command: python -m main +name=minecraft_unet algorithm=df_video dataset=video_minecraft

Is this normal? Any tips much appreciated!!

buoyancy99 commented 1 month ago

Hello,

How do the training_vis images look? Do they look legit?

I am wondering whether you finished downloading the entire dataset, which can take a couple of days, or whether you accidentally used the mini subset shipped with the checkpoint file.

Some people have trained with the mini subset, and as a result the model only trains on about 10 videos and overfits.
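
If it helps, a quick sanity check is to count how many episode files you actually have on disk. This is just a sketch; the directory and file extension below are guesses, so adapt them to your layout:

```python
from pathlib import Path

# Assumed location and extension of the downloaded Minecraft episodes --
# point this at wherever your data actually lives.
data_dir = Path("data/minecraft/train")
n_videos = sum(1 for _ in data_dir.glob("*.npz"))

# The full TECO Minecraft download has thousands of episodes; the mini subset
# shipped with the checkpoint only has on the order of 10.
print(f"found {n_videos} episode files under {data_dir}")
```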

julian-q commented 1 month ago

Here's the latest training_vis/video_2 from wandb - it looks pretty noisy: video_2_6086_b0aaf937f038141a0728

I think I have the whole dataset because I symlinked the dataset I previously downloaded with the TECO bash script.

Here's my config.yaml. I think the one modification I made was the batch size. Maybe I should retry on a machine with more memory so that I can use the original config.

Update: Oh, oops, it turns out I was not training on the whole dataset because I was still using the metadata.json from the sample dataset. To create a metadata.json for my existing TECO dataset, I wrote this script. I'll train again with this metadata file and see how it goes!
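
Roughly, the script just globs the episode files and dumps their names into a JSON file. This is only a sketch of the idea; the real metadata format expected by the dataloader may have more fields, and the paths and extension here are assumptions:

```python
import json
from pathlib import Path

# Sketch only: the metadata layout the dataloader expects may differ, so treat
# the paths, extension, and structure here as placeholders.
DATA_ROOT = Path("data/minecraft")  # assumed location of the symlinked TECO download

def list_episodes(split: str) -> list[str]:
    """Collect every episode file name under the given split directory."""
    # The .npz extension is an assumption about how the TECO download is stored.
    return sorted(p.name for p in (DATA_ROOT / split).glob("*.npz"))

metadata = {split: list_episodes(split) for split in ("train", "validation")}

with open(DATA_ROOT / "metadata.json", "w") as f:
    json.dump(metadata, f, indent=2)

print({split: len(files) for split, files in metadata.items()})
```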

julian-q commented 1 month ago

Even after training with the original config on the whole dataset, it still looks stagnant 🤔

[screenshot: training loss curve, 2024-08-15 11:57 AM]

And the training vis is still noisy: video_2_4993_bd4aaf968e0eca8147c1

julian-q commented 1 month ago

I don't think it's a problem with the dataset, since when I switch to the paper branch and train the RNN-based video model, the loss goes down just fine.

[screenshot: loss curves for both runs, 2024-08-15 1:44 PM]

"minecraft_video" is on the paper branch, "minecraft_unet_transformer" is on the main branch, same data for both runs

julian-q commented 1 month ago

Okay, I tried cloning the diffusion-forcing-transformer repo for the original implementation of the transformer UNet to see if there's a difference. I started training using this command, and I now see the loss going down:

[screenshot: training loss curve, 2024-08-15 2:54 PM]

After 9k steps the training vis looks good too: video_2_427_871367e5c96e2fa56b1a

I'll try to see what the discrepancy is between main and diffusion-forcing-transformer...

buoyancy99 commented 4 weeks ago

Hi Julian, I will start debugging this immediately. Right now the v1.5 video diffusion is supposed to be exactly the same as diffusion-forcing-transformer.

buoyancy99 commented 4 weeks ago

Thank you for pointing this out; I found the bug. Can you pull this repo again?

The bug fix is this

julian-q commented 4 weeks ago

Haha, nice catch:) Loss is going down now on main! Thank you 🙏