Closed siddk closed 2 years ago
Nice! This looks great, very excited to play with it.
A couple of things to note:
- The original file was written in the jupytext light format (https://jupytext.readthedocs.io/en/latest/formats.html#notebooks-as-scripts), which lets you render it as a notebook (and eventually a blog). Let's aim to keep it like that unless you disagree. Basically this means putting a blank line between literate comments and code.
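For reference, here is a tiny sketch of what the light format looks like in practice (the section title and code below are purely illustrative, not from the repo):

```python
# ## A literate section
# In the jupytext light format, a run of comment lines like this is
# rendered as a markdown cell when the script is converted to a notebook.

# The blank line above is what separates the literate comment block
# from the code cell below.
squares = [n ** 2 for n in range(5)]
print(squares)
```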
- Let's document the formatting and style and add a CI check. Generally I use black and flake8. Here's an example: https://github.com/harvardnlp/annotated-transformer/blob/v2021/.github/workflows/checks.yml
- The automatically generated requirements.txt is too intense for me. Let's just pin a Jax / flax / matplotlib version. Ideally we can get rid of the torchvision dependency as well.
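That is, something like the following requirements.txt (the exact version pins here are illustrative guesses, not tested against the repo):

```
jax==0.2.26
flax==0.3.6
matplotlib==3.5.1
```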
- If possible, let's not have any inline code in the main file (that way we can have tests). This is a little annoying to do; my convention is to prefix code I want to run inline with `example_` and then gate calls to the example functions.
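A minimal sketch of that convention (the function name and body are made-up placeholders, not code from the PR):

```python
def example_plot_predictions():
    # Inline demo code lives in an `example_`-prefixed function rather
    # than at module top level, so importing this file (e.g. from a
    # test suite) executes nothing.
    data = [0.0, 0.5, 1.0]
    return sum(data) / len(data)


if __name__ == "__main__":
    # Calls to example functions are gated here, so they only run when
    # the file is executed directly, never on import.
    print(example_plot_predictions())
```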
If some of this is annoying we could also put the data and model parts in a different file and import them. They may not be necessary for the readable part.
Ok - I think all of the above feedback has been addressed, and I have the s4.py --> streambook pipeline pretty fleshed out. I started some initial text structure for these experiments, but I think we should probably factor most of it out into separate files.
I think maybe starting with an RNN on one of these tasks is probably enough... then we build S4 from there.
Implements a simple, single-input feed-forward model as a sanity check for the initial synthetic datasets. Specifically, this PR:
- **quantized sin(x)**: toy overfitting task -- takes 360 samples of the function `sin(x)` and quantizes the outputs into 8 bins, to serve as input for a simple sequence modeling task.
- **quantized sin(ax + b)**: task with randomly sampled `a, b` values; a more complicated version of the toy task above.
- Using `jax.jit`, implements a very simple feed-forward model that predicts only based on the current state. This is a dumb baseline, but will serve as a scaffold for future RNN and S4 experiments.

@srush - one note is that not all `jax.numpy.linalg` functions are implemented on the GPU backend; I went ahead and fixed the non-symmetric eigendecomposition call `eigs()` to use the CPU-backed variant instead.
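To make the toy tasks concrete, here is one way the quantization described above could look (the exact binning scheme is my guess, not necessarily the PR's code; with `a=1, b=0` it reduces to the plain sin(x) task):

```python
import numpy as np


def quantized_sin(n_samples=360, n_bins=8, a=1.0, b=0.0):
    # Sample sin(a*x + b) at n_samples evenly spaced points on [0, 2*pi)
    # and quantize each output into one of n_bins integer classes.
    x = np.linspace(0, 2 * np.pi, n_samples, endpoint=False)
    y = np.sin(a * x + b)
    # Map values in [-1, 1] onto bin indices 0..n_bins-1.
    bins = np.minimum((n_bins * (y + 1) / 2).astype(int), n_bins - 1)
    return bins


tokens = quantized_sin()
print(tokens.shape, tokens.min(), tokens.max())
```

The randomized variant would then just draw `a, b` per sequence, e.g. `quantized_sin(a=rng.uniform(0.5, 2.0), b=rng.uniform(0, np.pi))`.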
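One way to express that CPU fallback (a sketch; `jax.default_device` is one mechanism, and the PR may well use a different one such as jitting the call onto the CPU backend):

```python
import jax
import jax.numpy as jnp

# jnp.linalg.eig (non-symmetric eigendecomposition) is only implemented
# on the CPU backend, so run just this call on a CPU device while the
# rest of the program stays on the default (possibly GPU) device.
with jax.default_device(jax.devices("cpu")[0]):
    w, v = jnp.linalg.eig(jnp.eye(3))

print(w)
```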