a-ru2016 closed this issue.
I tested this on a T4 in Colab and got the following result:
The T4 cannot use bfloat16, which appears to be a limitation of the GPU itself, so training should work normally if you use float32 instead (set `precision` under `train` in `config.yaml` to `float32`).
If you still get a `NotImplementedError ...` after setting it to `float32`, please check that the file path is correct and that the file was actually saved, then try again.
I have also created a simple training notebook that I have verified on Colab, so please refer to it as well.
https://github.com/p1atdev/LECO/blob/main/train.ipynb
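The GPU limitation described above can be checked before editing `config.yaml`. This is a minimal sketch, not part of LECO: it assumes that bfloat16 support tracks CUDA compute capability 8.0 (Ampere, e.g. A100), which matches the xformers message "bf16 is only supported on A100+ GPUs". A T4 reports compute capability 7.5. The helper names `pick_precision` and `detect_capability` are hypothetical.

```python
from typing import Optional, Tuple

# Assumption: bfloat16 kernels (including the xformers ones in the traceback)
# need compute capability 8.0 or newer (Ampere, e.g. A100). A T4 is 7.5.
BF16_MIN_MAJOR = 8


def pick_precision(capability: Optional[Tuple[int, int]]) -> str:
    """Suggest a value for train.precision in config.yaml.

    capability is (major, minor) as returned by
    torch.cuda.get_device_capability(), or None when no GPU was found.
    """
    if capability is None:
        return "float32"
    major, _minor = capability
    return "bfloat16" if major >= BF16_MIN_MAJOR else "float32"


def detect_capability() -> Optional[Tuple[int, int]]:
    """Return the current GPU's compute capability, or None if unavailable."""
    try:
        import torch  # imported lazily so pick_precision works without torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability()
```

On a Colab T4, `pick_precision(detect_capability())` returns `"float32"`, because `torch.cuda.get_device_capability()` reports `(7, 5)` there.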
Thank you for checking. It turned out I had simply made a mistake when copying and pasting from Mr. Plat's article. Sorry for taking up your time, and thank you very much.
Thank you for creating the training code. I'd like to report an error I get on a T4 in Google Colab.
error message
```
╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /content/LECO/./train_lora.py:318 in │
│ │
│ 315 │ │
│ 316 │ args = parser.parse_args() │
│ 317 │ │
│ ❱ 318 │ main(args) │
│ 319 │
│ │
│ /content/LECO/./train_lora.py:305 in main │
│ │
│ 302 │ config = config_util.load_config_from_yaml(config_file) │
│ 303 │ prompts = prompt_util.load_prompts_from_yaml(config.prompts_file) │
│ 304 │ │
│ ❱ 305 │ train(config, prompts) │
│ 306 │
│ 307 │
│ 308 if __name__ == "__main__": │
│ │
│ /content/LECO/./train_lora.py:156 in train │
│ │
│ 153 │ │ │ │
│ 154 │ │ │ with network: │
│ 155 │ │ │ │ # returns slightly denoised latents │
│ ❱ 156 │ │ │ │ denoised_latents = train_util.diffusion( │
│ 157 │ │ │ │ │ unet, │
│ 158 │ │ │ │ │ scheduler, │
│ 159 │ │ │ │ │ latents, # pass in latents of pure noise │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115 in │
│ decorate_context │
│ │
│ 112 │ @functools.wraps(func) │
│ 113 │ def decorate_context(*args, **kwargs): │
│ 114 │ │ with ctx_factory(): │
│ ❱ 115 │ │ │ return func(*args, **kwargs) │
│ 116 │ │
│ 117 │ return decorate_context │
│ 118 │
│ │
│ /content/LECO/train_util.py:142 in diffusion │
│ │
│ 139 │ # latents_steps = [] │
│ 140 │ │
│ 141 │ for timestep in tqdm(scheduler.timesteps[start_timesteps:total_tim │
│ ❱ 142 │ │ noise_pred = predict_noise( │
│ 143 │ │ │ unet, scheduler, timestep, latents, text_embeddings, **kwa │
│ 144 │ │ ) │
│ 145 │
│ │
│ /content/LECO/train_util.py:113 in predict_noise │
│ │
│ 110 │ latent_model_input = scheduler.scale_model_input(latent_model_inpu │
│ 111 │ │
│ 112 │ # predict the noise residual │
│ ❱ 113 │ noise_pred = unet( │
│ 114 │ │ latent_model_input, │
│ 115 │ │ timestep, │
│ 116 │ │ encoder_hidden_states=text_embeddings, │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in │
│ _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or s │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hoo │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/unet_2d_condition.p │
│ y:797 in forward │
│ │
│ 794 │ │ down_block_res_samples = (sample,) │
│ 795 │ │ for downsample_block in self.down_blocks: │
│ 796 │ │ │ if hasattr(downsample_block, "has_cross_attention") and do │
│ ❱ 797 │ │ │ │ sample, res_samples = downsample_block( │
│ 798 │ │ │ │ │ hidden_states=sample, │
│ 799 │ │ │ │ │ temb=emb, │
│ 800 │ │ │ │ │ encoder_hidden_states=encoder_hidden_states, │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in │
│ _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or s │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hoo │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/unet_2d_blocks.py:9 │
│ 24 in forward │
│ │
│ 921 │ │ │ │ )[0] │
│ 922 │ │ │ else: │
│ 923 │ │ │ │ hidden_states = resnet(hidden_states, temb) │
│ ❱ 924 │ │ │ │ hidden_states = attn( │
│ 925 │ │ │ │ │ hidden_states, │
│ 926 │ │ │ │ │ encoder_hidden_states=encoder_hidden_states, │
│ 927 │ │ │ │ │ cross_attention_kwargs=cross_attention_kwargs, │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in │
│ _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or s │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hoo │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/transformer_2d.py:2 │
│ 96 in forward │
│ │
│ 293 │ │ │
│ 294 │ │ # 2. Blocks │
│ 295 │ │ for block in self.transformer_blocks: │
│ ❱ 296 │ │ │ hidden_states = block( │
│ 297 │ │ │ │ hidden_states, │
│ 298 │ │ │ │ attention_mask=attention_mask, │
│ 299 │ │ │ │ encoder_hidden_states=encoder_hidden_states, │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in │
│ _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or s │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hoo │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/attention.py:144 in │
│ forward │
│ │
│ 141 │ │ │ norm_hidden_states = self.norm1(hidden_states) │
│ 142 │ │ │
│ 143 │ │ cross_attention_kwargs = cross_attention_kwargs if cross_atten │
│ ❱ 144 │ │ attn_output = self.attn1( │
│ 145 │ │ │ norm_hidden_states, │
│ 146 │ │ │ encoder_hidden_states=encoder_hidden_states if self.only_c │
│ 147 │ │ │ attention_mask=attention_mask, │
│ │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in │
│ _call_impl │
│ │
│ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or s │
│ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hoo │
│ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1501 │ │ │ return forward_call(*args, **kwargs) │
│ 1502 │ │ # Do not call functions when jit is used │
│ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │
│ 1504 │ │ backward_pre_hooks = [] │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/attention_processor │
│ .py:320 in forward │
│ │
│ 317 │ │ # The `Attention` class can call different attention processo │
│ 318 │ │ # here we simply pass along all tensors to the selected proce │
│ 319 │ │ # For standard processors that are defined here, `**cross_att │
│ ❱ 320 │ │ return self.processor( │
│ 321 │ │ │ self, │
│ 322 │ │ │ hidden_states, │
│ 323 │ │ │ encoder_hidden_states=encoder_hidden_states, │
│ │
│ /usr/local/lib/python3.10/dist-packages/diffusers/models/attention_processor │
│ .py:1045 in __call__ │
│ │
│ 1042 │ │ key = attn.head_to_batch_dim(key).contiguous() │
│ 1043 │ │ value = attn.head_to_batch_dim(value).contiguous() │
│ 1044 │ │ │
│ ❱ 1045 │ │ hidden_states = xformers.ops.memory_efficient_attention( │
│ 1046 │ │ │ query, key, value, attn_bias=attention_mask, op=self.atte │
│ 1047 │ │ ) │
│ 1048 │ │ hidden_states = hidden_states.to(query.dtype) │
│ │
│ /usr/local/lib/python3.10/dist-packages/xformers/ops/fmha/__init__.py:192 in │
│ memory_efficient_attention │
│ │
│ 189 │ │ and options. │
│ 190 │ :return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]`` │
│ 191 │ """ │
│ ❱ 192 │ return _memory_efficient_attention( │
│ 193 │ │ Inputs( │
│ 194 │ │ │ query=query, key=key, value=value, p=p, attn_bias=attn_bia │
│ 195 │ │ ), │
│ │
│ /usr/local/lib/python3.10/dist-packages/xformers/ops/fmha/__init__.py:290 in │
│ _memory_efficient_attention │
│ │
│ 287 ) -> torch.Tensor: │
│ 288 │ # fast-path that doesn't require computing the logsumexp for backw │
│ 289 │ if all(x.requires_grad is False for x in [inp.query, inp.key, inp. │
│ ❱ 290 │ │ return _memory_efficient_attention_forward( │
│ 291 │ │ │ inp, op=op[0] if op is not None else None │
│ 292 │ │ ) │
│ 293 │
│ │
│ /usr/local/lib/python3.10/dist-packages/xformers/ops/fmha/__init__.py:306 in │
│ _memory_efficient_attention_forward │
│ │
│ 303 │ inp.validate_inputs() │
│ 304 │ output_shape = inp.normalize_bmhk() │
│ 305 │ if op is None: │
│ ❱ 306 │ │ op = _dispatch_fw(inp) │
│ 307 │ else: │
│ 308 │ │ _ensure_op_supports_or_raise(ValueError, "memory_efficient_att │
│ 309 │
│ │
│ /usr/local/lib/python3.10/dist-packages/xformers/ops/fmha/dispatch.py:94 in │
│ _dispatch_fw │
│ │
│ 91 │ if _is_triton_fwd_fastest(inp): │
│ 92 │ │ priority_list_ops.remove(triton.FwOp) │
│ 93 │ │ priority_list_ops.insert(0, triton.FwOp) │
│ ❱ 94 │ return _run_priority_list( │
│ 95 │ │ "memory_efficient_attention_forward", priority_list_ops, inp │
│ 96 │ ) │
│ 97 │
│ │
│ /usr/local/lib/python3.10/dist-packages/xformers/ops/fmha/dispatch.py:69 in │
│ _run_priority_list │
│ │
│ 66 {textwrap.indent(_format_inputs_description(inp), ' ')}""" │
│ 67 │ for op, not_supported in zip(priority_list, not_supported_reasons) │
│ 68 │ │ msg += "\n" + _format_not_supported_reasons(op, not_supported) │
│ ❱ 69 │ raise NotImplementedError(msg) │
│ 70 │
│ 71 │
│ 72 def _dispatch_fw(inp: Inputs) -> Type[AttentionFwOpBase]: │
╰──────────────────────────────────────────────────────────────────────────────╯
NotImplementedError: No operator found for `memory_efficient_attention_forward`
with inputs:
query : shape=(20, 4096, 1, 64) (torch.bfloat16)
key : shape=(20, 4096, 1, 64) (torch.bfloat16)
value : shape=(20, 4096, 1, 64) (torch.bfloat16)
attn_bias :
p : 0.0
`flshattF` is not supported because:
bf16 is only supported on A100+ GPUs
`tritonflashattF` is not supported because:
bf16 is only supported on A100+ GPUs
requires A100 GPU
`cutlassF` is not supported because:
bf16 is only supported on A100+ GPUs
`smallkF` is not supported because:
dtype=torch.bfloat16 (supported: {torch.float32})
max(query.shape[-1] != value.shape[-1]) > 32
has custom scale
bf16 is only supported on A100+ GPUs
unsupported embed per head: 64
```
config.yaml
```
prompts_file: "./prompts.yaml"

pretrained_model:
  name_or_path: "/content/model/CounterfeitV3_fix.ckpt" # you can also use .ckpt or .safetensors models
  v2: False # true if model is v2.x
  v_pred: False # true if model uses v-prediction

network:
  type: "lierla" # or "c3lier"
  rank: 16
  alpha: 1.0

train:
  precision: "float32" # the same error also occurs with float16 and bfloat16
  noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
  iterations: 1000
  lr: 1e-4
  optimizer: "AdamW" # Lion, AdamW, Adam
  lr_scheduler: "cosine" # constant, linear, cosine

save:
  name: "hand"
  path: "./output"
  per_steps: 100
  precision: "float32" # same as above: changing this still gives the same error

logging:
  use_wandb: false
  verbose: false

other:
  use_xformers: true
```
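The config above sets `train.precision` to `float32`, yet the traceback shows `torch.bfloat16` query/key/value tensors, which suggests the script was not reading the file as shown. A minimal sketch for checking what is actually saved on disk, assuming PyYAML is installed and that `prompts_file` is resolved against the current working directory (as Python's `open` does by default; how LECO itself resolves it is an assumption). The helpers `find_problems`, `load_config`, and `missing_prompts_file` are hypothetical, not LECO functions.

```python
from pathlib import Path
from typing import List

# The T4 workaround from this thread: train.precision must be float32.
EXPECTED_PRECISION = "float32"


def find_problems(config: dict, expected: str = EXPECTED_PRECISION) -> List[str]:
    """Return human-readable problems with the train precision setting."""
    train = config.get("train") or {}
    precision = train.get("precision")
    if precision != expected:
        return [f"train.precision is {precision!r}, expected {expected!r}"]
    return []


def load_config(path: str) -> dict:
    """Load the YAML file exactly as it is saved on disk."""
    import yaml  # PyYAML, imported lazily; assumed to be installed

    config_path = Path(path)
    if not config_path.is_file():
        raise FileNotFoundError(f"config file not found: {config_path.resolve()}")
    with config_path.open(encoding="utf-8") as f:
        return yaml.safe_load(f) or {}


def missing_prompts_file(config: dict) -> bool:
    """True if prompts_file is unset or absent relative to the current directory."""
    prompts_file = config.get("prompts_file")
    return prompts_file is None or not Path(prompts_file).is_file()
```

For example, if `find_problems(load_config("config.yaml"))` returns an empty list but training still fails with `torch.bfloat16` inputs, the training command is probably pointing at a different config file.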
prompts.yaml
```
- target: "hand" # what word for erasing the positive concept from
  positive: "hand" # concept to erase
  unconditional: "" # word to take the difference from the positive concept
  neutral: "" # starting point for conditioning the target
  action: "erase" # erase or enhance
  guidance_scale: 1.0
  resolution: 512
  batch_size: 2
```
Could you please take a look? Thank you.