Have you saved anything yourself to the checkpoint in any of the hooks, or are you just using the default checkpointing that PL provides? Are you interacting with the process group in any way?
@awaelchli, to the best of my knowledge, I have not saved anything to the checkpoints in any of my hooks. I am using the following hooks for training, validation, and testing.
def training_step(self, batch, batch_idx):
    """Lightning calls this inside the training loop."""
    graph1, graph2, examples = batch[0], batch[1], batch[2]
    # Forward propagate with network layers
    logits = self(graph1, graph2)
    # Make predictions
    preds = torch.sigmoid(logits)
    int_labels = examples[:, 2].int()
    # Calculate loss and other metrics
    loss = self.loss_fn(logits, examples[:, 2].float())  # Calculate loss of a single sample
    train_acc = self.train_acc(preds, int_labels)  # Calculate Accuracy of a single sample
    train_prec = self.train_prec(preds, int_labels)  # Calculate Precision of a single sample
    train_recall = self.train_recall(preds, int_labels)  # Calculate Recall of a single sample
    # Log training step metric(s)
    self.log('train_bce', loss)
    return {
        'loss': loss,
        'train_acc': train_acc,
        'train_prec': train_prec,
        'train_recall': train_recall
    }
def training_epoch_end(self, outputs: pl.utilities.types.EPOCH_OUTPUT) -> None:
    """Lightning calls this at the end of every training epoch."""
    # Tuplize scores for the current device (e.g. Rank 0)
    train_accs = torch.cat([output_dict['train_acc'].unsqueeze(0) for output_dict in outputs])
    train_precs = torch.cat([output_dict['train_prec'].unsqueeze(0) for output_dict in outputs])
    train_recalls = torch.cat([output_dict['train_recall'].unsqueeze(0) for output_dict in outputs])
    # Concatenate scores over all devices (e.g. Rank 0 | ... | Rank N) - Warning: Memory Intensive
    train_accs = torch.cat([train_acc for train_acc in self.all_gather(train_accs)], dim=0)
    train_precs = torch.cat([train_prec for train_prec in self.all_gather(train_precs)], dim=0)
    train_recalls = torch.cat([train_recall for train_recall in self.all_gather(train_recalls)], dim=0)
    # Reset training TorchMetrics for all devices
    self.train_acc.reset()
    self.train_prec.reset()
    self.train_recall.reset()
    # Log metric(s) aggregated from all ranks
    self.log('med_train_acc', torch.median(train_accs))  # Log MedAccuracy of an epoch
    self.log('med_train_prec', torch.median(train_precs))  # Log MedPrecision of an epoch
    self.log('med_train_recall', torch.median(train_recalls))  # Log MedRecall of an epoch
def validation_step(self, batch, batch_idx):
    """Lightning calls this inside the validation loop."""
    graph1, graph2, examples = batch[0], batch[1], batch[2]
    # Forward propagate with network layers
    logits = self(graph1, graph2)
    # Make predictions
    preds = torch.sigmoid(logits)
    int_labels = examples[:, 2].int()
    # Calculate loss and other metrics
    loss = self.loss_fn(logits, examples[:, 2].float())  # Calculate loss of a single sample
    val_acc = self.val_acc(preds, int_labels)  # Calculate Accuracy of a single sample
    val_prec = self.val_prec(preds, int_labels)  # Calculate Precision of a single sample
    val_recall = self.val_recall(preds, int_labels)  # Calculate Recall of a single sample
    val_f1 = self.val_f1(preds, int_labels)  # Calculate F1 score of a single sample
    val_auroc = self.val_auroc(preds, int_labels)  # Calculate AUROC of a sample
    val_auprc = self.val_auprc(preds, int_labels)  # Calculate AveragePrecision (i.e. AUPRC) of a sample
    # Log validation step metric(s)
    self.log('val_bce', loss, sync_dist=True)
    return {
        'loss': loss,
        'val_acc': val_acc,
        'val_prec': val_prec,
        'val_recall': val_recall,
        'val_f1': val_f1,
        'val_auroc': val_auroc,
        'val_auprc': val_auprc
    }
def validation_epoch_end(self, outputs: pl.utilities.types.EPOCH_OUTPUT) -> None:
    """Lightning calls this at the end of every validation epoch."""
    # Tuplize scores for the current device (e.g. Rank 0)
    val_accs = torch.cat([output_dict['val_acc'].unsqueeze(0) for output_dict in outputs])
    val_precs = torch.cat([output_dict['val_prec'].unsqueeze(0) for output_dict in outputs])
    val_recalls = torch.cat([output_dict['val_recall'].unsqueeze(0) for output_dict in outputs])
    val_f1s = torch.cat([output_dict['val_f1'].unsqueeze(0) for output_dict in outputs])
    val_aurocs = torch.cat([output_dict['val_auroc'].unsqueeze(0) for output_dict in outputs])
    val_auprcs = torch.cat([output_dict['val_auprc'].unsqueeze(0) for output_dict in outputs])
    # Concatenate scores over all devices (e.g. Rank 0 | ... | Rank N) - Warning: Memory Intensive
    val_accs = torch.cat([val_acc for val_acc in self.all_gather(val_accs)], dim=0)
    val_precs = torch.cat([val_prec for val_prec in self.all_gather(val_precs)], dim=0)
    val_recalls = torch.cat([val_recall for val_recall in self.all_gather(val_recalls)], dim=0)
    val_f1s = torch.cat([val_f1 for val_f1 in self.all_gather(val_f1s)], dim=0)
    val_aurocs = torch.cat([val_auroc for val_auroc in self.all_gather(val_aurocs)], dim=0)
    val_auprcs = torch.cat([val_auprc for val_auprc in self.all_gather(val_auprcs)], dim=0)
    # Reset validation TorchMetrics for all devices
    self.val_acc.reset()
    self.val_prec.reset()
    self.val_recall.reset()
    self.val_f1.reset()
    self.val_auroc.reset()
    self.val_auprc.reset()
    # Log metric(s) aggregated from all ranks
    self.log('med_val_acc', torch.median(val_accs))  # Log MedAccuracy of an epoch
    self.log('med_val_prec', torch.median(val_precs))  # Log MedPrecision of an epoch
    self.log('med_val_recall', torch.median(val_recalls))  # Log MedRecall of an epoch
    self.log('med_val_f1', torch.median(val_f1s))  # Log MedF1 of an epoch
    self.log('med_val_auroc', torch.median(val_aurocs))  # Log MedAUROC of an epoch
    self.log('med_val_auprc', torch.median(val_auprcs))  # Log epoch MedAveragePrecision
def test_step(self, batch, batch_idx):
    """Lightning calls this inside the testing loop."""
    graph1, graph2, examples = batch[0], batch[1], batch[2]
    # Forward propagate with network layers
    logits = self(graph1, graph2)
    # Make predictions
    preds = torch.sigmoid(logits)
    int_labels = examples[:, 2].int()
    # Calculate loss and other metrics
    loss = self.loss_fn(logits, examples[:, 2].float())  # Calculate loss of a single sample
    test_acc = self.test_acc(preds, int_labels)  # Calculate Accuracy of a single sample
    test_prec = self.test_prec(preds, int_labels)  # Calculate Precision of a single sample
    test_recall = self.test_recall(preds, int_labels)  # Calculate Recall of a single sample
    test_f1 = self.test_f1(preds, int_labels)  # Calculate F1 score of a single sample
    test_auroc = self.test_auroc(preds, int_labels)  # Calculate AUROC of a sample
    test_auprc = self.test_auprc(preds, int_labels)  # Calculate AveragePrecision (i.e. AUPRC) of a sample
    # Manually evaluate test performance by collecting all predicted and ground-truth interaction tensors
    argmaxed_logits = torch.round(torch.sigmoid(logits)).cpu().detach()
    argmaxed_logits = argmaxed_logits.reshape(graph1.num_nodes(), graph2.num_nodes())
    argmaxed_examples = examples[:, 2].float().cpu().detach()
    argmaxed_examples = argmaxed_examples.reshape(graph1.num_nodes(), graph2.num_nodes())
    # Log test step metric(s)
    self.log('test_bce', loss, sync_dist=True)
    return {
        'loss': loss,
        'test_acc': test_acc,
        'test_prec': test_prec,
        'test_recall': test_recall,
        'test_f1': test_f1,
        'test_auroc': test_auroc,
        'test_auprc': test_auprc,
        'test_preds': argmaxed_logits,
        'test_labels': argmaxed_examples
    }
def test_epoch_end(self, outputs: pl.utilities.types.EPOCH_OUTPUT):
    """Lightning calls this at the end of every test epoch."""
    # Tuplize scores for the current device (e.g. Rank 0)
    test_accs = torch.cat([output_dict['test_acc'].unsqueeze(0) for output_dict in outputs]).unsqueeze(1)
    test_precs = torch.cat([output_dict['test_prec'].unsqueeze(0) for output_dict in outputs]).unsqueeze(1)
    test_recalls = torch.cat([output_dict['test_recall'].unsqueeze(0) for output_dict in outputs]).unsqueeze(1)
    test_f1s = torch.cat([output_dict['test_f1'].unsqueeze(0) for output_dict in outputs]).unsqueeze(1)
    test_aurocs = torch.cat([output_dict['test_auroc'].unsqueeze(0) for output_dict in outputs]).unsqueeze(1)
    test_auprcs = torch.cat([output_dict['test_auprc'].unsqueeze(0) for output_dict in outputs]).unsqueeze(1)
    # Concatenate scores over all devices (e.g. Rank 0 | ... | Rank N) - Warning: Memory Intensive
    test_accs = torch.cat([test_acc for test_acc in self.all_gather(test_accs)])
    test_precs = torch.cat([test_prec for test_prec in self.all_gather(test_precs)])
    test_recalls = torch.cat([test_recall for test_recall in self.all_gather(test_recalls)])
    test_f1s = torch.cat([test_f1 for test_f1 in self.all_gather(test_f1s)])
    test_aurocs = torch.cat([test_auroc for test_auroc in self.all_gather(test_aurocs)])
    test_auprcs = torch.cat([test_auprc for test_auprc in self.all_gather(test_auprcs)])
    test_preds = [wandb.Image(output_dict['test_preds']) for output_dict in outputs]  # Convert to grayscale image
    test_labels = [wandb.Image(output_dict['test_labels']) for output_dict in outputs]  # Convert to grayscale image
    # Reset test TorchMetrics for all devices
    self.test_acc.reset()
    self.test_prec.reset()
    self.test_recall.reset()
    self.test_f1.reset()
    self.test_auroc.reset()
    self.test_auprc.reset()
    # Log metric(s) aggregated from all ranks
    self.log('med_test_acc', torch.median(test_accs))  # Log MedAccuracy of an epoch
    self.log('med_test_prec', torch.median(test_precs))  # Log MedPrecision of an epoch
    self.log('med_test_recall', torch.median(test_recalls))  # Log MedRecall of an epoch
    self.log('med_test_f1', torch.median(test_f1s))  # Log MedF1 of an epoch
    self.log('med_test_auroc', torch.median(test_aurocs))  # Log MedAUROC of an epoch
    self.log('med_test_auprc', torch.median(test_auprcs))  # Log epoch MedAveragePrecision
    # Log test predictions with their ground-truth interaction tensors to WandB for visual inspection
    if self.hparams['logger_name'].lower() == 'wandb':
        self.trainer.logger.experiment.log({'test_preds': test_preds})
        self.trainer.logger.experiment.log({'test_labels': test_labels})
I should note that I am logging WandB "Images" using my LightningModule's self.trainer.logger.experiment object (in the test_epoch_end hook). I am not sure whether this could be a contributing factor to my original error.
In addition, my training script makes use of PyTorch Lightning's ModelCheckpoint callback to save all of its model checkpoints, alongside an EarlyStopping callback that stops training after a patience of 3 epochs based on val_bce (i.e., validation binary cross-entropy). All of my checkpoints are created by this ModelCheckpoint callback.
# -----------
# Data
# -----------
# Load data module
data_module = LitDataModule()
data_module.setup()
# ------------
# Model
# ------------
# Assemble a dictionary of model arguments
dict_args = vars(args)
model = LitModel(**dict_args)
args.experiment_name = f'LitModel Experiment {i}'
template_ckpt_filename = 'LitModel-{epoch:02d}-{val_bce:.2f}'
# ------------
# Checkpoint
# ------------
ckpt_path = os.path.join(args.ckpt_dir, args.ckpt_name)
ckpt_path_exists = os.path.exists(ckpt_path)
ckpt_provided = args.ckpt_name != '' and ckpt_path_exists
args.resume_from_checkpoint = ckpt_path if ckpt_provided else None
# ------------
# Trainer
# ------------
trainer = pl.Trainer.from_argparse_args(args)
# ------------
# Logger
# ------------
pl_logger = construct_pl_logger(args) # Log everything to an external logger
trainer.logger = pl_logger # Assign specified logger (e.g. TensorBoardLogger) to Trainer instance
using_wandb_logger = args.logger_name.lower() == 'wandb' # Determine whether the user requested to use WandB
# -----------
# Callbacks
# -----------
# Create and use callbacks
early_stop_callback = pl.callbacks.EarlyStopping(
    monitor=args.metric_to_track,
    mode='min' if 'ce' in args.metric_to_track else 'max',
    min_delta=args.min_delta,
    patience=args.patience
)
ckpt_callback = pl.callbacks.ModelCheckpoint(
    monitor=args.metric_to_track,
    mode='min' if 'ce' in args.metric_to_track else 'max',
    save_top_k=5, dirpath=args.ckpt_dir, verbose=True,
    filename=template_ckpt_filename  # May cause a race condition when calling trainer.test() with many GPUs
)
trainer.callbacks = [early_stop_callback, ckpt_callback]
# ------------
# Restore
# ------------
# If using WandB, download checkpoint file from their servers if the checkpoint is not already stored locally
if using_wandb_logger and ckpt_provided and not os.path.exists(ckpt_path):
    # Download checkpoint from WandB
    trainer.logger.experiment.restore(ckpt_path, run_path=f'{args.entity}/{args.project_name}/{args.run_id}')
# -------------
# Training
# -------------
# Train with the provided model and DataModule
trainer.fit(model=model, datamodule=data_module)
# Save best checkpoint only on the main process
trainer.save_checkpoint(ckpt_callback.best_model_path)
# -----------
# Testing
# -----------
trainer.test()
# -----------
# Finalizing
# -----------
if using_wandb_logger:
    trainer.logger.experiment.save(ckpt_callback.best_model_path)
    trainer.logger.experiment.finish()
When it comes to modifying process groups, I do not believe I am making any manual changes there. Everything regarding process groups should be Lightning's defaults.
Dear @amorehead,
It seems a process group was pickled for some reason. Mind printing the checkpoint content while dropping all tensors?
Best, T.C
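(For reference, a minimal sketch of what I take this request to mean, assuming the checkpoint can be deserialized at all; the path is a placeholder:)

import torch

def drop_tensors(obj):
    """Recursively replace tensors with a shape summary so the rest of the checkpoint prints cleanly."""
    if isinstance(obj, torch.Tensor):
        return f'<Tensor shape={tuple(obj.shape)}>'
    if isinstance(obj, dict):
        return {key: drop_tensors(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(drop_tensors(value) for value in obj)
    return obj

ckpt = torch.load('path/to/checkpoint.ckpt', map_location='cpu')  # Placeholder path
print(drop_tensors(ckpt))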
@tchaton, in my understanding, each checkpoint from a pl.Trainer instance is saved to storage using torch.save(), so I would need to use torch.load() to load it back into memory to view its contents, right? If so, which argument(s) can I pass to torch.load(ckpt_filepath) so that it drops all tensors upon loading? Without such an argument, I encounter the same error as above:
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.
@awaelchli and @tchaton, when I view my checkpoint files in a text editor (Gnome's, for instance), this is what I'm seeing (for its head):
and its tail:
@amorehead is it maybe possible that this only happens with the wandb logger?
Could you maybe share the checkpoint file you are trying to reload? I'll try to have a look inside. You can also try to look inside by initializing the process group as the error message suggests.
@awaelchli, attached below is one of the checkpoints that has been giving me this error. I will also check to see if the WandbLogger is causing this issue.
@awaelchli,
It appears that the WandbLogger is not the root cause of the issue, at least in my view. I tried swapping it out for the TensorBoardLogger, and I receive the same process-group error when I call trainer.test().
I was able to load the checkpoint after initializing the process group manually, as you suggested. I cannot identify any areas in the dictionary where process groups are mentioned, at least on my first few passes.
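(A minimal sketch of that workaround: initialize a trivial single-process group before calling torch.load. The backend, address, and port here are assumptions:)

import os
import torch
import torch.distributed as dist

# Stand up a single-process group so that unpickling objects referencing
# the default process group can succeed (placeholder address/port)
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '12355')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

ckpt = torch.load('path/to/checkpoint.ckpt', map_location='cpu')  # Placeholder path
print(list(ckpt.keys()))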
In addition, I tried running trainer.test() with my (now) loaded-in checkpoint to run inference, and I am seeing a new error stemming from my LightningModule's self.save_hyperparameters() call in its __init__() method. It seems my LightningModule is trying to pickle a _SingleProcessDataLoaderIter instance when it saves its hyperparameters.
Any ideas what this may point back to? I am not entirely sure why my LightningModule is trying to pickle a DataLoader.
/home/alexm/anaconda3/envs/DeepInteract/bin/python /home/alexm/Repositories/Lab_Repositories/DeepInteract/project/lit_model_test.py --logger_name WandB --experiment_name LitGINI-b4-gl1-n128-e128-il2-i128-Alex-Local-Model-Testing --online --max_hours 1 --max_minutes 55 --ckpt_dir checkpoints --num_gpus 1 --num_compute_nodes 1 --gpu_precision 32 --num_workers 2 --dips_percent_to_use 1.00 --self_loops --pn_ratio 0.1 --use_dgl --num_epochs 50 --patience 5 --batch_size 1 --accum_grad_batches 1 --lr 3e-4 --weight_decay 1e-4 --dropout_rate 0.1 --model_name gini --gnn_layer_type geotran --num_gnn_layers 1 --num_gnn_hidden_channels 128 --num_gnn_attention_heads 8 --num_interact_layers 128 --interact_module_type resnet --num_interact_hidden_channels 128 --num_interact_attention_heads 8 --patch_size 7 --final_layer_bias_value -7.0 --ckpt_name LitGINI-epoch=00-val_bce=0.31.ckpt
Using backend: pytorch
Global seed set to 42
DB5 cache found
Loaded DB5 train-set, source: datasets/DB5/final/processed, length: 140
DB5 cache found
Loaded DB5 val-set, source: datasets/DB5/final/processed, length: 35
DB5 cache found
Loaded DB5 val-set, source: datasets/DB5/final/processed, length: 16
DB5 cache found
Loaded DB5 test-set, source: datasets/DB5/final/processed, length: 55
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AUROC` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.
warnings.warn(*args, **kwargs)
/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AveragePrecision` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.
warnings.warn(*args, **kwargs)
Traceback (most recent call last):
File "/home/alexm/Repositories/Lab_Repositories/DeepInteract/project/lit_model_test.py", line 129, in <module>
main(args)
File "/home/alexm/Repositories/Lab_Repositories/DeepInteract/project/lit_model_test.py", line 74, in main
model = LitGINI.load_from_checkpoint(ckpt_path)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 153, in load_from_checkpoint
model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 195, in _load_model_state
model = cls(**_cls_kwargs)
File "/home/alexm/Repositories/Lab_Repositories/DeepInteract/project/utils/deepinteract_modules.py", line 1883, in __init__
self.save_hyperparameters()
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/pytorch_lightning/core/mixins/hparams_mixin.py", line 105, in save_hyperparameters
save_hyperparameters(self, *args, ignore=ignore, frame=frame)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py", line 252, in save_hyperparameters
obj._hparams_initial = copy.deepcopy(obj._hparams)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 296, in _reconstruct
value = deepcopy(value, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 205, in _deepcopy_list
append(deepcopy(a, memo))
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 270, in _reconstruct
state = deepcopy(state, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 230, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 270, in _reconstruct
state = deepcopy(state, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 230, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 296, in _reconstruct
value = deepcopy(value, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 270, in _reconstruct
state = deepcopy(state, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 230, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 296, in _reconstruct
value = deepcopy(value, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 270, in _reconstruct
state = deepcopy(state, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 230, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 270, in _reconstruct
state = deepcopy(state, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 230, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 270, in _reconstruct
state = deepcopy(state, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 230, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 270, in _reconstruct
state = deepcopy(state, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 230, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 161, in deepcopy
rv = reductor(4)
File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 547, in __getstate__
raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
NotImplementedError: ('{} cannot be pickled', '_SingleProcessDataLoaderIter')
Yes, you need to exclude any unpickleable objects that are passed in via the constructor:
self.save_hyperparameters(ignore="some_arg_name")
Maybe you are passing in dataloaders.
After trying to load your checkpoint, I am sure you have a hyperparameter object pickled, something specific to your project. It must be an object you have passed in via __init__.
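(A minimal sketch of this suggestion, with hypothetical argument names:)

import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self, lr: float = 3e-4, dataloader=None):  # 'dataloader' is a hypothetical unpickleable argument
        super().__init__()
        # Exclude the unpickleable object from the saved hyperparameters
        self.save_hyperparameters(ignore=['dataloader'])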
@awaelchli, you are absolutely right. I discovered that the issue was caused by passing my command-line arguments from argparse directly into my LightningModule's __init__() method. That is, after converting my argparse arguments into a Python dictionary, I was feeding this dictionary in its entirety into the __init__() method using the unpacking syntax (i.e., **dict_args). A DataLoader instance must have made its way into this dictionary somewhere along the line, so my LightningModule would then try to pickle a DataLoader (which currently isn't possible).
The issue is fixed on my end, and now I can test all my checkpoints. Thank you very much for your help in troubleshooting this issue with me!
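(For completeness, a sketch of the kind of filtering that avoids this; it assumes only DataLoader values need to be dropped:)

import torch

# Drop unpickleable values (e.g. DataLoaders) before unpacking into the module
dict_args = {k: v for k, v in vars(args).items()
             if not isinstance(v, torch.utils.data.DataLoader)}
model = LitModel(**dict_args)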
Great! Glad you were able to fix it
🐛 Bug
Cannot train a model using multiple GPUs and then test it using only a single GPU. That is, I have to use the same environment (e.g., number of GPUs/nodes) in which I trained the model to perform testing on a held-out dataset. I am using DDP as my distributed backend.
To Reproduce
Train a model using 16 GPUs across 4 nodes on a compute cluster like Summit, and then test one of the checkpoints either on the same cluster (in a second call to the training script with max_epochs=$the_last_epoch_completed) or on a local machine in a separate script that calls trainer.test(); see the sketch below.
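(A sketch of the two runs with placeholder values, using the Lightning 1.x-style flags that appear elsewhere in this thread:)

# Run 1: train with DDP across multiple GPUs/nodes
trainer = pl.Trainer(gpus=4, num_nodes=4, accelerator='ddp', max_epochs=50)
trainer.fit(model, datamodule=data_module)

# Run 2 (separate script, single GPU): reload a checkpoint and test
model = LitModel.load_from_checkpoint('checkpoints/some_checkpoint.ckpt')  # Placeholder path
trainer = pl.Trainer(gpus=1)
trainer.test(model, datamodule=data_module)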
Expected behavior
Trainer should be able to initialize the process group such that only 1 GPU is used at inference time, while multiple GPUs can be used to train the model.
Environment
Additional context
Exact error trace I'm seeing: