CompVis / stable-diffusion

A latent text-to-image diffusion model
https://ommer-lab.com/research/latent-diffusion-models/

CheckpointFunction: AttributeError: 'NoneType' object has no attribute 'detach' #857

Closed by sivannavis 3 days ago

sivannavis commented 2 weeks ago

Hi! When training the UNet, my code runs through the forward pass and computes the loss, but as soon as it enters the backward pass it fails at this line with AttributeError: 'NoneType' object has no attribute 'detach': https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/util.py#L132

The full traceback is:

Traceback (most recent call last):
  File "/home/szding/.pycharm_helpers/pydev/pydevd.py", line 2236, in <module>
    main()
  File "/home/szding/.pycharm_helpers/pydev/pydevd.py", line 2218, in main
    globals = debugger.run(setup['file'], None, None, is_module)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/szding/.pycharm_helpers/pydev/pydevd.py", line 1528, in run
    return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/szding/.pycharm_helpers/pydev/pydevd.py", line 1535, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/szding/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/szding/v2sa/v2sa/trainer/trainer_ldm.py", line 232, in <module>
    main(config_yaml, arguments.exp_group_name, arguments.exp_name)
  File "/home/szding/v2sa/v2sa/trainer/trainer_ldm.py", line 211, in main
    trainer.fit(model, datamodule, ckpt_path=last_ckpt)
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 987, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py", line 1033, in _run_stage
    self.fit_loop.run()
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
    self.advance(data_fetcher)
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 250, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 190, in run
    self._optimizer_step(batch_idx, closure)
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 268, in _optimizer_step
    call._call_lightning_module_hook(
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 157, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/core/module.py", line 1303, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/core/optimizer.py", line 152, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py", line 239, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py", line 122, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/torch/optim/optimizer.py", line 391, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/torch/optim/optimizer.py", line 76, in _use_grad
    ret = func(self, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/torch/optim/adamw.py", line 165, in step
    loss = closure()
           ^^^^^^^^^
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py", line 108, in _wrap_closure
    closure_result = closure()
                     ^^^^^^^^^
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 144, in __call__
    self._result = self.closure(*args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 138, in closure
    self._backward_fn(step_output.closure_loss)
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 239, in backward_fn
    call._call_strategy_hook(self.trainer, "backward", loss, optimizer)
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py", line 213, in backward
    self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py", line 72, in backward
    model.backward(tensor, *args, **kwargs)
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/pytorch_lightning/core/module.py", line 1090, in backward
    loss.backward(*args, **kwargs)
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/torch/_tensor.py", line 525, in backward
    torch.autograd.backward(
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/szding/.conda/envs/v2sa/lib/python3.12/site-packages/torch/autograd/function.py", line 301, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/szding/v2sa/v2sa/models/ldm/ldm_utils.py", line 202, in backward
    ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
                         ^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'detach'

When I check what's inside ctx.input_tensors, it turns out to be a list of two elements where the second element is None. Any idea what this line is trying to do, and why there is a None?
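
For context, the failing line belongs to gradient checkpointing: in the backward pass, each saved input is detached and has gradients re-enabled so that the forward computation can be replayed to rebuild the graph. The snippet below is only a minimal sketch of why a None entry breaks that step. The two-element `inputs` list is a made-up stand-in for ctx.input_tensors, not the actual training data.

```python
import torch

x = torch.randn(2, 3, requires_grad=True)

# Stand-in for ctx.input_tensors: one real tensor plus an optional
# argument that was left as None (an assumption for illustration).
inputs = [x * 2, None]

# For a real tensor, detach() cuts it from the old graph and
# requires_grad_(True) turns it into a fresh leaf for recomputation.
leaf = inputs[0].detach().requires_grad_(True)

# The same comprehension as in the traceback fails on the None entry.
try:
    [t.detach().requires_grad_(True) for t in inputs]
except AttributeError as err:
    message = str(err)
    print(message)
```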

pwppwpwpw commented 5 days ago

Have you solved it?

sivannavis commented 3 days ago

I think it comes from how the checkpoint function behaves in the UNet model. I changed it and now it works.
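
The thread does not show what was actually changed. One possible shape of such a change, as a hedged sketch only: a checkpoint function whose backward detaches real tensors and passes None through untouched. This assumes the None comes from an optional argument (for example a conditioning input that was not provided). `NoneSafeCheckpoint` and `block` are hypothetical names, not part of the repository.

```python
import torch


class NoneSafeCheckpoint(torch.autograd.Function):
    """Hypothetical checkpoint function that tolerates None inputs."""

    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])  # may contain None
        ctx.input_params = list(args[length:])
        with torch.no_grad():
            return ctx.run_function(*ctx.input_tensors)

    @staticmethod
    def backward(ctx, *output_grads):
        # Detach only real tensors; leave None entries as they are.
        inputs = [
            t.detach().requires_grad_(True) if torch.is_tensor(t) else t
            for t in ctx.input_tensors
        ]
        with torch.enable_grad():
            # Shallow copies guard against in-place ops on detached tensors.
            copies = [t.view_as(t) if torch.is_tensor(t) else t for t in inputs]
            outputs = ctx.run_function(*copies)
        # Differentiate only with respect to tensors and parameters.
        tensor_inputs = [t for t in inputs if torch.is_tensor(t)]
        grads = list(
            torch.autograd.grad(
                outputs,
                tensor_inputs + ctx.input_params,
                output_grads,
                allow_unused=True,
            )
        )
        # Put a None gradient back in the slot of every non-tensor input.
        input_grads = [grads.pop(0) if torch.is_tensor(t) else None for t in inputs]
        # Two leading Nones cover run_function and length.
        return (None, None, *input_grads, *grads)


# Usage: a toy block with an optional second argument left as None.
w = torch.randn(3, 3, requires_grad=True)
x = torch.randn(2, 3, requires_grad=True)


def block(x, context=None):
    h = x @ w
    return h if context is None else h + context


y = NoneSafeCheckpoint.apply(block, 2, x, None, w)
y.sum().backward()
```

An alternative that avoids touching the checkpoint function is to disable checkpointing for the affected blocks, or to not pass the optional argument through the checkpointed call when it is None. Which of these fits depends on the model code, which the thread does not include.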