Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Cannot trainer.test() with checkpoints saved after training a model on multiple GPUs and nodes (using DDP) #9153

Closed: amorehead closed this issue 3 years ago

amorehead commented 3 years ago

🐛 Bug

Cannot train a model using multiple GPUs and then test it using only a single GPU. That is, I have to use the same environment (e.g., number of GPUs/nodes) in which I trained the model to perform testing on a held-out dataset. I am using DDP as my distributed backend.

To Reproduce

Train a model using 16 GPUs across 4 nodes on a computer cluster like Summit and then test one of the checkpoints either on the same cluster (in a second call to the training script with max_epochs=$the_last_epoch_completed) or on a local machine in a separate script calling trainer.test().

Expected behavior

Trainer should be able to initialize the process group such that only 1 GPU is used at inference time, while multiple GPUs can be used to train the model.

Environment

* CUDA:
    - GPU:
        - Tesla V100-SXM2-16GB
    - available:         True
    - version:           10.2
* Packages:
    - numpy:             1.19.2
    - pyTorch_debug:     False
    - pyTorch_version:   1.7.1
    - pytorch-lightning: 1.4.4
    - tqdm:              4.62.0
* System:
    - OS:                Linux
    - architecture:
        - 64bit
        - ELF
    - processor:         ppc64le
    - python:            3.8.8
    - version:           #1 SMP Thu Feb 18 09:47:51 EST 2021

Additional context

Exact error trace I'm seeing:

Using backend: pytorch
Global seed set to 42
DB5 cache found
Loaded DB5 train-set, source: datasets/DB5/final/processed, length: 140
DB5 cache found
Loaded DB5 val-set, source: datasets/DB5/final/processed, length: 35
DB5 cache found
Loaded DB5 test-set, source: datasets/DB5/final/processed, length: 55
/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AUROC` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.
  warnings.warn(*args, **kwargs)
/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AveragePrecision` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.
  warnings.warn(*args, **kwargs)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:446: UserWarning: Checkpoint directory checkpoints exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Traceback (most recent call last):
  File "/home/alexm/Repositories/Lab_Repositories/DeepInteract/project/lit_model_test.py", line 186, in <module>
    main(args)
  File "/home/alexm/Repositories/Lab_Repositories/DeepInteract/project/lit_model_test.py", line 128, in main
    model = model.load_from_checkpoint(ckpt_path)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 131, in load_from_checkpoint
    checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/pytorch_lightning/utilities/cloud_io.py", line 33, in load
    return torch.load(f, map_location=map_location)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/torch/serialization.py", line 607, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/torch/serialization.py", line 882, in _load
    result = unpickler.load()
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 576, in __setstate__
    self.process_group = _get_default_group()
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 358, in _get_default_group
    raise RuntimeError("Default process group has not been initialized, "
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.
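
The last frames of this traceback show where the error comes from: unpickling a `torch.nn.parallel.DistributedDataParallel` object runs its `__setstate__`, which calls `_get_default_group()`. The stdlib-only toy below reproduces that failure mode and the manual workaround of initializing a group before loading. `FakeDDPWrapper`, `_DEFAULT_GROUP` and the checkpoint layout are hypothetical stand-ins, not PyTorch internals.

```python
import pickle

# Hypothetical stand-in for torch.distributed's global default process group.
_DEFAULT_GROUP = None


def _get_default_group():
    # Mirrors the check that raises in the traceback above.
    if _DEFAULT_GROUP is None:
        raise RuntimeError(
            "Default process group has not been initialized, "
            "please make sure to call init_process_group."
        )
    return _DEFAULT_GROUP


class FakeDDPWrapper:
    """Toy object that, like DDP, drops its process group when pickled."""

    def __init__(self, module_state):
        self.module_state = module_state
        self.process_group = "group-from-training-run"

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["process_group"]  # live groups are not serialized
        return state

    def __setstate__(self, state):
        # Re-attach to the default group of the *loading* process.
        self.process_group = _get_default_group()
        self.__dict__.update(state)


# Illustrative checkpoint layout: assumes the wrapper ended up among the saved hyperparameters.
payload = pickle.dumps({"hyper_parameters": {"model": FakeDDPWrapper({"w": 1.0})}})

try:
    pickle.loads(payload)
    load_error = None
except RuntimeError as err:
    load_error = str(err)

# Simulate calling init_process_group() before loading, then retry.
_DEFAULT_GROUP = "single-process-group"
restored = pickle.loads(payload)["hyper_parameters"]["model"]
```

The point of the toy is that the failure happens during unpickling itself, before Lightning ever sees the checkpoint, which is why it surfaces inside `torch.load`.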
awaelchli commented 3 years ago

Have you saved anything yourself to the checkpoint in any of the hooks, or are you just using the default checkpointing that PL provides? Are you interacting with the process group in any way?

amorehead commented 3 years ago

@awaelchli, to the best of my knowledge, I have not saved anything to the checkpoints in any of my hooks. I am using the following hooks for training, validation, and testing.

    def training_step(self, batch, batch_idx):
        """Lightning calls this inside the training loop."""
        graph1, graph2, examples = batch[0], batch[1], batch[2]

        # Forward propagate with network layers
        logits = self(graph1, graph2)

        # Make predictions
        preds = torch.sigmoid(logits)
        int_labels = examples[:, 2].int()

        # Calculate loss and other metrics
        loss = self.loss_fn(logits, examples[:, 2].float())  # Calculate loss of a single sample
        train_acc = self.train_acc(preds, int_labels)  # Calculate Accuracy of a single sample
        train_prec = self.train_prec(preds, int_labels)  # Calculate Precision of a single sample
        train_recall = self.train_recall(preds, int_labels)  # Calculate Recall of a single sample

        # Log training step metric(s)
        self.log('train_bce', loss)

        return {
            'loss': loss,
            'train_acc': train_acc,
            'train_prec': train_prec,
            'train_recall': train_recall
        }

    def training_epoch_end(self, outputs: pl.utilities.types.EPOCH_OUTPUT) -> None:
        """Lightning calls this at the end of every training epoch."""
        # Tuplize scores for the current device (e.g. Rank 0)
        train_accs = torch.cat([output_dict['train_acc'].unsqueeze(0) for output_dict in outputs])
        train_precs = torch.cat([output_dict['train_prec'].unsqueeze(0) for output_dict in outputs])
        train_recalls = torch.cat([output_dict['train_recall'].unsqueeze(0) for output_dict in outputs])
        # Concatenate scores over all devices (e.g. Rank 0 | ... | Rank N) - Warning: Memory Intensive
        train_accs = torch.cat([train_acc for train_acc in self.all_gather(train_accs)], dim=0)
        train_precs = torch.cat([train_prec for train_prec in self.all_gather(train_precs)], dim=0)
        train_recalls = torch.cat([train_recall for train_recall in self.all_gather(train_recalls)], dim=0)

        # Reset training TorchMetrics for all devices
        self.train_acc.reset()
        self.train_prec.reset()
        self.train_recall.reset()

        # Log metric(s) aggregated from all ranks
        self.log('med_train_acc', torch.median(train_accs))  # Log MedAccuracy of an epoch
        self.log('med_train_prec', torch.median(train_precs))  # Log MedPrecision of an epoch
        self.log('med_train_recall', torch.median(train_recalls))  # Log MedRecall of an epoch

    def validation_step(self, batch, batch_idx):
        """Lightning calls this inside the validation loop."""
        graph1, graph2, examples = batch[0], batch[1], batch[2]

        # Forward propagate with network layers
        logits = self(graph1, graph2)

        # Make predictions
        preds = torch.sigmoid(logits)
        int_labels = examples[:, 2].int()

        # Calculate loss and other metrics
        loss = self.loss_fn(logits, examples[:, 2].float())  # Calculate loss of a single sample
        val_acc = self.val_acc(preds, int_labels)  # Calculate Accuracy of a single sample
        val_prec = self.val_prec(preds, int_labels)  # Calculate Precision of a single sample
        val_recall = self.val_recall(preds, int_labels)  # Calculate Recall of a single sample
        val_f1 = self.val_f1(preds, int_labels)  # Calculate F1 score of a single sample
        val_auroc = self.val_auroc(preds, int_labels)  # Calculate AUROC of a sample
        val_auprc = self.val_auprc(preds, int_labels)  # Calculate AveragePrecision (i.e. AUPRC) of a sample

        # Log validation step metric(s)
        self.log('val_bce', loss, sync_dist=True)

        return {
            'loss': loss,
            'val_acc': val_acc,
            'val_prec': val_prec,
            'val_recall': val_recall,
            'val_f1': val_f1,
            'val_auroc': val_auroc,
            'val_auprc': val_auprc
        }

    def validation_epoch_end(self, outputs: pl.utilities.types.EPOCH_OUTPUT) -> None:
        """Lightning calls this at the end of every validation epoch."""
        # Tuplize scores for the current device (e.g. Rank 0)
        val_accs = torch.cat([output_dict['val_acc'].unsqueeze(0) for output_dict in outputs])
        val_precs = torch.cat([output_dict['val_prec'].unsqueeze(0) for output_dict in outputs])
        val_recalls = torch.cat([output_dict['val_recall'].unsqueeze(0) for output_dict in outputs])
        val_f1s = torch.cat([output_dict['val_f1'].unsqueeze(0) for output_dict in outputs])
        val_aurocs = torch.cat([output_dict['val_auroc'].unsqueeze(0) for output_dict in outputs])
        val_auprcs = torch.cat([output_dict['val_auprc'].unsqueeze(0) for output_dict in outputs])

        # Concatenate scores over all devices (e.g. Rank 0 | ... | Rank N) - Warning: Memory Intensive
        val_accs = torch.cat([val_acc for val_acc in self.all_gather(val_accs)], dim=0)
        val_precs = torch.cat([val_prec for val_prec in self.all_gather(val_precs)], dim=0)
        val_recalls = torch.cat([val_recall for val_recall in self.all_gather(val_recalls)], dim=0)
        val_f1s = torch.cat([val_f1 for val_f1 in self.all_gather(val_f1s)], dim=0)
        val_aurocs = torch.cat([val_auroc for val_auroc in self.all_gather(val_aurocs)], dim=0)
        val_auprcs = torch.cat([val_auprc for val_auprc in self.all_gather(val_auprcs)], dim=0)

        # Reset validation TorchMetrics for all devices
        self.val_acc.reset()
        self.val_prec.reset()
        self.val_recall.reset()
        self.val_f1.reset()
        self.val_auroc.reset()
        self.val_auprc.reset()

        # Log metric(s) aggregated from all ranks
        self.log('med_val_acc', torch.median(val_accs))  # Log MedAccuracy of an epoch
        self.log('med_val_prec', torch.median(val_precs))  # Log MedPrecision of an epoch
        self.log('med_val_recall', torch.median(val_recalls))  # Log MedRecall of an epoch
        self.log('med_val_f1', torch.median(val_f1s))  # Log MedF1 of an epoch
        self.log('med_val_auroc', torch.median(val_aurocs))  # Log MedAUROC of an epoch
        self.log('med_val_auprc', torch.median(val_auprcs))  # Log epoch MedAveragePrecision

    def test_step(self, batch, batch_idx):
        """Lightning calls this inside the testing loop."""
        graph1, graph2, examples = batch[0], batch[1], batch[2]

        # Forward propagate with network layers
        logits = self(graph1, graph2)

        # Make predictions
        preds = torch.sigmoid(logits)
        int_labels = examples[:, 2].int()

        # Calculate loss and other metrics
        loss = self.loss_fn(logits, examples[:, 2].float())  # Calculate loss of a single sample
        test_acc = self.test_acc(preds, int_labels)  # Calculate Accuracy of a single sample
        test_prec = self.test_prec(preds, int_labels)  # Calculate Precision of a single sample
        test_recall = self.test_recall(preds, int_labels)  # Calculate Recall of a single sample
        test_f1 = self.test_f1(preds, int_labels)  # Calculate F1 score of a single sample
        test_auroc = self.test_auroc(preds, int_labels)  # Calculate AUROC of a sample
        test_auprc = self.test_auprc(preds, int_labels)  # Calculate AveragePrecision (i.e. AUPRC) of a sample

        # Manually evaluate test performance by collecting all predicted and ground-truth interaction tensors
        argmaxed_logits = torch.round(torch.sigmoid(logits)).cpu().detach()
        argmaxed_logits = argmaxed_logits.reshape(graph1.num_nodes(), graph2.num_nodes())

        argmaxed_examples = examples[:, 2].float().cpu().detach()
        argmaxed_examples = argmaxed_examples.reshape(graph1.num_nodes(), graph2.num_nodes())

        # Log test step metric(s)
        self.log('test_bce', loss, sync_dist=True)

        return {
            'loss': loss,
            'test_acc': test_acc,
            'test_prec': test_prec,
            'test_recall': test_recall,
            'test_f1': test_f1,
            'test_auroc': test_auroc,
            'test_auprc': test_auprc,
            'test_preds': argmaxed_logits,
            'test_labels': argmaxed_examples
        }

    def test_epoch_end(self, outputs: pl.utilities.types.EPOCH_OUTPUT):
        """Lightning calls this at the end of every test epoch."""
        # Tuplize scores for the current device (e.g. Rank 0)
        test_accs = torch.cat([output_dict['test_acc'].unsqueeze(0) for output_dict in outputs]).unsqueeze(1)
        test_precs = torch.cat([output_dict['test_prec'].unsqueeze(0) for output_dict in outputs]).unsqueeze(1)
        test_recalls = torch.cat([output_dict['test_recall'].unsqueeze(0) for output_dict in outputs]).unsqueeze(1)
        test_f1s = torch.cat([output_dict['test_f1'].unsqueeze(0) for output_dict in outputs]).unsqueeze(1)
        test_aurocs = torch.cat([output_dict['test_auroc'].unsqueeze(0) for output_dict in outputs]).unsqueeze(1)
        test_auprcs = torch.cat([output_dict['test_auprc'].unsqueeze(0) for output_dict in outputs]).unsqueeze(1)

        # Concatenate scores over all devices (e.g. Rank 0 | ... | Rank N) - Warning: Memory Intensive
        test_accs = torch.cat([test_acc for test_acc in self.all_gather(test_accs)])
        test_precs = torch.cat([test_prec for test_prec in self.all_gather(test_precs)])
        test_recalls = torch.cat([test_recall for test_recall in self.all_gather(test_recalls)])
        test_f1s = torch.cat([test_f1 for test_f1 in self.all_gather(test_f1s)])
        test_aurocs = torch.cat([test_auroc for test_auroc in self.all_gather(test_aurocs)])
        test_auprcs = torch.cat([test_auprc for test_auprc in self.all_gather(test_auprcs)])
        test_preds = [wandb.Image(output_dict['test_preds']) for output_dict in outputs]  # Convert to grayscale image
        test_labels = [wandb.Image(output_dict['test_labels']) for output_dict in outputs]  # Convert to grayscale image

        # Reset test TorchMetrics for all devices
        self.test_acc.reset()
        self.test_prec.reset()
        self.test_recall.reset()
        self.test_f1.reset()
        self.test_auroc.reset()
        self.test_auprc.reset()

        # Log metric(s) aggregated from all ranks
        self.log('med_test_acc', torch.median(test_accs))  # Log MedAccuracy of an epoch
        self.log('med_test_prec', torch.median(test_precs))  # Log MedPrecision of an epoch
        self.log('med_test_recall', torch.median(test_recalls))  # Log MedRecall of an epoch
        self.log('med_test_f1', torch.median(test_f1s))  # Log MedF1 of an epoch
        self.log('med_test_auroc', torch.median(test_aurocs))  # Log MedAUROC of an epoch
        self.log('med_test_auprc', torch.median(test_auprcs))  # Log epoch MedAveragePrecision

        # Log test predictions with their ground-truth interaction tensors to WandB for visual inspection
        if self.hparams['logger_name'].lower() == 'wandb':
            self.trainer.logger.experiment.log({'test_preds': test_preds})
            self.trainer.logger.experiment.log({'test_labels': test_labels})

I should note that I log wandb.Image objects to WandB through my LightningModule's self.trainer.logger.experiment object (in the test_epoch_end hook). I am not sure whether this contributes to my original error.

In addition, my training script uses PyTorch Lightning's ModelCheckpoint callback to save all of its model checkpoints, together with an EarlyStopping callback that stops training after a patience of 3 epochs without improvement in val_bce (i.e., validation binary cross-entropy).

    # -----------
    # Data
    # -----------
    # Load data module
    data_module = LitDataModule()
    data_module.setup()

    # ------------
    # Model
    # ------------
    # Assemble a dictionary of model arguments
    dict_args = vars(args)

    model = LitModel(**dict_args)
    args.experiment_name = f'LitModel Experiment {i}'
    template_ckpt_filename = 'LitModel-{epoch:02d}-{val_bce:.2f}'

    # ------------
    # Checkpoint
    # ------------
    ckpt_path = os.path.join(args.ckpt_dir, args.ckpt_name)
    ckpt_path_exists = os.path.exists(ckpt_path)
    ckpt_provided = args.ckpt_name != '' and ckpt_path_exists
    args.resume_from_checkpoint = ckpt_path if ckpt_provided else None

    # ------------
    # Trainer
    # ------------
    trainer = pl.Trainer.from_argparse_args(args)

    # ------------
    # Logger
    # ------------
    pl_logger = construct_pl_logger(args)  # Log everything to an external logger
    trainer.logger = pl_logger  # Assign specified logger (e.g. TensorBoardLogger) to Trainer instance
    using_wandb_logger = args.logger_name.lower() == 'wandb'  # Determine whether the user requested to use WandB

    # -----------
    # Callbacks
    # -----------
    # Create and use callbacks
    early_stop_callback = pl.callbacks.EarlyStopping(monitor=args.metric_to_track,
                                                     mode='min' if 'ce' in args.metric_to_track else 'max',
                                                     min_delta=args.min_delta, patience=args.patience)
    ckpt_callback = pl.callbacks.ModelCheckpoint(
        monitor=args.metric_to_track,
        mode='min' if 'ce' in args.metric_to_track else 'max',
        save_top_k=5, dirpath=args.ckpt_dir, verbose=True,
        filename=template_ckpt_filename  # May cause a race condition when calling trainer.test() with many GPUs
    )
    trainer.callbacks = [early_stop_callback, ckpt_callback]

    # ------------
    # Restore
    # ------------
    # If using WandB, download checkpoint file from their servers if the checkpoint is not already stored locally
    if using_wandb_logger and args.ckpt_name != '' and not ckpt_path_exists:
        # Download checkpoint from WandB
        trainer.logger.experiment.restore(ckpt_path, run_path=f'{args.entity}/{args.project_name}/{args.run_id}')

    # -------------
    # Training
    # -------------
    # Train with the provided model and DataModule
    trainer.fit(model=model, datamodule=data_module)

    # The best checkpoint was already written by ModelCheckpoint to ckpt_callback.best_model_path;
    # calling trainer.save_checkpoint() on that path would overwrite it with the final-epoch weights

    # -----------
    # Testing
    # -----------
    trainer.test()

    # -----------
    # Finalizing
    # -----------
    if using_wandb_logger:
        trainer.logger.experiment.save(ckpt_callback.best_model_path)
        trainer.logger.experiment.finish()

As for process groups, I do not believe I am modifying them manually anywhere. Everything related to process groups should follow Lightning's defaults.

tchaton commented 3 years ago

Dear @amorehead,

It seems a process group was pickled for some reason. Would you mind printing the checkpoint contents while dropping all tensors?

Best, T.C

amorehead commented 3 years ago

@tchaton, in my understanding, each checkpoint written by a pl.Trainer instance is saved to storage using torch.save(), so I would need to use torch.load() to load it back into memory to view its contents, right? If so, which argument(s) can I pass to torch.load(ckpt_filepath) so that it drops all tensors upon loading? Without such an argument, I encounter the same error as above:

RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.
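
`torch.load` does not appear to offer an argument that skips tensors, so one approach is to load the checkpoint normally (initializing a process group first if the error above persists) and then replace every tensor with a short description before printing. A minimal sketch, assuming tensors are recognized by duck typing on `shape` and `dtype` attributes (true for `torch.Tensor`); `summarize_checkpoint` is a hypothetical helper, not a Lightning or PyTorch API.

```python
def summarize_checkpoint(obj):
    """Recursively replace tensor-like leaves with a short description.

    Tensors are detected by duck typing: anything exposing both ``shape`` and
    ``dtype`` (e.g. torch.Tensor) is summarized instead of printed in full.
    """
    if hasattr(obj, "shape") and hasattr(obj, "dtype"):
        return f"<tensor shape={tuple(obj.shape)} dtype={obj.dtype}>"
    if isinstance(obj, dict):
        return {key: summarize_checkpoint(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [summarize_checkpoint(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(summarize_checkpoint(item) for item in obj)
    if isinstance(obj, (str, int, float, bool, type(None))):
        return obj
    # Any other object: show only its type, so unexpected pickled objects
    # (e.g. a DistributedDataParallel wrapper) stand out.
    return f"<{type(obj).__module__}.{type(obj).__qualname__}>"
```

For example, `pprint.pprint(summarize_checkpoint(torch.load(ckpt_path, map_location="cpu")))` would print the checkpoint's structure. On recent PyTorch releases, where `torch.load` defaults to `weights_only=True`, loading arbitrary pickled objects may also require passing `weights_only=False`.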
amorehead commented 3 years ago

@awaelchli and @tchaton, when I view my checkpoint files in a text editor (Gnome's, for instance), this is what I'm seeing (for its head):

[Screenshot from 2021-08-28 17-18-00: head of the checkpoint file]

and its tail:

[Screenshot from 2021-08-28 17-19-44: tail of the checkpoint file]

awaelchli commented 3 years ago

@amorehead, is it possible that this only happens with the WandB logger?

Could you share the checkpoint file you are trying to reload? I'll have a look inside. You can also inspect it yourself by initializing the process group, as the error message suggests.

amorehead commented 3 years ago

@awaelchli, attached below is one of the checkpoints that has been giving me this error. I will also check to see if the WandbLogger is causing this issue.

checkpoint.tar.gz

amorehead commented 3 years ago

@awaelchli,

It appears that the WandbLogger is not the root cause of the issue: I swapped it out for the TensorBoardLogger and received the same process group error when calling trainer.test().

I was able to load in the checkpoint when initializing the process group manually, as you suggested. I cannot identify any areas in the dictionary where process groups are mentioned, at least on my first few passes.

In addition, I tried running trainer.test() with my (now) loaded-in checkpoint, and I am seeing a new error that stems from the self.save_hyperparameters() call in my LightningModule's __init__() method. It seems that my LightningModule is trying to pickle a _SingleProcessDataLoaderIter instance when it saves its hyperparameters.

Any ideas what this may point back to? I am not entirely sure why my LightningModule is trying to pickle a DataLoader.

/home/alexm/anaconda3/envs/DeepInteract/bin/python /home/alexm/Repositories/Lab_Repositories/DeepInteract/project/lit_model_test.py --logger_name WandB --experiment_name LitGINI-b4-gl1-n128-e128-il2-i128-Alex-Local-Model-Testing --online --max_hours 1 --max_minutes 55 --ckpt_dir checkpoints --num_gpus 1 --num_compute_nodes 1 --gpu_precision 32 --num_workers 2 --dips_percent_to_use 1.00 --self_loops --pn_ratio 0.1 --use_dgl --num_epochs 50 --patience 5 --batch_size 1 --accum_grad_batches 1 --lr 3e-4 --weight_decay 1e-4 --dropout_rate 0.1 --model_name gini --gnn_layer_type geotran --num_gnn_layers 1 --num_gnn_hidden_channels 128 --num_gnn_attention_heads 8 --num_interact_layers 128 --interact_module_type resnet --num_interact_hidden_channels 128 --num_interact_attention_heads 8 --patch_size 7 --final_layer_bias_value -7.0 --ckpt_name LitGINI-epoch=00-val_bce=0.31.ckpt
Using backend: pytorch
Global seed set to 42
DB5 cache found
Loaded DB5 train-set, source: datasets/DB5/final/processed, length: 140
DB5 cache found
Loaded DB5 val-set, source: datasets/DB5/final/processed, length: 35
DB5 cache found
Loaded DB5 val-set, source: datasets/DB5/final/processed, length: 16
DB5 cache found
Loaded DB5 test-set, source: datasets/DB5/final/processed, length: 55
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AUROC` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.
  warnings.warn(*args, **kwargs)
/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/torchmetrics/utilities/prints.py:36: UserWarning: Metric `AveragePrecision` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.
  warnings.warn(*args, **kwargs)
Traceback (most recent call last):
  File "/home/alexm/Repositories/Lab_Repositories/DeepInteract/project/lit_model_test.py", line 129, in <module>
    main(args)
  File "/home/alexm/Repositories/Lab_Repositories/DeepInteract/project/lit_model_test.py", line 74, in main
    model = LitGINI.load_from_checkpoint(ckpt_path)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 153, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 195, in _load_model_state
    model = cls(**_cls_kwargs)
  File "/home/alexm/Repositories/Lab_Repositories/DeepInteract/project/utils/deepinteract_modules.py", line 1883, in __init__
    self.save_hyperparameters()
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/pytorch_lightning/core/mixins/hparams_mixin.py", line 105, in save_hyperparameters
    save_hyperparameters(self, *args, ignore=ignore, frame=frame)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py", line 252, in save_hyperparameters
    obj._hparams_initial = copy.deepcopy(obj._hparams)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 296, in _reconstruct
    value = deepcopy(value, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 205, in _deepcopy_list
    append(deepcopy(a, memo))
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 296, in _reconstruct
    value = deepcopy(value, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 296, in _reconstruct
    value = deepcopy(value, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/copy.py", line 161, in deepcopy
    rv = reductor(4)
  File "/home/alexm/anaconda3/envs/DeepInteract/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 547, in __getstate__
    raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
NotImplementedError: ('{} cannot be pickled', '_SingleProcessDataLoaderIter')
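
This traceback shows `save_hyperparameters()` calling `copy.deepcopy(obj._hparams)`, which fails on a `_SingleProcessDataLoaderIter` nested somewhere inside the collected `__init__` arguments. To find which argument is responsible, one can try the same deep copy on each value separately. A minimal sketch; `find_uncopyable_hparams` and `_NotCopyable` are hypothetical, and the input is assumed to be the same dictionary passed to the model as `**dict_args`.

```python
import copy


def find_uncopyable_hparams(hparams):
    """Return {name: error} for every value that copy.deepcopy rejects.

    save_hyperparameters() deep-copies the collected __init__ arguments, so
    any name reported here would make it fail like the traceback above.
    """
    failures = {}
    for name, value in hparams.items():
        try:
            copy.deepcopy(value)
        except Exception as err:  # e.g. NotImplementedError from a DataLoader iterator
            failures[name] = f"{type(err).__name__}: {err}"
    return failures


class _NotCopyable:
    """Hypothetical stand-in for a DataLoader iterator that refuses to be pickled."""

    def __getstate__(self):
        raise NotImplementedError("cannot be pickled")


# Usage: the offending object is nested one level down, as in the traceback.
report = find_uncopyable_hparams({"lr": 3e-4, "nested": {"it": _NotCopyable()}, "name": "gini"})
```

Checking values one at a time keeps the report at the level of constructor argument names, which are exactly the names `save_hyperparameters(ignore=...)` would need.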
awaelchli commented 3 years ago

Yes, you need to exclude any unpicklable objects that are passed in via the constructor, e.g. with self.save_hyperparameters(ignore="some_arg_name"). Maybe you are passing in dataloaders?

awaelchli commented 3 years ago

After trying to load your checkpoint I am sure you have a hyperparameter object pickled, something specific to your project. It must be an object you have passed in via __init__.

amorehead commented 3 years ago

@awaelchli, you are absolutely right. I discovered that the issue was caused by my passing command-line arguments from argparse directly into my LightningModule's __init__() method. That is, after converting my argparse arguments into a Python dictionary, I was feeding this entire dictionary into __init__() using the unpacking syntax (i.e., **dict_args). The argparse namespace must have picked up a DataLoader instance somehow while parsing my CLI arguments, so my LightningModule would then try to pickle a DataLoader (which isn't currently feasible).

The issue is fixed on my end, and now I can test all my checkpoints. Thank you very much for your help in troubleshooting this issue with me!

awaelchli commented 3 years ago

Great! Glad you were able to fix it.