HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU training
BSD 2-Clause "Simplified" License

Non-Autoregressive Generation #12

Open ClashLuke opened 2 years ago

ClashLuke commented 2 years ago

Recently there have been papers about non-autoregressive text generation, in which models generate many tokens simultaneously instead of one at a time. Not only does this mean faster decoding, it also means that all hidden states can attend to one another, so every position is aware of every other. Using non-autoregressive generation, a model could first decide on concepts it wants to mention later and then produce earlier text that leads up to them. With autoregressive language modelling, this isn't possible to the same extent.

This issue involves implementing such a language model and benchmarking it against current, autoregressive language models.
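To make the architectural difference concrete, here is a minimal sketch (not code from this repo, names are illustrative) of the only change needed at the attention level: a causal mask for autoregressive decoding versus a full mask where every position can attend to every other.

```python
import jax.numpy as jnp

def attention_mask(seq_len: int, autoregressive: bool) -> jnp.ndarray:
    if autoregressive:
        # causal mask: position i only attends to positions <= i
        return jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))
    # non-autoregressive: every position attends to every other position,
    # so "future" concepts can influence earlier tokens during generation
    return jnp.ones((seq_len, seq_len), dtype=jnp.bool_)

causal_mask = attention_mask(4, autoregressive=True)
full_mask = attention_mask(4, autoregressive=False)
```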

ClashLuke commented 2 years ago

One possible approach to implementing non-autoregressive text generation could be diffusion. While diffusion directly on token probabilities seems unlikely to work, diffusion in the embedding space immediately before a classification layer might. The idea is that the diffusion model generates all states simultaneously and, by running 64 to 1024 denoising steps on the input, produces a valid embedding that the classifier can decode.

CompVis did similar work where they run diffusion and CLIP guidance in embedding space, so the idea seems to be out there. Additionally, we could add prompt-based conditioning by running latent diffusion as they do. Having a "prompt" the model attends to and generates text from makes the setup almost identical to the classical encoder-decoder architecture. The primary difference would be that the decoder produces all tokens simultaneously and is not conditioned on previously generated tokens.
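A hedged sketch of that sampling loop, under assumed shapes, a placeholder `denoise_fn`, and a standard DDPM-style linear noise schedule (none of this is from the repo): run the reverse diffusion process over all token embeddings at once, then map the denoised embeddings to tokens with an ordinary classification layer.

```python
import jax
import jax.numpy as jnp

def sample_embeddings(denoise_fn, params, rng, seq_len, dim, steps=256):
    # illustrative linear beta schedule; 64-1024 steps as discussed above
    betas = jnp.linspace(1e-4, 0.02, steps)
    alphas = 1.0 - betas
    alpha_bars = jnp.cumprod(alphas)

    rng, init_rng = jax.random.split(rng)
    x = jax.random.normal(init_rng, (seq_len, dim))  # all positions at once

    for t in reversed(range(steps)):
        rng, noise_rng = jax.random.split(rng)
        eps = denoise_fn(params, x, t)               # predicted noise
        coef = betas[t] / jnp.sqrt(1.0 - alpha_bars[t])
        mean = (x - coef * eps) / jnp.sqrt(alphas[t])
        noise = jax.random.normal(noise_rng, x.shape)
        x = mean + jnp.sqrt(betas[t]) * noise * (t > 0)  # no noise at t == 0
    return x                                          # denoised embeddings

def decode_tokens(embeddings, classifier_w):
    # ordinary classification layer applied to the denoised embeddings
    logits = embeddings @ classifier_w                # (seq_len, vocab)
    return jnp.argmax(logits, axis=-1)
```

The key property is that `x` covers the whole sequence from the first step, so the model refines all positions jointly instead of committing to a left-to-right order.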

ClashLuke commented 2 years ago

Diffusion-LM showed that running diffusion in the embedding space of a classification layer works well. Reproducing their results would be a viable first step.
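For reference, a rough sketch of a Diffusion-LM-style training objective as I understand it (placeholder model, assumed shapes, not code from this repo): noise the clean token embeddings, predict them back with an MSE term, and add a "rounding" term so the predicted embeddings still decode to the original tokens through the shared embedding/classifier matrix.

```python
import jax
import jax.numpy as jnp

def diffusion_lm_loss(model_fn, params, rng, token_ids, embed_table, alpha_bars):
    x0 = embed_table[token_ids]                        # clean embeddings (seq, dim)
    t_rng, noise_rng = jax.random.split(rng)
    t = jax.random.randint(t_rng, (), 0, alpha_bars.shape[0])
    noise = jax.random.normal(noise_rng, x0.shape)
    # forward process: corrupt the embeddings at a random timestep
    xt = jnp.sqrt(alpha_bars[t]) * x0 + jnp.sqrt(1.0 - alpha_bars[t]) * noise

    x0_hat = model_fn(params, xt, t)                   # predict clean embeddings
    mse = jnp.mean((x0_hat - x0) ** 2)

    # rounding loss: predicted embeddings must still map to the right tokens
    logits = x0_hat @ embed_table.T                    # (seq, vocab)
    log_probs = jax.nn.log_softmax(logits)
    nll = -jnp.mean(log_probs[jnp.arange(token_ids.shape[0]), token_ids])
    return mse + nll
```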