sangmichaelxie / doremi

Pytorch implementation of DoReMi, a method for optimizing the data mixture weights in language modeling datasets
https://arxiv.org/abs/2305.10429
MIT License
286 stars 32 forks

Edge Case Discussion #21

Closed thangld201 closed 7 months ago

thangld201 commented 8 months ago

Thanks for the wonderful work! I have a question about Equation 1 in the paper. If the proxy model's parameters become (or are initialized) the same as the reference model's, then training would converge and the learned domain weights might not be meaningful (since they no longer matter, as the excess loss is 0). Can you clarify @sangmichaelxie ?

sangmichaelxie commented 8 months ago

The reference model doesn't affect the proxy model's objective (its loss enters only as a constant), so the question is how this situation affects the domain weights \alpha. In our implementation, we update \alpha as \alpha_new \propto \alpha \exp(\eta \max(excess loss, 0)), so when the excess loss is 0 we simply stop updating \alpha (and continue training the proxy model with this \alpha). We output the time-averaged \alpha, so the average slowly moves toward this final value. Intuitively, this value of \alpha allowed us to quickly reduce the excess loss, so we should just continue with it for the rest of the remaining compute budget.
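
For concreteness, here is a minimal sketch of the update rule described above, not the repo's actual code; the function and variable names (`update_domain_weights`, `proxy_loss`, `ref_loss`, `eta`) are illustrative:

```python
import torch

def update_domain_weights(alpha, proxy_loss, ref_loss, eta=1.0):
    # Clipped per-domain excess loss: zero wherever the proxy already matches the reference.
    excess = torch.clamp(proxy_loss - ref_loss, min=0.0)
    # alpha_new ∝ alpha * exp(eta * excess); excess == 0 leaves alpha unchanged.
    log_alpha = torch.log(alpha) + eta * excess
    return torch.softmax(log_alpha, dim=0)

# Toy usage: when the proxy matches the reference, excess loss is 0 on every
# domain, so alpha stops moving and the time-averaged alpha drifts toward it.
alpha = torch.tensor([0.5, 0.3, 0.2])
running_sum, steps = torch.zeros_like(alpha), 0
for _ in range(100):
    proxy_loss = ref_loss = torch.zeros(3)  # proxy == reference => zero excess loss
    alpha = update_domain_weights(alpha, proxy_loss, ref_loss)
    running_sum, steps = running_sum + alpha, steps + 1
print(running_sum / steps)  # time-averaged alpha, converging to the fixed alpha
```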