sangmichaelxie / doremi

Pytorch implementation of DoReMi, a method for optimizing the data mixture weights in language modeling datasets
https://arxiv.org/abs/2305.10429
MIT License

question about only updating the domain weights on process 0 #8

Closed SueJane closed 9 months ago

SueJane commented 1 year ago

Hi Michael,

Thanks for releasing this code base and all the amazing work you have done! I'm learning about DoReMi and have a question: I noticed that the domain weights are updated only on process 0, so how do the other processes get the new weights when computing the loss and updating the proxy model?

Thanks!

SueJane commented 1 year ago

cc @zhuzilin

sangmichaelxie commented 1 year ago

The domain weight update is communicated to the other processes automatically since the weights are stored in a torch buffer. This is similar to how batch norm is implemented.
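A minimal sketch of this pattern (not the repo's exact code; the module and method names here are hypothetical): the domain weights live in a registered buffer, which is part of the module's state but receives no gradients. When such a module is wrapped in DDP with `broadcast_buffers=True` (the default), DDP broadcasts rank 0's buffers to all other ranks at the start of each forward pass, so an in-place update on rank 0 propagates automatically.

```python
import torch
import torch.nn as nn

class ProxyModelWithDomainWeights(nn.Module):  # hypothetical name for illustration
    def __init__(self, num_domains: int):
        super().__init__()
        # Buffers are saved/moved with the module but get no gradients.
        self.register_buffer(
            "domain_weights", torch.full((num_domains,), 1.0 / num_domains)
        )

    def update_domain_weights(self, new_weights: torch.Tensor) -> None:
        # In-place copy keeps the same buffer tensor, so DDP's buffer
        # broadcast picks up the new values on the next forward pass.
        self.domain_weights.copy_(new_weights)

model = ProxyModelWithDomainWeights(num_domains=3)
model.update_domain_weights(torch.tensor([0.5, 0.3, 0.2]))
# The buffer must exist *before* wrapping in DDP, e.g.:
# ddp_model = nn.parallel.DistributedDataParallel(model, broadcast_buffers=True)
```

Note that `broadcast_buffers=True` only syncs buffers on `forward()`, so updates made between forward passes are visible to other ranks only after the next forward call.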

MauritsBleeker commented 9 months ago

Hi,

Thanks for sharing the project code!

Are you sure buffers are automatically broadcast to all processes when you change the weights/values of the buffer only on process/rank 0? I implemented something similar a while ago, and when I changed the buffer on rank/process 0, only the buffer on rank/process 0 was updated; the others stayed the same. I had to broadcast the update explicitly.

Thanks, Maurits

sangmichaelxie commented 9 months ago

Yes, we've checked before (and I just checked again) and the buffers are automatically updated. I also ran a version with broadcasting for a while and the domain weight trajectories were almost identical. In your other implementation, did you create the buffer before wrapping the model in DDP?

Also see the image below from the PyTorch DDP docs: https://pytorch.org/docs/stable/notes/ddp.html

[image: screenshot from the PyTorch DDP documentation]