graphcore-research / unit-scaling

A library for unit scaling in PyTorch
https://graphcore-research.github.io/unit-scaling/
Apache License 2.0

Custom loss unit scaling #63

Closed: norikazu99 closed this issue 2 months ago

norikazu99 commented 2 months ago

Hello and thanks a lot for sharing your great research and code.

The MSE loss doesn't seem to use scale_fwd, whereas the cross-entropy loss does. On the other hand, the MSE loss applies scale_bwd to both input and target with a fixed value, while cross entropy doesn't. Can you please explain the reasons behind these decisions?

The following is my first attempt at implementing a unit-scaled version of quantile loss; I would like to know if I missed something.

import torch
from einops import repeat
from unit_scaling.scale import scale_bwd, scale_fwd

def loss_fn(self, y_hat, y):
    loss = repeat(y, 'b s d -> b s (d n)', n=self.n_quantiles) - y_hat
    loss = torch.max((self.quantiles-1.0)*loss, self.quantiles*loss).mean()
    return loss

def unit_mup_loss_fn(self, y_hat, y):
    grad_scale = 8 ** -0.5
    y_hat = scale_bwd(y_hat, grad_scale)
    y = scale_bwd(y, grad_scale)

    loss = repeat(y, 'b s d -> b s (d n)', n=self.n_quantiles) - y_hat
    loss = torch.max((self.quantiles-1.0)*loss, self.quantiles*loss).sum()
    return scale_fwd(loss, 1 / y_hat.nelement())
DouglasOrr commented 2 months ago

Hi @norikazu99, thanks for your comments & for sharing your question.

The MSE loss doesn't seem to use scale_fwd, whereas the cross-entropy loss does. On the other hand, the MSE loss applies scale_bwd to both input and target with a fixed value, while cross entropy doesn't. Can you please explain the reasons behind these decisions?

I'll have a go at annotating the code. An important thing to note is that for losses, we assume constraint=None, which means we're free to choose separate scaling factors in the forward and backward passes.

# Cross entropy loss
    input = scale_bwd(input, vocab_size / (vocab_size - 1) ** 0.5)  # from [unit scaling] Table 5
    input = scale_fwd(input, mult)  # multiplier to "sharpen" the softmax; no need to scale grad since constraint=None
    loss = F.cross_entropy(...)
    if reduction == "mean":
        return scale_fwd(loss, 1 / batch_size)  # just scale_fwd, thanks to separate scaling of fwd/bwd
    assert reduction == "sum"
    return loss

# MSE loss
    grad_scale = 8**-0.5  # Assume `input` and `target` are independent ~N(0, 1) & grad(input)=2*(input-target)
    input = scale_bwd(input, grad_scale)
    target = scale_bwd(target, grad_scale)  # typically target.requires_grad is False, but just in case
    loss = F.mse_loss(input, target, size_average, reduce, reduction="sum")
    if reduction == "mean":
        return scale_fwd(loss, 1 / input.nelement())  # just scale_fwd, thanks to separate scaling of fwd/bwd
    assert reduction == "sum"
    return loss
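The cross-entropy backward factor vocab_size / (vocab_size - 1) ** 0.5 can be sanity-checked without PyTorch. This is a minimal sketch that assumes all-zero (uniform) logits at initialisation. Under that assumption, the gradient of sum-reduced softmax cross entropy with respect to one row of logits is softmax(logits) - one_hot(label), and its RMS is exactly sqrt(s - 1) / s for vocabulary size s. Multiplying by s / sqrt(s - 1) therefore brings it to 1. The helper ce_logit_grad_rms is hypothetical and not part of unit_scaling.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ce_logit_grad_rms(vocab_size, label=0):
    """RMS of d(cross_entropy)/d(logits) for one row of all-zero logits."""
    probs = softmax([0.0] * vocab_size)
    # d/dlogits of -log(softmax(logits)[label]) is probs - one_hot(label).
    grad = [p - (1.0 if k == label else 0.0) for k, p in enumerate(probs)]
    return math.sqrt(sum(g * g for g in grad) / vocab_size)


for s in (2, 10, 1000, 50_000):
    rms = ce_logit_grad_rms(s)
    scale = s / math.sqrt(s - 1)  # the backward factor from the annotated code
    print(s, round(rms, 6), round(rms * scale, 6))
```

Here the gradient has zero mean, so its RMS equals its population std.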

[unit scaling] https://arxiv.org/abs/2303.11257

So I think they're reasonably consistent. Softmax cross entropy has no scaling for the labels, which can't receive a gradient. MSE loss has no mult, because I think a multiplier on the inputs wouldn't make sense. A multiplier on (inputs - targets) would end up as a scaling factor on both the forward and backward passes, which we'd then cancel out to recover unit scale.
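Here is a small numerical check of the MSE choices. It assumes independent standard normal input and target, as in the comment on grad_scale. With reduction="sum", the per-element gradient is 2 * (input - target), which has std sqrt(8). Scaling the backward pass by 8 ** -0.5 therefore gives roughly unit scale at any size. A plain mean reduction instead divides every gradient by the number of elements. This stdlib-only sketch writes the gradients out by hand; the helper mse_input_grad_std is hypothetical and is not the library's implementation.

```python
import math
import random
import statistics


def mse_input_grad_std(n, grad_scale=8 ** -0.5, reduction="sum", seed=0):
    """Std of d(loss)/d(input) for MSE with independent N(0, 1) input and target.

    With reduction="sum" each element's gradient is 2 * (input - target), and
    grad_scale multiplies it as scale_bwd would. With reduction="mean" the
    gradient is also divided by n, as for an unscaled mean-reduced loss.
    """
    rng = random.Random(seed)
    grads = []
    for _ in range(n):
        x, t = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        g = 2.0 * (x - t) * grad_scale
        if reduction == "mean":
            g /= n
        grads.append(g)
    return statistics.pstdev(grads)


for n in (1_000, 100_000):
    unscaled = mse_input_grad_std(n, grad_scale=1.0)  # roughly sqrt(8), about 2.83
    scaled = mse_input_grad_std(n)  # roughly 1
    mean_reduced = mse_input_grad_std(n, grad_scale=1.0, reduction="mean")  # about 2.83 / n
    print(n, round(unscaled, 3), round(scaled, 3), mean_reduced)
```

This is the reason to sum in the forward computation and apply 1 / input.nelement() only through scale_fwd. The reported loss value is still a mean, but the gradients do not shrink as the tensor grows.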


Re. quantile loss: that looks broadly reasonable, yes, although I am unfamiliar with the particulars of this loss function. If independent Gaussian inputs are a reasonable assumption, it looks like the scale should possibly be $\times 2$ rather than $\times 8^{-1/2} = \times 1/(2\sqrt{2})$. Here is a quick check:

import torch
from einops import repeat
from unit_scaling.scale import scale_bwd, scale_fwd

quantiles = torch.tensor([0.5])  # multiple quantile scaling is broken
y = torch.randn(1000, 128, 1).requires_grad_()
y_hat = torch.randn(1000, 128, len(quantiles)).requires_grad_()

def loss_fn(y_hat, y):
    loss = repeat(y, 'b s d -> b s (d n)', n=len(quantiles)) - y_hat
    loss = torch.max((quantiles-1.0)*loss, quantiles*loss).mean()
    return loss

def unit_mup_loss_fn(y_hat, y):
    grad_scale = 2  # previously 8 ** -0.5
    y_hat = scale_bwd(y_hat, grad_scale)
    y = scale_bwd(y, grad_scale)

    loss = repeat(y, 'b s d -> b s (d n)', n=len(quantiles)) - y_hat
    loss = torch.max((quantiles-1.0)*loss, quantiles*loss).sum()
    return scale_fwd(loss, 1 / y_hat.nelement())

y.grad = y_hat.grad = None
loss = loss_fn(y_hat, y)
loss.backward()
print("   unscaled", loss.item(), y.grad.std().item(), y_hat.grad.std().item())

y.grad = y_hat.grad = None
loss = unit_mup_loss_fn(y_hat, y)
loss.backward()
print("unit-scaled", loss.item(), y.grad.std().item(), y_hat.grad.std().item())

This seems to be robust to different choices of a single quantile. But the scaling is not good when using multiple quantiles. This is because the repeat() causes a sum of gradients in the backward pass. This might require some thought about the distributions to fix. (Also, perhaps in the case of multiple quantiles, drawing y_hat independently is a bad idea?)
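A dependency-free sketch of the backward pass makes this concrete. Like the script above, it assumes independent standard normal y and y_hat at initialisation. It writes the pinball-loss gradient out by hand rather than using autograd, and pinball_grad_stds is a hypothetical helper. With a single quantile, every gradient takes one of two values that are 1 apart, each with probability 1/2. Its std is therefore 0.5 whatever q is, hence a scale of 2. With several quantiles, y's gradient sums terms that are positively correlated through the shared y, and y_hat pools gradient distributions with different means. As a result, neither std stays at 0.5.

```python
import random
import statistics


def pinball_grad_stds(quantiles, n_rows=50_000, seed=0):
    """Stds of the gradients of a sum-reduced quantile (pinball) loss.

    Mirrors loss = max((q - 1) * r, q * r) with r = repeat(y) - y_hat: one target
    y per row is shared by every quantile, and y, y_hat ~ N(0, 1) independently.
    Returns (std of grad_y, std of grad_y_hat).
    """
    rng = random.Random(seed)
    grad_y, grad_y_hat = [], []
    for _ in range(n_rows):
        y = rng.gauss(0.0, 1.0)
        g_y = 0.0
        for q in quantiles:
            r = y - rng.gauss(0.0, 1.0)
            d_r = q if r > 0 else q - 1.0  # d/dr of max((q - 1) * r, q * r)
            grad_y_hat.append(-d_r)  # r = y - y_hat, so d/dy_hat = -d/dr
            g_y += d_r  # repeat() sums the per-quantile gradients into y
        grad_y.append(g_y)
    return statistics.stdev(grad_y), statistics.stdev(grad_y_hat)


for quantiles in ([0.1], [0.5], [0.9], [0.1, 0.5, 0.9]):
    s_y, s_y_hat = pinball_grad_stds(quantiles)
    print(quantiles, round(s_y, 3), round(s_y_hat, 3))
```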

norikazu99 commented 2 months ago

Hello @DouglasOrr. Thanks for your help and quick response. Indeed, drawing y_hat independently using the method described above isn't ideal.

In my case, the bwd of y doesn't really matter. I attempted to find a rule for y_hat_grad_scale but was unsuccessful for multiple quantile sets (different values and lengths), because the loss bwd depends on the quantile values (and, I suppose, on their number too). I tried playing around with averaging out the 3rd dim of the loss and scaling y_hat_grad_scale by functions of quantiles.std(). Would hand-picking y_hat_grad_scale for a given set of quantiles (to satisfy y_hat.grad.std() == 1.0) do the trick for a given setup?

If independent Gaussian inputs are a reasonable assumption

Are you referring to the inputs of the model or of the loss function? I'm assuming both, right? If so, how does this translate to sequence models and next-token prediction, where I believe the sequence isn't i.i.d.?

DouglasOrr commented 2 months ago

I tried playing around with averaging out the 3rd dim of the loss and scaling y_hat_grad_scale by functions of quantiles.std(). Would hand-picking y_hat_grad_scale for a given set of quantiles (to satisfy y_hat.grad.std() == 1.0) do the trick for a given setup?

That would be fine, but the approach we've often taken has been to go for a simple and rough scaling rule, and then check the at-initialisation scales of gradients in the full model. We treat them as acceptable if they're within a factor of 10 or so of unit scale. The technique does not break down if scales aren't close to 1. For the sake of numerics, the main thing to look out for is any scale comparable to the range of the FP format being used (max 448 for FP8-E4M3, 65,504 for FP16).
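For the hand-picked y_hat_grad_scale question, the same at-initialisation assumption gives a simple closed-form rule. Suppose each residual y - y_hat is equally likely to be positive or negative, which holds when y and y_hat are independent and identically distributed. Then the y_hat gradient for quantile q is -q or 1 - q with equal probability. Pooled over all elements, its variance is 1/4 plus the population variance of the quantile values, which is in the spirit of scaling by a function of quantiles.std(). The sketch below works under that assumption only. Its helper names are hypothetical, and it uses a stdlib Monte Carlo check in place of autograd.

```python
import math
import random
import statistics


def y_hat_grad_scale(quantiles):
    """Hypothetical backward scale for y_hat in a sum-reduced quantile loss.

    Assumes every residual y - y_hat is positive with probability 1/2. The
    gradient for quantile q is then -q or 1 - q with equal probability
    (variance 1/4, mean 1/2 - q), so over all quantiles the variance is
    1/4 + pvariance(quantiles).
    """
    return 1.0 / math.sqrt(0.25 + statistics.pvariance(quantiles))


def simulated_y_hat_grad_std(quantiles, grad_scale, n_rows=50_000, seed=0):
    """Monte Carlo std of the scaled y_hat gradient with independent N(0, 1) y and y_hat."""
    rng = random.Random(seed)
    grads = []
    for _ in range(n_rows):
        y = rng.gauss(0.0, 1.0)  # one target shared by all quantiles, as with repeat()
        for q in quantiles:
            r = y - rng.gauss(0.0, 1.0)  # residual repeat(y) - y_hat
            d_r = q if r > 0 else q - 1.0  # d/dr of max((q - 1) * r, q * r)
            grads.append(-d_r * grad_scale)  # d/dy_hat = -d/dr, then scale_bwd
    return statistics.pstdev(grads)


for quantiles in ([0.5], [0.1, 0.5, 0.9], [0.05, 0.25, 0.5, 0.75, 0.95]):
    scale = y_hat_grad_scale(quantiles)
    std = simulated_y_hat_grad_std(quantiles, scale)
    print(quantiles, round(scale, 3), round(std, 3))
```

A single quantile recovers the factor of 2. At realistic tensor sizes, the difference between the population variance used here and PyTorch's default unbiased .std() is negligible. y is left out because its backward pass was said not to matter in this setup.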

If independent Gaussian inputs are a reasonable assumption

Are you referring to the inputs of the model or of the loss function? I'm assuming both, right? If so, how does this translate to sequence models and next-token prediction, where I believe the sequence isn't i.i.d.?

I was thinking of the at-initialisation inputs to the loss function, since we're just trying to determine a local scaling rule for the loss function in isolation. You're right, in any case, that this assumption is almost always wrong, but often it isn't "too wrong" to give reasonable overall scaling behaviour, especially at initialisation. The benefit of taking simplistic assumptions is that they keep the components quite generic, not requiring the user to choose the right version based on their input distribution, but the opportunity remains open to make specialised versions that do specify input distributions.
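As an illustration of that trade-off, here is what a "specialised" MSE scale could look like if input and target were assumed to be correlated with coefficient rho instead of independent. This is a hypothetical sketch, not something unit_scaling provides. Under that assumption, the generic 8 ** -0.5 rule leaves the gradient std at sqrt(1 - rho). That stays within the rough factor-of-10 band unless rho exceeds 0.99.

```python
import math
import random
import statistics


def mse_grad_scale(rho=0.0):
    """Hypothetical 'specialised' MSE backward scale.

    Assumes input and target are N(0, 1) with correlation rho. Since
    grad(input) = 2 * (input - target) and Var(input - target) = 2 * (1 - rho),
    the gradient std is sqrt(8 * (1 - rho)); rho = 0 recovers 8 ** -0.5.
    """
    if not -1.0 <= rho < 1.0:
        raise ValueError("rho must be in [-1, 1)")
    return 1.0 / math.sqrt(8.0 * (1.0 - rho))


def simulated_grad_std(rho, grad_scale, n=50_000, seed=0):
    """Monte Carlo std of the scaled MSE input gradient for correlated pairs."""
    rng = random.Random(seed)
    grads = []
    for _ in range(n):
        x = rng.gauss(0.0, 1.0)
        # Unit-variance target with correlation rho to x.
        t = rho * x + math.sqrt(1.0 - rho * rho) * rng.gauss(0.0, 1.0)
        grads.append(2.0 * (x - t) * grad_scale)
    return statistics.pstdev(grads)


for rho in (0.0, 0.5, 0.9):
    generic = simulated_grad_std(rho, mse_grad_scale())  # independence assumption
    specialised = simulated_grad_std(rho, mse_grad_scale(rho))
    print(rho, round(generic, 3), round(specialised, 3))
```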

I hope this helps, but I can't be too prescriptive, as I don't think there's a right or wrong answer when it comes to the assumptions to make and the tightness of the scaling estimation; it's somewhat of a trade-off!

norikazu99 commented 2 months ago

That is reassuring to hear from one of the authors. When referring to the "technique breaking down", are you talking about stable low-precision training or hyperparameter transfer?

Yes. This definitely helps a lot. It seems that loss scaling is one of the simplest parts of unit scaling. The fwd scale can be decoupled from the bwd scale because the loss is a cut-edge (if I understood the term correctly), which allows us to unit-scale the bwd.

Thanks for sharing your work. I will share how well hyperparameter transfer does on different loss functions as I work on them (I use differentiable simulators and other less common loss functions).

DouglasOrr commented 2 months ago

When referring to the "technique breaking down", are you talking about stable low-precision training or hyperparameter transfer?

I was thinking primarily of hyperparameter transfer, but stable low-precision training should be OK with a somewhat relaxed tolerance on scaling properties (the rough factor-of-10 heuristic I mentioned).
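The factor-of-10 check can be automated with a small helper. This is only a sketch: review_scales, the headroom margin, and the example numbers are all hypothetical. The format limits are the largest finite values of FP8-E4M3 in its "fn" variant (448, as for torch.float8_e4m3fn) and of FP16 (65,504). In practice, the input dictionary would hold gradient stds measured on the full model at initialisation. The helper only looks at the upper end of the range.

```python
# Largest finite values: FP8-E4M3 (the "fn" variant, e.g. torch.float8_e4m3fn) and FP16.
FORMAT_MAX = {"fp8_e4m3": 448.0, "fp16": 65504.0}


def review_scales(scales, fmt="fp16", tolerance=10.0, headroom=10.0):
    """Flag measured scales (e.g. at-initialisation gradient stds) worth a closer look.

    scales: mapping from a tensor name to its measured std.
    tolerance: acceptable factor away from unit scale (the rough factor-of-10 rule).
    headroom: hypothetical margin; flag anything within this factor of the format max.
    """
    max_value = FORMAT_MAX[fmt]
    flagged = {}
    for name, scale in scales.items():
        if scale * headroom >= max_value:
            flagged[name] = "close to the format's range"
        elif not 1.0 / tolerance <= scale <= tolerance:
            flagged[name] = f"more than {tolerance:g}x from unit scale"
    return flagged


# Hypothetical gradient stds, for illustration only.
example = {"embedding.grad": 0.4, "attn_out.grad": 0.03, "head.grad": 60.0}
print(review_scales(example, fmt="fp8_e4m3"))
print(review_scales(example, fmt="fp16"))
```

The same std of 60 counts as a numerics risk in FP8-E4M3 but is merely "off unit scale" in FP16, in line with the relaxed tolerance described above.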

I will share how well hyperparameter transfer does on different loss functions as I work on them (I use differentiable simulators and other less common loss functions).

Sounds interesting; let me know if/when your findings are public - I'd be keen to know how it goes!

DouglasOrr commented 2 months ago

Hi, thanks for the update, good to hear!

Would you recommend I scale the loss function as a whole, or the different operations independently?

I think the general rule is that if it's possible to make reasonable distributional assumptions about an intermediate value, we can scale ops independently; otherwise we should consider the larger unit. An example is dot-product attention. There, the output of the attention softmax has an unusual distribution and cannot be scaled to approximate a unit normal, so we consider scaling it together with the following matrix multiplication.
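A stdlib sketch can show why the softmax output is awkward to scale on its own while the softmax-then-matmul unit is not. It assumes V has independent unit-normal entries and, by default, uniform attention (all-zero logits). It is an illustration only and not how unit_scaling implements attention; attention_output_std is a hypothetical helper. The attention weights are positive and sum to 1, so no constant multiplier makes them look like N(0, 1). For fixed weights w, however, each output entry is N(0, sum(w^2)), which is 1/seq_len when attention is uniform.

```python
import math
import random
import statistics


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def attention_output_std(seq_len, n_cols=4_000, logit_std=0.0, seed=0):
    """One query's attention weights and the std of softmax(logits) @ V.

    logits ~ N(0, logit_std^2) (all zero by default, i.e. uniform attention) and
    V has independent N(0, 1) entries with shape (seq_len, n_cols).
    """
    rng = random.Random(seed)
    weights = softmax([rng.gauss(0.0, logit_std) for _ in range(seq_len)])
    outputs = []
    for _ in range(n_cols):
        column = [rng.gauss(0.0, 1.0) for _ in range(seq_len)]
        outputs.append(sum(w * v for w, v in zip(weights, column)))
    return weights, statistics.pstdev(outputs)


weights, out_std = attention_output_std(seq_len=256)
# All weights are positive and they sum to 1: not something a constant can make ~N(0, 1).
print(min(weights) > 0, round(sum(weights), 6))
# The combined unit has a predictable scale: std close to 1 / sqrt(seq_len) here.
print(round(out_std * math.sqrt(256), 2))
```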

How would you recommend I scale this loss function with respect to s and d?

The particular difficulty with the loss you give is the occurrence of out in (out * x[:, i] * y[:, i]).mean(dim=1, keepdim=True). If we were to ignore this, the only part of the loss with bad scaling would be the mean, so we can do something like this:

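from unit_scaling.scale import scale_bwd, scale_fwd

def scaled_sim_loss_MODIFIED(x, y):
    grad_scale = x.size(2)  # undo the 1/d that .mean(dim=1) puts on x's gradient
    x = scale_bwd(x, grad_scale)
    out = 1.0
    for i in range(x.size(1)):
        out = out + (x[:, i] * y[:, i]).mean(dim=1, keepdim=True)
    out = (out - 1.0).sum()
    return scale_fwd(out, 1/x.size(0))  # forward only: report the batch mean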

My working here is to go backwards. The last two lines (the subtract-and-sum, then scale_fwd) do not modify the gradients of out. The for loop is now a simple reduce-sum over non-overlapping inputs, so it doesn't modify gradients either. The mean of products only introduces bad scaling through the mean, which divides the gradient by the length of that axis. We correct this by scaling the backward pass up by the same length.
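The same backward walk can be written out as equations. Here B, s and d are taken to be x.size(0), x.size(1) and x.size(2), an assumption based on grad_scale = x.size(2) and the loop over x.size(1). Since scale_fwd only touches the forward value, the 1/B factor does not appear in the gradient:

$$
\mathrm{out}_b = 1 + \sum_{i=1}^{s} \frac{1}{d} \sum_{k=1}^{d} x_{bik}\, y_{bik},
\qquad
\frac{\partial}{\partial x_{bik}} \sum_{b'=1}^{B} \left(\mathrm{out}_{b'} - 1\right) = \frac{y_{bik}}{d}
$$

$$
\text{with } \texttt{scale\_bwd}(x, d):\quad
\nabla_{x_{bik}} = d \cdot \frac{y_{bik}}{d} = y_{bik}
$$

So if y has unit-normal entries, x's gradient is at unit scale regardless of s and d.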

If I take this and reintroduce out into (x[:, i] * y[:, i]).mean(dim=1, keepdim=True), the results are mis-scaled, but not too badly (I think the mis-scaling depends on the ratio s/d).

Hope this is somewhat useful, even if not quite there.