srush / annotated-s4

Implementation of https://srush.github.io/annotated-s4
MIT License

V2 - RNN mode, various precision improvements #43

Closed: srush closed this pull request 2 years ago

srush commented 2 years ago

This PR works toward a V2 of the blog post. Extensions include:

Successes:

Failures:

srush commented 2 years ago

@albertfgu if you are interested

siddk commented 2 years ago

This is awesome - thanks @srush! Have you already re-run QuickDraw, MNIST completion, and MNIST/CIFAR Classification with the updated code? If not, let me do that over the next couple of days!

srush commented 2 years ago

I reran MNIST and MNIST-Classification, and the results looked good. Not sure we need to run the others, but go ahead if you have time. Someone also added wandb integration, which might make it easier to store logs and even models, synced with versions.

I was going to try a bit more on speech. I would like to get SpeechCommands generation working on a 12GB GPU, but I might need to figure out some JAX memory tricks first.
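(One candidate trick, just as a sketch on my part: rematerialize activations with jax.checkpoint so the backward pass recomputes them instead of storing them. The block below is a stand-in, not the real S4 layer.)

import jax
import jax.numpy as np

# Sketch: wrap an expensive block with jax.checkpoint (a.k.a. jax.remat) so
# its intermediate activations are recomputed during the backward pass rather
# than kept in memory.
@jax.checkpoint
def heavy_block(params, x):
    return np.tanh(x @ params)

def loss(params, x):
    return np.sum(heavy_block(params, x) ** 2)

grads = jax.grad(loss)(np.ones((512, 512)), np.ones((64, 512)))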

siddk commented 2 years ago

Sounds good; let me re-run CIFAR Classification at least, since we couldn't reproduce the v1 results for whatever reason (#17). Happy to try adding W&B integration as well.

Speech on a 12 GB GPU would be super cool - let me know if I can help out there as well.

frederick0329 commented 2 years ago

Hi @siddk, I was wondering if you've made any progress reproducing sCIFAR-10. To add another data point: I was able to reproduce the sCIFAR-10 results with v1 using my own trainer and the following setup:

model_cls = functools.partial(
    s4.BatchSeqModel,
    layer=s4.S4LayerInit(N=64),
    d_model=1024,
    d_output=10,
    dropout=0.25,
    n_layers=6,
    l_max=1024,
    classification=True,
)

state = create_train_state(
    's4',
    model_cls,
    rng,
    # Defined input shape
    in_dim=3,
    bsz=1,
    seq_len=1024,
    lr=1e-2,
    total_steps=total_steps,
)

and the following optimizer:

s4_fn = map_nested_fn(
    lambda k, _: 's4' if k in ['B', 'C', 'Ct', 'D', 'log_step'] else 'regular'
)
tx = optax.multi_transform(
    {
        's4': optax.adam(learning_rate=min(lr, 0.001)),
        'regular': optax.adamw(learning_rate=lr, weight_decay=0.01),
    },
    s4_fn,
)
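
(map_nested_fn isn't shown above; it's the usual recursive labeling helper in the spirit of the optax multi_transform docs, roughly:)

def map_nested_fn(fn):
    # Recursively apply fn(key, leaf) over a nested parameter dict to produce
    # the label tree that optax.multi_transform expects.
    def map_fn(nested_dict):
        return {
            k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
            for k, v in nested_dict.items()
        }
    return map_fn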

Also attached is my eval on the validation set:

[validation accuracy plot]

Hope this helps!

srush commented 2 years ago

This rocks! Thanks so much for letting us know.

If I get a big GPU free, I'll try it on v2 as well.

siddk commented 2 years ago

Oh this is amazing! I hadn't run CIFAR since like v0.5 (before we added some of the last changes for v1), so this is wonderful - thanks @frederick0329!

@srush - I can try grabbing a free GPU next week/week after (dead week/spring break) and see if we can reproduce for v2!

albertfgu commented 2 years ago

Thanks for chiming in @frederick0329!

FYI @srush @siddk - for reproducing, I would actually recommend using a smaller model, which should be a lot faster and get pretty close. In particular, you can decrease the model dimension from 1024 to 512. Also, you might want to lower the learning rate if you're using postnorm (from 1e-2 to 4e-3 or 5e-3); prenorm gives similar results, maybe at most 1% worse. Keeping the dt/A/B learning rates at 1e-3, or 0.1x lower, should still be fine.

siddk commented 2 years ago

Got it - I'll give it a shot later this week, @albertfgu! Thanks for the tips.

srush commented 2 years ago

Also note that our model is not training A at all. I saw that this gave a small decrease in some of Albert's experiments, but maybe it's okay?

siddk commented 2 years ago

Hey @frederick0329 - we're trying to replicate the results you got with the newer version of the codebase; we're pretty sure there are no breaking changes, but we still aren't able to match the results you saw on the CIFAR test set.

Specifically, we're running the existing V2 code as follows (two versions):

# From your (@frederick0329) initial reply
python -m s4.train --dataset cifar-classification --model s4 --epoch 100 --bsz 64 --n_layers 6 --p_dropout 0.25 --lr 1e-2 --d_model 1024

# Following @albertfgu's suggestions
python -m s4.train --dataset cifar-classification --model s4 --epoch 100 --bsz 64 --n_layers 6 --p_dropout 0.25 --lr 5e-3 --d_model 512

The only differences between how we're running things and your original code seem to be the batch size (you seem to be using a batch size of 1) and anything special in your custom training loop -- do you remember if you were using an LR schedule, or anything else that wouldn't be covered (e.g., gradient clipping)?
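
For concreteness, here is roughly what an LR schedule plus gradient clipping would look like in optax (just a sketch with placeholder values, not something we know your trainer used):

import optax

# Warmup-cosine schedule plus global-norm gradient clipping, chained in front
# of AdamW. All numbers here are placeholders.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=5e-3,
    warmup_steps=1_000,
    decay_steps=100_000,
)
tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule, weight_decay=0.01),
)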

albertfgu commented 2 years ago

What sort of performance are you getting?

siddk commented 2 years ago

Hey @albertfgu - sorry, we took the discussion offline; the newest results are in the README. We're getting 85.81% accuracy now. As for the other points:

albertfgu commented 2 years ago

I also have some dumb questions about the code/JAX. This line in the convolution isn't passing any axis arguments, so according to the numpy documentation it's performing the convolution over the last axis? But then why are the shapes being passed in from the 0-th axis? Overall, what's the tensor shape throughout the network - is it (length, batch, dim)?

siddk commented 2 years ago

Yeah we can try removing "C" from the special parameters list and see!

As for the question -- it's definitely a little weird, but we're using a vmap at the batch level (as well as in some other places) that basically maps over those dimensions automatically.

Basically, the internals of our model never see the batch dimension, so for the most part it's always (length, dim). (I think - @srush, correct me if I'm wrong!)

albertfgu commented 2 years ago

I see. Speaking of vmapping, I was curious whether this duplicates the initialization of the layers as well, or whether they are initialized afterwards (and all independently).

Cool, I gathered that it was (length, dim), although I'm still confused about that rfft line, since the documentation says the default axis is -1.

srush commented 2 years ago

Some answers inline.

> I see. Speaking of vmapping, I was curious whether this duplicates the initialization of the layers as well, or whether they are initialized afterwards (and all independently).

The argument split_rngs={"params": True} means they are each initialized with a different random seed.
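
As a toy illustration (not the exact stacking code in the repo; the module here is made up), the difference shows up like this in flax:

import jax
import jax.numpy as np
import flax.linen as nn

class Layer(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(4)(x)

# Stack copies of Layer along a new leading axis. variable_axes={"params": 0}
# gives each copy its own parameters; split_rngs={"params": True} initializes
# each copy from a different RNG, so the copies start out different.
Stacked = nn.vmap(
    Layer,
    in_axes=0,
    out_axes=0,
    variable_axes={"params": 0},
    split_rngs={"params": True},
)

params = Stacked().init(jax.random.PRNGKey(0), np.ones((6, 4)))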

> This line in the convolution isn't passing any axis arguments, so according to the numpy documentation it's performing the convolution over the last axis? But then why are the shapes being passed in from the 0-th axis? Overall, what's the tensor shape throughout the network - is it (length, batch, dim)?

> Cool, I gathered that it was (length, dim), although I'm still confused about that rfft line, since the documentation says the default axis is -1.

I should document this better, but both u and K here are 1-D vectors of shape (length,). None of this code knows about any of the other dimensions:

def non_circular_convolution(u, K, nofft=False):

When we call it down here, u and K are only of shape (L). Both the batch and H dimensions have been vmapped over.
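
Roughly, the FFT path works like this on those 1-D inputs (a sketch from me, so the exact padding/truncation details may differ from the code):

import jax.numpy as np

def non_circular_convolution_sketch(u, K):
    # u, K: both shape (L,). Zero-pad to length 2L so the circular FFT
    # convolution becomes a causal (non-circular) one, then truncate back to L.
    L = u.shape[0]
    ud = np.fft.rfft(np.pad(u, (0, L)))
    Kd = np.fft.rfft(np.pad(K, (0, L)))
    # rfft/irfft default to axis=-1, which is the only axis these 1-D inputs have.
    return np.fft.irfft(ud * Kd)[:L]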

srush commented 2 years ago

To make it a little clearer, this line handles H and then this one handles batch. This lets us basically implement it like the scalar version in the paper.

Unfortunately, to implement batchnorm, we'll probably need to pull the second vmap into the model.
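
In plain jax.vmap terms it looks roughly like this (the repo uses flax's lifted nn.vmap instead, so treat these wrappers as illustrative):

import jax
import jax.numpy as np

def conv_1d(u, K):
    # 1-D causal convolution, as sketched in the previous comment.
    L = u.shape[0]
    return np.fft.irfft(np.fft.rfft(np.pad(u, (0, L)))
                        * np.fft.rfft(np.pad(K, (0, L))))[:L]

# "This line handles H": map the per-feature convolution over the H axis,
# so u and K become (L, H).
conv_over_H = jax.vmap(conv_1d, in_axes=1, out_axes=1)

# "This one handles batch": map over the batch axis of u only; the kernel K
# is shared across the batch, so u is (B, L, H) and K stays (L, H).
conv_over_batch = jax.vmap(conv_over_H, in_axes=(0, None), out_axes=0)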

albertfgu commented 2 years ago

Got it - I need to get my hands dirty with JAX when I get a chance. Batchnorm is really annoying in general; I don't know if it's even worth it, and I generally stick to LN unless I really feel like tuning benchmarks.

srush commented 2 years ago

Yeah, it's actually kind of impressive how annoying it is in JAX. It breaks all kinds of statefulness and independence assumptions that never occurred to me.

I think we'll add it though. A couple people have requested it.
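
For a flavor of the statefulness issue: in flax, BatchNorm keeps its running statistics in a separate "batch_stats" collection that has to be marked mutable and threaded through apply by hand (a sketch, not what the PR will actually do):

import jax
import jax.numpy as np
import flax.linen as nn

class Block(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool):
        x = nn.Dense(64)(x)
        # Running mean/var live in the "batch_stats" collection, separate
        # from "params", and are only updated when marked mutable.
        x = nn.BatchNorm(use_running_average=not train)(x)
        return nn.relu(x)

x = np.ones((8, 32))
variables = Block().init(jax.random.PRNGKey(0), x, train=True)
# Training step: batch_stats must be declared mutable and comes back as a
# second return value that the training loop has to carry along.
y, updates = Block().apply(variables, x, train=True, mutable=["batch_stats"])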