elixir-nx / bumblebee

Pre-trained Neural Network models in Axon (+ 🤗 Models integration)
Apache License 2.0

Apply cross attention spec unet #298

Closed robinmonjo closed 4 months ago

robinmonjo commented 7 months ago

Warning: beginner 😋.

I have been trying to use the UNet2DConditional model but without cross attention (so basically a UNet2D model).

In Python, Hugging Face has both: UNet2DCondition and UNet2D.

I need an equivalent of UNet2D. It seems Bumblebee already supports this: for example, in the UNet layers, the doc for down_block_2d says:

When encoder_hidden_state is not nil, applies cross-attention.

However, the UNet2DConditional model gives us no way to pass a nil encoder_hidden_state, so this PR adds an option to the model.
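
To illustrate the behaviour that doc line describes, here is a minimal Python sketch. It is a toy stand-in, not Bumblebee's or Hugging Face's code: the function name `down_block`, the doubling step, and the mean-based "attention" are all hypothetical. They only show that the cross-attention step is skipped when no encoder_hidden_state is given.

```python
from typing import List, Optional

Vector = List[float]


def down_block(
    hidden_state: Vector,
    encoder_hidden_state: Optional[Vector] = None,
) -> Vector:
    """Toy stand-in for a UNet down block (hypothetical, for illustration only)."""
    # Placeholder for the part of the block that always runs
    # (residual/convolution layers in a real UNet).
    hidden_state = [2.0 * x for x in hidden_state]

    # Cross-attention is applied only when conditioning is provided.
    # Here a mean over the conditioning vector stands in for real attention.
    if encoder_hidden_state is not None:
        context = sum(encoder_hidden_state) / len(encoder_hidden_state)
        hidden_state = [x + context for x in hidden_state]

    return hidden_state


# Without conditioning the block behaves like an unconditional UNet block.
unconditional = down_block([1.0, 2.0])

# With conditioning the extra cross-attention step changes the output.
conditional = down_block([1.0, 2.0], encoder_hidden_state=[3.0, 5.0])
```

The point of the sketch is the single optional argument: the same block serves both cases, and what is missing today is a way to leave that argument unset at the model level.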

I'm not sure this won't cause confusion because, as I understand it, without encoder_hidden_state this is no longer a conditional UNet.

Let me know what you think.

jonatanklosko commented 7 months ago

Hey @robinmonjo, we should add UNet2D as a separate model and have a test that we can correctly load a hf/transformers model (this checkpoint should be fine). The easiest way to start is to just copy the whole UNet2DConditional and strip the unnecessary parts. If there are obvious duplications we can extract them later, but I'd focus on having a passing UNet2D test first :)

robinmonjo commented 7 months ago

Hey @jonatanklosko, thanks for your feedback. I can't access the link in "this checkpoint": I get a 404, so I'm probably not allowed to see it 🤷🏻‍♂️

jonatanklosko commented 7 months ago

Ah sorry, I put the repo name as URL, updated!

jonatanklosko commented 4 months ago

I'm going to close this one. If you want to add the separate model, feel free to open a new PR :)