enolan / decision_transformer

CLIP image generator using a decision-transformer-inspired model
5 stars, 1 fork

nystrom causal attention #1

Status: open. enolan opened this issue 3 years ago.

enolan commented 3 years ago

Should get a big improvement from working directly on pixels rather than on VQGAN tokens.

enolan commented 3 years ago

When working with pixels, we want the model to output a distribution; let's say a multivariate normal. We can then implement an equivalent of top-p sampling by turning the output distributions into truncated normal distributions, using the CDF to find the correct truncation point. See here for p = 0.05.
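A minimal JAX sketch of that idea, assuming each pixel channel is treated as its own 1-D normal (the identity-covariance case) and that p is the probability mass kept, as in top-p. Because the normal is symmetric and unimodal, the smallest interval holding mass p is the central one, [mu - z*sigma, mu + z*sigma] with z = Phi^-1((1 + p) / 2). The helper name truncated_normal_top_p is hypothetical; jax.scipy.stats.norm.ppf and jax.random.truncated_normal are existing JAX functions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def truncated_normal_top_p(key, mean, std, p):
    """Sample N(mean, std^2) restricted to its central interval holding mass p.

    Hypothetical helper: every element (e.g. each r/g/b value of each pixel)
    is treated as an independent 1-D normal, i.e. identity covariance up to
    the per-element scale `std`.
    """
    # The interval [-z, z] of a standard normal holds mass p when
    # Phi(z) - Phi(-z) = p, i.e. z = Phi^{-1}((1 + p) / 2).
    z = norm.ppf((1.0 + p) / 2.0)
    # Standard-normal draws restricted to [-z, z], one per output element.
    eps = jax.random.truncated_normal(key, -z, z, shape=jnp.shape(mean))
    # Shift and scale into each element's predicted distribution.
    return mean + std * eps


# Usage with made-up predictions for an 8x8 RGB image.
key = jax.random.PRNGKey(0)
mean = jnp.full((8, 8, 3), 0.5)
std = jnp.full((8, 8, 3), 0.1)
pixels = truncated_normal_top_p(key, mean, std, 0.05)
```

With p = 0.05 the kept interval is very narrow (about 0.063 standard deviations either side of the mean), so samples land close to the predicted mean, which is the same "be conservative" effect a small p has in token-level top-p.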

This becomes more complicated if we make the r/g/b covariance non-identity. If that turns out to be hard, I think it's fine to constrain the covariance to identity, since we're sampling multiple times with MC-dropout anyway.
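To make the extra complication concrete (these are standard facts about the normal distribution, not something stated in the thread): with covariance sigma^2 I each channel can be truncated on its own using the 1-D CDF, but with a full r/g/b covariance Sigma the region of highest density holding mass p is an ellipsoid whose size is set by a chi-square quantile with 3 degrees of freedom, so a per-channel CDF no longer gives the truncation point.

```latex
% Identity covariance \Sigma = \sigma^2 I: truncate each channel c independently
x_c \in \left[\mu_c - \sigma z_p,\; \mu_c + \sigma z_p\right],
\qquad z_p = \Phi^{-1}\!\left(\tfrac{1+p}{2}\right)

% Full covariance \Sigma over (r, g, b): highest-density region of mass p,
% using (x-\mu)^\top \Sigma^{-1} (x-\mu) \sim \chi^2_3
\left\{\, x \in \mathbb{R}^3 \;:\; (x-\mu)^\top \Sigma^{-1} (x-\mu) \le F^{-1}_{\chi^2_3}(p) \,\right\}
```

One side effect of the per-channel version: three independent channels each truncated to mass p keep p^3 of the joint mass, rather than p.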

enolan commented 3 years ago

Important if I'm doing this in JAX: the JAX pseudoinverse function (jnp.linalg.pinv) uses SVD, which is hideously slow, especially on TPU. I need to use the fast approximation from the paper instead, which is also implemented here. Could contribute a faster pseudoinverse to JAX too; that'd be good.
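A minimal sketch of an SVD-free pseudoinverse in JAX, assuming "the paper" is Nyströmformer (Xiong et al., 2021). As I understand its reference implementation, it starts from Z0 = A^T / (||A||_1 * ||A||_inf) and repeats Z <- 1/4 Z (13I - AZ (15I - AZ (7I - AZ))), which needs only matrix multiplies. The function name iterative_pinv and the default of 6 iterations are assumptions here, not taken from the thread.

```python
import jax.numpy as jnp


def iterative_pinv(a, n_iter=6):
    """Approximate Moore-Penrose pseudoinverse over the last two axes of `a`.

    Hypothetical helper. Uses only matmuls (no SVD); accuracy depends on
    n_iter and on how well-conditioned `a` is. No convergence check is done.
    """
    # Z has shape (..., n, m) and A @ Z has shape (..., m, m), so the identity
    # must match the row count of `a`.
    eye = jnp.eye(a.shape[-2], dtype=a.dtype)
    # ||A||_1 (max column abs sum) and ||A||_inf (max row abs sum), per matrix.
    norm_1 = jnp.max(jnp.sum(jnp.abs(a), axis=-2), axis=-1)
    norm_inf = jnp.max(jnp.sum(jnp.abs(a), axis=-1), axis=-1)
    # Scaling by 1 / (||A||_1 ||A||_inf) keeps the nonzero eigenvalues of
    # A @ Z0 in (0, 1], which the iteration needs in order to converge.
    z = jnp.swapaxes(a, -1, -2) / (norm_1 * norm_inf)[..., None, None]
    for _ in range(n_iter):
        az = a @ z
        # Z <- 1/4 Z (13I - AZ (15I - AZ (7I - AZ)))
        z = 0.25 * z @ (13 * eye - az @ (15 * eye - az @ (7 * eye - az)))
    return z


a = jnp.array([[2.0, 1.0], [1.0, 3.0]])
approx = iterative_pinv(a)  # close to [[0.6, -0.2], [-0.2, 0.4]]
```

Because each step is three or four matmuls, it maps well onto TPU; the trade-off is that it is an approximation, and badly conditioned matrices need more iterations to converge.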