kuprel / min-dalle

min(DALL·E) is a fast, minimal port of DALL·E Mini to PyTorch
MIT License

Streaming intermediate images? #56

Closed: sabetAI closed this issue 2 years ago

sabetAI commented 2 years ago

Is it possible to publish an update of the model that supports streaming intermediate images during reverse diffusion, i.e. with an iterator? It would greatly help UX if the user could see their image form while they're waiting for the process to finish.

kuprel commented 2 years ago

This isn't a diffusion model so that wouldn't work

sabetAI commented 2 years ago

Diffusion models iteratively update the image over multiple steps, and those intermediate iterates can be streamed out (e.g. see the GLIDE demo). 'Reverse diffusion' is simply the image generation step ('diffusion' is the noising process during training), which is what your model is doing during inference. Can you update the code to output intermediate images?

sabetAI commented 2 years ago

Using the term 'reverse diffusion' might have caused some confusion with what I was asking.

iScriptLex commented 2 years ago

This model is not like GLIDE or VQGAN+CLIP. DALL-E works on an entirely different principle. The image is generated as tiny squares (tokens), square by square, from left to right and top to bottom. It does not change the whole image at once on every iteration like diffusion models do. On each iteration it just fills another tiny bit of the empty area with a completely finished tiny portion of the final image.
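
To make that concrete, here is a minimal sketch of that generation order (the grid size, blank placeholder, and sampler below are illustrative stand-ins, not min-dalle's actual internals):

```python
import torch

GRID = 16      # the decoder fills a 16x16 grid of VQ image tokens
BLANK = -1     # placeholder for positions not yet generated
VOCAB = 16384  # illustrative codebook size

def sample_next_token(tokens: torch.Tensor) -> int:
    """Stand-in for the transformer's next-token sampler."""
    return int(torch.randint(0, VOCAB, (1,)).item())

tokens = torch.full((GRID, GRID), BLANK, dtype=torch.long)
for row in range(GRID):
    for col in range(GRID):
        # each token is conditioned on everything generated so far
        tokens[row, col] = sample_next_token(tokens)
# only after all 256 steps does the detokenizer turn the grid into pixels
```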

sabetAI commented 2 years ago

Ah, good point @iScriptLex, I made assumptions about the model architecture. Even if it's outputting autoregressively, tokens can still be streamed out to incrementally update a canvas a pixel at a time. The main use case here is to show intermediate results to the user, since waiting kills the UX.

kuprel commented 2 years ago

It might be possible to generate the images each time a row of tokens is decoded, and use some kind of blank token for the missing rows
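
A rough sketch of that idea, assuming a `detokenize` function that maps a full 16x16 token grid to pixels (the names here are hypothetical, not min-dalle's actual API):

```python
import torch

GRID = 16   # tokens per row and rows per image
BLANK = 0   # some fixed codebook index used as filler for undecoded rows

def preview_image(decoded_rows: torch.Tensor, detokenize):
    """Build a preview after each decoded row by padding the rest with BLANK.

    decoded_rows: (n, GRID) long tensor of the rows generated so far, n <= GRID
    detokenize:   maps a full (GRID, GRID) token grid to an image
    """
    n = decoded_rows.shape[0]
    padding = torch.full((GRID - n, GRID), BLANK, dtype=torch.long)
    return detokenize(torch.cat([decoded_rows, padding], dim=0))
```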

sabetAI commented 2 years ago

@kuprel yes, exactly. Also, would it be more efficient to just stream rows of tokens and have the client handle everything else? I want to minimize whatever latency streaming may add. Something like this on the server side is what I have in mind (a sketch with a hypothetical row-by-row decoder):
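
```python
def stream_token_rows(model, text: str):
    """Yield each row of image tokens as soon as it is decoded,
    leaving detokenization entirely to the client."""
    for row in model.decode_rows(text):  # hypothetical row-by-row decoder
        # a row is just 16 codebook indices, tiny next to a PNG frame
        yield row.tolist()
```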

w4ffl35 commented 2 years ago

> This model is not like GLIDE or VQGAN+CLIP. DALL-E works on an entirely different principle. The image is generated as tiny squares (tokens), square by square, from left to right and top to bottom. It does not change the whole image at once on every iteration like diffusion models do. On each iteration it just fills another tiny bit of the empty area with a completely finished tiny portion of the final image.

This would still look cool while it was loading, but I worry about latency and bandwidth. Wouldn't a loading bar or something work just as well?

sabetAI commented 2 years ago

@w4ffl35 can you quantify the marginal latency/bandwidth costs? Loaders may work for one-time uses, but users will churn if they're stuck looking at loaders 95% of the time. See urzas.ai for an example of UX with intermediate outputs. IMO, if a flag were made available it would be hugely valuable for devs.

w4ffl35 commented 2 years ago

@sabetAI those are great points

kuprel commented 2 years ago

Ok, I got it working in the colab. Now I just have to figure out how to get it on Replicate. An intermediate image count of 8 only adds a couple of seconds to the overall decoding time on a P100.

kuprel commented 2 years ago

Here's what it looks like (open in a new tab to see the animation): [animated GIF]

sabetAI commented 2 years ago

@kuprel so good 👏. When can you merge 🙏?

kuprel commented 2 years ago

I merged it. You can try it in the colab. Hopefully I'll get it onto Replicate by tomorrow.
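
Usage looks roughly like this (a hedged sketch; the `generate_image_stream` method and `log2_mid_count` parameter are my recollection of the merged API and may differ, with 2**3 = 8 matching the intermediate image count mentioned above):

```python
from min_dalle import MinDalle

model = MinDalle(is_mega=True, is_reusable=True)

# assumed streaming API: yields progressively completed images,
# 2**log2_mid_count of them, the last one being the finished image
for image in model.generate_image_stream(
    text='a comfy chair that looks like an avocado',  # example prompt
    seed=-1,
    grid_size=1,
    log2_mid_count=3,  # 8 intermediate images
):
    image.save('progress.png')  # overwrite so a UI can poll the file
```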

kuprel commented 2 years ago

Ok, it's live on Replicate now.