HazyResearch / safari

Convolutions for Sequence Modeling
Apache License 2.0

Hyena seems to have forward leakage? #40

Closed datatalkv closed 8 months ago

datatalkv commented 8 months ago

Hi, I'm testing the Hyena structure, and for simplicity let's focus on standalone_hyena.py. I revised the existing example just a little to demonstrate what looks like forward leakage below.

I understand there is a "Causality check by gradient". Alternatively, we can check causality by propagating a disturbance from the input to the output. Specifically, a disturbance applied to a future input position should not affect outputs at earlier positions.
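For reference, my understanding of the gradient-based check is roughly the following sketch (not copied from the script; it assumes the same imports and HyenaOperator from standalone_hyena.py):

m_hyena = HyenaOperator(d_model=512, l_max=1024, order=2, filter_order=64)
x = torch.randn(1, 1024, 512, requires_grad=True)

t = 3                                   # pick an output position
y = m_hyena(x)
y[0, t].sum().backward()                # gradient of the output at position t w.r.t. the input
# for a causal model, inputs strictly after position t must receive zero gradient
print(x.grad[0, t + 1:].abs().max())

The perturbation version of the same idea only needs forward passes: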


if __name__ == "__main__":
    x = torch.randn(1, 1024, 512, requires_grad=True)
    x2 = x.detach().clone()
    x2[:, 3, :] = 1e10 * torch.sign(x2[:, 3, :])      # an obvious interrupt/disturbance added at somewhere, set to 3 in this example

    m_hyena = HyenaOperator(
        d_model=x.shape[2],
        l_max=1024,
        order=2,
        filter_order=64
    )

    yhat = m_hyena(x)
    yhat2 = m_hyena(x2)
    print(yhat[0,:20,0])
    print(yhat2[0,:20,0])

tensor([-0.0627, -0.0652, -0.0135, -0.0132, -0.0580, -0.0491,  0.0646, -0.1024,
        -0.0986, -0.0817, -0.0367, -0.1008,  0.0637, -0.0156, -0.0301, -0.1105,
        -0.0579, -0.1214,  0.0497, -0.1366], grad_fn=<...>)
tensor([ 5.7309e+08,  3.5364e+09,  2.8772e+09,  1.9058e+27,  7.9790e+27,
         1.5603e+27,  1.4734e+16, -4.5865e+16, -1.1996e+17, -6.1985e+16,
         5.0280e+16,  2.5590e+16,  6.0684e+16,  1.7392e+16, -8.0251e+15,
         6.0426e+16,  1.1078e+16,  8.8085e+16,  1.4208e+16,  1.3096e+17],
       grad_fn=<...>)

Intuitively, we should expect the first three outputs of yhat and yhat2 to be the same. However, as the printout shows, they are not. I think this is an indicator of a "forward leakage" problem in the Hyena implementation.

DanFu09 commented 8 months ago

I'm looking into this. Does it still happen with smaller interruptions, i.e. 1e5? One possibility is you may be hitting the limits of fp32 precision. Hyena gates the input, so it looks like v * (conv(x1 * x2, k)), where x1, x2, and v are linear projections of the input. Randomly initialized weights could make that calculation overflow.
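For intuition, here's a rough standalone sketch of how that structure can amplify a large activation (this is not the actual HyenaOperator code; the projections and the filter are just stand-ins):

import torch

def fft_conv(u, k):
    # zero-padded FFT convolution, truncated back to the input length (causal for a lag-0 kernel)
    L = u.shape[-1]
    return torch.fft.irfft(torch.fft.rfft(u, n=2 * L) * torch.fft.rfft(k, n=2 * L), n=2 * L)[..., :L]

L_seq = 1024
u = torch.randn(L_seq)
u[3] = 1e5                              # one very large activation

# in HyenaOperator, x1, x2 and v are separate linear projections of the input;
# here we simply reuse u so the spike shows up in all three
x1 = x2 = v = u
k = 0.01 * torch.randn(L_seq)           # stand-in for the implicit Hyena filter

y = v * fft_conv(x1 * x2, k)            # order-2 structure: v * conv(x1 * x2, k)
print((x1 * x2)[3])                     # the spike is squared to 1e10 before the convolution
print(y.abs().max())                    # output magnitudes get huge, so fp32's ~1e-7 relative
                                        # round-off becomes a large absolute error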

datatalkv commented 8 months ago

I'm looking into this. Does it still happen with smaller interruptions, i.e. 1e5? One possibility is you may be hitting the limits of fp32 precision. Hyena gates the input, so it looks like v * (conv(x1 * x2, k)), where x1, x2, and v are linear projections of the input. Randomly initialized weights could make that calculation overflow.

Hi, yes: 1e3/1e4/1e5 still lead to significant differences.

Zymrael commented 8 months ago

fp32 precision limits are never hit during training unless something goes really wrong during optimization, especially with normalization layers. Causal leakage is easy to detect since the loss will rapidly decrease towards unreasonably small values (often 0) and generation outputs will be nonsensical. Downstream eval numbers will also generally be terrible with leakage. FWIW, we've never seen this happen in practice (and have trained large causal models with a similar structure).

With your code, and a perturbation of magnitude 1e3 at position 3, at float32 precision:

tensor([0.0700, 0.0530, 0.0830, 0.0300, 0.0730, 0.0920, 0.1050, 0.1250, 0.0870,
        0.1480, 0.1910, 0.0740, 0.0240, 0.0160, 0.0900, 0.1170, 0.0100, 0.0820,
        0.1730, 0.1700])
tensor([ 7.0000e-02,  5.3000e-02,  8.3000e-02,  9.6446e+05, -3.1481e+05,
         1.6080e+06,  5.0385e+02, -7.4629e+01,  1.5281e+03,  1.5118e+03,
         1.3675e+03,  7.9940e+02,  1.0951e+03,  3.9184e+03,  4.1972e+03,
         1.5654e+03,  1.4578e+03,  1.2164e+03,  2.4598e+03,  1.2853e+03])

You can see a difference at 1e4 and larger. For example, if you compare single and double precision (torch.float32 and torch.float64):

tensor([-0.0270, -0.0500, -0.0600, -0.0250, -0.0550, -0.1160, -0.0360, -0.0520,
        -0.0050, -0.0310,  0.0490,  0.1170,  0.0120,  0.0790, -0.0160,  0.0110,
         0.0790,  0.0330, -0.0270,  0.1380])
tensor([-2.9000e-02, -5.4000e-02, -6.0000e-02, -1.2226e+09,  6.4406e+08,
        -5.8460e+08, -1.5587e+05,  4.3244e+03, -1.0345e+05, -5.9982e+04,
         9.2474e+03, -2.0271e+03, -6.4360e+04, -4.8636e+04, -4.5383e+04,
        -2.9082e+04, -5.0083e+04, -3.7026e+04, -5.1901e+04,  1.0700e+05])
tensor([ 0.0210,  0.0830, -0.0710, -0.0170,  0.0020, -0.0180, -0.0580, -0.0380,
         0.0360, -0.1420,  0.0550,  0.0040,  0.0770,  0.0200, -0.0940, -0.0460,
         0.0980, -0.0270, -0.0050,  0.0250], dtype=torch.float64)
tensor([ 2.1000e-02,  8.3000e-02, -7.1000e-02,  3.9577e+09,  4.1313e+09,
        -2.5326e+09, -1.7930e+05,  3.4640e+03, -3.1660e+05, -1.3449e+05,
        -1.1140e+05, -1.3196e+05, -8.4130e+04, -2.7075e+05, -1.4431e+05,
        -1.0833e+05, -3.4407e+05, -3.9656e+05, -9.8230e+04, -2.0352e+05],
       dtype=torch.float64)

You can see torch.float32 having 1e-3 errors in the first three positions. In other words, this has nothing to do with the model structure, but rather with the particular kernel implementing the convolution.

If in your application domain or datasets you have reason to expect activations with large magnitude, and for some reason you can't introduce normalization layers, you'd need to write a custom kernel for the convolution / fft convolution.
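If you want to convince yourself that the remaining differences are purely numerical, you can also repeat your perturbation check in double precision, e.g. something along these lines (a sketch; it assumes the standalone module runs cleanly under .double()):

m_hyena = HyenaOperator(d_model=512, l_max=1024, order=2, filter_order=64).double()

x = torch.randn(1, 1024, 512, dtype=torch.float64)
x2 = x.clone()
x2[:, 3, :] = 1e3 * torch.sign(x2[:, 3, :])   # perturb position 3 only

y, y2 = m_hyena(x), m_hyena(x2)
# positions before the perturbation should now match to a much tighter tolerance
print((y - y2)[0, :3, 0].abs().max())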

datatalkv commented 8 months ago

What are those four tensors, and what exactly are we supposed to compare between them?

For example, if you compare single and double precision (torch.float32 and torch.float64):

tensor([-0.0270, -0.0500, -0.0600, -0.0250, -0.0550, -0.1160, -0.0360, -0.0520,
        -0.0050, -0.0310,  0.0490,  0.1170,  0.0120,  0.0790, -0.0160,  0.0110,
         0.0790,  0.0330, -0.0270,  0.1380])
tensor([-2.9000e-02, -5.4000e-02, -6.0000e-02, -1.2226e+09,  6.4406e+08,
        -5.8460e+08, -1.5587e+05,  4.3244e+03, -1.0345e+05, -5.9982e+04,
         9.2474e+03, -2.0271e+03, -6.4360e+04, -4.8636e+04, -4.5383e+04,
        -2.9082e+04, -5.0083e+04, -3.7026e+04, -5.1901e+04,  1.0700e+05])
tensor([ 0.0210,  0.0830, -0.0710, -0.0170,  0.0020, -0.0180, -0.0580, -0.0380,
         0.0360, -0.1420,  0.0550,  0.0040,  0.0770,  0.0200, -0.0940, -0.0460,
         0.0980, -0.0270, -0.0050,  0.0250], dtype=torch.float64)
tensor([ 2.1000e-02,  8.3000e-02, -7.1000e-02,  3.9577e+09,  4.1313e+09,
        -2.5326e+09, -1.7930e+05,  3.4640e+03, -3.1660e+05, -1.3449e+05,
        -1.1140e+05, -1.3196e+05, -8.4130e+04, -2.7075e+05, -1.4431e+05,
        -1.0833e+05, -3.4407e+05, -3.9656e+05, -9.8230e+04, -2.0352e+05],
       dtype=torch.float64)

You can see torch.float32 having 1e-3 errors in the first three positions. In other words, this has nothing to do with the model structure, but rather with the particular kernel implementing the convolution.

DanFu09 commented 8 months ago

The tensors are the output of your code that you pasted in with the bug report. When you say "1e3/1e4/1e5 still leads to significant differences," what differences do you observe?

DanFu09 commented 8 months ago

The error that you're seeing in the original bug report is just a result of fp32 machine-precision error being multiplied by a large magnitude.

Consider this code:

if __name__ == "__main__":
    layer = HyenaOperator(
        d_model=512, 
        l_max=1024, 
        order=2, 
        filter_order=64
    )
    x = torch.randn(1, 1024, 512, requires_grad=True, dtype=torch.float32)
    x2 = x.detach().clone()
    for m in [1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10]:
        x2[:, 3, :] = m * torch.sign(x2[:, 3, :])
        y = layer(x)
        y2 = layer(x2)

        print('m:', m)
        print('Error: ', (y-y2)[0, :5, 0].detach())

It's the same as your original bug report, with different magnitudes for the interrupt. This is the output:

m: 1000.0
Error:  tensor([-3.3397e-05, -5.7926e-05, -4.5875e-05, -2.0014e+06, -1.4382e+06])
m: 10000.0
Error:  tensor([-3.6624e-03, -6.7233e-03,  1.6033e-03, -2.0073e+09, -1.4371e+09])
m: 100000.0
Error:  tensor([ 3.1293e-01,  1.6798e-01, -2.0627e-01, -2.0079e+12, -1.4370e+12])
m: 1000000.0
Error:  tensor([-4.0541e+01,  5.4897e-01, -8.9160e+01, -2.0080e+15, -1.4369e+15])
m: 10000000.0
Error:  tensor([-3.1954e+03, -4.6748e+03, -1.1974e+04, -2.0080e+18, -1.4369e+18])
m: 100000000.0
Error:  tensor([ 8.1633e+04, -6.5942e+05, -3.0295e+04, -2.0080e+21, -1.4369e+21])
m: 1000000000.0
Error:  tensor([-1.8014e+07,  1.1134e+07,  8.2600e+06, -2.0080e+24, -1.4369e+24])
m: 10000000000.0
Error:  tensor([-9.3507e+08,  3.1199e+09, -4.5474e+09, -2.0080e+27, -1.4369e+27])

You can see at m=1e3, the difference for the first three elements is around 1e-5.

When m=1e10, the error increases by a factor of (1e7) ** 2 = 1e14. This comes from accumulating error across FFT and inverse FFT in the FFT convolution.
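As a rough sanity check on that quadratic scaling, you can read the approximate first-position errors off the sweep above (the numbers below are eyeballed from the printout):

ms   = [1e3, 1e4, 1e5, 1e6, 1e7]
errs = [3.3e-5, 3.7e-3, 3.1e-1, 4.1e1, 3.2e3]   # |error| of the first element at each m

for (m0, e0), (m1, e1) in zip(zip(ms, errs), zip(ms[1:], errs[1:])):
    # each 10x increase in the perturbation grows the pre-perturbation error by roughly 100x,
    # i.e. the error scales roughly quadratically with the perturbation magnitude
    print(f"m x{m1 / m0:.0f} -> error x{e1 / e0:.0f}")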

As Michael said, normally during training we never see the values of activations get this high unless something goes very wrong.

datatalkv commented 8 months ago

Thanks, this makes the observations much clearer.

Well, I don't quite follow the claim that this problem comes from the FFT/inverse FFT implementation. Could you please provide more evidence or hints?

Furthermore, do you mean one doesn't need to worry about this problem in practice?

datatalkv commented 8 months ago

fp32 precision limits are never hit during training unless something goes really wrong during optimization, especially with normalization layers. Causal leakage is easy to detect since the loss will rapidly decrease towards unreasonably small values (often 0) and generation outputs will be nonsensical. Downstream eval numbers will also generally be terrible with leakage. FWIW, we've never seen this happen in practice (and have trained large causal models with a similar structure).

With your code, and a perturbation of magnitude 1e3 at position 3, at float32 precision:

tensor([0.0700, 0.0530, 0.0830, 0.0300, 0.0730, 0.0920, 0.1050, 0.1250, 0.0870,
        0.1480, 0.1910, 0.0740, 0.0240, 0.0160, 0.0900, 0.1170, 0.0100, 0.0820,
        0.1730, 0.1700])
tensor([ 7.0000e-02,  5.3000e-02,  8.3000e-02,  9.6446e+05, -3.1481e+05,
         1.6080e+06,  5.0385e+02, -7.4629e+01,  1.5281e+03,  1.5118e+03,
         1.3675e+03,  7.9940e+02,  1.0951e+03,  3.9184e+03,  4.1972e+03,
         1.5654e+03,  1.4578e+03,  1.2164e+03,  2.4598e+03,  1.2853e+03])

You can see a difference at 1e4 and larger. For example, if you compare single and double precision (torch.float32 and torch.float64):

tensor([-0.0270, -0.0500, -0.0600, -0.0250, -0.0550, -0.1160, -0.0360, -0.0520,
        -0.0050, -0.0310,  0.0490,  0.1170,  0.0120,  0.0790, -0.0160,  0.0110,
         0.0790,  0.0330, -0.0270,  0.1380])
tensor([-2.9000e-02, -5.4000e-02, -6.0000e-02, -1.2226e+09,  6.4406e+08,
        -5.8460e+08, -1.5587e+05,  4.3244e+03, -1.0345e+05, -5.9982e+04,
         9.2474e+03, -2.0271e+03, -6.4360e+04, -4.8636e+04, -4.5383e+04,
        -2.9082e+04, -5.0083e+04, -3.7026e+04, -5.1901e+04,  1.0700e+05])
tensor([ 0.0210,  0.0830, -0.0710, -0.0170,  0.0020, -0.0180, -0.0580, -0.0380,
         0.0360, -0.1420,  0.0550,  0.0040,  0.0770,  0.0200, -0.0940, -0.0460,
         0.0980, -0.0270, -0.0050,  0.0250], dtype=torch.float64)
tensor([ 2.1000e-02,  8.3000e-02, -7.1000e-02,  3.9577e+09,  4.1313e+09,
        -2.5326e+09, -1.7930e+05,  3.4640e+03, -3.1660e+05, -1.3449e+05,
        -1.1140e+05, -1.3196e+05, -8.4130e+04, -2.7075e+05, -1.4431e+05,
        -1.0833e+05, -3.4407e+05, -3.9656e+05, -9.8230e+04, -2.0352e+05],
       dtype=torch.float64)

You can see torch.float32 having 1e-3 errors in the first three positions. In other words, this has nothing to do with the model structure, but rather with the particular kernel implementing the convolution.

If in your application domain or datasets you have reason to expect activations with large magnitude, and for some reason you can't introduce normalization layers, you'd need to write a custom kernel for the convolution / fft convolution.

BTW, I observe that all the input tensors are printed with the same precision, only out to the third decimal place. Is that intentional, i.e. something we need to follow?

DanFu09 commented 8 months ago

Well, I don't quite follow the claim that this problem comes from the FFT/inverse FFT implementation. Could you please provide more evidence or hints?

Here's a minimal example that you can run to see it in action. It's not quite the same, since Hyena has more operations (which accumulate more error), but you can see that there's quite a bit of error in the FFT/inverse FFT when you have very high magnitudes:

import torch

for m in [1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10]:
    x = torch.randn(1024)
    x2 = x.clone()
    x2[3:] = m * torch.sign(x[3:])

    y = torch.fft.irfft(torch.fft.rfft(x, n=2048), n=2048)
    y2 = torch.fft.irfft(torch.fft.rfft(x2, n=2048), n=2048)

    print('m:', m)
    print('Error: ', (y-y2)[:5])

This code just takes the FFT and then the inverse FFT, so mathematically it's a no-op. However, when we look at the actual errors, it's a different story:

m: 1000.0
Error:  tensor([-8.9407e-07,  6.0678e-05,  6.7472e-05, -9.9923e+02, -9.9959e+02])
m: 10000.0
Error:  tensor([ 1.0395e-04, -1.2151e-03, -2.2638e-04, -9.9976e+03,  9.9993e+03])
m: 100000.0
Error:  tensor([-6.3157e-03,  7.4155e-03, -3.0141e-03, -1.0000e+05, -9.9999e+04])
m: 1000000.0
Error:  tensor([ 3.7217e-03, -3.9581e-02,  1.2945e-01,  1.0000e+06, -1.0000e+06])
m: 10000000.0
Error:  tensor([-1.4344e-01,  7.4805e-01,  2.8370e-01, -1.0000e+07,  1.0000e+07])
m: 100000000.0
Error:  tensor([-7.9974e+00,  1.7516e+01, -5.7919e+00,  1.0000e+08,  1.0000e+08])
m: 1000000000.0
Error:  tensor([-6.3118e+01, -5.3386e-01,  7.8356e+01, -1.0000e+09,  1.0000e+09])
m: 10000000000.0
Error:  tensor([ 2.5470e+02, -1.0224e+03, -1.2801e+03, -1.0000e+10,  1.0000e+10])

Furthermore, do you mean one doesn't need to worry about this problem in practice?

Yes, we've never seen this be a problem in practice.

BTW, I observe that all the input tensors are printed with the same precision, only out to the third decimal place. Is that intentional, i.e. something we need to follow?

No, I think this is just an artifact of this particular test.