sangmichaelxie / doremi

PyTorch implementation of DoReMi, a method for optimizing the data mixture weights in language modeling datasets.
https://arxiv.org/abs/2305.10429
MIT License

Domain weights are mostly near one-hot #5

Closed: xiamengzhou closed this issue 1 year ago

xiamengzhou commented 1 year ago

Hi Michael,

Thanks for the amazing work and for releasing the codebase! I am running DoReMi in my own setting and noticed that, most of the time, the domain weights are nearly one-hot. As a result, most proxy model updates are dominated by the domain with the worst excess loss. I wonder if this is expected? In this case, the proxy model reduces the overall average loss much more slowly than training with the original domain proportions.
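
For reference, here is a minimal sketch of the exponentiated-gradient weight update described in the paper (illustrative names, not this repo's code). It shows how a domain whose excess loss stays the largest for many steps pulls the weights toward one-hot:

```python
import torch

def update_domain_weights(weights, excess_losses, eta=1.0, smoothing=1e-3):
    """One multiplicative-weights step in the style of the DoReMi paper.

    weights:       current domain weights, shape (k,), summing to 1
    excess_losses: per-domain excess loss (proxy loss - reference loss, clipped at 0)
    eta:           step size (the reweight_eta knob)
    smoothing:     mixes in a uniform distribution so no weight hits exactly 0
    """
    # Exponentiated-gradient step: domains with larger excess loss grow fastest.
    new_w = torch.softmax(torch.log(weights) + eta * excess_losses, dim=0)
    # Uniform smoothing keeps every weight positive, but it does not stop one
    # domain from absorbing almost all of the mass if its excess loss stays
    # the largest for many consecutive steps.
    uniform = torch.full_like(new_w, 1.0 / new_w.numel())
    return (1 - smoothing) * new_w + smoothing * uniform

# Toy run: one domain has a persistently larger excess loss than the others.
w = torch.full((3,), 1 / 3)
for _ in range(50):
    w = update_domain_weights(w, torch.tensor([0.30, 0.05, 0.02]))
print(w)  # the first entry ends up close to 1, i.e. nearly one-hot
```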

Thanks!

mysuns commented 1 year ago

The weights are one-hot. Why? There are only 2 datasets, and the model parameters are the repo defaults (batch size of 2 on 8 GPUs). Can you give some suggestions? Thank you.

sangmichaelxie commented 1 year ago

This could depend on many factors, but you could try a few things:

1. Train the reference model with the original domain proportions (according to domain sizes) to avoid overfitting to any small domains.
2. Make sure the batch size is big enough to include examples from all the domains.
3. Reduce the `reweight_eta` parameter to 0.1 (but I haven't found this to be necessary before); see the sketch below.

Properly shuffling the data also makes a big difference in domain weight stability. If you were using an old version of this codebase, we didn't handle shuffling well, so it could be worth updating to HEAD.
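
To make the `reweight_eta` suggestion concrete, here is a toy comparison (illustrative only, with fixed per-domain excess losses; in real training the excess losses shift as the proxy model adapts, which slows concentration further):

```python
import torch

def weights_after(eta: float, steps: int, excess: torch.Tensor) -> torch.Tensor:
    """Repeatedly apply the multiplicative update with fixed excess losses."""
    w = torch.full_like(excess, 1.0 / excess.numel())
    for _ in range(steps):
        w = torch.softmax(torch.log(w) + eta * excess, dim=0)
    return w

excess = torch.tensor([0.30, 0.05, 0.02])
print(weights_after(eta=1.0, steps=50, excess=excess))  # ~[1.00, 0.00, 0.00]: one-hot
print(weights_after(eta=0.1, steps=50, excess=excess))  # ~[0.65, 0.19, 0.16]: much softer
```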

xiamengzhou commented 1 year ago

Hi Michael, thanks for the response! I was not using this codebase; I implemented DoReMi in my own tech stack based on it. It seems that the one-hot weights are more of a feature than a bug, since the data loading prioritizes the worst-performing domain whenever a domain stays the worst-performing one for >50 steps.
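
For anyone hitting the same behavior, a rough sketch of what weight-driven data loading looks like (hypothetical sampler, not this repo's loader): once the weights are near one-hot, essentially every sampled example comes from the single worst-performing domain until another domain's excess loss overtakes it.

```python
import torch

# Hypothetical helper: each example's domain is drawn from the current mixture weights.
def sample_batch_domains(domain_weights: torch.Tensor, batch_size: int) -> torch.Tensor:
    return torch.multinomial(domain_weights, batch_size, replacement=True)

near_one_hot = torch.tensor([0.97, 0.02, 0.01])
print(sample_batch_domains(near_one_hot, batch_size=16))  # almost all indices are 0
```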