Closed briancheung closed 2 years ago
Good catch. For MNIST classification you can remove the assert and it should be fine.
The real issue is that we drop the last pixel ([:-1]) for generation, and that slicing was left in for classification as well. We should not drop the last pixel for classification.
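To make the off-by-one concrete, here is a minimal sketch of the two input paths. It is not the repository's data-loading code: the function names are hypothetical, and the only assumption is that an MNIST image is flattened to 28 * 28 = 784 pixels.

```python
# Hypothetical helpers for illustration only; not taken from annotated-s4.
SEQ_LEN = 28 * 28  # a flattened MNIST image has 784 pixels


def generation_input(pixels):
    # Generation: the model input drops the last pixel.
    return pixels[:-1]


def classification_input(pixels):
    # Classification: the whole image is the input, nothing is dropped.
    return pixels


image = list(range(SEQ_LEN))  # stand-in for one flattened image
gen_in = generation_input(image)
cls_in = classification_input(image)

print(len(gen_in), len(cls_in))
```

If the `[:-1]` slice is also applied on the classification path, the input has 783 steps while a kernel built for 784 steps does not match it, which is consistent with the length assertion failing.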
Fixed now.
When using the following run command:
python -m s4.train --dataset mnist-classification --model s4 --epochs 10 --bsz 128 --d_model 128 --ssm_n 64
I'm finding that this assertion is failing for the MNIST classification task: https://github.com/srush/annotated-s4/blob/2cfd155217cc928ded27f1434fb71e3ef8245a95/s4/s4.py#L343. I'm not sure whether the failure is intentional or whether there's an off-by-one bug somewhere between the initialization of the model and the training. The lines that follow the assertion pad out any discrepancy between the kernel length and the length of u, so the code runs fine if you simply remove the assertion, but that might not be the intended behavior.
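As a rough check of why removing the assertion still gives a working convolution, here is a sketch of an FFT-based causal convolution that zero-pads both signals to a common length. This is an assumption about the general technique, not a copy of the code in s4.py: `fft_causal_conv` is a hypothetical name, and NumPy is used in place of JAX to keep it self-contained.

```python
import numpy as np


def fft_causal_conv(u, K):
    # Zero-pad both signals to len(u) + len(K), so the circular convolution
    # computed via the FFT equals the linear one even when lengths differ.
    n = u.shape[0] + K.shape[0]
    ud = np.fft.rfft(u, n=n)  # passing n pads the input with trailing zeros
    Kd = np.fft.rfft(K, n=n)
    # Pass n explicitly so odd padded lengths are inverted correctly.
    return np.fft.irfft(ud * Kd, n=n)[: u.shape[0]]


rng = np.random.default_rng(0)
K = rng.normal(size=784)   # kernel built for the full sequence length
u_full = rng.normal(size=784)
u_short = u_full[:-1]      # input with the last pixel dropped

results = []
for u in (u_full, u_short):
    direct = np.convolve(u, K, mode="full")[: u.shape[0]]
    results.append(bool(np.allclose(fft_causal_conv(u, K), direct)))
```

In this sketch both the 784-step and the 783-step input match the direct convolution, so a length mismatch does not break the computation. It does mean the shorter input never sees the final pixel, which is why the slicing matters for classification.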