archinetai / audio-diffusion-pytorch

Audio generation using diffusion models, in PyTorch.
MIT License

I have a few questions about 1D-UNet #76

Open 0417keito opened 1 year ago

0417keito commented 1 year ago

The code shows that a-unet is used to construct the UNet, but looking at a-unet, the UNet is built as a nested structure. So, does this UNet have middle blocks in addition to the encoder and decoder parts, as in other diffusion models? And what would the UNet be without middle blocks?

jbmaxwell commented 1 year ago

The way it's written is highly modular, which can be handy, but also makes it quite tricky to interpret (and hack!). But yes, it does have mid-blocks. I wound up using model.modules() to enumerate the model and save it to a text file. I first got the XUNet using the following:

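# Get the XUNet (assumes `model` is the already-constructed diffusion model)
from a_unet.apex import XUNet  # import path assumed; adjust if XUNet lives elsewhere

xunet = None
for mod in model.modules():
    # modules() walks the whole model recursively; stop at the first XUNet
    if isinstance(mod, XUNet):
        xunet = mod
        break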

Then I just iterated over the unet and figured out where the different structures were. I don't recall whether the mid-blocks were labelled in any way, but the up and down blocks were pretty easy to identify by looking for strides (down) or "upscale" modules (up), as far as I remember (it was a while ago now). I think I just figured out the mid-blocks by deduction.
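
A rough sketch of that kind of inspection is shown here. It is not the code used in the comment above: the helper name `describe_modules`, the toy model, and the tagging rules (strided `nn.Conv1d` means "down", `nn.Upsample` or `nn.ConvTranspose1d` means "up") are assumptions for illustration, and the real a-unet modules may need different checks.

```python
import os
import tempfile

import torch.nn as nn


def describe_modules(model: nn.Module) -> list[str]:
    """Return one line per submodule: qualified name, class name, rough role tag."""
    lines = []
    for name, mod in model.named_modules():
        tag = ""
        # Heuristic (assumption): a strided 1D conv is treated as a downsampling step.
        if isinstance(mod, nn.Conv1d) and max(mod.stride) > 1:
            tag = "down"
        # Heuristic (assumption): upsample / transposed conv is treated as an upsampling step.
        elif isinstance(mod, (nn.Upsample, nn.ConvTranspose1d)):
            tag = "up"
        lines.append(f"{name or '<root>'}\t{type(mod).__name__}\t{tag}")
    return lines


# Toy stand-in for the real model, only so the sketch runs on its own.
toy = nn.Sequential(
    nn.Conv1d(1, 8, kernel_size=3, stride=2, padding=1),  # tagged "down"
    nn.Conv1d(8, 8, kernel_size=3, padding=1),            # untagged
    nn.Upsample(scale_factor=2),                          # tagged "up"
)

report = describe_modules(toy)

# Save the listing to a text file for reading offline.
out_path = os.path.join(tempfile.gettempdir(), "model_structure.txt")
with open(out_path, "w") as f:
    f.write("\n".join(report))
```

With the real model, passing `xunet` instead of `toy` would give a listing in which untagged blocks at the deepest nesting level are the likely mid-block candidates, though that still has to be confirmed by reading the output.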

0417keito commented 1 year ago

I think that nested recursive structures are convenient, but they are a bit difficult to use when you want to apply something like ControlNet or IP-Adapter that performs processing for each block. So, I wanted to flatten it with the following code, but is it working…

https://github.com/0417keito/Encofusion/blob/main/flat_audio_diffusion/flat_a_unet/apex.py#L342-L434

jbmaxwell commented 1 year ago

Yeah, I totally agree about the nesting making things tricky to work with. I'm actually also very much interested in ControlNet. It would be great to get that working with ADP!

SuperiorDtj commented 1 year ago

Is there an open-source, well-trained checkpoint (ckpt) that could be used for ControlNet?