Shilin-LU / TF-ICON

[ICCV 2023] "TF-ICON: Diffusion-Based Training-Free Cross-Domain Image Composition" (Official Implementation)
https://shilin-lu.github.io/tf-icon.github.io/
MIT License
798 stars, 103 forks

How much VRAM is needed for this? #9

Closed: Echo411 closed this issue 1 year ago

Echo411 commented 1 year ago

This looks great!

However, I ran into an out-of-memory error while running the code. I'm using an RTX 3090 with 24 GB of VRAM. Could you share how much GPU memory is needed to run this code successfully? Thanks!
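Before launching, it can help to check how much memory is actually free on the device passed via `--gpu`, since other processes sharing the card reduce what is left for TF-ICON. This is an editor's sketch, not part of TF-ICON: `has_enough_memory` and `query_device` are hypothetical helpers, and any threshold you pass in (for example the roughly 22 GB peak mentioned later in this thread) is an assumption. It relies on `torch.cuda.mem_get_info`, which returns `(free, total)` in bytes.

```python
"""Check free CUDA memory before a run (illustrative sketch, not TF-ICON code)."""

GIB = 2**30  # PyTorch reports memory in GiB (2**30 bytes)


def has_enough_memory(free_bytes: int, required_gib: float) -> bool:
    """Return True if `free_bytes` covers `required_gib` GiB."""
    return free_bytes >= required_gib * GIB


def query_device(device: str = "cuda:0"):
    """Return (free_bytes, total_bytes) for `device`, or None without CUDA."""
    try:
        import torch
    except ImportError:
        return None  # torch not installed
    if not torch.cuda.is_available():
        return None  # no usable CUDA device
    # mem_get_info reports device-wide free/total memory, including memory
    # held by other processes on the same GPU.
    return torch.cuda.mem_get_info(torch.device(device))
```

For example, on the machine above one would call `query_device("cuda:2")` and pass the returned free value to `has_enough_memory`. Because the free figure is device-wide, a GPU shared with other jobs can fail this check even though its total capacity is 24 GB.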

The error message is as follows:

/home/wenhuaszhgc/miniconda3/envs/xqqpy38/bin/python3.8 /home/wenhuaszhgc/users/xqq/TF-ICON/scripts/main_tf_icon.py --ckpt /home/wenhuaszhgc/users/xqq/TF-ICON/ckpt/v2-1_512-ema-pruned.ckpt --root /home/wenhuaszhgc/users/xqq/TF-ICON/inputs/cross_domain --domain cross --dpm_steps 20 --dpm_order 2 --scale 5 --tau_a 0.4 --tau_b 0.8 --outdir /home/wenhuaszhgc/users/xqq/TF-ICON/outputs --gpu cuda:2 --seed 3407 
Loading model from /home/wenhuaszhgc/users/xqq/TF-ICON/ckpt/v2-1_512-ema-pruned.ckpt
Global Step: 220000
/home/wenhuaszhgc/miniconda3/envs/xqqpy38/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:258: LightningDeprecationWarning: `pytorch_lightning.utilities.distributed.rank_zero_only` has been deprecated in v1.8.1 and will be removed in v2.0.0. You can import it from `pytorch_lightning.utilities` instead.
  rank_zero_deprecation(
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 865.91 M params.
making attention of type 'vanilla-xformers' with 512 in_channels
building MemoryEfficientAttnBlock with 512 in_channels...
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla-xformers' with 512 in_channels
building MemoryEfficientAttnBlock with 512 in_channels...
Global seed set to 3407
1
loaded input image of size (512, 512) from /home/wenhuaszhgc/users/xqq/TF-ICON/inputs/cross_domain/a pencil drawing of an eiffel tower in the distance, black and white painting/bg48.png
['a pencil drawing of an eiffel tower in the distance, black and white painting']
╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /home/wenhuaszhgc/users/xqq/TF-ICON/scripts/main_tf_icon.py:565 in <module>  │
│                                                                              │
│   562                                                                        │
│   563                                                                        │
│   564 if __name__ == "__main__":                                             │
│ ❱ 565 │   main()                                                             │
│   566                                                                        │
│                                                                              │
│ /home/wenhuaszhgc/users/xqq/TF-ICON/scripts/main_tf_icon.py:519 in main      │
│                                                                              │
│   516 │   │   │   │   │   │   │   mask = torch.zeros_like(z_enc, device=devi │
│   517 │   │   │   │   │   │   │   mask[:, :, param[0]:param[1], param[2]:par │
│   518 │   │   │   │   │   │   │                                              │
│ ❱ 519 │   │   │   │   │   │   │   samples, _ = sampler.sample(steps=opt.dpm_ │
│   520 │   │   │   │   │   │   │   │   │   │   │   │   │   │   inv_emb=inv_em │
│   521 │   │   │   │   │   │   │   │   │   │   │   │   │   │   conditioning=c │
│   522 │   │   │   │   │   │   │   │   │   │   │   │   │   │   batch_size=opt │
│                                                                              │
│ /home/wenhuaszhgc/miniconda3/envs/xqqpy38/lib/python3.8/site-packages/torch/ │
│ autograd/grad_mode.py:27 in decorate_context                                 │
│                                                                              │
│    24 │   │   @functools.wraps(func)                                         │
│    25 │   │   def decorate_context(*args, **kwargs):                         │
│    26 │   │   │   with self.clone():                                         │
│ ❱  27 │   │   │   │   return func(*args, **kwargs)                           │
│    28 │   │   return cast(F, decorate_context)                               │
│    29 │                                                                      │
│    30 │   def _wrap_generator(self, func):                                   │
│                                                                              │
│ /home/wenhuaszhgc/users/xqq/TF-ICON/ldm/models/diffusion/dpm_solver/sampler. │
│ py:184 in sample                                                             │
│                                                                              │
│   181 │   │   │   │   # decoded background                                   │
│   182 │   │   │   │   ptp_utils.register_attention_control(self.model, orig_ │
│   183 │   │   │   │   │   │   │   │   │   │   │   │   │    width, height, to │
│ ❱ 184 │   │   │   │   orig = dpm_solver_decode.sample_one_step(orig, step, s │
│   185 │   │   │   │                                                          │
│   186 │   │   │   │   # decode for cross-attention                           │
│   187 │   │   │   │   ptp_utils.register_attention_control(self.model, cross │
│                                                                              │
│ /home/wenhuaszhgc/users/xqq/TF-ICON/scripts/dpm_solver_pytorch.py:1289 in    │
│ sample_one_step                                                              │
│                                                                              │
│   1286 │   │   data['t_prev_list'][-1] = vec_t                               │
│   1287 │   │   # We do not need to evaluate the final model value.           │
│   1288 │   │   if step < steps:                                              │
│ ❱ 1289 │   │   │   data['model_prev_list'][-1] = self.model_fn(data['x'], ve │
│   1290 │   │                                                                 │
│   1291 │   │   del vec_t                                                     │
│   1292                                                                       │
│                                                                              │
│ /home/wenhuaszhgc/users/xqq/TF-ICON/scripts/dpm_solver_pytorch.py:471 in     │
│ model_fn                                                                     │
│                                                                              │
│    468 │   │   Convert the model to the noise prediction model or the data p │
│    469 │   │   """                                                           │
│    470 │   │   if self.algorithm_type == "dpmsolver++":                      │
│ ❱  471 │   │   │   return self.data_prediction_fn(x, t, DPMencode=DPMencode, │
│    472 │   │   else:                                                         │
│    473 │   │   │   return self.noise_prediction_fn(x, t, DPMencode=DPMencode │
│    474                                                                       │
│                                                                              │
│ /home/wenhuaszhgc/users/xqq/TF-ICON/scripts/dpm_solver_pytorch.py:459 in     │
│ data_prediction_fn                                                           │
│                                                                              │
│    456 │   │   """                                                           │
│    457 │   │   Return the data prediction model (with corrector).            │
│    458 │   │   """                                                           │
│ ❱  459 │   │   noise = self.noise_prediction_fn(x, t, DPMencode=DPMencode, c │
│    460 │   │   alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), sel │
│    461 │   │   x0 = (x - sigma_t * noise) / alpha_t                          │
│    462 │   │   if self.correcting_x0_fn is not None:                         │
│                                                                              │
│ /home/wenhuaszhgc/users/xqq/TF-ICON/scripts/dpm_solver_pytorch.py:453 in     │
│ noise_prediction_fn                                                          │
│                                                                              │
│    450 │   │   """                                                           │
│    451 │   │   Return the noise prediction model.                            │
│    452 │   │   """                                                           │
│ ❱  453 │   │   return self.model(x, t, DPMencode=DPMencode, controller=contr │
│    454 │                                                                     │
│    455 │   def data_prediction_fn(self, x, t, DPMencode=False, controller=No │
│    456 │   │   """                                                           │
│                                                                              │
│ /home/wenhuaszhgc/users/xqq/TF-ICON/scripts/dpm_solver_pytorch.py:347 in     │
│ model_fn                                                                     │
│                                                                              │
│    344 │   │   │   │   t_in = torch.cat([t_continuous] * 2)                  │
│    345 │   │   │   │                                                         │
│    346 │   │   │   │   if ref_init == None:                                  │
│ ❱  347 │   │   │   │   │   noise_uncond, noise = noise_pred_fn(x_in, t_in, c │
│    348 │   │   │   │   else:                                                 │
│    349 │   │   │   │   │   noise_uncond, noise, _, _ = noise_pred_fn(x_in, t │
│    350                                                                       │
│                                                                              │
│ /home/wenhuaszhgc/users/xqq/TF-ICON/scripts/dpm_solver_pytorch.py:295 in     │
│ noise_pred_fn                                                                │
│                                                                              │
│    292 │   │   if cond is None:                                              │
│    293 │   │   │   output = model(x, t_input, **model_kwargs)                │
│    294 │   │   else:                                                         │
│ ❱  295 │   │   │   output = model(x, t_input, cond, DPMencode, controller=co │
│    296 │   │   if model_type == "noise":                                     │
│    297 │   │   │   return output                                             │
│    298 │   │   elif model_type == "x_start":                                 │
│                                                                              │
│ /home/wenhuaszhgc/users/xqq/TF-ICON/ldm/models/diffusion/dpm_solver/sampler. │
│ py:118 in <lambda>                                                           │
│                                                                              │
│   115 │   │   else:                                                          │
│   116 │   │   │   # x_T is a list                                            │
│   117 │   │   │   model_fn_decode = model_wrapper(                           │
│ ❱ 118 │   │   │   │   lambda x, t, c, DPMencode, controller, inject: self.mo │
│   119 │   │   │   │   ns,                                                    │
│   120 │   │   │   │   model_type=MODEL_TYPES[self.model.parameterization],   │
│   121 │   │   │   │   guidance_type="classifier-free",                       │
│                                                                              │
│ /home/wenhuaszhgc/users/xqq/TF-ICON/ldm/models/diffusion/ddpm.py:859 in      │
│ apply_model                                                                  │
│                                                                              │
│    856 │   │   │   key = 'c_concat' if self.model.conditioning_key == 'conca │
│    857 │   │   │   cond = {key: cond}                                        │
│    858 │   │                                                                 │
│ ❱  859 │   │   x_recon = self.model(x_noisy, t, **cond, encode=encode, encod │
│    860 │   │                                                                 │
│    861 │   │   if isinstance(x_recon, tuple) and not return_ids:             │
│    862 │   │   │   return x_recon[0]                                         │
│                                                                              │
│ /home/wenhuaszhgc/miniconda3/envs/xqqpy38/lib/python3.8/site-packages/torch/ │
│ nn/modules/module.py:1130 in _call_impl                                      │
│                                                                              │
│   1127 │   │   # this function, and just call forward.                       │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                     │
│   1131 │   │   # Do not call functions when jit is used                      │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /home/wenhuaszhgc/users/xqq/TF-ICON/ldm/models/diffusion/ddpm.py:1330 in     │
│ forward                                                                      │
│                                                                              │
│   1327 │   │   │   │   cc = torch.cat(c_crossattn, 1)                        │
│   1328 │   │   │   else:                                                     │
│   1329 │   │   │   │   cc = c_crossattn                                      │
│ ❱ 1330 │   │   │   out = self.diffusion_model(x, t, context=cc, encode=encod │
│   1331 │   │   elif self.conditioning_key == 'hybrid':                       │
│   1332 │   │   │   xc = torch.cat([x] + c_concat, dim=1)                     │
│   1333 │   │   │   cc = torch.cat(c_crossattn, 1)                            │
│                                                                              │
│ /home/wenhuaszhgc/miniconda3/envs/xqqpy38/lib/python3.8/site-packages/torch/ │
│ nn/modules/module.py:1130 in _call_impl                                      │
│                                                                              │
│   1127 │   │   # this function, and just call forward.                       │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                     │
│   1131 │   │   # Do not call functions when jit is used                      │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /home/wenhuaszhgc/users/xqq/TF-ICON/ldm/modules/diffusionmodules/openaimodel │
│ .py:796 in forward                                                           │
│                                                                              │
│   793 │   │   layernum = 0                                                   │
│   794 │   │   for module in self.output_blocks:                              │
│   795 │   │   │   h = th.cat([h, hs.pop()], dim=1)                           │
│ ❱ 796 │   │   │   h, layernum = module(h, emb, context, encode=encode, encod │
│   797 │   │   │   # print(layernum)                                          │
│   798 │   │                                                                  │
│   799 │   │   h = h.type(x.dtype)                                            │
│                                                                              │
│ /home/wenhuaszhgc/miniconda3/envs/xqqpy38/lib/python3.8/site-packages/torch/ │
│ nn/modules/module.py:1130 in _call_impl                                      │
│                                                                              │
│   1127 │   │   # this function, and just call forward.                       │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                     │
│   1131 │   │   # Do not call functions when jit is used                      │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /home/wenhuaszhgc/users/xqq/TF-ICON/ldm/modules/diffusionmodules/openaimodel │
│ .py:89 in forward                                                            │
│                                                                              │
│    86 │   │   │   │   else:                                                  │
│    87 │   │   │   │   │   x = layer(x, emb)                                  │
│    88 │   │   │   elif isinstance(layer, SpatialTransformer):                │
│ ❱  89 │   │   │   │   x, layernum = layer(x, context, encode=encode, encode_ │
│    90 │   │   │   │   │   │   │   │   │   controller=controller, inject=inje │
│    91 │   │   │   else:                                                      │
│    92 │   │   │   │   x = layer(x)                                           │
│                                                                              │
│ /home/wenhuaszhgc/miniconda3/envs/xqqpy38/lib/python3.8/site-packages/torch/ │
│ nn/modules/module.py:1130 in _call_impl                                      │
│                                                                              │
│   1127 │   │   # this function, and just call forward.                       │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                     │
│   1131 │   │   # Do not call functions when jit is used                      │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /home/wenhuaszhgc/users/xqq/TF-ICON/ldm/modules/attention.py:366 in forward  │
│                                                                              │
│   363 │   │   if self.use_linear:                                            │
│   364 │   │   │   x = self.proj_in(x)                                        │
│   365 │   │   for i, block in enumerate(self.transformer_blocks):            │
│ ❱ 366 │   │   │   x = block(x, context=context[i], encode=encode, encode_unc │
│   367 │   │   │   │   │     controller=controller, inject=inject, layernum=l │
│   368 │   │   if self.use_linear:                                            │
│   369 │   │   │   x = self.proj_out(x)                                       │
│                                                                              │
│ /home/wenhuaszhgc/miniconda3/envs/xqqpy38/lib/python3.8/site-packages/torch/ │
│ nn/modules/module.py:1130 in _call_impl                                      │
│                                                                              │
│   1127 │   │   # this function, and just call forward.                       │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                     │
│   1131 │   │   # Do not call functions when jit is used                      │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /home/wenhuaszhgc/users/xqq/TF-ICON/ldm/modules/attention.py:280 in forward  │
│                                                                              │
│   277 │   │   self.checkpoint = checkpoint                                   │
│   278 │                                                                      │
│   279 │   def forward(self, x, context=None, encode=False, encode_uncon=Fals │
│ ❱ 280 │   │   return checkpoint(self._forward, (x, context, encode, encode_u │
│   281 │                                                                      │
│   282 │   def _forward(self, x, context=None, encode=False, encode_uncon=Fal │
│   283                                                                        │
│                                                                              │
│ /home/wenhuaszhgc/users/xqq/TF-ICON/ldm/modules/diffusionmodules/util.py:117 │
│ in checkpoint                                                                │
│                                                                              │
│   114 │   """                                                                │
│   115 │   if flag:                                                           │
│   116 │   │   args = tuple(inputs) + tuple(params)                           │
│ ❱ 117 │   │   return CheckpointFunction.apply(func, len(inputs), *args)      │
│   118 │   else:                                                              │
│   119 │   │   return func(*inputs)                                           │
│   120                                                                        │
│                                                                              │
│ /home/wenhuaszhgc/users/xqq/TF-ICON/ldm/modules/diffusionmodules/util.py:132 │
│ in forward                                                                   │
│                                                                              │
│   129 │   │   │   │   │   │   │   │      "dtype": torch.get_autocast_gpu_dty │
│   130 │   │   │   │   │   │   │   │      "cache_enabled": torch.is_autocast_ │
│   131 │   │   with torch.no_grad():                                          │
│ ❱ 132 │   │   │   output_tensors = ctx.run_function(*ctx.input_tensors)      │
│   133 │   │   return output_tensors                                          │
│   134 │                                                                      │
│   135 │   @staticmethod                                                      │
│                                                                              │
│ /home/wenhuaszhgc/users/xqq/TF-ICON/ldm/modules/attention.py:301 in _forward │
│                                                                              │
│   298 │   │   │   │   x = self.attn2(self.norm2(x), context=context) + x     │
│   299 │   │                                                                  │
│   300 │   │   elif encode_uncon == False and decode_uncon == False:          │
│ ❱ 301 │   │   │   x = self.attn1(self.norm1(x), context=context if self.disa │
│   302 │   │   │   │   │   │      controller_for_inject=controller, inject=in │
│   303 │   │   │   x = self.attn2(self.norm2(x), context=context, encode=enco │
│   304 │   │   │   # pass                                                     │
│                                                                              │
│ /home/wenhuaszhgc/miniconda3/envs/xqqpy38/lib/python3.8/site-packages/torch/ │
│ nn/modules/module.py:1130 in _call_impl                                      │
│                                                                              │
│   1127 │   │   # this function, and just call forward.                       │
│   1128 │   │   if not (self._backward_hooks or self._forward_hooks or self._ │
│   1129 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1130 │   │   │   return forward_call(*input, **kwargs)                     │
│   1131 │   │   # Do not call functions when jit is used                      │
│   1132 │   │   full_backward_hooks, non_full_backward_hooks = [], []         │
│   1133 │   │   if self._backward_hooks or _global_backward_hooks:            │
│                                                                              │
│ /home/wenhuaszhgc/users/xqq/TF-ICON/ptp_scripts/ptp_utils.py:375 in forward  │
│                                                                              │
│   372 │   │   │   │                                                          │
│   373 │   │   │   │   del orig_mask, mask_for_realSA, orig_loc_masked, orig_ │
│   374 │   │   │                                                              │
│ ❱ 375 │   │   │   sim = sim.softmax(dim=-1)                                  │
│   376 │   │   │                                                              │
│   377 │   │   │   out = einsum('b i j, b j d -> b i d', sim, v)              │
│   378 │   │   │   out = rearrange(out, '(b h) n d -> b n (h d)', h=h)        │
╰──────────────────────────────────────────────────────────────────────────────╯
RuntimeError: CUDA out of memory. Tried to allocate 640.00 MiB (GPU 2; 23.70 GiB
total capacity; 18.52 GiB already allocated; 434.69 MiB free; 19.23 GiB reserved
in total by PyTorch) If reserved memory is >> allocated memory try setting 
max_split_size_mb to avoid fragmentation.  See documentation for Memory 
Management and PYTORCH_CUDA_ALLOC_CONF

Process finished with exit code 1
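The figures in the error message can be read off directly. The sketch below is an editor's illustration: `summarize_oom` is a hypothetical helper, and its regular expressions assume the older PyTorch OOM message format shown above. It derives two quantities: memory reserved by PyTorch's caching allocator but not occupied by tensors (reserved minus allocated), and memory in use on the GPU that PyTorch's allocator does not hold at all (total minus reserved minus free).

```python
import re

# Copied from the error above (older PyTorch OOM message format).
OOM_TEXT = """RuntimeError: CUDA out of memory. Tried to allocate 640.00 MiB (GPU 2; 23.70 GiB
total capacity; 18.52 GiB already allocated; 434.69 MiB free; 19.23 GiB reserved
in total by PyTorch)"""

# Each field is "<number> <unit>" next to a fixed label; \s+ spans line wraps.
PATTERNS = {
    "requested": r"Tried to allocate ([\d.]+) (GiB|MiB)",
    "total": r"([\d.]+) (GiB|MiB)\s+total capacity",
    "allocated": r"([\d.]+) (GiB|MiB) already allocated",
    "free": r"([\d.]+) (GiB|MiB) free",
    "reserved": r"([\d.]+) (GiB|MiB) reserved",
}


def summarize_oom(message: str) -> dict:
    """Return the OOM figures in GiB plus two derived quantities."""
    values = {}
    for name, pattern in PATTERNS.items():
        match = re.search(pattern, message)
        if match is None:
            raise ValueError(f"field {name!r} not found in message")
        number, unit = float(match.group(1)), match.group(2)
        values[name] = number if unit == "GiB" else number / 1024  # MiB -> GiB
    # Cached by PyTorch but not occupied by live tensors; a large value
    # would point to fragmentation (the max_split_size_mb hint).
    values["cached_unused"] = values["reserved"] - values["allocated"]
    # In use on the GPU outside PyTorch's caching allocator: CUDA context,
    # other libraries, or other processes sharing the device.
    values["outside_pytorch"] = values["total"] - values["reserved"] - values["free"]
    return values


for key, gib in summarize_oom(OOM_TEXT).items():
    print(f"{key:>16}: {gib:6.2f} GiB")
```

On these numbers, about 0.71 GiB is cached but unused, which is small next to the 18.52 GiB allocated, so reserved memory is not much larger than allocated memory and `max_split_size_mb` is unlikely to help much here. About 4 GiB is in use outside PyTorch's allocator, which is more than a CUDA context alone typically needs; that may indicate another process on GPU 2 (worth checking with `nvidia-smi`). Both readings are inferences from the message, not something confirmed in this thread.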
Shilin-LU commented 1 year ago

Hi, thanks for your question.

25 GB is safe, but 24 GB should be enough in most cases. Make sure your conda environment matches ours.

You may try our exemplar input.
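A rough calculation suggests why a 24 GB card sits close to the limit. The failing allocation in the traceback is at `sim.softmax(dim=-1)` in `ptp_utils.py`, where `sim` holds full attention scores. Assume (none of this is stated in the thread) a 512×512 input encoded to a 64×64 latent (4096 tokens), 5 heads at the 320-channel UNet level (Stable Diffusion 2.x uses 64 channels per head), a batch of 2 from classifier-free guidance, and float32 scores. Under these assumptions a single score tensor is exactly the 640 MiB the error reports, and `softmax` allocates a second tensor of the same size while `sim` is still alive:

```python
MIB = 2**20


def attention_scores_bytes(batch: int, heads: int, tokens: int, bytes_per_elem: int = 4) -> int:
    """Bytes for one (batch * heads, tokens, tokens) attention-score tensor."""
    return batch * heads * tokens * tokens * bytes_per_elem


# Assumed values (see lead-in): CFG batch 2, 5 heads, 64 x 64 latent, fp32.
scores = attention_scores_bytes(batch=2, heads=5, tokens=64 * 64)
print(scores / MIB)  # 640.0
```

Because the score size grows with the square of the token count, the highest-resolution attention layers dominate peak memory. The log shows xFormers is used in the autoencoder, but the traceback shows that the UNet attention at this point goes through a plain `einsum` plus `softmax` in `ptp_utils.py`, which materializes the full matrix.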

Echo411 commented 1 year ago

Hi, thank you very much for your reply! I am already using the exemplar input provided with the code. Is there any way to reduce memory usage? Being able to run this code is very important to me. Thanks again!

Shilin-LU commented 1 year ago

After some testing, I found a simple way to reduce memory usage, which is included in the latest version: calling torch.cuda.empty_cache() more frequently, as pointed out in this section of the code.
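A minimal sketch of the general pattern, for anyone applying it to their own copy of the script: drop references to intermediate tensors, then call `torch.cuda.empty_cache()`. `free_gpu_cache` is a hypothetical helper, not the code in the TF-ICON repository. According to the PyTorch documentation, `empty_cache()` returns unoccupied cached blocks to the driver; it does not free tensors that are still referenced, although it may reduce fragmentation in some cases.

```python
import gc


def free_gpu_cache() -> bool:
    """Collect unreachable Python objects, then release PyTorch's unused
    cached CUDA blocks. Returns True if the CUDA cache was emptied."""
    gc.collect()  # drop tensors kept alive only by reference cycles
    try:
        import torch
    except ImportError:
        return False  # torch not installed
    if not torch.cuda.is_available():
        return False  # nothing to release without CUDA
    torch.cuda.empty_cache()
    return True


# Hypothetical placement between sampling stages:
#   del intermediate   # remove the last reference to a large tensor first
#   free_gpu_cache()
```

The `del` matters: as long as a Python name still points at a tensor, its memory counts as allocated and `empty_cache()` cannot release it.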

In our evaluation, peak memory consumption stayed at or below 22 GB. Hope this helps you out!
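To check whether a run stays under a figure like 22 GB on your own setup, you can record PyTorch's peak memory around the sampling call. The context manager below is a hypothetical helper (not from the repository) built on `torch.cuda.reset_peak_memory_stats`, `torch.cuda.max_memory_allocated`, and `torch.cuda.max_memory_reserved`. `nvidia-smi` will report more than these values because it also counts the CUDA context and non-PyTorch allocations, and the thread does not say which measure the 22 GB refers to.

```python
import contextlib

GIB = 2**30


@contextlib.contextmanager
def track_peak_memory(device=None):
    """Record peak PyTorch CUDA memory (GiB) for the enclosed block.

    Yields a dict whose values stay None when torch or CUDA is unavailable.
    """
    stats = {"peak_allocated_gib": None, "peak_reserved_gib": None}
    try:
        import torch
        use_cuda = torch.cuda.is_available()
    except ImportError:
        use_cuda = False
    if use_cuda:
        torch.cuda.reset_peak_memory_stats(device)  # start a fresh peak window
    try:
        yield stats
    finally:
        if use_cuda:
            torch.cuda.synchronize(device)  # wait for queued kernels to finish
            stats["peak_allocated_gib"] = torch.cuda.max_memory_allocated(device) / GIB
            stats["peak_reserved_gib"] = torch.cuda.max_memory_reserved(device) / GIB
```

For example, wrap the `sampler.sample(...)` call in `main_tf_icon.py` with `with track_peak_memory() as stats:` and print `stats` afterwards. The reserved peak is the closer of the two to what actually has to fit on the card.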

I encourage you to clone the revised version of our repository and give it another shot.