Mikubill / naifu

Train generative models with PyTorch Lightning
MIT License

AttributeError: 'StableDiffusionModel' object has no attribute 'data_sampler' #25

Closed · biasnhbi closed this issue 1 year ago

biasnhbi commented 1 year ago

The following error occurred after enabling cache_latents (diffusers==0.17.0):


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:612: UserWarning: Checkpoint directory /home/ubuntu/checkpoint exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
Loading captions: 967it [00:00, 6004.40it/s]
BucketManager initialized with base_res = [512, 512], max_size = [768, 512]
Loading resolutions: 967it [00:09, 102.76it/s]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using scaled LR: 2e-06

  | Name         | Type                 | Params
------------------------------------------------------
0 | unet         | UNet2DConditionModel | 859 M 
1 | vae          | AutoencoderKL        | 83.7 M
2 | text_encoder | CLIPTextModel        | 123 M 
------------------------------------------------------
859 M     Trainable params
206 M     Non-trainable params
1.1 B     Total params
2,132.471 Total estimated model params size (MB)
/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:224: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/lightning/fabric/utilities/data.py:63: UserWarning: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.
  rank_zero_warn(
Training: 0it [00:00, ?it/s]╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/ubuntu/naifu-diffusion/trainer.py:116 in <module>                                          │
│                                                                                                  │
│   113                                                                                            │
│   114 if __name__ == "__main__":                                                                 │
│   115 │   args = parse_args()                                                                    │
│ ❱ 116 │   main(args)                                                                             │
│   117                                                                                            │
│                                                                                                  │
│ /home/ubuntu/naifu-diffusion/trainer.py:112 in main                                              │
│                                                                                                  │
│   109 │                                                                                          │
│   110 │   config, callbacks = pl_compat_fix(config, callbacks)                                   │
│   111 │   trainer = pl.Trainer(logger=logger, callbacks=callbacks, strategy=strategy, plugins=   │
│ ❱ 112 │   trainer.fit(model=model, ckpt_path=args.resume if args.resume else None)               │
│   113                                                                                            │
│   114 if __name__ == "__main__":                                                                 │
│   115 │   args = parse_args()                                                                    │
│                                                                                                  │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.p │
│ y:608 in fit                                                                                     │
│                                                                                                  │
│    605 │   │   if not isinstance(model, pl.LightningModule):                                     │
│    606 │   │   │   raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.  │
│    607 │   │   self.strategy._lightning_module = model                                           │
│ ❱  608 │   │   call._call_and_handle_interrupt(                                                  │
│    609 │   │   │   self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule,  │
│    610 │   │   )                                                                                 │
│    611                                                                                           │
│                                                                                                  │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:3 │
│ 8 in _call_and_handle_interrupt                                                                  │
│                                                                                                  │
│   35 │   │   if trainer.strategy.launcher is not None:                                           │
│   36 │   │   │   return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer,     │
│   37 │   │   else:                                                                               │
│ ❱ 38 │   │   │   return trainer_fn(*args, **kwargs)                                              │
│   39 │                                                                                           │
│   40 │   except _TunerExitException:                                                             │
│   41 │   │   trainer._call_teardown_hook()                                                       │
│                                                                                                  │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.p │
│ y:650 in _fit_impl                                                                               │
│                                                                                                  │
│    647 │   │   │   model_provided=True,                                                          │
│    648 │   │   │   model_connected=self.lightning_module is not None,                            │
│    649 │   │   )                                                                                 │
│ ❱  650 │   │   self._run(model, ckpt_path=self.ckpt_path)                                        │
│    651 │   │                                                                                     │
│    652 │   │   assert self.state.stopped                                                         │
│    653 │   │   self.training = False                                                             │
│                                                                                                  │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.p │
│ y:1103 in _run                                                                                   │
│                                                                                                  │
│   1100 │   │                                                                                     │
│   1101 │   │   self._checkpoint_connector.resume_end()                                           │
│   1102 │   │                                                                                     │
│ ❱ 1103 │   │   results = self._run_stage()                                                       │
│   1104 │   │                                                                                     │
│   1105 │   │   log.detail(f"{self.__class__.__name__}: trainer tearing down")                    │
│   1106 │   │   self._teardown()                                                                  │
│                                                                                                  │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.p │
│ y:1182 in _run_stage                                                                             │
│                                                                                                  │
│   1179 │   │   │   return self._run_evaluate()                                                   │
│   1180 │   │   if self.predicting:                                                               │
│   1181 │   │   │   return self._run_predict()                                                    │
│ ❱ 1182 │   │   self._run_train()                                                                 │
│   1183 │                                                                                         │
│   1184 │   def _pre_training_routine(self) -> None:                                              │
│   1185 │   │   # wait for all to join if on distributed                                          │
│                                                                                                  │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.p │
│ y:1205 in _run_train                                                                             │
│                                                                                                  │
│   1202 │   │   self.fit_loop.trainer = self                                                      │
│   1203 │   │                                                                                     │
│   1204 │   │   with torch.autograd.set_detect_anomaly(self._detect_anomaly):                     │
│ ❱ 1205 │   │   │   self.fit_loop.run()                                                           │
│   1206 │                                                                                         │
│   1207 │   def _run_evaluate(self) -> _EVALUATE_OUTPUT:                                          │
│   1208 │   │   assert self.evaluating                                                            │
│                                                                                                  │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/lightning/pytorch/loops/loop.py:194 │
│ in run                                                                                           │
│                                                                                                  │
│   191 │   │                                                                                      │
│   192 │   │   self.reset()                                                                       │
│   193 │   │                                                                                      │
│ ❱ 194 │   │   self.on_run_start(*args, **kwargs)                                                 │
│   195 │   │                                                                                      │
│   196 │   │   while not self.done:                                                               │
│   197 │   │   │   try:                                                                           │
│                                                                                                  │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py │
│ :218 in on_run_start                                                                             │
│                                                                                                  │
│   215 │   │   self._results.to(device=self.trainer.lightning_module.device)                      │
│   216 │   │                                                                                      │
│   217 │   │   self.trainer._call_callback_hooks("on_train_start")                                │
│ ❱ 218 │   │   self.trainer._call_lightning_module_hook("on_train_start")                         │
│   219 │   │   self.trainer._call_strategy_hook("on_train_start")                                 │
│   220 │                                                                                          │
│   221 │   def on_advance_start(self) -> None:                                                    │
│                                                                                                  │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.p │
│ y:1347 in _call_lightning_module_hook                                                            │
│                                                                                                  │
│   1344 │   │   pl_module._current_fx_name = hook_name                                            │
│   1345 │   │                                                                                     │
│   1346 │   │   with self.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{ho  │
│ ❱ 1347 │   │   │   output = fn(*args, **kwargs)                                                  │
│   1348 │   │                                                                                     │
│   1349 │   │   # restore current_fx when nested context                                          │
│   1350 │   │   pl_module._current_fx_name = prev_fx_name                                         │
│                                                                                                  │
│ /home/ubuntu/naifu-diffusion/lib/model.py:288 in on_train_start                                  │
│                                                                                                  │
│   285 │   │   │   self.ema.to(self.device, dtype=self.unet.dtype)                                │
│   286 │   │                                                                                      │
│   287 │   │   if self.use_latent_cache:                                                          │
│ ❱ 288 │   │   │   self.dataset.cache_latents(self.vae, self.data_sampler.buckets if self.confi   │
│   289 │                                                                                          │
│   290 │   def on_train_epoch_start(self) -> None:                                                │
│   291 │   │   if self.use_latent_cache:                                                          │
│                                                                                                  │
│ /home/ubuntu/miniconda3/envs/nd/lib/python3.10/site-packages/torch/nn/modules/module.py:1614 in  │
│ __getattr__                                                                                      │
│                                                                                                  │
│   1611 │   │   │   modules = self.__dict__['_modules']                                           │
│   1612 │   │   │   if name in modules:                                                           │
│   1613 │   │   │   │   return modules[name]                                                      │
│ ❱ 1614 │   │   raise AttributeError("'{}' object has no attribute '{}'".format(                  │
│   1615 │   │   │   type(self).__name__, name))                                                   │
│   1616 │                                                                                         │
│   1617 │   def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:             │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AttributeError: 'StableDiffusionModel' object has no attribute 'data_sampler'
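
For context, the failure mode is standard `torch.nn.Module` attribute lookup: `StableDiffusionModel` is a `LightningModule` (which subclasses `nn.Module`), and `nn.Module.__getattr__` only consults registered parameters, buffers, and submodules before raising. A plain attribute like `data_sampler` that was never assigned in `__init__` therefore surfaces exactly this way the moment `on_train_start` touches it. A minimal, self-contained reproduction (class and attribute layout are illustrative, not naifu's actual code):

```python
import torch.nn as nn

class StableDiffusionModelSketch(nn.Module):
    """Illustrative stand-in for the LightningModule in lib/model.py."""

    def __init__(self, use_latent_cache: bool = True):
        super().__init__()
        self.use_latent_cache = use_latent_cache
        # NOTE: data_sampler is deliberately never assigned, mirroring
        # the code path that triggers the reported bug.

    def on_train_start(self) -> None:
        if self.use_latent_cache:
            # nn.Module.__getattr__ finds no parameter, buffer, or
            # submodule named "data_sampler" and raises AttributeError.
            _ = self.data_sampler

model = StableDiffusionModelSketch()
try:
    model.on_train_start()
except AttributeError as err:
    print(err)  # 'StableDiffusionModelSketch' object has no attribute 'data_sampler'
```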
Mikubill commented 1 year ago

Fixed in https://github.com/Mikubill/naifu-diffusion/commit/ee665e5eee966c4800a8271fc9e8961061aaf8d1
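
The linked commit is the authoritative fix. For anyone pinned to an older checkout, a guard along these lines sidesteps the crash; this is a sketch only (the `cache_latents` call signature is inferred from the truncated traceback line, and it is not the commit's actual diff):

```python
# Hypothetical workaround for lib/model.py's on_train_start (illustrative):
if self.use_latent_cache:
    # getattr with a default avoids nn.Module.__getattr__ raising when
    # no bucket sampler was ever attached to the model.
    sampler = getattr(self, "data_sampler", None)
    buckets = sampler.buckets if sampler is not None else None
    self.dataset.cache_latents(self.vae, buckets)
```

Updating past the linked commit remains the recommended solution.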