finegrain-ai / refiners

A microframework on top of PyTorch with first-class citizen APIs for foundation model adaptation
https://refine.rs
MIT License

Injecting/ejecting multiple controlnet models results in an error #377

Closed. holwech closed this issue 2 months ago

holwech commented 2 months ago

I was trying to inject, eject, and then re-inject two ControlNet models, but the second injection fails with an AssertionError.

Reproduction:

from refiners.foundationals.latent_diffusion import (
    StableDiffusion_1
)
from refiners.foundationals.latent_diffusion.solvers import DDIM
from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter
from refiners.fluxion.utils import load_from_safetensors
import torch

device, dtype = ("cuda", torch.float16)
solver = DDIM(num_inference_steps=10)
sd = StableDiffusion_1(device=device, dtype=dtype, solver=solver)

sd.clip_text_encoder.load_from_safetensors("../weights/sd-text-encoder.safetensors")
sd.lda.load_from_safetensors("../weights/sd-lda.safetensors")
sd.unet.load_from_safetensors("../weights/sd-unet.safetensors")

controlnet = {
    "lineart": SD1ControlnetAdapter(
        sd.unet, name="lineart", scale=1.0, weights=load_from_safetensors("../weights/cn-lineart.safetensors")
    ).to(device, dtype),
    "canny": SD1ControlnetAdapter(
        sd.unet, name="canny", scale=1.0, weights=load_from_safetensors("../weights/cn-canny.safetensors")
    ).to(device, dtype),
}

if controlnet:
    for value in controlnet.values():
        value.inject()
if controlnet:
    for value in controlnet.values():
        value.eject()

# Error
if controlnet:
    for value in controlnet.values():
        value.inject()

Stack trace:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
File /home/azureuser/cloudfiles/code/Users/joachim/migrate to refiners/inject_eject.py:4
      2 if controlnet:
      3     for value in controlnet.values():
----> 4         value.inject()

File /anaconda/envs/azureml_py310_sdkv2/lib/python3.10/site-packages/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py:164, in SD1ControlnetAdapter.inject(self, parent)
    162     assert cn.name != self.name, f"Controlnet named {self.name} is already injected"
    163 self.target.insert(0, controlnet)
--> 164 return super().inject(parent)

File /anaconda/envs/azureml_py310_sdkv2/lib/python3.10/site-packages/refiners/fluxion/adapters/adapter.py:78, in Adapter.inject(self, parent)
     75 # In general, `true_parent` is `parent`. We do this to support multiple adaptation,
     76 # i.e. initializing two adapters before injecting them.
     77 true_parent = parent.ensure_find_parent(self.target)
---> 78 true_parent.replace(
     79     old_module=self.target,
     80     new_module=self,
     81     old_module_parent=target_parent,
     82 )
     83 return self

File /anaconda/envs/azureml_py310_sdkv2/lib/python3.10/site-packages/refiners/fluxion/layers/chain.py:607, in Chain.replace(self, old_module, new_module, old_module_parent)
    605     new_module._set_parent(self)
    606 if isinstance(old_module, ContextModule):
--> 607     old_module._set_parent(old_module_parent)

File /anaconda/envs/azureml_py310_sdkv2/lib/python3.10/site-packages/refiners/fluxion/layers/module.py:187, in ContextModule._set_parent(self, parent)
    185     return
    186 # Always insert the module in the Chain first to avoid inconsistencies.
--> 187 assert self in iter(parent), f"{self} not in {parent}"
    188 self._parent = [parent]

AssertionError: SD1UNet(in_channels=4) not in SD1ControlnetAdapter(name=canny)

limiteinductive commented 2 months ago

Hey @holwech,

This happens because the adapters are nested and you are ejecting them in the wrong order. If you eject them in the reverse order of injection (outermost first), it should work:

if controlnet:
    for value in controlnet.values():
        value.inject()
if controlnet:
    # We iterate starting from the outermost adapter
    for value in reversed(controlnet.values()):
        value.eject()

# Now, this should be working
if controlnet:
    for value in controlnet.values():
        value.inject()
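
One way to make the reverse ordering hard to get wrong is to wrap it in a small context manager. The sketch below is only an illustration: `injected` and `SupportsInjectEject` are hypothetical names, not part of refiners, and the only assumption about the adapters is that they expose the `inject()` and `eject()` methods used above. It relies on `contextlib.ExitStack`, which runs registered callbacks in last-in, first-out order.

```python
from contextlib import ExitStack, contextmanager
from typing import Iterable, Iterator, Protocol


class SupportsInjectEject(Protocol):
    """Anything exposing inject() and eject(), like the adapters above."""

    def inject(self) -> object: ...
    def eject(self) -> object: ...


@contextmanager
def injected(adapters: Iterable[SupportsInjectEject]) -> Iterator[None]:
    """Hypothetical helper (not a refiners API): inject in order, eject in reverse."""
    with ExitStack() as stack:
        for adapter in adapters:
            adapter.inject()
            # ExitStack runs callbacks last-in, first-out, so the adapter
            # injected last (the outermost one) is ejected first.
            stack.callback(adapter.eject)
        yield
```

With this in place, `with injected(controlnet.values()): ...` would inject both adapters on entry and eject them outermost-first on exit, including when an exception is raised inside the block.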
holwech commented 2 months ago

I see, thanks!