XiangLi1999 / Diffusion-LM


E2E training procedure #67

Open

elephantmipt commented 1 year ago

Hi,

Thank you for sharing the code from your insightful paper!

I'm trying to train the model in the end-to-end (e2e) setup and have run into a question about the embeddings. As I understand it, the e2e setup uses the `TextDataset_NoCache` class for the dataset, and that class holds the model's embedding layer:

https://github.com/XiangLi1999/Diffusion-LM/blob/759889d58ef38e2eed41a8c34db8032e072826f4/improved-diffusion/improved_diffusion/text_datasets.py#L815-L828
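For concreteness, here is a minimal sketch of how I read that class: a dataset that maps token ids to vectors with whatever weights its embedding layer holds at fetch time. The class and variable names below are mine, not the repo's.

```python
import torch
from torch.utils.data import Dataset


class EmbeddingDataset(Dataset):
    """Toy stand-in for TextDataset_NoCache: embeds token ids at fetch time."""

    def __init__(self, token_id_seqs, embedding_layer):
        self.token_id_seqs = token_id_seqs      # list of LongTensor sequences
        self.embedding_layer = embedding_layer  # nn.Embedding captured at construction

    def __len__(self):
        return len(self.token_id_seqs)

    def __getitem__(self, idx):
        ids = self.token_id_seqs[idx]
        with torch.no_grad():
            # Uses the weights this layer holds *now*; if it is a detached
            # copy, updates to the training model never reach it.
            return self.embedding_layer(ids), ids
```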

In the training script, you're passing `model=None` to the `load_data_text` function:

https://github.com/XiangLi1999/Diffusion-LM/blob/759889d58ef38e2eed41a8c34db8032e072826f4/improved-diffusion/scripts/train.py#L81-L105

So I assume the embeddings are instead initialized inside `text_datasets.py` itself:

https://github.com/XiangLi1999/Diffusion-LM/blob/main/improved-diffusion/improved_diffusion/text_datasets.py
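If I read this right, the consequence of `model=None` is something like the following. This is a hedged sketch of my understanding with my own function and attribute names; the actual branch in `text_datasets.py` may differ.

```python
import torch


def get_embedding(vocab_size, in_channel, model=None):
    if model is None:
        # A fresh, randomly initialized embedding, detached from the
        # parameters that the optimizer actually updates.
        return torch.nn.Embedding(vocab_size, in_channel)
    # Otherwise, reuse the live embedding layer from the training model
    # (attribute name assumed for illustration).
    return model.word_embedding


emb = get_embedding(vocab_size=821, in_channel=16, model=None)
print(emb.weight.requires_grad)  # True, but no optimizer holds this weight
```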

However, in the e2e setup one would expect the dataset to see the model's continuously updated embeddings. Looking through the training loop, I couldn't find anywhere the dataset's embeddings are refreshed from the model after a gradient step. Could you shed some light on how the embeddings can be trained end-to-end when they live inside the dataset class? To make my mental model concrete, I've sketched below the pattern I expected.
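In the end-to-end pattern I expected, the dataset yields raw token ids and the *model* embeds them inside the forward pass, so every optimizer step updates the same embedding matrix that produces the next batch's inputs. This is an illustrative toy, not the repo's code:

```python
import torch

vocab_size, dim = 100, 16
embedding = torch.nn.Embedding(vocab_size, dim)
denoiser = torch.nn.Linear(dim, dim)  # toy stand-in for the diffusion network
opt = torch.optim.Adam(
    list(embedding.parameters()) + list(denoiser.parameters()), lr=1e-3
)

ids = torch.randint(0, vocab_size, (8, 12))  # a batch of raw token ids

for step in range(3):
    x0 = embedding(ids)                         # recomputed from *live* weights
    pred = denoiser(x0 + torch.randn_like(x0))  # toy "denoising" step
    loss = (pred - x0.detach()).pow(2).mean()   # grads reach `embedding` via `pred`
    opt.zero_grad()
    loss.backward()
    opt.step()
```

Alternatively, if `TextDataset_NoCache` were handed the model's own embedding layer (the same tensor, not a copy), updates would be visible to the dataset automatically; I just can't see where that handoff happens given `model=None`.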

Thank you for your time and clarification!