Closed yardenas closed 1 year ago
Having a second look, it seems like in most demonstrations you train for 100 epochs, which amounts to many more steps than what I tried. Is it really required to run training for so many epochs?
Thanks again!
Hi yardenas,
Nice to hear from you. I like Equinox a lot.
Thanks for the prompt response!
I'm off for a ski vacation for the next week, so not so much time to set up plotting and testing ⛷️ 😁
In any case, I can also try running your implementation for the same number of steps I run my code and see if it fails. It wouldn't necessarily verify that my implementation is correct, but at least I'll know I have to run my code for the same number of steps as yours to get decent results.
Regarding Equinox, my code relies heavily on your implementation. I'm happy to adapt it so that it fits this repo and make a PR; let me know if you'd be interested.
Did you fix this?
I made a one-epoch run today on your repo, and the results seem different! For some reason, the results I get from my sampling collapse to black pixels. It's most likely a bug on my end; I'm just not sure exactly where.
Thanks again!
One gotcha is that the normalization of inputs seems to matter. If you're using scalar inputs (rather than embeddings), I would recommend normalizing inputs to the range [0, 1] (so black pixels get mapped to the value 0) rather than [-1, 1] (centered, which might be more standard).
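Concretely, assuming raw uint8 pixel values in 0-255 (the function names below are just for illustration):

```python
import jax.numpy as jnp

# Maps raw MNIST bytes (0..255) to [0, 1]; black pixels become exactly 0.
def normalize_01(pixels):
    return pixels.astype(jnp.float32) / 255.0

# The more common centered scaling to [-1, 1], which seems to work less well here.
def normalize_pm1(pixels):
    return pixels.astype(jnp.float32) / 127.5 - 1.0
```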
@albertfgu, thanks for looping in, makes sense!
@srush, one thing I did notice is that I get different outputs when running the same model with the decode flag turned on and off.
```python
import jax
import numpy as np

# `s4_flax` refers to this repo's Flax S4 implementation,
# imported from wherever it lives in your checkout.

def test_flax(n=8, sequence_length=16):
    # Same layer twice: once in recurrent (decode) mode, once in convolutional mode.
    cell1 = s4_flax.S4Layer(n, sequence_length, decode=True)
    cell2 = s4_flax.S4Layer(n, sequence_length, decode=False)
    u = jax.numpy.ones((sequence_length,))
    params1 = cell1.init({"params": jax.random.PRNGKey(666)}, u)
    params2 = cell2.init({"params": jax.random.PRNGKey(666)}, u)
    # Both initializations should yield identical parameters.
    assert jax.tree_util.tree_all(
        jax.tree_map(
            lambda x, y: jax.numpy.allclose(x, y), params1["params"], params2["params"]
        )
    )
    # decode=True mutates its cache collection, so it has to be marked mutable.
    y1 = cell1.apply(params1, u, mutable=["cache"])[0]
    y2 = cell2.apply(params2, u)
    assert np.allclose(y1, y2)
```
I don't know Flax so well, so I might be doing something wrong here. Testing decode=False against my implementation (both RNN and CNN modes) yields the same result.
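For reference, the RNN/CNN equivalence I'm relying on boils down to the following self-contained sketch (a toy diagonal state matrix stands in for the actual S4 parameterization; none of this uses the repo's API):

```python
import jax
import jax.numpy as jnp

def rnn_mode(Ab, Bb, Cb, u):
    # Step the discretized SSM x_k = Ab x_{k-1} + Bb u_k, y_k = Cb x_k one input at a time.
    def step(x, u_k):
        x = Ab @ x + Bb[:, 0] * u_k
        return x, (Cb @ x)[0]
    _, ys = jax.lax.scan(step, jnp.zeros(Ab.shape[0]), u)
    return ys

def cnn_mode(Ab, Bb, Cb, u):
    # Materialize the kernel K_k = Cb Ab^k Bb and apply it as a causal convolution.
    L = u.shape[0]
    K = jnp.stack([(Cb @ jnp.linalg.matrix_power(Ab, k) @ Bb)[0, 0] for k in range(L)])
    return jnp.convolve(u, K)[:L]

n, L = 4, 16
Ab = jnp.diag(jax.random.uniform(jax.random.PRNGKey(0), (n,), minval=0.1, maxval=0.9))
Bb = jax.random.normal(jax.random.PRNGKey(1), (n, 1))
Cb = jax.random.normal(jax.random.PRNGKey(2), (1, n))
u = jnp.ones(L)
assert jnp.allclose(rnn_mode(Ab, Bb, Cb, u), cnn_mode(Ab, Bb, Cb, u), atol=1e-5)
```

Both modes agree in this standalone setup, which is why the decode=True mismatch above surprises me.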
Any ideas?
Ok, so I started running my implementation on a GPU. I used the parameters we have in `python -m s4.train dataset=mnist layer=s4 train.epochs=100 train.bsz=128 model.d_model=128 model.layer.N=64` and got the following results:
Seems like something is off, though I'm still not sure exactly what it is: when I use the CNN mode, my code and your code produce the same outputs for the S4 cell.
The next step is to debug my training loop with your implementation; I'll update once I have more results.
Hi!
First of all, thanks for this great implementation! It really helps with understanding the underlying methods used in the paper.
I've been working on my own JAX implementation using Equinox instead of Flax.
So far, it seems like I cannot make my implementation generate MNIST digit images from a short context (as you show in the tutorial). I'm not sure if it's because I don't train my model for long enough.
When I train for ~50-100 steps, one-step-ahead predictions look great. However, when I use the same model to generate images based on context (specifically here), my model gloriously fails. I'm not sure if I'm missing something or if it's just the fact that I'm not running the full training loop (your implementation seems to run for ~4000 steps).
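To make the failure mode concrete, the prefix-conditioned sampling I'm attempting is roughly the following (a hand-wavy sketch; `step_fn` and `state` are hypothetical stand-ins for a single-step recurrent cell, not code from either repo):

```python
import jax
import jax.numpy as jnp

def sample_with_prefix(step_fn, state, prefix, total_len, key):
    # Prime the recurrent state on all but the last context pixel.
    for x in prefix[:-1]:
        state, _ = step_fn(state, x)
    out = list(prefix)
    x = prefix[-1]
    # Autoregressive loop: the prediction made after seeing pixel k becomes pixel k + 1.
    for _ in range(total_len - len(prefix)):
        state, logits = step_fn(state, x)
        key, subkey = jax.random.split(key)
        x = jax.random.categorical(subkey, logits)
        out.append(x)
    return jnp.stack(out)
```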
Any suggestions or ideas would be great!