srush / annotated-s4

Implementation of https://srush.github.io/annotated-s4

Sampling #63

Closed albertfgu closed 2 years ago

albertfgu commented 2 years ago

An interesting finding is the role of normalization for MNIST sampling. The working hypothesis is that any input tokens that semantically act as "padding" (i.e. having more or fewer of them on the boundaries of the input does not change the semantic meaning) should be encoded to the 0 vector for S4 layers, because of their global receptive field. Two options exist for doing this, sketched in the code after this list:

  1. An nn.Dense encoder should map inputs to [0.0, 1.0] instead of [-1.0, 1.0] so that black pixels get mapped to 0 (the Dense layer should also turn off its bias, though it works fine without that).
  2. A custom nn.Embed encoder that maps the 0 token to the 0.0 vector. This has noticeably better NLL, although sampling might be a little more finicky.
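
A minimal Flax sketch of both options (module names, shapes, and the 256-intensity vocabulary are my assumptions for illustration, not code from the repo):

```python
import jax.numpy as jnp
from flax import linen as nn


class DenseEncoder(nn.Module):
    """Option 1: rescale pixels to [0, 1] and project without a bias."""
    d_model: int

    @nn.compact
    def __call__(self, x):
        # x: (L, 1) raw pixel intensities in [0, 255].
        x = x.astype(jnp.float32) / 255.0  # black pixels become exactly 0.0
        # use_bias=False keeps the zero input mapped to the zero vector.
        return nn.Dense(self.d_model, use_bias=False)(x)


class ZeroEmbedEncoder(nn.Module):
    """Option 2: learned embedding with token 0 pinned to the zero vector."""
    num_embeddings: int  # e.g. 256 pixel intensities (assumed)
    d_model: int

    @nn.compact
    def __call__(self, tokens):
        # tokens: (L,) integer pixel tokens.
        emb = nn.Embed(self.num_embeddings, self.d_model)(tokens)
        # Force token 0 (black / padding-like) to the exact 0.0 vector.
        return jnp.where((tokens > 0)[..., None], emb, 0.0)
```

Either way, a zero-valued boundary token contributes nothing through the S4 layers' global receptive field, which matches the padding hypothesis above.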

Both of these also fix earlier issues with generation under pre-norm architectures: previously the magnitude of the activations was too large (raw pixel values in [0, 255]), and pre-norm passes the residual stream through the network without ever normalizing it.
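
For intuition, here is a sketch of why pre-norm leaves the input scale untouched (a generic pre-norm residual block, not the repo's exact code):

```python
def prenorm_block(x, mix, norm):
    # Pre-norm: only the branch input is normalized; the residual
    # stream x keeps its original magnitude, so raw [0, 255] inputs
    # would propagate unnormalized through every layer.
    return x + mix(norm(x))
```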