make sampling a little more modular so that other datasets can be incorporated more easily
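One way to make sampling pluggable per dataset is a small registry that maps a dataset name to its sampling setup. A minimal sketch (the registry, names, and decorator are hypothetical, not the actual implementation):

```python
# Hypothetical registry: each dataset registers its own sampling function,
# so adding a new dataset only requires registering a new entry.
SAMPLERS = {}

def register_sampler(name):
    """Decorator that records a sampling function under a dataset name."""
    def wrap(fn):
        SAMPLERS[name] = fn
        return fn
    return wrap

@register_sampler("mnist")
def sample_mnist(model, prefix):
    # dataset-specific sampling logic would go here
    return prefix

def get_sampler(name):
    """Look up the sampler for a dataset, failing loudly if unregistered."""
    if name not in SAMPLERS:
        raise KeyError(f"no sampler registered for dataset {name!r}")
    return SAMPLERS[name]
```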
add options for sampling (e.g. prefix length) that are configurable from the config file or command line
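Exposing these options could look like the following argparse sketch (option names and defaults are hypothetical; the real config keys may differ):

```python
import argparse

def add_sampling_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Hypothetical sampling options, grouped so they show up together in --help
    group = parser.add_argument_group("sampling")
    group.add_argument("--sample_len", type=int, default=784,
                       help="total number of tokens to generate (e.g. 28*28 for MNIST)")
    group.add_argument("--prefix_len", type=int, default=100,
                       help="number of ground-truth tokens used as the sampling prefix")
    group.add_argument("--temperature", type=float, default=1.0,
                       help="softmax temperature applied to the logits when sampling")
    return parser

parser = add_sampling_args(argparse.ArgumentParser())
args = parser.parse_args(["--prefix_len", "392"])
```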
add Embed vs Dense option for the StackedModel backbone, with proper normalization
An interesting finding is the role of normalization for MNIST sampling. The working hypothesis is that any input token that semantically acts as "padding" (i.e. adding or removing it at the boundaries of the input does not change the semantic meaning) should be encoded to the 0 vector for S4 layers, because of their global receptive field. Two options exist for doing this:
An nn.Dense encoder should map inputs to [0.0, 1.0] instead of [-1.0, 1.0] so that black pixels get mapped to 0 (the Dense layer should also disable its bias, although it works acceptably without that)
A custom nn.Embed encoder that maps the 0 token to the 0.0 vector. This achieves noticeably better NLL, although sampling can be a little more finicky
Both of these also fix earlier issues with generation under pre-norm architectures: previously the activation magnitudes were too large (raw pixel values in [0, 255]), and pre-norm passes the input through the residual stream without normalizing it