srush / annotated-s4

Implementation of https://srush.github.io/annotated-s4
MIT License

Number of training steps for MNIST forecast with context #72

Closed: yardenas closed this 1 year ago

yardenas commented 1 year ago

Hi!

First, thanks for this great implementation! It really helps with understanding the underlying methods used in the paper.

I've been working on my own JAX implementation using Equinox instead of Flax.

So far I cannot get my implementation to generate MNIST digit images based on a short context (as you show in the tutorial). I'm not sure if it's because I don't train my model for long enough.

When I train for ~50-100 steps, one-step-ahead predictions look great. However, when I use the same model to generate images based on context (specifically here), it fails spectacularly. I'm not sure if I'm missing something or if it's just that I'm not running the full training loop (in your implementation, it seems like you run it for ~4000 steps).
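For concreteness, the prefix-conditioned generation I mean looks roughly like the sketch below. This is a minimal sketch, not the repo's code; model_step (a one-step prediction function) and its state handling are hypothetical.

import jax.numpy as jnp

def sample_with_prefix(model_step, init_state, context, total_len):
    # model_step is a hypothetical one-step function:
    # (state, x_t) -> (new_state, prediction of x_{t+1}).
    state, x = init_state, context[0]
    outputs = [x]
    for t in range(total_len - 1):
        state, pred = model_step(state, x)
        # Teacher-force while inside the context, then free-run on
        # the model's own predictions.
        x = context[t + 1] if t + 1 < len(context) else pred
        outputs.append(x)
    return jnp.stack(outputs)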

Any suggestions/ideas would be great for me!

yardenas commented 1 year ago

Having a second look, it seems like in most demonstrations you train for 100 epochs, which amounts to many more steps than what I tried (see the back-of-the-envelope count below). Is it really necessary to run training for so many epochs?
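Rough numbers, assuming the standard 60,000-image MNIST training split and a batch size of 128:

train_size, batch_size, epochs = 60_000, 128, 100
steps_per_epoch = train_size // batch_size  # ~468 steps
total_steps = steps_per_epoch * epochs      # ~46,800 steps, vs. the 50-100 I tried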

Thanks again!

srush commented 1 year ago

Hi Yardenas,

Nice to hear from you. I like Equinox a lot.

yardenas commented 1 year ago

Thanks for the prompt response!

I'm off for a ski vacation for the next week, so not so much time to set up plotting and testing ⛷️ 😁

In any case, I can also try running your implementation for the same number of steps as I run my code and see if it fails. It wouldn't necessarily verify that my implementation is correct, but at least I'd know whether I have to run my code for the same number of steps as yours to get decent results.

Regarding Equinox, my code relies heavily on your implementation; I'm happy to adapt it so that it fits this repo and make a PR. Let me know if you'd be interested.

srush commented 1 year ago

Did you fix this?

yardenas commented 1 year ago

I did a one-epoch run on your repo today, and the results look different! For some reason, the samples from my implementation collapse to black pixels. It's most likely a bug on my end; I'm not sure exactly where, though.

Thanks again!

albertfgu commented 1 year ago

One gotcha is that the normalization of inputs seems to matter. If you're using scalar inputs (rather than embeddings), I would recommend normalizing inputs to the range [0, 1] (so black pixels map to exactly 0) rather than [-1, 1] (centered, which might be more standard).
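A minimal sketch of that preprocessing, assuming raw uint8 pixels in [0, 255]:

import numpy as np

def normalize_scalar_inputs(pixels):
    # Map uint8 pixels [0, 255] -> [0, 1], so black stays exactly 0
    # (instead of -1 under a centered [-1, 1] normalization).
    return pixels.astype(np.float32) / 255.0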

yardenas commented 1 year ago

@albertfgu, thanks for chiming in; that makes sense!

@srush, one thing I did notice is that I get different outputs when running the same model with the decode flag turned on and off:

import jax
import jax.numpy as jnp
import numpy as np

from s4 import s4 as s4_flax  # import path assumed


def test_flax(n=8, sequence_length=16):
    # Same layer in recurrent (decode=True) and convolutional (decode=False) mode.
    cell1 = s4_flax.S4Layer(n, sequence_length, decode=True)
    cell2 = s4_flax.S4Layer(n, sequence_length, decode=False)
    u = jnp.ones((sequence_length,))
    params1 = cell1.init({"params": jax.random.PRNGKey(666)}, u)
    params2 = cell2.init({"params": jax.random.PRNGKey(666)}, u)
    # Check every parameter leaf; all(...) on the raw dict would only
    # iterate over its keys, which is always truthy.
    assert jax.tree_util.tree_all(
        jax.tree_util.tree_map(
            lambda x, y: jnp.allclose(x, y), params1["params"], params2["params"]
        )
    )
    # decode=True threads the SSM state through a mutable "cache" collection.
    y1 = cell1.apply(params1, u, mutable=["cache"])[0]
    y2 = cell2.apply(params2, u)
    assert np.allclose(y1, y2)

I don't know Flax so well, so I might be doing something wrong. Testing decode=False against my implementation (both RNN and CNN modes) yields the same result.
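When the assertion fails, a generic debugging snippet like this (nothing Flax-specific) helps me see where the two modes start to disagree:

import jax.numpy as jnp

def report_divergence(y_rnn, y_cnn, atol=1e-5):
    # Report the first timestep at which the two outputs diverge.
    diff = jnp.abs(jnp.asarray(y_rnn) - jnp.asarray(y_cnn))
    (bad,) = jnp.nonzero(diff > atol)
    if bad.size == 0:
        print(f"outputs match to within {atol}")
    else:
        print(f"first divergence at step {int(bad[0])}, max |diff| = {float(diff.max()):.3g}")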

Any ideas?

yardenas commented 1 year ago

OK, so I started running my implementation on a GPU. I used the same parameters as in python -m s4.train dataset=mnist layer=s4 train.epochs=100 train.bsz=128 model.d_model=128 model.layer.N=64 and got the following results:

[images: six generated MNIST samples, im 0 through im 5]

Something seems off, though I'm still not sure exactly what it is. When I use the CNN mode, my code and your code produce the same outputs for the S4 cell.

The next step is to debug my training loop against your implementation; I'll update once I have more results.