srush / annotated-s4

Implementation of https://srush.github.io/annotated-s4
MIT License
460 stars, 61 forks

Assertion failure for differing kernel K and input u #39

Closed briancheung closed 2 years ago

briancheung commented 2 years ago

When using the following run command: python -m s4.train --dataset mnist-classification --model s4 --epochs 10 --bsz 128 --d_model 128 --ssm_n 64, I'm finding that this assertion is failing: https://github.com/srush/annotated-s4/blob/2cfd155217cc928ded27f1434fb71e3ef8245a95/s4/s4.py#L343

This happens for the MNIST classification task. I'm not sure whether the failure is intentional or whether there's an off-by-one bug somewhere between the initialization of the model and the training. The following lines pad out any discrepancy between the length of the kernel K and the length of u, so the code runs fine if you simply remove the assertion, but that might not be the intended behavior.

        ud = np.fft.rfft(np.pad(u, (0, K.shape[0])))
        Kd = np.fft.rfft(np.pad(K, (0, u.shape[0])))
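
To illustrate why the padding above tolerates a length mismatch, here is a minimal sketch in plain NumPy (an assumption; the repository's own np may be a different module). The function name and the test lengths 783 and 784 are chosen for illustration only and this is not the repository's code. One detail that may be worth checking: np.fft.irfft defaults to an even output length, so when the two lengths sum to an odd number the sketch passes n explicitly.

```python
import numpy as np


def causal_convolution(u, K):
    """FFT-based causal convolution; u and K may differ in length (illustrative sketch)."""
    L_u, L_K = u.shape[0], K.shape[0]
    # Pad both signals to L_u + L_K so the circular convolution computed
    # via the FFT matches the linear convolution with no wrap-around.
    ud = np.fft.rfft(np.pad(u, (0, L_K)))
    Kd = np.fft.rfft(np.pad(K, (0, L_u)))
    # Pass n explicitly: without it, irfft assumes an even output length,
    # which is wrong when L_u + L_K is odd.
    out = np.fft.irfft(ud * Kd, n=L_u + L_K)
    # Keep only the first L_u outputs (the causal part aligned with u).
    return out[:L_u]


# Hypothetical lengths mirroring an off-by-one mismatch.
rng = np.random.default_rng(0)
u = rng.normal(size=783)
K = rng.normal(size=784)
y = causal_convolution(u, K)
```

Comparing y against np.convolve(u, K)[:783] is one way to confirm the two agree for mismatched lengths.
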
srush commented 2 years ago

Good catch. For MNIST classification you can remove the assert and it should be fine.

The real issue is that we drop the last pixel ([:-1]) for generation, and that slicing was also left in for classification. We should not drop the last pixel for classification.
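
A small sketch of the off-by-one described here, assuming a flattened 28x28 MNIST image (784 pixels) and a kernel built for the full sequence length. The variable names are hypothetical, and the generation setup shown is a generic next-pixel arrangement, not necessarily the repository's exact data pipeline.

```python
import numpy as np

L = 28 * 28  # a flattened MNIST image has 784 pixels
image = np.arange(L)  # stand-in for pixel values

# Generation: predict the next pixel from the pixels seen so far,
# so the inputs drop the last pixel and the targets drop the first.
gen_inputs, gen_targets = image[:-1], image[1:]

# Classification: the label depends on the whole image, so keep every pixel.
cls_inputs = image

# If the [:-1] slice is applied to classification too, the input has 783
# steps while a kernel built for length 784 does not, and a check such as
# K.shape[0] == u.shape[0] fails.
print(gen_inputs.shape[0], cls_inputs.shape[0])  # prints: 783 784
```
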

srush commented 2 years ago

Fixed now.