@albertfgu if you are interested
This is awesome - thanks @srush! Have you already re-run QuickDraw, MNIST completion, and MNIST/CIFAR Classification with the updated code? If not, let me do that in the next couple of days!
I reran MNIST and MNIST-Classification, and the results looked good. Not sure if we need to run the others, but go ahead if you have time. Someone also added wandb integration, which might make it easier to store logs and even models, synced with versions.
Was going to try a bit more on speech. I would like to get SpeechCommands generation working on a 12GB GPU, but might need to figure out some JAX memory tricks first.
Sounds good; let me re-run CIFAR Classification at least, as we couldn't reproduce results with v1 for whatever reason (#17). Happy to try to add W&B integration as well.
Speech on a 12 GB GPU would be super cool - let me know if I can help out there as well.
Hi @siddk, I wonder if you made any progress reproducing sCIFAR-10. Adding another data point: I was able to reproduce sCIFAR-10 with v1 using my own trainer and the following setup:
```python
model_cls = functools.partial(
    s4.BatchSeqModel,
    layer=s4.S4LayerInit(N=64),
    d_model=1024,
    d_output=10,
    dropout=0.25,
    n_layers=6,
    l_max=1024,
    classification=True,
)
state = create_train_state(
    's4',
    model_cls,
    rng,
    # Defined input shape
    in_dim=3,
    bsz=1,
    seq_len=1024,
    lr=1e-2,
    total_steps=total_steps,
)
```
and the following optimizer:

```python
s4_fn = map_nested_fn(
    lambda k, _: 's4' if k in ['B', 'C', 'Ct', 'D', 'log_step'] else 'regular'
)
tx = optax.multi_transform(
    {
        's4': optax.adam(learning_rate=min(lr, 0.001)),
        'regular': optax.adamw(learning_rate=lr, weight_decay=0.01),
    },
    s4_fn,
)
```
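(`map_nested_fn` isn't shown above; a helper along the lines of the one in the optax `multi_transform` docs would do the job. This is an assumption about the exact helper used:)

```python
# Sketch of a map_nested_fn helper in the style of the optax multi_transform
# docs: label every leaf of a nested params dict by applying fn to (key, value).
def map_nested_fn(fn):
    def map_fn(nested_dict):
        return {
            k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
            for k, v in nested_dict.items()
        }
    return map_fn
```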
Also attached is my eval on the validation set.
Hope this helps!
This rocks! Thanks so much for letting us know.
If I get a big GPU free, I'll try it on v2 as well.
Oh this is amazing! I hadn't run CIFAR since like v0.5 (before we added some of the last changes for v1), so this is wonderful - thanks @frederick0329!
@srush - I can try grabbing a free GPU next week/week after (dead week/spring break) and see if we can reproduce for v2!
Thanks for chiming in @frederick0329!
FYI @srush @siddk - for reproducing, I would actually recommend using a smaller model, which should be a lot faster and get pretty close. In particular, you can decrease the model dimension from 1024 to 512. Also, you might want to lower the learning rate if you're using postnorm (from 1e-2 to 4e-3 or 5e-3); prenorm gives similar results, maybe at most 1% worse. Keeping the dt/A/B learning rates at 1e-3, or 0.1x lower, should still be fine.
Got it - will give it a shot later this week @albertfgu! Thanks for the tips
Also note that our model is not training A at all. That gave a small decrease in some of Albert's experiments, but maybe it is okay?
Hey @frederick0329 - we're trying to replicate the results you got with the newer version of the codebase; we're pretty sure there are no breaking changes, but we still aren't able to get the results you've seen on the CIFAR test set.
Specifically, we're running the existing V2 code as follows (two versions):
```bash
# From your (@frederick0329) initial reply
python -m s4.train --dataset cifar-classification --model s4 --epoch 100 --bsz 64 --n_layers 6 --p_dropout 0.25 --lr 1e-2 --d_model 1024

# Following @albertfgu's suggestions
python -m s4.train --dataset cifar-classification --model s4 --epoch 100 --bsz 64 --n_layers 6 --p_dropout 0.25 --lr 5e-3 --d_model 512
```
The only differences between how we're running and your original code seem to be the batch size (you seem to be using a batch size of 1) and anything special in your custom training loop -- do you remember if you were using an LR schedule, or anything else that wouldn't be covered (e.g., gradient clipping)?
What sort of performance are you getting?
The dropout used here is `nn.Dropout`, which seems to be the standard one? IIRC this can generalize a lot worse for CNNs/RNNs - but I'm not sure how @frederick0329 managed to get his numbers if he used the code here without modifying it. He mentioned using his own trainer; I'm not sure if that makes any difference.

Hey @albertfgu - sorry we took the discussion offline; the newest results are in the README. We're getting 85.81% accuracy now. As for the other points:
Our dropout sets `broadcast_dims=0`, which should be equivalent to Dropout2D (channel-zeroing); I remember ensuring this with Karan.
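(As a hedged illustration of that equivalence, not the repo's exact code: with activations shaped (length, dim) inside the model, broadcasting the mask over axis 0 drops whole channels at once.)

```python
# Illustrative sketch, assuming flax.linen and hypothetical shapes: sharing
# the dropout mask along axis 0 (length) zeroes entire channels, like Dropout2D.
import jax
import jax.numpy as jnp
import flax.linen as nn

drop = nn.Dropout(rate=0.25, broadcast_dims=(0,), deterministic=False)
x = jnp.ones((1024, 512))  # (length, dim); the batch axis is vmapped outside
y = drop.apply({}, x, rngs={"dropout": jax.random.PRNGKey(0)})
# Each of the 512 channels is either kept (and rescaled) or zeroed across
# all 1024 positions.
```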
I also have some dumb questions about the code/JAX. This line in the convolution is not passing in any axis arguments, so according to the numpy documentation it's performing convolutions on the last axis? But why are the shapes being passed in from the 0-th axis? Overall, what's the tensor shape throughout the network? Is it (length, batch, dim)?
Yeah we can try removing "C" from the special parameters list and see!
As for the question -- it's definitely a little weird, but we're using a vmap at the batch level, as well as in some other places, that basically automatically maps over those dimensions.
Basically, the internals of our model never see the batch dimension, so for the most part it's always `(length, dim)`. (I think - @srush, correct me if I'm wrong!)
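To make that concrete, here's a toy sketch (not the actual model code) of vmapping over the batch axis; the mapped function only ever sees a `(length, dim)` array:

```python
import jax
import jax.numpy as jnp

def forward(x):
    # x has shape (length, dim): the batch axis is invisible in here.
    return x * 2.0

batched = jax.vmap(forward)   # maps forward over a leading batch axis
u = jnp.ones((8, 1024, 512))  # (batch, length, dim)
y = batched(u)                # shapes preserved: (8, 1024, 512)
```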
I see. Speaking of vmapping, I was curious whether this duplicates the initialization of the layers as well, or whether they are initialized afterwards (and all independently).
Cool, I gathered that it was `(length, dim)`, although I'm still confused about that `rfft` line, since the documentation says the default axis is `-1`.
Some answers inline.
> I see. Speaking of vmapping, I was curious whether this duplicates the initialization of the layers as well, or whether they are initialized afterwards (and all independently).
The argument `split_rngs={"params": True}` means they are each initialized with a different random seed.
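A sketch of the idea with flax's lifted vmap (assuming `flax.linen`; the actual stack may use `nn.scan` or `nn.vmap`, and this module is illustrative):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(4)(x)

Stacked = nn.vmap(
    Block,
    in_axes=0,
    variable_axes={"params": 0},  # each mapped copy has its own parameters,
    split_rngs={"params": True},  # drawn from a different RNG at init time
)

# Six independently initialized copies, one per leading-axis slice.
params = Stacked().init(jax.random.PRNGKey(0), jnp.ones((6, 2)))
```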
> This line in the convolution is not passing in any axis arguments, so according to the numpy documentation it's performing convolutions on the last axis? But why are the shapes being passed in from the 0-th axis? Overall, what's the tensor shape throughout the network? Is it (length, batch, dim)?
> Cool, I gathered that it was `(length, dim)`, although I'm still confused about that `rfft` line, since the documentation says the default axis is `-1`.
I should document this better, but both `u` and `K` here are 1D vectors of shape (length). None of this code knows about any of the other dimensions:

```python
def non_circular_convolution(u, K, nofft=False):
```
When we call it down here, `u` and `K` are only of shape (L). Both the batch and H dimensions have been vmapped over.
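That also answers the `rfft` question: since `u` and `K` are 1D by the time they reach the FFT, the default `axis=-1` is the length axis. A hedged sketch of what the FFT path can look like (not necessarily the repo's exact body):

```python
import jax.numpy as np

def non_circular_convolution(u, K, nofft=False):
    # u, K: shape (L,); batch and H have already been vmapped away,
    # so rfft's default axis=-1 is the length axis.
    L = u.shape[0]
    if nofft:
        # Direct (quadratic-time) causal convolution.
        return np.convolve(u, K, mode="full")[:L]
    # Zero-pad to 2L so the circular FFT convolution matches the linear one.
    ud = np.fft.rfft(np.pad(u, (0, L)))
    Kd = np.fft.rfft(np.pad(K, (0, L)))
    return np.fft.irfft(ud * Kd)[:L]
```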
Got it - I need to get my hands dirty with JAX when I get a chance. BatchNorm is really annoying in general; I don't know if it's even worth it. I generally stick to LayerNorm unless I really feel like tuning benchmarks.
Yeah, it's actually kind of impressive how annoying it is in JAX. It breaks all kinds of statefulness and independence assumptions that never occurred to me.
I think we'll add it though. A couple people have requested it.
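To illustrate the statefulness issue, a toy flax sketch (illustrative, not the repo's code): BatchNorm's running statistics live in a separate mutable collection that has to be threaded through every call.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Net(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Dense(4)(x)
        # Running mean/var live in a separate "batch_stats" collection,
        # not in "params".
        return nn.BatchNorm(use_running_average=not train)(x)

model = Net()
x = jnp.ones((8, 3))
variables = model.init(jax.random.PRNGKey(0), x, train=True)
# The collection must be declared mutable, and its updates handled manually.
y, updates = model.apply(variables, x, train=True, mutable=["batch_stats"])
```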
This PR is towards a V2 of the blog post. Extensions include:

Successes:

- `prime` the preprocessing during inference to avoid recalculating the parameterization when parameters do not change.

Failures: