Closed: zhouwei5113 closed this issue 2 years ago
@property
def num_training_steps(self) -> int:
    """Total training steps inferred from datamodule and devices."""
    # self.train_dataloader() is not implemented when the dataloaders come from a
    # DataModule, so read the loader from the trainer's data connector instead.
    dataset = self.trainer._data_connector._train_dataloader_source.dataloader()
    dataset_size = len(dataset)  # number of batches per epoch
    num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
    if self.trainer.tpu_cores:
        num_devices = max(num_devices, self.trainer.tpu_cores)
    # Gradient accumulation and the device count both reduce the number of optimizer steps.
    num_steps = dataset_size * self.trainer.max_epochs // (self.trainer.accumulate_grad_batches * num_devices)
    return num_steps
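For reference, this property is typically consumed from configure_optimizers when sizing an LR scheduler. A minimal sketch of that wiring, with a placeholder optimizer and scheduler (AdamW and OneCycleLR here are illustrative, not the ones from the original code):

import torch

def configure_optimizers(self):
    # Placeholder optimizer/scheduler; the point is passing the computed
    # step count into whatever scheduler expects a total-steps argument.
    optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=1e-4, total_steps=self.num_training_steps
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
    }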
You are like a hero to me
Hi, where should I paste the code snippet to fix this issue?
Doesn't work with Lightning 2.0
Error: AttributeError: '_DataConnector' object has no attribute '_train_dataloader_source'
For Lightning 2.0
def num_steps(self) -> int:
    """Get the number of training steps."""
    # Accessing _data_source is flaky and might break
    dataset = self.trainer.fit_loop._data_source.dataloader()
    dataset_size = len(dataset)
    num_devices = max(1, self.trainer.num_devices)
    num_steps = dataset_size * self.trainer.max_epochs // (self.trainer.accumulate_grad_batches * num_devices)
    return num_steps
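As a side note, if your Lightning release exposes it, the trainer's built-in estimate avoids reaching into private attributes altogether. A minimal sketch, assuming trainer.estimated_stepping_batches is available in the version you are running:

def num_steps(self) -> int:
    """Total optimizer steps, using Lightning's own estimate (it already
    accounts for gradient accumulation and the number of devices)."""
    return int(self.trainer.estimated_stepping_batches)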
Traceback (most recent call last):
  File "train_finetune.py", line 39, in <module>
    main(args)
  File "train_finetune.py", line 29, in main
    trainer.fit(model, dm)
  File "/workspace/cpfs-data/miniconda3/envs/tensorflow/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 741, in fit
    self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
  File "/workspace/cpfs-data/miniconda3/envs/tensorflow/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 685, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/workspace/cpfs-data/miniconda3/envs/tensorflow/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 777, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/workspace/cpfs-data/miniconda3/envs/tensorflow/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1145, in _run
    self.accelerator.setup(self)
  File "/workspace/cpfs-data/miniconda3/envs/tensorflow/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu.py", line 46, in setup
    return super().setup(trainer)
  File "/workspace/cpfs-data/miniconda3/envs/tensorflow/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 93, in setup
    self.setup_optimizers(trainer)
  File "/workspace/cpfs-data/miniconda3/envs/tensorflow/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 355, in setup_optimizers
    trainer=trainer, model=self.lightning_module
  File "/workspace/cpfs-data/miniconda3/envs/tensorflow/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 245, in init_optimizers
    return trainer.init_optimizers(model)
  File "/workspace/cpfs-data/miniconda3/envs/tensorflow/lib/python3.7/site-packages/pytorch_lightning/trainer/optimizers.py", line 35, in init_optimizers
    optim_conf = self.call_hook("configure_optimizers", pl_module=pl_module)
  File "/workspace/cpfs-data/miniconda3/envs/tensorflow/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1501, in call_hook
    output = model_fx(*args, **kwargs)
  File "/workspace/cpfs-data/workspace_pytorch/hclip/train-CLIP/models/wrapper.py", line 337, in configure_optimizers
    first_cycle_steps=self.num_training_steps,
  File "/workspace/cpfs-data/workspace_pytorch/hclip/train-CLIP/models/wrapper.py", line 38, in num_training_steps
    dataset = self.train_dataloader()
  File "/workspace/cpfs-data/miniconda3/envs/tensorflow/lib/python3.7/site-packages/pytorch_lightning/core/hooks.py", line 477, in train_dataloader
    raise NotImplementedError("`train_dataloader` must be implemented to be used with the Lightning Trainer")
NotImplementedError: `train_dataloader` must be implemented to be used with the Lightning Trainer
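This traceback comes from calling self.train_dataloader() inside num_training_steps while the dataloaders live in a LightningDataModule, so the hook is not implemented on the module itself. That is exactly what the snippet at the top works around; a minimal sketch of the failing versus working access, assuming Lightning 1.5.x attribute names:

# Fails when the dataloader is defined on a LightningDataModule:
# dataset = self.train_dataloader()  # raises NotImplementedError
# Works on Lightning 1.5.x (private API, may change between versions):
dataset = self.trainer._data_connector._train_dataloader_source.dataloader()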