kohya-ss / sd-scripts

Apache License 2.0
5.24k stars, 869 forks

Add `accelerate` torch.compile() support for faster training on PyTorch 2.0 #65

[Open] brucethemoose opened this issue 1 year ago

brucethemoose commented 1 year ago

When I select the following options in `accelerate config`:

Do you wish to optimize your script with torch dynamo?[yes/NO]:yes
---------------------------------------------------------------------------------------------------------
Which dynamo backend would you like to use?
Please select a choice using the arrow or number keys, and selecting with enter
    eager
    aot_eager
 ➔  inductor
    nvfuser
    aot_nvfuser
    aot_cudagraphs
    ofi
    fx2trt
    onnxrt
    ipex

The LoRA training script (`train_network.py`) then fails with:

steps:   0%|                                                                    | 0/1600 [00:00<?, ?it/s]
epoch 1/2
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:674 in             │
│ call_user_compiler                                                                               │
│                                                                                                  │
│   671 │   │   │   elif config.DO_NOT_USE_legacy_non_fake_example_inputs:                         │
│   672 │   │   │   │   compiled_fn = compiler_fn(gm, self.example_inputs())                       │
│   673 │   │   │   else:                                                                          │
│ ❱ 674 │   │   │   │   compiled_fn = compiler_fn(gm, self.fake_example_inputs())                  │
│   675 │   │   │   _step_logger()(logging.INFO, f"done compiler function {name}")                 │
│   676 │   │   │   assert callable(compiled_fn), "compiler_fn did not return callable"            │
│   677 │   │   except Exception as e:                                                             │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py:1032 in             │
│ debug_wrapper                                                                                    │
│                                                                                                  │
│   1029 │   │   │   │   │   )                                                                     │
│   1030 │   │   │   │   │   raise                                                                 │
│   1031 │   │   else:                                                                             │
│ ❱ 1032 │   │   │   compiled_gm = compiler_fn(gm, example_inputs, **kwargs)                       │
│   1033 │   │                                                                                     │
│   1034 │   │   return compiled_gm                                                                │
│   1035                                                                                           │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:398 in compile_fx  │
│                                                                                                  │
│   395 │   │   # TODO: can add logging before/after the call to create_aot_dispatcher_function    │
│   396 │   │   # in torch._functorch/aot_autograd.py::aot_module_simplified::aot_function_simpl   │
│   397 │   │   # once torchdynamo is merged into pytorch                                          │
│ ❱ 398 │   │   return aot_autograd(                                                               │
│   399 │   │   │   fw_compiler=fw_compiler,                                                       │
│   400 │   │   │   bw_compiler=bw_compiler,                                                       │
│   401 │   │   │   decompositions=select_decomp_table(),                                          │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/optimizations/training.py:78 in    │
│ compiler_fn                                                                                      │
│                                                                                                  │
│    75 │   │   try:                                                                               │
│    76 │   │   │   # NB: NOT cloned!                                                              │
│    77 │   │   │   with enable_aot_logging():                                                     │
│ ❱  78 │   │   │   │   cg = aot_module_simplified(gm, example_inputs, **kwargs)                   │
│    79 │   │   │   │   counters["aot_autograd"]["ok"] += 1                                        │
│    80 │   │   │   │   return eval_frame.disable(cg)                                              │
│    81 │   │   except Exception:                                                                  │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:2355 in         │
│ aot_module_simplified                                                                            │
│                                                                                                  │
│   2352 │   full_args.extend(params_flat)                                                         │
│   2353 │   full_args.extend(args)                                                                │
│   2354 │                                                                                         │
│ ❱ 2355 │   compiled_fn = create_aot_dispatcher_function(                                         │
│   2356 │   │   functional_call,                                                                  │
│   2357 │   │   full_args,                                                                        │
│   2358 │   │   aot_config,                                                                       │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py:94 in time_wrapper        │
│                                                                                                  │
│     91 │   │   if key not in compilation_metrics:                                                │
│     92 │   │   │   compilation_metrics[key] = []                                                 │
│     93 │   │   t0 = time.time()                                                                  │
│ ❱   94 │   │   r = func(*args, **kwargs)                                                         │
│     95 │   │   latency = time.time() - t0                                                        │
│     96 │   │   # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")                    │
│     97 │   │   compilation_metrics[key].append(latency)                                          │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:2052 in         │
│ create_aot_dispatcher_function                                                                   │
│                                                                                                  │
│   2049 │   │   compiler_fn = partial(aot_wrapper_dedupe, compiler_fn=compiler_fn)                │
│   2050 │   │   # You can put more passes here                                                    │
│   2051 │   │                                                                                     │
│ ❱ 2052 │   │   compiled_fn = compiler_fn(flat_fn, fake_flat_tensor_args, aot_config)             │
│   2053 │   │                                                                                     │
│   2054 │   │   if not hasattr(compiled_fn, '_boxed_call'):                                       │
│   2055 │   │   │   compiled_fn = make_boxed_func(compiled_fn)                                    │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1273 in         │
│ aot_wrapper_dedupe                                                                               │
│                                                                                                  │
│   1270 │   # or not                                                                              │
│   1271 │   try:                                                                                  │
│   1272 │   │   with enable_python_dispatcher():                                                  │
│ ❱ 1273 │   │   │   fw_metadata, _out, _num_aliasing_metadata_outs = run_functionalized_fw_and_c  │
│   1274 │   │   │   │   flat_fn                                                                   │
│   1275 │   │   │   )(*flat_args)                                                                 │
│   1276 │   except RuntimeError as e:                                                             │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:289 in inner    │
│                                                                                                  │
│    286 │   │                                                                                     │
│    287 │   │   torch._enable_functionalization(reapply_views=True)                               │
│    288 │   │   try:                                                                              │
│ ❱  289 │   │   │   outs = f(*f_args)                                                             │
│    290 │   │   finally:                                                                          │
│    291 │   │   │   torch._disable_functionalization()                                            │
│    292                                                                                           │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:2327 in         │
│ functional_call                                                                                  │
│                                                                                                  │
│   2324 │   │   │   │   │   │   "ignore", "Anomaly Detection has been enabled."                   │
│   2325 │   │   │   │   │   )                                                                     │
│   2326 │   │   │   │   │   with torch.autograd.detect_anomaly(check_nan=False):                  │
│ ❱ 2327 │   │   │   │   │   │   out = Interpreter(mod).run(*args[params_len:], **kwargs)          │
│   2328 │   │   │   else:                                                                         │
│   2329 │   │   │   │   out = mod(*args[params_len:], **kwargs)                                   │
│   2330                                                                                           │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/fx/interpreter.py:136 in run               │
│                                                                                                  │
│   133 │   │   │   │   continue                                                                   │
│   134 │   │   │                                                                                  │
│   135 │   │   │   try:                                                                           │
│ ❱ 136 │   │   │   │   self.env[node] = self.run_node(node)                                       │
│   137 │   │   │   except Exception as e:                                                         │
│   138 │   │   │   │   msg = f"While executing {node.format_node()}"                              │
│   139 │   │   │   │   msg = '{}\n\n{}'.format(e.args[0], msg) if e.args else str(msg)            │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/fx/interpreter.py:177 in run_node          │
│                                                                                                  │
│   174 │   │   │   args, kwargs = self.fetch_args_kwargs_from_env(n)                              │
│   175 │   │   │   assert isinstance(args, tuple)                                                 │
│   176 │   │   │   assert isinstance(kwargs, dict)                                                │
│ ❱ 177 │   │   │   return getattr(self, n.op)(n.target, args, kwargs)                             │
│   178 │                                                                                          │
│   179 │   # Main Node running APIs                                                               │
│   180 │   @compatibility(is_backward_compatible=True)                                            │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/fx/interpreter.py:294 in call_module       │
│                                                                                                  │
│   291 │   │   assert isinstance(target, str)                                                     │
│   292 │   │   submod = self.fetch_attr(target)                                                   │
│   293 │   │                                                                                      │
│ ❱ 294 │   │   return submod(*args, **kwargs)                                                     │
│   295 │                                                                                          │
│   296 │   @compatibility(is_backward_compatible=True)                                            │
│   297 │   def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str,    │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1482 in _call_impl    │
│                                                                                                  │
│   1479 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1480 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1481 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1482 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1483 │   │   # Do not call functions when jit is used                                          │
│   1484 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1485 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/alpha/clone/sd-scripts/networks/lora.py:44 in forward                                      │
│                                                                                                  │
│    41 │   del self.org_module                                                                    │
│    42                                                                                            │
│    43   def forward(self, x):                                                                    │
│ ❱  44 │   return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier         │
│    45                                                                                            │
│    46                                                                                            │
│    47 def create_network(multiplier, network_dim, vae, text_encoder, unet, **kwargs):            │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1482 in _call_impl    │
│                                                                                                  │
│   1479 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1480 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1481 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1482 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1483 │   │   # Do not call functions when jit is used                                          │
│   1484 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1485 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/nn/modules/linear.py:114 in forward        │
│                                                                                                  │
│   111 │   │   │   init.uniform_(self.bias, -bound, bound)                                        │
│   112 │                                                                                          │
│   113 │   def forward(self, input: Tensor) -> Tensor:                                            │
│ ❱ 114 │   │   return F.linear(input, self.weight, self.bias)                                     │
│   115 │                                                                                          │
│   116 │   def extra_repr(self) -> str:                                                           │
│   117 │   │   return 'in_features={}, out_features={}, bias={}'.format(                          │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_inductor/overrides.py:37 in               │
│ __torch_function__                                                                               │
│                                                                                                  │
│    34 │   │   │   and replacements[func] in replacements_using_triton_random                     │
│    35 │   │   ):                                                                                 │
│    36 │   │   │   return replacements[func](*args, **kwargs)                                     │
│ ❱  37 │   │   return func(*args, **kwargs)                                                       │
│    38                                                                                            │
│    39                                                                                            │
│    40 patch_functions = AutogradMonkeypatch                                                      │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:825 in          │
│ __torch_dispatch__                                                                               │
│                                                                                                  │
│    822 │   │   │   ), f"{args} {kwargs}"                                                         │
│    823 │   │   │   return converter(self, args[0])                                               │
│    824 │   │                                                                                     │
│ ❱  825 │   │   args, kwargs = self.validate_and_convert_non_fake_tensors(                        │
│    826 │   │   │   func, converter, args, kwargs                                                 │
│    827 │   │   )                                                                                 │
│    828                                                                                           │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:973 in          │
│ validate_and_convert_non_fake_tensors                                                            │
│                                                                                                  │
│    970 │   │   │   │   return converter(self, x)                                                 │
│    971 │   │   │   return x                                                                      │
│    972 │   │                                                                                     │
│ ❱  973 │   │   return tree_map_only(                                                             │
│    974 │   │   │   torch.Tensor,                                                                 │
│    975 │   │   │   validate,                                                                     │
│    976 │   │   │   (args, kwargs),                                                               │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/utils/_pytree.py:259 in tree_map_only      │
│                                                                                                  │
│   256 │   ...                                                                                    │
│   257                                                                                            │
│   258 def tree_map_only(ty: TypeAny, fn: FnAny[Any], pytree: PyTree) -> PyTree:                  │
│ ❱ 259 │   return tree_map(map_only(ty)(fn), pytree)                                              │
│   260                                                                                            │
│   261 def tree_all(pred: Callable[[Any], bool], pytree: PyTree) -> bool:                         │
│   262 │   flat_args, _ = tree_flatten(pytree)                                                    │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/utils/_pytree.py:195 in tree_map           │
│                                                                                                  │
│   192                                                                                            │
│   193 def tree_map(fn: Any, pytree: PyTree) -> PyTree:                                           │
│   194 │   flat_args, spec = tree_flatten(pytree)                                                 │
│ ❱ 195 │   return tree_unflatten([fn(i) for i in flat_args], spec)                                │
│   196                                                                                            │
│   197 Type2 = Tuple[Type[T], Type[S]]                                                            │
│   198 TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]                                          │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/utils/_pytree.py:195 in <listcomp>         │
│                                                                                                  │
│   192                                                                                            │
│   193 def tree_map(fn: Any, pytree: PyTree) -> PyTree:                                           │
│   194 │   flat_args, spec = tree_flatten(pytree)                                                 │
│ ❱ 195 │   return tree_unflatten([fn(i) for i in flat_args], spec)                                │
│   196                                                                                            │
│   197 Type2 = Tuple[Type[T], Type[S]]                                                            │
│   198 TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]                                          │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/utils/_pytree.py:244 in inner              │
│                                                                                                  │
│   241 │   │   @functools.wraps(f)                                                                │
│   242 │   │   def inner(x: T) -> Any:                                                            │
│   243 │   │   │   if isinstance(x, ty):                                                          │
│ ❱ 244 │   │   │   │   return f(x)                                                                │
│   245 │   │   │   else:                                                                          │
│   246 │   │   │   │   return x                                                                   │
│   247 │   │   return inner                                                                       │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:965 in validate │
│                                                                                                  │
│    962 │   │   │   │   │   │   f"Can't call metadata mutating ops on non-Fake Tensor inputs. Fo  │
│    963 │   │   │   │   │   )                                                                     │
│    964 │   │   │   │   if not self.allow_non_fake_inputs:                                        │
│ ❱  965 │   │   │   │   │   raise Exception(                                                      │
│    966 │   │   │   │   │   │   f"Please convert all Tensors to FakeTensors first or instantiate  │
│    967 │   │   │   │   │   │   f"with 'allow_non_fake_inputs'. Found in {func}(*{args}, **{kwar  │
│    968 │   │   │   │   │   )                                                                     │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
Exception: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten._to_copy.default(*(Parameter containing:
tensor([[ 0.0292,  0.0266,  0.0296,  ...,  0.0353, -0.0317, -0.0230],
        [ 0.0112, -0.0135,  0.0291,  ..., -0.0087,  0.0124,  0.0297],
        [-0.0299,  0.0291, -0.0143,  ..., -0.0097,  0.0106, -0.0191],
        [-0.0344, -0.0083,  0.0227,  ...,  0.0093,  0.0345, -0.0343]],
       device='cuda:0', requires_grad=True),), **{'dtype': torch.float16})

While executing %self_text_model_encoder_layers_0_self_attn_q_proj : [#users=1] = call_module[target=self_text_model_encoder_layers_0_self_attn_q_proj](args = (%self_text_model_encoder_layers_0_layer_norm1,), kwargs = {})
Original traceback:
  File "/home/alpha/.local/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 209, in forward
    query_states = self.q_proj(hidden_states) * self.scale
 |   File "/home/alpha/.local/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 317, in forward
    hidden_states, attn_weights = self.self_attn(
 |   File "/home/alpha/.local/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 574, in forward
    layer_outputs = encoder_layer(
 |   File "/home/alpha/.local/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 643, in forward
    encoder_outputs = self.encoder(
 |   File "/home/alpha/.local/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line 722, in forward
    return self.text_model(

The above exception was the direct cause of the following exception:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/alpha/clone/sd-scripts/train_network.py:419 in <module>                                    │
│                                                                                                  │
│   416 │   │   │   │   │     help="only training Text Encoder part / Text Encoder関連部分のみ学   │
│   417                                                                                            │
│   418   args = parser.parse_args()                                                               │
│ ❱ 419   train(args)                                                                              │
│   420                                                                                            │
│                                                                                                  │
│ /home/alpha/clone/sd-scripts/train_network.py:283 in train                                       │
│                                                                                                  │
│   280 │   │   with torch.set_grad_enabled(train_text_encoder):                                   │
│   281 │   │     # Get the text embedding for conditioning                                        │
│   282 │   │     input_ids = batch["input_ids"].to(accelerator.device)                            │
│ ❱ 283 │   │     encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenize   │
│   284 │   │                                                                                      │
│   285 │   │   # Sample noise that we'll add to the latents                                       │
│   286 │   │   noise = torch.randn_like(latents, device=latents.device)                           │
│                                                                                                  │
│ /home/alpha/clone/sd-scripts/library/train_util.py:1257 in get_hidden_states                     │
│                                                                                                  │
│   1254   if args.clip_skip is None:                                                              │
│   1255 │   encoder_hidden_states = text_encoder(input_ids)[0]                                    │
│   1256   else:                                                                                   │
│ ❱ 1257 │   enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)        │
│   1258 │   encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]                     │
│   1259 │   if weight_dtype is not None:                                                          │
│   1260 │     # this is required for additional network training                                  │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1482 in _call_impl    │
│                                                                                                  │
│   1479 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1480 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1481 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1482 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1483 │   │   # Do not call functions when jit is used                                          │
│   1484 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1485 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/accelerate/utils/operations.py:490 in __call__   │
│                                                                                                  │
│   487 │   │   update_wrapper(self, model_forward)                                                │
│   488 │                                                                                          │
│   489 │   def __call__(self, *args, **kwargs):                                                   │
│ ❱ 490 │   │   return convert_to_fp32(self.model_forward(*args, **kwargs))                        │
│   491 │                                                                                          │
│   492 │   def __getstate__(self):                                                                │
│   493 │   │   raise pickle.PicklingError(                                                        │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/amp/autocast_mode.py:14 in                 │
│ decorate_autocast                                                                                │
│                                                                                                  │
│    11 │   @functools.wraps(func)                                                                 │
│    12 │   def decorate_autocast(*args, **kwargs):                                                │
│    13 │   │   with autocast_instance:                                                            │
│ ❱  14 │   │   │   return func(*args, **kwargs)                                                   │
│    15 │   decorate_autocast.__script_unsupported = '@autocast() decorator is not supported in    │
│    16 │   return decorate_autocast                                                               │
│    17                                                                                            │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:83 in forward        │
│                                                                                                  │
│    80 │   │   return getattr(self._orig_mod, name)                                               │
│    81 │                                                                                          │
│    82 │   def forward(self, *args, **kwargs):                                                    │
│ ❱  83 │   │   return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)                    │
│    84                                                                                            │
│    85                                                                                            │
│    86 def remove_from_cache(f):                                                                  │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:212 in _fn           │
│                                                                                                  │
│   209 │   │   │   dynamic_ctx = enable_dynamic(self.dynamic)                                     │
│   210 │   │   │   dynamic_ctx.__enter__()                                                        │
│   211 │   │   │   try:                                                                           │
│ ❱ 212 │   │   │   │   return fn(*args, **kwargs)                                                 │
│   213 │   │   │   finally:                                                                       │
│   214 │   │   │   │   set_eval_frame(prior)                                                      │
│   215 │   │   │   │   dynamic_ctx.__exit__(None, None, None)                                     │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:333 in catch_errors  │
│                                                                                                  │
│   330 │   │   │   │   │   return hijacked_callback(frame, cache_size, hooks)                     │
│   331 │   │                                                                                      │
│   332 │   │   with compile_lock:                                                                 │
│ ❱ 333 │   │   │   return callback(frame, cache_size, hooks)                                      │
│   334 │                                                                                          │
│   335 │   catch_errors._torchdynamo_orig_callable = callback  # type: ignore[attr-defined]       │
│   336 │   return catch_errors                                                                    │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:480 in            │
│ _convert_frame                                                                                   │
│                                                                                                  │
│   477 │   def _convert_frame(frame: types.FrameType, cache_size: int, hooks: Hooks):             │
│   478 │   │   counters["frames"]["total"] += 1                                                   │
│   479 │   │   try:                                                                               │
│ ❱ 480 │   │   │   result = inner_convert(frame, cache_size, hooks)                               │
│   481 │   │   │   counters["frames"]["ok"] += 1                                                  │
│   482 │   │   │   return result                                                                  │
│   483 │   │   except (NotImplementedError, Unsupported):                                         │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:103 in _fn        │
│                                                                                                  │
│   100 │   │   prior_fwd_from_src = torch.fx.graph_module._forward_from_src                       │
│   101 │   │   torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result          │
│   102 │   │   try:                                                                               │
│ ❱ 103 │   │   │   return fn(*args, **kwargs)                                                     │
│   104 │   │   finally:                                                                           │
│   105 │   │   │   torch._C._set_grad_enabled(prior_grad_mode)                                    │
│   106 │   │   │   torch.random.set_rng_state(rng_state)                                          │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py:94 in time_wrapper        │
│                                                                                                  │
│     91 │   │   if key not in compilation_metrics:                                                │
│     92 │   │   │   compilation_metrics[key] = []                                                 │
│     93 │   │   t0 = time.time()                                                                  │
│ ❱   94 │   │   r = func(*args, **kwargs)                                                         │
│     95 │   │   latency = time.time() - t0                                                        │
│     96 │   │   # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec")                    │
│     97 │   │   compilation_metrics[key].append(latency)                                          │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:339 in            │
│ _convert_frame_assert                                                                            │
│                                                                                                  │
│   336 │   │   global initial_grad_state                                                          │
│   337 │   │   initial_grad_state = torch.is_grad_enabled()                                       │
│   338 │   │                                                                                      │
│ ❱ 339 │   │   return _compile(                                                                   │
│   340 │   │   │   frame.f_code,                                                                  │
│   341 │   │   │   frame.f_globals,                                                               │
│   342 │   │   │   frame.f_locals,                                                                │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:400 in _compile   │
│                                                                                                  │
│   397 │   try:                                                                                   │
│   398 │   │   for attempt in itertools.count():                                                  │
│   399 │   │   │   try:                                                                           │
│ ❱ 400 │   │   │   │   out_code = transform_code_object(code, transform)                          │
│   401 │   │   │   │   orig_code_map[out_code] = code                                             │
│   402 │   │   │   │   break                                                                      │
│   403 │   │   │   except exc.RestartAnalysis:                                                    │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py:341 in  │
│ transform_code_object                                                                            │
│                                                                                                  │
│   338 │   instructions = cleaned_instructions(code, safe)                                        │
│   339 │   propagate_line_nums(instructions)                                                      │
│   340 │                                                                                          │
│ ❱ 341 │   transformations(instructions, code_options)                                            │
│   342 │                                                                                          │
│   343 │   fix_vars(instructions, code_options)                                                   │
│   344                                                                                            │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py:387 in transform  │
│                                                                                                  │
│   384 │   │   │   export,                                                                        │
│   385 │   │   │   mutated_closure_cell_contents,                                                 │
│   386 │   │   )                                                                                  │
│ ❱ 387 │   │   tracer.run()                                                                       │
│   388 │   │   output = tracer.output                                                             │
│   389 │   │   assert output is not None                                                          │
│   390 │   │   assert output.output_instructions                                                  │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1692 in run    │
│                                                                                                  │
│   1689 │                                                                                         │
│   1690 │   def run(self):                                                                        │
│   1691 │   │   _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")  │
│ ❱ 1692 │   │   super().run()                                                                     │
│   1693 │                                                                                         │
│   1694 │   def match_nested_cell(self, name, cell):                                              │
│   1695 │   │   """Match a cell in this method to one in a function we are inlining"""            │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:538 in run     │
│                                                                                                  │
│    535 │   │   │   while (                                                                       │
│    536 │   │   │   │   self.instruction_pointer is not None                                      │
│    537 │   │   │   │   and not self.output.should_exit                                           │
│ ❱  538 │   │   │   │   and self.step()                                                           │
│    539 │   │   │   ):                                                                            │
│    540 │   │   │   │   pass                                                                      │
│    541 │   │   except BackendCompilerFailed:                                                     │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:501 in step    │
│                                                                                                  │
│    498 │   │   try:                                                                              │
│    499 │   │   │   if not hasattr(self, inst.opname):                                            │
│    500 │   │   │   │   unimplemented(f"missing: {inst.opname}")                                  │
│ ❱  501 │   │   │   getattr(self, inst.opname)(inst)                                              │
│    502 │   │   │                                                                                 │
│    503 │   │   │   return inst.opname != "RETURN_VALUE"                                          │
│    504 │   │   except BackendCompilerFailed:                                                     │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py:1758 in        │
│ RETURN_VALUE                                                                                     │
│                                                                                                  │
│   1755 │   │   │   f"torchdynamo done tracing {self.f_code.co_name} (RETURN_VALUE)",             │
│   1756 │   │   )                                                                                 │
│   1757 │   │   log.debug("RETURN_VALUE triggered compile")                                       │
│ ❱ 1758 │   │   self.output.compile_subgraph(self)                                                │
│   1759 │   │   self.output.add_output_instructions([create_instruction("RETURN_VALUE")])         │
│   1760                                                                                           │
│   1761                                                                                           │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:551 in             │
│ compile_subgraph                                                                                 │
│                                                                                                  │
│   548 │   │   │   output = []                                                                    │
│   549 │   │   │   if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:              │
│   550 │   │   │   │   output.extend(                                                             │
│ ❱ 551 │   │   │   │   │   self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)    │
│   552 │   │   │   │   )                                                                          │
│   553 │   │   │   │                                                                              │
│   554 │   │   │   │   if len(pass2.graph_outputs) != 0:                                          │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:598 in             │
│ compile_and_call_fx_graph                                                                        │
│                                                                                                  │
│   595 │   │                                                                                      │
│   596 │   │   assert_no_fake_params_or_buffers(gm)                                               │
│   597 │   │   with tracing(self.tracing_context):                                                │
│ ❱ 598 │   │   │   compiled_fn = self.call_user_compiler(gm)                                      │
│   599 │   │   compiled_fn = disable(compiled_fn)                                                 │
│   600 │   │                                                                                      │
│   601 │   │   counters["stats"]["unique_graphs"] += 1                                            │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py:679 in             │
│ call_user_compiler                                                                               │
│                                                                                                  │
│   676 │   │   │   assert callable(compiled_fn), "compiler_fn did not return callable"            │
│   677 │   │   except Exception as e:                                                             │
│   678 │   │   │   compiled_fn = gm.forward                                                       │
│ ❱ 679 │   │   │   raise BackendCompilerFailed(self.compiler_fn, e) from e                        │
│   680 │   │   return compiled_fn                                                                 │
│   681 │                                                                                          │
│   682 │   def fake_example_inputs(self) -> List[torch.Tensor]:                                   │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
BackendCompilerFailed: compile_fx raised Exception: Please convert all Tensors to FakeTensors first or
instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten._to_copy.default(*(Parameter
containing:
tensor([[ 0.0292,  0.0266,  0.0296,  ...,  0.0353, -0.0317, -0.0230],
        [ 0.0112, -0.0135,  0.0291,  ..., -0.0087,  0.0124,  0.0297],
        [-0.0299,  0.0291, -0.0143,  ..., -0.0097,  0.0106, -0.0191],
        [-0.0344, -0.0083,  0.0227,  ...,  0.0093,  0.0345, -0.0343]],
       device='cuda:0', requires_grad=True),), **{'dtype': torch.float16})

While executing %self_text_model_encoder_layers_0_self_attn_q_proj : [#users=1] =
call_module[target=self_text_model_encoder_layers_0_self_attn_q_proj](args =
(%self_text_model_encoder_layers_0_layer_norm1,), kwargs = {})
Original traceback:
  File "/home/alpha/.local/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py", line
209, in forward
    query_states = self.q_proj(hidden_states) * self.scale
 |   File "/home/alpha/.local/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py",
line 317, in forward
    hidden_states, attn_weights = self.self_attn(
 |   File "/home/alpha/.local/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py",
line 574, in forward
    layer_outputs = encoder_layer(
 |   File "/home/alpha/.local/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py",
line 643, in forward
    encoder_outputs = self.encoder(
 |   File "/home/alpha/.local/lib/python3.10/site-packages/transformers/models/clip/modeling_clip.py",
line 722, in forward
    return self.text_model(

Set torch._dynamo.config.verbose=True for more information

You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

steps:   0%|                                                                    | 0/1600 [00:03<?, ?it/s]
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/alpha/.local/bin/accelerate:8 in <module>                                                  │
│                                                                                                  │
│   5 from accelerate.commands.accelerate_cli import main                                          │
│   6 if __name__ == '__main__':                                                                   │
│   7 │   sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])                         │
│ ❱ 8 │   sys.exit(main())                                                                         │
│   9                                                                                              │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py:45 in main │
│                                                                                                  │
│   42 │   │   exit(1)                                                                             │
│   43 │                                                                                           │
│   44 │   # Run                                                                                   │
│ ❱ 45 │   args.func(args)                                                                         │
│   46                                                                                             │
│   47                                                                                             │
│   48 if __name__ == "__main__":                                                                  │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/accelerate/commands/launch.py:1104 in            │
│ launch_command                                                                                   │
│                                                                                                  │
│   1101 │   elif defaults is not None and defaults.compute_environment == ComputeEnvironment.AMA  │
│   1102 │   │   sagemaker_launcher(defaults, args)                                                │
│   1103 │   else:                                                                                 │
│ ❱ 1104 │   │   simple_launcher(args)                                                             │
│   1105                                                                                           │
│   1106                                                                                           │
│   1107 def main():                                                                               │
│                                                                                                  │
│ /home/alpha/.local/lib/python3.10/site-packages/accelerate/commands/launch.py:567 in             │
│ simple_launcher                                                                                  │
│                                                                                                  │
│    564 │   process = subprocess.Popen(cmd, env=current_env)                                      │
│    565 │   process.wait()                                                                        │
│    566 │   if process.returncode != 0:                                                           │
│ ❱  567 │   │   raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)       │
│    568                                                                                           │
│    569                                                                                           │
│    570 def multi_gpu_launcher(args):                                                             │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
CalledProcessError: Command '['/usr/bin/python', 'train_network.py',
'--pretrained_model_name_or_path=/home/alpha/Storage/AIModels/Stable-diffusion/panatomy05full_0.7-AIModels_Anything-V3.0-pruned-fp16_0.3-Weighted_sum-merged.ckpt',
'--train_data_dir=/home/alpha/Storage/TrainingData/test/training_data',
'--output_dir=/home/alpha/Storage/TrainingOutput/test/', '--prior_loss_weight=1.0',
'--resolution=512,512', '--train_batch_size=1', '--learning_rate=1e-5', '--max_train_steps=1600',
'--use_8bit_adam', '--xformers', '--mixed_precision=fp16', '--cache_latents', '--save_precision=fp16',
'--save_model_as=safetensors', '--clip_skip=2', '--network_module=networks.lora']' returned non-zero exit
status 1.

The same arguments work fine with TorchDynamo disabled.

Maybe `torch.compile()` needs to be applied conditionally and manually in the training script, instead of automatically through accelerate?
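A minimal sketch of that idea, assuming accelerate's dynamo option is left at `NO` and the training script decides what to compile itself. `maybe_compile` and any `--torch_compile` flag that would drive it are hypothetical; neither exists in sd-scripts. Compiling only the U-Net while leaving the CLIP text encoder in eager mode is an untested assumption, motivated by the failure above occurring while Dynamo traced the text encoder (`modeling_clip.py`) under fp16 autocast.

```python
import torch


def maybe_compile(
    model: torch.nn.Module,
    enabled: bool,
    backend: str = "inductor",
    fallback_to_eager: bool = False,
) -> torch.nn.Module:
    """Hypothetical helper: compile `model` only when explicitly requested.

    Returns the model unchanged when compilation is disabled or when the
    installed PyTorch predates torch.compile (added in 2.0).
    """
    if not enabled:
        return model
    if not hasattr(torch, "compile"):
        # PyTorch < 2.0 has no torch.compile; keep training in eager mode.
        return model
    if fallback_to_eager:
        # Imported lazily because torch._dynamo only exists in PyTorch >= 2.0.
        from torch import _dynamo

        # Global switch named in the error message: Dynamo logs backend
        # failures and runs the affected frames eagerly instead of raising
        # BackendCompilerFailed. It affects every compiled model in the process.
        _dynamo.config.suppress_errors = True
    # torch.compile returns an OptimizedModule, which is itself an nn.Module.
    return torch.compile(model, backend=backend)
```

Usage, for example after `accelerator.prepare(...)`: `unet = maybe_compile(unet, enabled=args.torch_compile)`, where `args.torch_compile` is the hypothetical flag, with the text encoder left untouched. Keeping compilation opt-in means the default training path behaves exactly as it does with TorchDynamo disabled.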

lolxdmainkaisemaanlu commented 1 year ago

If this could be implemented, it would be awesome. 6 GB VRAM cards would be able to train better and faster.

iqiancheng commented 11 months ago

same issue