sangmichaelxie / doremi

PyTorch implementation of DoReMi, a method for optimizing the data mixture weights in language modeling datasets
https://arxiv.org/abs/2305.10429
MIT License

Questions about the loss used for optimizing the proxy model #25

Open clarkkent0618 opened 7 months ago

clarkkent0618 commented 7 months ago

@sangmichaelxie It seems that the loss used for optimizing the proxy model in the code is different from the one described in the paper.

loss = (pertoken_loss * curr_domain_weights.detach()).sum() / normalizer

In the code, you directly use the proxy model's own loss here for optimization, but in the paper the loss seems to be the minimax loss, which uses the excess loss. Which one should I follow, or is there something wrong with my understanding? Thanks.
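For reference, a minimal sketch of the two variants being compared. The names `pertoken_loss`, `curr_domain_weights`, and `normalizer` follow the snippet above; `ref_pertoken_loss` is an assumed tensor of per-token losses from a frozen reference model:

```python
# Variant in the repo snippet: reweight the proxy model's own per-token loss.
loss = (pertoken_loss * curr_domain_weights.detach()).sum() / normalizer

# Paper-style variant: reweight the excess loss (proxy minus reference).
# ref_pertoken_loss is assumed to come from a frozen reference model.
excess_loss = pertoken_loss - ref_pertoken_loss.detach()
minimax_loss = (excess_loss * curr_domain_weights.detach()).sum() / normalizer
```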

[Screenshot of the minimax objective from the paper]
yuzc19 commented 7 months ago

I have another question. When training the main model, what is the difference between resampling the data from the new distribution and using the new weights to re-weight the loss? Will these two have a significant performance gap?
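For concreteness, here is a minimal sketch of the two options under discussion (stand-in tensors and names, not the repo's code):

```python
import torch

torch.manual_seed(0)
domain_weights = torch.tensor([0.5, 0.3, 0.2])  # tuned mixture weights (alpha)
batch_size, seq_len = 8, 16

# Option 1: resample -- draw each example's domain id from alpha and
# build the batch from those domains, then train with the ordinary loss.
resampled_ids = torch.multinomial(domain_weights, batch_size, replacement=True)

# Option 2: reweight -- keep the original sampling distribution, but
# scale each example's per-token loss by its domain's weight.
pertoken_loss = torch.rand(batch_size, seq_len)        # stand-in losses
batch_domain_ids = torch.randint(0, 3, (batch_size,))  # original sampling
w = domain_weights[batch_domain_ids].unsqueeze(-1)     # (batch, 1)
loss = (pertoken_loss * w).sum() / w.expand_as(pertoken_loss).sum()
```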

sangmichaelxie commented 7 months ago

In the code, you directly use the proxy model's own loss here for optimization, but in the paper the loss seems to be the minimax loss, which uses the excess loss. Which one should I follow, or is there something wrong with my understanding? Thanks.

The reference model loss is a constant with respect to the proxy model's parameters, so it doesn't affect the proxy model update and we omit it. The reference model loss does affect the domain weight update.
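To see this concretely, here is a self-contained toy sketch (illustrative names, not the repo's exact code) showing that subtracting a detached reference loss leaves the proxy gradient unchanged, while the domain-weight update does consume the excess loss:

```python
import torch

# Toy setup: the proxy per-token loss depends on a parameter w; the
# reference per-token loss is a fixed (detached) tensor.
w = torch.tensor([1.0, 2.0], requires_grad=True)
pertoken_loss = w ** 2                          # stand-in for the proxy loss
ref_pertoken_loss = torch.tensor([0.5, 0.3])    # frozen reference model loss
domain_weights = torch.tensor([0.7, 0.3])       # current alpha (detached)

# Proxy update: plain loss vs. excess loss give identical gradients,
# because the detached reference term is a constant with respect to w.
loss_plain = (pertoken_loss * domain_weights).sum()
loss_excess = ((pertoken_loss - ref_pertoken_loss) * domain_weights).sum()
g_plain = torch.autograd.grad(loss_plain, w, retain_graph=True)[0]
g_excess = torch.autograd.grad(loss_excess, w)[0]
assert torch.allclose(g_plain, g_excess)

# The domain-weight update, in contrast, does consume the excess loss:
# a simplified exponentiated-gradient step (the paper additionally
# smooths the weights with a uniform distribution, omitted here).
eta = 1.0
excess = (pertoken_loss - ref_pertoken_loss).detach()
new_weights = domain_weights * torch.exp(eta * excess)
new_weights = new_weights / new_weights.sum()  # renormalize onto the simplex
```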

When training the main model, what is the difference between resampling the data from the new distribution and using the new weights to re-weight the loss? Will these two have a significant performance gap?

Check out the main paragraph on pg. 9 and Table 3b in the paper.

yuzc19 commented 7 months ago

I checked it, and it makes sense. Thank you!