ShihaoZhaoZSH / Uni-ControlNet

[NeurIPS 2023] Uni-ControlNet: All-in-One Control to Text-to-Image Diffusion Models
MIT License

A technical problem with the openpose feature #18

Open modric197 opened 11 months ago

modric197 commented 11 months ago

Thank you for your great work! I have a small question about it. In the code, the number of input channels is fixed at 21. However, for many images we cannot extract an openpose feature, so these images have only 6 features. How can the 6 features be made to fit into the 21 channels?

modric197 commented 11 months ago

I tried using the 6 features directly for such images, but got the following error:

Traceback (most recent call last):
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/spawn.py", line 101, in _wrapping_function
    results = function(*args, **kwargs)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 812, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1237, in _run
    results = self._run_stage()
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1324, in _run_stage
    return self._run_train()
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1354, in _run_train
    self.fit_loop.run()
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 269, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 208, in advance
    batch_output = self.batch_loop.run(batch, batch_idx)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
    outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 203, in advance
    result = self._run_optimization(
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 256, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 369, in _optimizer_step
    self.trainer._call_lightning_module_hook(
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1596, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1625, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 168, in step
    step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 193, in optimizer_step
    return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 155, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/optim/optimizer.py", line 88, in wrapper
    return func(*args, **kwargs)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/optim/adamw.py", line 100, in step
    loss = closure()
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 140, in _wrap_closure
    closure_result = closure()
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 148, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 134, in closure
    step_output = self._step_fn()
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 427, in _training_step
    training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1766, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp_spawn.py", line 240, in training_step
    return self.model(*args, **kwargs)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 963, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 82, in forward
    output = self.module.training_step(*inputs, **kwargs)
  File "./ldm/models/diffusion/ddpm.py", line 442, in training_step
    loss, loss_dict = self.shared_step(batch)
  File "./ldm/models/diffusion/ddpm.py", line 836, in shared_step
    loss = self(x, c)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "./ldm/models/diffusion/ddpm.py", line 848, in forward
    return self.p_losses(x, c, t, *args, **kwargs)
  File "./ldm/models/diffusion/ddpm.py", line 888, in p_losses
    model_output = self.apply_model(x_noisy, t, cond)
  File "./models/uni_controlnet.py", line 59, in apply_model
    local_control = self.local_adapter(x=x_noisy, timesteps=t, context=cond_txt, local_conditions=local_control)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "./models/local_adapter.py", line 401, in forward
    local_features = self.feature_extractor(local_conditions)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "./models/local_adapter.py", line 157, in forward
    local_features = self.pre_extractor(local_conditions, None)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "./models/local_adapter.py", line 27, in forward
    x = layer(x)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 447, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/yiw182/anaconda3/envs/unicontrol/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 443, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [32, 21, 3, 3], expected input[1, 18, 512, 512] to have 21 channels, but got 18 channels instead
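The shapes in this error can be read directly: PyTorch stores Conv2d weights as [out_channels, in_channels, kH, kW] and image inputs as [batch, channels, height, width], so the adapter's first convolution expects 21 input channels but received 18. Assuming every local condition contributes the same number of channels (3 is consistent with 21 = 7 × 3 and 18 = 6 × 3, but the thread does not confirm it), the gap corresponds to exactly one missing condition. The helper `missing_conditions` below is hypothetical and only does that arithmetic:

```python
# Hypothetical helper: interpret the shapes reported by the Conv2d RuntimeError above.
# Assumption: every local condition contributes the same number of channels (3 by default).


def missing_conditions(weight_shape, input_shape, channels_per_condition=3):
    """Return how many conditions are absent from the input tensor.

    weight_shape follows PyTorch's Conv2d layout [out_channels, in_channels, kH, kW]
    (with groups=1); input_shape follows [batch, channels, height, width].
    """
    expected = weight_shape[1]
    got = input_shape[1]
    gap = expected - got
    if gap % channels_per_condition != 0:
        raise ValueError(
            f"channel gap {gap} is not a multiple of {channels_per_condition}"
        )
    return gap // channels_per_condition


if __name__ == "__main__":
    # Shapes copied from the error message in this issue.
    weight = (32, 21, 3, 3)
    inp = (1, 18, 512, 512)
    print(missing_conditions(weight, inp))  # 1
```

Under the same assumption, 21 // 3 = 7 conditions are expected in total, so dropping the openpose map removes one 3-channel block.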

yairshp commented 5 months ago

+1

haikuoxin commented 1 month ago

@modric197 Channels corresponding to unused conditions should be filled with zeros; refer to src.test.test.
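A minimal sketch of that advice, assuming each local condition is an H×W×3 array and the conditions are concatenated channel-last in a fixed order, with an all-zero block substituted for any condition that could not be extracted (for example, when no pose is detected). The condition names, their order, and the helper `stack_local_conditions` are illustrative assumptions, not taken from the repository; the actual order and preprocessing should be checked against src.test.test.

```python
import numpy as np

# Assumed condition order; it must match whatever order the model was trained with.
CONDITION_NAMES = ["canny", "mlsd", "hed", "sketch", "openpose", "midas", "seg"]


def stack_local_conditions(conditions, height, width, channels_per_condition=3):
    """Concatenate per-condition maps into one (H, W, len(CONDITION_NAMES) * 3) array.

    `conditions` maps a condition name to an (H, W, 3) array. Names that are missing
    or mapped to None (e.g. no pose detected) are replaced with zeros, so the
    channel count always matches the adapter's first convolution.
    """
    expected_shape = (height, width, channels_per_condition)
    blocks = []
    for name in CONDITION_NAMES:
        cond = conditions.get(name)
        if cond is None:
            # Absent condition: an all-zero block keeps the channel layout intact.
            cond = np.zeros(expected_shape, dtype=np.float32)
        elif cond.shape != expected_shape:
            raise ValueError(f"{name}: expected shape {expected_shape}, got {cond.shape}")
        blocks.append(cond.astype(np.float32))
    return np.concatenate(blocks, axis=2)


if __name__ == "__main__":
    h, w = 512, 512
    # Simulate an image where openpose extraction failed.
    available = {
        name: np.random.rand(h, w, 3).astype(np.float32)
        for name in CONDITION_NAMES
        if name != "openpose"
    }
    stacked = stack_local_conditions(available, h, w)
    print(stacked.shape)  # (512, 512, 21)
```

The sketch stops at the channel-last array; any transpose to [batch, channels, height, width] that the training pipeline performs is not shown. Zero-filling keeps each remaining condition in the channel slot the pretrained weights expect, which simply dropping the condition (as in the traceback above) does not.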