sangmichaelxie / doremi

PyTorch implementation of DoReMi, a method for optimizing the data mixture weights of language modeling datasets
https://arxiv.org/abs/2305.10429
MIT License

loss computation wrong? #9

Closed: tt6746690 closed this issue 11 months ago

tt6746690 commented 1 year ago

It seems that the loss implementation (https://github.com/sangmichaelxie/doremi/blob/main/doremi/trainer.py#L360) is not exactly the same as the loss in the paper. In the implementation, the normalizer is Σ_i α_i Σ_{x ∈ Dᵢ} |x|, but for samples from the i-th domain it should just be Σ_{x ∈ Dᵢ} |x|. Any comments on this observation?
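To make the difference concrete, here is a small standalone sketch on toy numbers (illustrative values only, not from the repo) comparing the pooled normalizer with the per-domain one:

```python
import torch

# Toy setup (illustrative): 3 domains, 5 tokens total.
alpha = torch.tensor([0.5, 0.3, 0.2])           # domain weights α_i
pertoken_loss = torch.tensor([1.0, 2.0, 0.5, 3.0, 1.5])
domain_ids = torch.tensor([0, 0, 1, 1, 2])      # domain of each token

num_tokens = torch.bincount(domain_ids, minlength=3).float()  # |D_i|

# Pooled normalizer: divide everything by Σ_i α_i |D_i|.
repo_loss = (alpha[domain_ids] * pertoken_loss).sum() / (alpha * num_tokens).sum()

# Per-domain normalizer: each domain's tokens divided by its own |D_i|.
paper_loss = (alpha[domain_ids] / num_tokens[domain_ids] * pertoken_loss).sum()

print(repo_loss.item(), paper_loss.item())  # the two values generally differ
```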

Here is code that implements the loss as written in the paper. The domain weights it produces seem smoother than with the current implementation.

# compute the rescaled loss, divide by domain weights
train_domain_weights = self.read_weights().to(pertoken_loss.device)
# if doing non-uniform sampling, normalize by inverse sampling weight
train_domain_weights = train_domain_weights / self.sampling_weights.to(train_domain_weights.device)
train_domain_weights = train_domain_weights / train_domain_weights.sum()

# (#domains,) total number of tokens amongst samples from each domain
perdomain_num_tokens = []
for domain_id in range(len(train_domain_weights)):
    domain_mask = (inputs['domain_ids'] == domain_id)
    if domain_mask.sum() > 0:
        num_tokens = token_mask[domain_mask].sum()
    else:
        num_tokens = torch.tensor(0., device=token_mask.device)
    perdomain_num_tokens.append(num_tokens)
perdomain_num_tokens = torch.stack(perdomain_num_tokens)

# sync `perdomain_num_tokens` across processes, since different processes
# may see micro-batch samples from the same domain
dist.all_reduce(perdomain_num_tokens, op=dist.ReduceOp.SUM)
# scale by world size because DDP averages gradients
perdomain_num_tokens = perdomain_num_tokens / self.args.world_size

# avoid division by zero
perdomain_num_tokens[perdomain_num_tokens == 0] = 1.
# (#domains,) equivalent to αᵢ / Σ_{x ∈ Dᵢ} |x|
perdomain_coeff = train_domain_weights/perdomain_num_tokens
# (bsz, seq_len-1)
coeff = perdomain_coeff[inputs['domain_ids']].unsqueeze(-1) * token_mask
loss = (pertoken_loss * coeff.detach()).sum()
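Stripped of the distributed pieces, the same computation can be exercised single-process on toy tensors (hypothetical shapes and values, just to sanity-check the coefficients):

```python
import torch

# Toy stand-ins (illustrative, not from the repo).
train_domain_weights = torch.tensor([0.5, 0.3, 0.2])
domain_ids = torch.tensor([0, 0, 1, 2])        # domain of each sequence
pertoken_loss = torch.rand(4, 7)               # (bsz, seq_len-1)
token_mask = torch.ones(4, 7)                  # 1 = real token, 0 = padding

perdomain_num_tokens = torch.zeros(3)
for d in range(3):
    perdomain_num_tokens[d] = token_mask[domain_ids == d].sum()
perdomain_num_tokens[perdomain_num_tokens == 0] = 1.  # avoid division by zero

perdomain_coeff = train_domain_weights / perdomain_num_tokens
coeff = perdomain_coeff[domain_ids].unsqueeze(-1) * token_mask
loss = (pertoken_loss * coeff.detach()).sum()

# With no padding, the coefficients within each domain sum to α_d,
# so the loss is Σ_d α_d * (mean token loss in domain d).
```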
sangmichaelxie commented 12 months ago

As long as there are enough tokens in the batch, the two objectives should be about the same, and they converge to the same value with more samples. In my tests, the two versions return almost the same average domain weights. Our implementation follows PyTorch's class-weighting behavior in CrossEntropyLoss, and it makes sure the loss scale is preserved even when some domains are missing from the batch (which could matter when there are many domains).
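For context, PyTorch's CrossEntropyLoss with per-class weights and the default 'mean' reduction normalizes by the summed weights of the targets that appear in the batch, not by the batch size; a quick check on toy values (not from the repo), with class weights playing the role of domain weights:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy classification problem: 4 samples, 3 classes.
logits = torch.randn(4, 3)
targets = torch.tensor([0, 0, 1, 2])
w = torch.tensor([0.5, 0.3, 0.2])

weighted = nn.CrossEntropyLoss(weight=w)(logits, targets)

# Manual equivalent: weighted per-sample losses, normalized by the sum
# of the weights of the observed targets (not by the batch size).
per_sample = nn.CrossEntropyLoss(reduction='none')(logits, targets)
manual = (w[targets] * per_sample).sum() / w[targets].sum()

assert torch.allclose(weighted, manual)
```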

About your code: if you choose to normalize by the observed token frequency, I don't think you need this part:

# if doing non-uniform sampling, normalize by inverse sampling weight
train_domain_weights = train_domain_weights / self.sampling_weights.to(train_domain_weights.device)
train_domain_weights = train_domain_weights / train_domain_weights.sum()
tt6746690 commented 11 months ago

Thanks for the clarification!