lllyasviel / ControlNet

Let us control diffusion models!
Apache License 2.0

Using Multiple Text Encoders alongside CLIP #627

Closed: sayeh1994 closed this issue 10 months ago

sayeh1994 commented 10 months ago

Hi. Is it possible to use multiple text encoders with this model? I want to combine CLIP, T5, and BERT. I extended FrozenCLIPT5Encoder into a new FrozenCLIPT5BertEncoder and defined a new class, FrozenBertEmbedder, in encoders/modules.py:

class FrozenCLIPT5BertEncoder(AbstractEncoder):
    """Encodes the prompt with CLIP, T5 and BERT and returns all three embeddings."""
    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl",
                 bert_version="emilyalsentzer/Bio_ClinicalBERT", device="cuda",
                 clip_max_length=77, t5_max_length=77, bert_max_length=110):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
        self.bert_encoder = FrozenBertEmbedder(bert_version, device, max_length=bert_max_length)
        print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
              f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params,  "
              f"{self.bert_encoder.__class__.__name__} comes with {count_params(self.bert_encoder)*1.e-6:.2f} M params.")

    def encode(self, text):
        return self(text)

    def forward(self, text):
        clip_z = self.clip_encoder.encode(text)
        t5_z = self.t5_encoder.encode(text)
        bert_z = self.bert_encoder.encode(text)
        return [clip_z, t5_z, bert_z]

with FrozenBertEmbedder defined as:

from transformers import AutoTokenizer, AutoModel  # required for the Auto* classes used below

class FrozenBertEmbedder(AbstractEncoder):
    """Uses the Bert transformer encoder for text"""
    def __init__(self, version="emilyalsentzer/Bio_ClinicalBERT", device="cuda", max_length=110, freeze=True):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(version)
        self.transformer = AutoModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length   # TODO: typical value?
        if freeze:
            self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        #self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)

I also changed cond_stage_config in models/cldm_v15.yaml from

cond_stage_config:
  target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

to

cond_stage_config:
  target: ldm.modules.encoders.modules.FrozenCLIPT5BertEncoder

However, loading the model then raises this error:

RuntimeError: Error(s) in loading state_dict for ControlLDM:
    Missing key(s) in state_dict:

followed by Unexpected key(s) in state_dict:, with a long list of .weight and .bias entries from all three encoders. Note that the encoders themselves are built successfully; the print statement in FrozenCLIPT5BertEncoder.__init__ outputs:

FrozenCLIPEmbedder has 123.06 M parameters, FrozenT5Embedder comes with 1223.53 M params, FrozenBertEmbedder comes with 108.31 M params.

I would really appreciate your help.

sayeh1994 commented 10 months ago

Update on the previous error: I managed to solve it by running

python tool_add_control.py ./models/v1-5-pruned.ckpt ./models/control_sd15_ini.ckpt

again with the modified config.
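The state_dict error is consistent with a key-name mismatch: a checkpoint built for the single-encoder config stores the text-encoder weights under one set of module names, while the wrapper nests them under new submodules (clip_encoder, t5_encoder, bert_encoder), so every encoder key shows up once as missing and once as unexpected. The toy sketch below reproduces this with hypothetical stand-in modules (OldCond, ClipWrapper, NewCond are illustrations, not ControlNet classes) and PyTorch's load_state_dict(strict=False), which reports the mismatches instead of raising. Presumably rerunning tool_add_control.py helps because it builds the checkpoint from the model defined by the modified config, so the saved keys match the new layout.

```python
import torch.nn as nn


class OldCond(nn.Module):
    """Hypothetical stand-in for the single-encoder layout (keys: transformer.*)."""

    def __init__(self):
        super().__init__()
        self.transformer = nn.Linear(4, 4)


class ClipWrapper(nn.Module):
    """Hypothetical stand-in for one wrapped encoder inside the new layout."""

    def __init__(self):
        super().__init__()
        self.transformer = nn.Linear(4, 4)


class NewCond(nn.Module):
    """Hypothetical stand-in for the multi-encoder layout (keys: clip_encoder.transformer.*)."""

    def __init__(self):
        super().__init__()
        self.clip_encoder = ClipWrapper()


old_weights = OldCond().state_dict()
# strict=False collects the mismatched names instead of raising a RuntimeError
result = NewCond().load_state_dict(old_weights, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```

With the default strict=True, the same call raises a RuntimeError that lists these names under Missing key(s) and Unexpected key(s), the same shape of error as the one reported above.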

However, training now fails with a new error:

/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:56: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 4. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
  warning_cache.warn(
Traceback (most recent call last):
  File "tutorial_train_v2.py", line 39, in <module>
    trainer.fit(model, dataloader)
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in fit
    self._call_and_handle_interrupt(
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1193, in _run
    self._dispatch()
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1272, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1282, in run_stage
    return self._run_train()
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1312, in _run_train
    self.fit_loop.run()
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
    self.epoch_loop.run(data_fetcher)
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 195, in advance
    batch_output = self.batch_loop.run(batch, batch_idx)
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
    outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 215, in advance
    result = self._run_optimization(
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 266, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 378, in _optimizer_step
    lightning_module.optimizer_step(
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1662, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 164, in step
    trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 336, in optimizer_step
    self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 163, in optimizer_step
    optimizer.step(closure=closure, **kwargs)
  File "/.conda/envs/control/lib/python3.8/site-packages/torch/optim/optimizer.py", line 373, in wrapper
    out = func(*args, **kwargs)
  File "/.conda/envs/control/lib/python3.8/site-packages/torch/optim/optimizer.py", line 76, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/.conda/envs/control/lib/python3.8/site-packages/torch/optim/adamw.py", line 161, in step
    loss = closure()
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 148, in _wrap_closure
    closure_result = closure()
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 160, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 142, in closure
    step_output = self._step_fn()
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 435, in _training_step
    training_step_output = self.trainer.accelerator.training_step(step_kwargs)
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 216, in training_step
    return self.training_type_plugin.training_step(*step_kwargs.values())
  File "/.conda/envs/control/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 213, in training_step
    return self.model.training_step(*args, **kwargs)
  File "/ControlNet/ldm/models/diffusion/ddpm.py", line 442, in training_step
    loss, loss_dict = self.shared_step(batch)
  File "/ControlNet/ldm/models/diffusion/ddpm.py", line 836, in shared_step
    loss = self(x, c)
  File "/.conda/envs/control/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/.conda/envs/control/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ControlNet/ldm/models/diffusion/ddpm.py", line 848, in forward
    return self.p_losses(x, c, t, *args, **kwargs)
  File "/ControlNet/ldm/models/diffusion/ddpm.py", line 888, in p_losses
    model_output = self.apply_model(x_noisy, t, cond)
  File "/ControlNet/cldm/cldm.py", line 332, in apply_model
    cond_txt = torch.cat(cond['c_crossattn'], 1)
TypeError: expected Tensor as element 0 in argument 0, but got list

I suspect the list returned by the three text encoders is the issue, but I don't know how to fix it.
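The traceback ends in apply_model in cldm/cldm.py, which calls torch.cat(cond['c_crossattn'], 1), so each entry of c_crossattn must be a single tensor of shape (batch, tokens, context_dim). Assuming ControlLDM.get_input wraps the encoder output as c_crossattn=[c], a list returned by forward() ends up nested, and torch.cat rejects it. The sketch below reproduces the error with random tensors, then shows one possible shape-level fix: project each embedding to a common width and concatenate along the token axis so the encoder returns one tensor. The widths (768 for CLIP ViT-L/14 and Bio_ClinicalBERT, 2048 for t5-v1_1-xl), the context_dim of 768, and the ConcatProjector helper are assumptions for illustration, not part of ControlNet.

```python
import torch
import torch.nn as nn

# Assumed shapes for a batch of 4 (widths: CLIP ViT-L/14 = 768, t5-v1_1-xl = 2048, BERT-base = 768)
clip_z = torch.randn(4, 77, 768)
t5_z = torch.randn(4, 77, 2048)
bert_z = torch.randn(4, 110, 768)

# 1) Reproduce: the encoder's list ends up nested inside c_crossattn
cond = {"c_crossattn": [[clip_z, t5_z, bert_z]]}
try:
    torch.cat(cond["c_crossattn"], 1)
except TypeError as err:
    print("TypeError:", err)


class ConcatProjector(nn.Module):
    """Hypothetical helper: map each embedding to context_dim, then join along tokens."""

    def __init__(self, widths=(768, 2048, 768), context_dim=768):
        super().__init__()
        # Identity where the width already matches; a randomly initialised Linear otherwise
        self.proj = nn.ModuleList(
            [nn.Identity() if w == context_dim else nn.Linear(w, context_dim) for w in widths]
        )

    def forward(self, zs):
        # dim=1 is the token axis, the same axis apply_model concatenates on
        return torch.cat([p(z) for p, z in zip(self.proj, zs)], dim=1)


# 2) Fix: the encoder returns one tensor, so c_crossattn holds a list of tensors
projector = ConcatProjector()
context = projector([clip_z, t5_z, bert_z])
cond = {"c_crossattn": [context]}
cond_txt = torch.cat(cond["c_crossattn"], 1)
print(cond_txt.shape)  # torch.Size([4, 264, 768])
```

Inside FrozenCLIPT5BertEncoder this would mean creating the projector in __init__ and returning self.projector([clip_z, t5_z, bert_z]) from forward(). The T5 projection starts from random weights and, as far as I can tell, the conditioning is computed without gradients in the default ControlNet training setup, so it would not learn unless you change that; the new projector parameters also add state_dict keys, so the initial checkpoint would need regenerating again.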