rohitgandikota / sliders

Concept Sliders for Precise Control of Diffusion Models
https://sliders.baulab.info
MIT License

error NotImplementedError: No operator found for `memory_efficient_attention_forward` with inputs: when training in colab t4 #78

Open loboere opened 4 months ago

loboere commented 4 months ago

%cd /content/sliders
!python trainscripts/textsliders/train_lora.py --attributes 'male, female' --name 'ageslider' --rank 4 --alpha 1 --config_file 'trainscripts/textsliders/data/config.yaml'

I am trying to train on a Colab T4, but I am getting this error:


/content/sliders
2024-02-25 21:29:44.849533: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-25 21:29:44.849598: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-25 21:29:44.851635: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-25 21:29:46.623676: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[{'target': 'male person', 'positive': 'male person, very old', 'unconditional': 'male person, very young', 'neutral': 'male person', 'action': 'enhance', 'guidance_scale': 4, 'resolution': 512, 'dynamic_resolution': False, 'batch_size': 1}, {'target': 'female person', 'positive': 'female person, very old', 'unconditional': 'female person, very young', 'neutral': 'female person', 'action': 'enhance', 'guidance_scale': 4, 'resolution': 512, 'dynamic_resolution': False, 'batch_size': 1}]
[{'target': 'male male person', 'positive': 'male male person, very old', 'unconditional': 'male male person, very young', 'neutral': 'male male person', 'action': 'enhance', 'guidance_scale': 4, 'resolution': 512, 'dynamic_resolution': False, 'batch_size': 1}, {'target': 'female male person', 'positive': 'female male person, very old', 'unconditional': 'female male person, very young', 'neutral': 'female male person', 'action': 'enhance', 'guidance_scale': 4, 'resolution': 512, 'dynamic_resolution': False, 'batch_size': 1}, {'target': 'male female person', 'positive': 'male female person, very old', 'unconditional': 'male female person, very young', 'neutral': 'male female person', 'action': 'enhance', 'guidance_scale': 4, 'resolution': 512, 'dynamic_resolution': False, 'batch_size': 1}, {'target': 'female female person', 'positive': 'female female person, very old', 'unconditional': 'female female person, very young', 'neutral': 'female female person', 'action': 'enhance', 'guidance_scale': 4, 'resolution': 512, 'dynamic_resolution': False, 'batch_size': 1}]
2 4
create LoRA for U-Net: 150 modules.
Prompts
target='male male person' positive='male male person, very old' unconditional='male male person, very young' neutral='male male person' action='enhance' guidance_scale=4.0 resolution=512 dynamic_resolution=False batch_size=1 dynamic_crops=False
target='female male person' positive='female male person, very old' unconditional='female male person, very young' neutral='female male person' action='enhance' guidance_scale=4.0 resolution=512 dynamic_resolution=False batch_size=1 dynamic_crops=False
target='male female person' positive='male female person, very old' unconditional='male female person, very young' neutral='male female person' action='enhance' guidance_scale=4.0 resolution=512 dynamic_resolution=False batch_size=1 dynamic_crops=False
target='female female person' positive='female female person, very old' unconditional='female female person, very young' neutral='female female person' action='enhance' guidance_scale=4.0 resolution=512 dynamic_resolution=False batch_size=1 dynamic_crops=False
Module: 
    Parameter: lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight, Requires Grad: True
    Parameter: lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_up.weight, Requires Grad: True
Module: lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q
    Parameter: lora_down.weight, Requires Grad: True
    Parameter: lora_up.weight, Requires Grad: True
Module: lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down
    Parameter: weight, Requires Grad: True
Module: lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_up
    Parameter: weight, Requires Grad: True
Module: lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_k
    Parameter: lora_down.weight, Requires Grad: True
    Parameter: lora_up.weight, Requires Grad: True
Module: , Training Mode: True
Module: lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q, Training Mode: True
Module: lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down, Training Mode: True
Module: lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_up, Training Mode: True
Module: lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_k, Training Mode: True
target='male male person' positive='male male person, very old' unconditional='male male person, very young' neutral='male male person' action='enhance' guidance_scale=4.0 resolution=512 dynamic_resolution=False batch_size=1 dynamic_crops=False
male male person
male male person, very old
male male person
male male person, very young
target='female male person' positive='female male person, very old' unconditional='female male person, very young' neutral='female male person' action='enhance' guidance_scale=4.0 resolution=512 dynamic_resolution=False batch_size=1 dynamic_crops=False
female male person
female male person, very old
female male person
female male person, very young
target='male female person' positive='male female person, very old' unconditional='male female person, very young' neutral='male female person' action='enhance' guidance_scale=4.0 resolution=512 dynamic_resolution=False batch_size=1 dynamic_crops=False
male female person
male female person, very old
male female person
male female person, very young
target='female female person' positive='female female person, very old' unconditional='female female person, very young' neutral='female female person' action='enhance' guidance_scale=4.0 resolution=512 dynamic_resolution=False batch_size=1 dynamic_crops=False
female female person
female female person, very old
female female person
female female person, very young
  0% 0/1000 [00:00<?, ?it/s]
  0% 0/8 [00:00<?, ?it/s]
  0% 0/1000 [00:00<?, ?it/s]
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /content/sliders/trainscripts/textsliders/train_lora.py:419 in <module>                          │
│                                                                                                  │
│   416 │                                                                                          │
│   417 │   args = parser.parse_args()                                                             │
│   418 │                                                                                          │
│ ❱ 419 │   main(args)                                                                             │
│   420                                                                                            │
│                                                                                                  │
│ /content/sliders/trainscripts/textsliders/train_lora.py:364 in main                              │
│                                                                                                  │
│   361 │   prompts = prompt_util.load_prompts_from_yaml(config.prompts_file, attributes)          │
│   362 │   device = torch.device(f"cuda:{args.device}")                                           │
│   363 │                                                                                          │
│ ❱ 364 │   train(config=config, prompts=prompts, device=device)                                   │
│   365                                                                                            │
│   366                                                                                            │
│   367 if __name__ == "__main__":                                                                 │
│                                                                                                  │
│ /content/sliders/trainscripts/textsliders/train_lora.py:195 in train                             │
│                                                                                                  │
│   192 │   │   │                                                                                  │
│   193 │   │   │   with network:                                                                  │
│   194 │   │   │   │   # returns slightly denoised latents                                        │
│ ❱ 195 │   │   │   │   denoised_latents = train_util.diffusion(                                   │
│   196 │   │   │   │   │   unet,                                                                  │
│   197 │   │   │   │   │   noise_scheduler,                                                       │
│   198 │   │   │   │   │   latents,  # pass plain noise latents                                   │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115 in decorate_context       │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /content/sliders/trainscripts/textsliders/train_util.py:188 in diffusion                         │
│                                                                                                  │
│   185 │   # latents_steps = []                                                                   │
│   186 │                                                                                          │
│   187 │   for timestep in tqdm(scheduler.timesteps[start_timesteps:total_timesteps]):            │
│ ❱ 188 │   │   noise_pred = predict_noise(                                                        │
│   189 │   │   │   unet, scheduler, timestep, latents, text_embeddings, **kwargs                  │
│   190 │   │   )                                                                                  │
│   191                                                                                            │
│                                                                                                  │
│ /content/sliders/trainscripts/textsliders/train_util.py:159 in predict_noise                     │
│                                                                                                  │
│   156 │   latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)         │
│   157 │                                                                                          │
│   158 │   # predict the noise residual                                                           │
│ ❱ 159 │   noise_pred = unet(                                                                     │
│   160 │   │   latent_model_input,                                                                │
│   161 │   │   timestep,                                                                          │
│   162 │   │   encoder_hidden_states=text_embeddings,                                             │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/unet_2d_condition.py:930 in forward     │
│                                                                                                  │
│    927 │   │   │   │   if is_adapter and len(down_block_additional_residuals) > 0:               │
│    928 │   │   │   │   │   additional_residuals["additional_residuals"] = down_block_additional  │
│    929 │   │   │   │                                                                             │
│ ❱  930 │   │   │   │   sample, res_samples = downsample_block(                                   │
│    931 │   │   │   │   │   hidden_states=sample,                                                 │
│    932 │   │   │   │   │   temb=emb,                                                             │
│    933 │   │   │   │   │   encoder_hidden_states=encoder_hidden_states,                          │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/unet_2d_blocks.py:1053 in forward       │
│                                                                                                  │
│   1050 │   │   │   │   )[0]                                                                      │
│   1051 │   │   │   else:                                                                         │
│   1052 │   │   │   │   hidden_states = resnet(hidden_states, temb)                               │
│ ❱ 1053 │   │   │   │   hidden_states = attn(                                                     │
│   1054 │   │   │   │   │   hidden_states,                                                        │
│   1055 │   │   │   │   │   encoder_hidden_states=encoder_hidden_states,                          │
│   1056 │   │   │   │   │   cross_attention_kwargs=cross_attention_kwargs,                        │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/transformer_2d.py:309 in forward        │
│                                                                                                  │
│   306 │   │   │   │   │   use_reentrant=False,                                                   │
│   307 │   │   │   │   )                                                                          │
│   308 │   │   │   else:                                                                          │
│ ❱ 309 │   │   │   │   hidden_states = block(                                                     │
│   310 │   │   │   │   │   hidden_states,                                                         │
│   311 │   │   │   │   │   attention_mask=attention_mask,                                         │
│   312 │   │   │   │   │   encoder_hidden_states=encoder_hidden_states,                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/attention.py:194 in forward             │
│                                                                                                  │
│   191 │   │   cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs   │
│   192 │   │   gligen_kwargs = cross_attention_kwargs.pop("gligen", None)                         │
│   193 │   │                                                                                      │
│ ❱ 194 │   │   attn_output = self.attn1(                                                          │
│   195 │   │   │   norm_hidden_states,                                                            │
│   196 │   │   │   encoder_hidden_states=encoder_hidden_states if self.only_cross_attention els   │
│   197 │   │   │   attention_mask=attention_mask,                                                 │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl            │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/attention_processor.py:322 in forward   │
│                                                                                                  │
│    319 │   │   # The `Attention` class can call different attention processors / attention func  │
│    320 │   │   # here we simply pass along all tensors to the selected processor class           │
│    321 │   │   # For standard processors that are defined here, `**cross_attention_kwargs` is e  │
│ ❱  322 │   │   return self.processor(                                                            │
│    323 │   │   │   self,                                                                         │
│    324 │   │   │   hidden_states,                                                                │
│    325 │   │   │   encoder_hidden_states=encoder_hidden_states,                                  │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/attention_processor.py:1034 in __call__ │
│                                                                                                  │
│   1031 │   │   key = attn.head_to_batch_dim(key).contiguous()                                    │
│   1032 │   │   value = attn.head_to_batch_dim(value).contiguous()                                │
│   1033 │   │                                                                                     │
│ ❱ 1034 │   │   hidden_states = xformers.ops.memory_efficient_attention(                          │
│   1035 │   │   │   query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=att  │
│   1036 │   │   )                                                                                 │
│   1037 │   │   hidden_states = hidden_states.to(query.dtype)                                     │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/xformers/ops/fmha/__init__.py:193 in                     │
│ memory_efficient_attention                                                                       │
│                                                                                                  │
│   190 │   │   and options.                                                                       │
│   191 │   :return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``                     │
│   192 │   """                                                                                    │
│ ❱ 193 │   return _memory_efficient_attention(                                                    │
│   194 │   │   Inputs(                                                                            │
│   195 │   │   │   query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale       │
│   196 │   │   ),                                                                                 │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/xformers/ops/fmha/__init__.py:291 in                     │
│ _memory_efficient_attention                                                                      │
│                                                                                                  │
│   288 ) -> torch.Tensor:                                                                         │
│   289 │   # fast-path that doesn't require computing the logsumexp for backward computation      │
│   290 │   if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]):             │
│ ❱ 291 │   │   return _memory_efficient_attention_forward(                                        │
│   292 │   │   │   inp, op=op[0] if op is not None else None                                      │
│   293 │   │   )                                                                                  │
│   294                                                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/xformers/ops/fmha/__init__.py:307 in                     │
│ _memory_efficient_attention_forward                                                              │
│                                                                                                  │
│   304 │   inp.validate_inputs()                                                                  │
│   305 │   output_shape = inp.normalize_bmhk()                                                    │
│   306 │   if op is None:                                                                         │
│ ❱ 307 │   │   op = _dispatch_fw(inp, False)                                                      │
│   308 │   else:                                                                                  │
│   309 │   │   _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)    │
│   310                                                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/xformers/ops/fmha/dispatch.py:96 in _dispatch_fw         │
│                                                                                                  │
│    93 │   │   │   # With multiquery, cutlass is sometimes faster than decoder                    │
│    94 │   │   │   # but it's not currently clear when.                                           │
│    95 │   │   │   priority_list_ops.appendleft(decoder.FwOp)                                     │
│ ❱  96 │   return _run_priority_list(                                                             │
│    97 │   │   "memory_efficient_attention_forward", priority_list_ops, inp                       │
│    98 │   )                                                                                      │
│    99                                                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/xformers/ops/fmha/dispatch.py:63 in _run_priority_list   │
│                                                                                                  │
│    60 {textwrap.indent(_format_inputs_description(inp), '     ')}"""                             │
│    61 │   for op, not_supported in zip(priority_list, not_supported_reasons):                    │
│    62 │   │   msg += "\n" + _format_not_supported_reasons(op, not_supported)                     │
│ ❱  63 │   raise NotImplementedError(msg)                                                         │
│    64                                                                                            │
│    65                                                                                            │
│    66 def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]:            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
NotImplementedError: No operator found for `memory_efficient_attention_forward` with inputs:
     query       : shape=(16, 4096, 1, 40) (torch.bfloat16)
     key         : shape=(16, 4096, 1, 40) (torch.bfloat16)
     value       : shape=(16, 4096, 1, 40) (torch.bfloat16)
     attn_bias   : <class 'NoneType'>
     p           : 0.0
`decoderF` is not supported because:
    requires device with capability > (8, 0) but your GPU has capability (7, 5) (too old)
    attn_bias type is <class 'NoneType'>
    bf16 is only supported on A100+ GPUs
`flshattF@v2.0.8` is not supported because:
    requires device with capability > (8, 0) but your GPU has capability (7, 5) (too old)
    bf16 is only supported on A100+ GPUs
`tritonflashattF` is not supported because:
    requires device with capability > (8, 0) but your GPU has capability (7, 5) (too old)
    bf16 is only supported on A100+ GPUs
    operator wasn't built - see `python -m xformers.info` for more info
    triton is not available
    requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4
`cutlassF` is not supported because:
    bf16 is only supported on A100+ GPUs
`smallkF` is not supported because:
    max(query.shape[-1] != value.shape[-1]) > 32
    dtype=torch.bfloat16 (supported: {torch.float32})
    has custom scale
    bf16 is only supported on A100+ GPUs
    unsupported embed per head: 40

rohitgandikota commented 4 months ago

It looks like the T4 in Colab doesn't support xformers. You can set the flag to false here: https://github.com/rohitgandikota/sliders/blob/b76a63e0df1da5fa51c15b3c68e0344ceb4fd095/trainscripts/textsliders/data/config-xl.yaml#L28

Whichever model and type of slider you are trying to create, go to the corresponding config.yaml and set the xformers flag to false.
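
The log above reports compute capability (7, 5) for the T4, and each kernel is rejected with "bf16 is only supported on A100+ GPUs", so the precision setting may matter as well as the xformers flag. Here is a minimal sketch of choosing settings from the GPU's compute capability. The `precision` and `use_xformers` names are assumptions about the config keys and should be checked against the actual config.yaml. The float32 fallback is also an assumption, not a tested recommendation.

```python
from typing import Dict, Tuple, Union


def supports_bf16_kernels(capability: Tuple[int, int]) -> bool:
    """Assumption taken from the log: bf16 attention kernels need sm80 or newer."""
    return capability >= (8, 0)


def suggest_overrides(capability: Tuple[int, int]) -> Dict[str, Union[str, bool]]:
    """Return hypothetical config overrides for a given compute capability.

    The key names are illustrative; map them onto the real config.yaml by hand.
    """
    if supports_bf16_kernels(capability):
        return {"precision": "bfloat16", "use_xformers": True}
    # Older GPUs (e.g. T4 at (7, 5)): avoid bf16 and turn xformers off.
    return {"precision": "float32", "use_xformers": False}


def current_capability() -> Tuple[int, int]:
    """Query the active CUDA device; torch is imported lazily so the
    helpers above can be used on a machine without a GPU."""
    import torch

    return torch.cuda.get_device_capability()


if __name__ == "__main__":
    # (7, 5) is the capability the log reports for the Colab T4.
    print(suggest_overrides((7, 5)))
```

On a Colab runtime, `suggest_overrides(current_capability())` would give the values to copy into the config before rerunning `train_lora.py`.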

StefanM03 commented 3 months ago

Hello, I have the same issue using the V100 high-RAM GPU on Colab. I tried setting the flag to false, but I still get the same error. However, setting the precision to fp16 seemed to work, until I used the trained model in the script and the result was 4 black images. This is for the textual concept slider, non-XL version.