Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Callback metrics not being populated during multi-gpu training #7671

Closed jacanchaplais closed 3 years ago

jacanchaplais commented 3 years ago

🐛 Bug

When pl.Trainer is created with gpus > 1, the callback_metrics dictionary appears not to be populated. I've tried to integrate both Optuna and Ray Tune, and both have failed as a result.

To ensure this was the issue, I printed the trainer.callback_metrics attribute:

(pid=144372) callback metrics are:
(pid=144372)  {}

It does, however, work with 1 GPU. Unfortunately for me, my datasets are graphs, and they are so large that I can only fit one into memory at a time, so the number of GPUs = batch size, and tuning with a batch size of 1 might not be very indicative. Any help much appreciated!

Environment

Hardware

Software

Code

I attach the LightningModule below, which uses TorchMetrics and the self.log features, as per the docs. I did (in desperation) try setting the callback_metrics dictionary myself in validation_epoch_end(), but that didn't work. Neither did setting sync_dist=True in the self.log() calls.

import torch
import torchmetrics
import torch_geometric as pyg
import pytorch_lightning as pl

class Interaction(pyg.nn.MessagePassing):
    def __init__(self, in_edge, in_node, out_edge, out_node):
        super(Interaction, self).__init__(
            aggr='add',
            flow="source_to_target")
        self.in_edge = 2 * in_node + in_edge
        self.in_node = in_node + out_edge
        self.mlp_edge = torch.nn.Sequential(
            torch.nn.Linear(self.in_edge, out_edge, bias=True),
            torch.nn.ReLU(),
            torch.nn.Linear(out_edge, out_edge, bias=True)
        )
        self.mlp_node = torch.nn.Sequential(
            torch.nn.Linear(self.in_node, out_node, bias=True),
            torch.nn.ReLU(),
            torch.nn.Linear(out_node, out_node, bias=True)
        )

    def forward(self, x, edge_index, edge_attrs):
        return self.propagate(
            x=x,
            edge_index=edge_index,
            edge_attrs=edge_attrs
        )

    def message(self, x_i, x_j, edge_index, edge_attrs):
        recv_send = [x_i, x_j]
        if edge_attrs is not None:
            recv_send.append(edge_attrs)
        recv_send = torch.cat(recv_send, dim=1)
        self.edge_embed = self.mlp_edge(recv_send)
        return self.edge_embed

    def update(self, aggr_out, x):
        node_embed = self.mlp_node(torch.cat([x, aggr_out], dim=1))
        return (self.edge_embed, node_embed)

class Net(pl.LightningModule):
    def __init__(self, dim_node: int = 4, dim_edge: int = 0,
                 dim_embed_edge: int = 64, dim_embed_node: int = 32,
                 num_hidden: int = 3, final_bias: bool = False,
                 pos_weight: float = 80.0,
                 learn_rate: float = 1e-4, weight_decay: float = 5e-4,
                 infer_thresh: float = 0.5):
        super(Net, self).__init__()
        # define the architecture
        self.encode = Interaction(dim_edge, dim_node,
                                  dim_embed_edge, dim_embed_node)
        self.message = pyg.nn.Sequential('x, edge_index, edge_attrs', [
            (Interaction(dim_embed_edge, dim_embed_node,
                         dim_embed_edge, dim_embed_node),
             'x, edge_index, edge_attrs -> edge_attrs, x')
             for i in range(num_hidden)
             ])
        self.classify = torch.nn.Linear(dim_embed_edge, 1, bias=final_bias)
        # optimiser args
        self.lr = learn_rate
        self.decay = weight_decay
        # configure the loss
        self.criterion = torch.nn.BCEWithLogitsLoss(
                pos_weight=torch.tensor(pos_weight, device=self.device),
                reduction='mean')
        # add metrics
        self.train_ACC = torchmetrics.Accuracy(threshold=infer_thresh)
        self.train_F1 = torchmetrics.F1(
                num_classes=1, threshold=infer_thresh)
        self.val_ACC = torchmetrics.Accuracy(threshold=infer_thresh)
        self.val_F1 = torchmetrics.F1(
                num_classes=1, threshold=infer_thresh)
        self.val_PR = torchmetrics.BinnedPrecisionRecallCurve(
                num_classes=1, num_thresholds=5)

    def forward(self, data, sigmoid=True):
        node_attrs, edge_attrs = data.x, data.edge_attr
        edge_attrs, node_attrs = self.encode(node_attrs, data.edge_index,
                                             edge_attrs)
        edge_attrs, node_attrs = self.message(node_attrs, data.edge_index,
                                              edge_attrs)
        pred = self.classify(edge_attrs)
        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
                self.parameters(),
                lr=self.lr,
                weight_decay=self.decay
            )
        return optimizer

    def _train_av_loss(self, outputs):
        return torch.stack([x['loss'] for x in outputs]).mean()

    def _val_av_loss(self, losses):
        return torch.stack(losses).mean()

    def training_step(self, batch, batch_idx):
        edge_pred = self(batch, sigmoid=False)
        loss = self.criterion(edge_pred, batch.y.view(-1, 1))
        return {'loss': loss,
                'preds': torch.sigmoid(edge_pred),
                'target': batch.y.view(-1, 1).int()}

    def training_step_end(self, outputs):
        self.train_ACC(outputs['preds'], outputs['target'])
        self.train_F1(outputs['preds'], outputs['target'])
        self.log('ptl/train_loss', outputs['loss'], on_step=True)
        return outputs['loss']

    def training_epoch_end(self, outputs):
        self.log('ptl/train_loss', self._train_av_loss(outputs))
        self.log('ptl/train_accuracy', self.train_ACC.compute())
        self.log('ptl/train_f', self.train_F1.compute())

    def validation_step(self, batch, batch_idx):
        edge_pred = self(batch, sigmoid=False)
        loss = self.criterion(edge_pred, batch.y.view(-1, 1))
        return {'loss': loss,
                'preds': torch.sigmoid(edge_pred),
                'target': batch.y.view(-1, 1).int()}

    def validation_step_end(self, outputs):
        self.val_ACC(outputs['preds'], outputs['target'])
        self.val_F1(outputs['preds'], outputs['target'])
        self.val_PR(outputs['preds'], outputs['target'])
        self.log('ptl/val_loss', outputs['loss'], on_step=True)
        return outputs['loss']

    def validation_epoch_end(self, outputs):
        metrics = {
            'ptl/val_loss': self._val_av_loss(outputs),
            'ptl/val_accuracy': self.val_ACC.compute(),
            'ptl/val_f': self.val_F1.compute(),
            }
        self.log_dict(metrics, sync_dist=True)
        prec, recall, thresh = self.val_PR.compute()
        for i, t in enumerate(thresh):
            self.log(f'ptl/val_prec_thresh_{t:.3f}', prec[i])
            self.log(f'ptl/val_recall_thresh_{t:.3f}', recall[i])
        self.trainer.callback_metrics = metrics
        return metrics

Here I attach the tuning script using Ray[Tune], as per their docs https://docs.ray.io/en/master/tune/tutorials/tune-pytorch-lightning.html.

import os

import pytorch_lightning as pl
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback

from cluster_gnn.models import gnn
from cluster_gnn.data import loader

# slurm hack
os.environ["SLURM_JOB_NAME"] = "bash"

ROOT_DIR = '/home/jlc1n20/projects/cluster_gnn/'
MODEL_DIR = ROOT_DIR + '/models/tune/'

def train_gnn(config, data_module, num_epochs=10, num_gpus=4, callbacks=None,
              checkpoint_dir=None):
    logger = pl.loggers.TensorBoardLogger(
        save_dir=tune.get_trial_dir(), name="", version=".")
    if checkpoint_dir:
        ckpt = pl.utilities.cloud_io.pl_load(
            os.path.join(checkpoint_dir, 'checkpoint'),
            map_location=lambda storage, loc: storage)
        model = gnn.Net._load_model_state(
            checkpoint=ckpt,
            num_hidden=6, dim_embed_edge=64, dim_embed_node=32,
            learn_rate=config['learn_rate'],
            pos_weight=config['pos_weight'])
    else:
        model = gnn.Net(num_hidden=6, dim_embed_edge=64, dim_embed_node=32,
                        learn_rate=config['learn_rate'],
                        pos_weight=config['pos_weight'])
    trainer = pl.Trainer(gpus=num_gpus, num_nodes=1, max_epochs=num_epochs,
                         progress_bar_refresh_rate=0,
                         limit_train_batches=0.1,
                         logger=logger,
                         callbacks=callbacks)
    trainer.fit(model, data_module)
    print('callback metrics are:\n {}'.format(trainer.callback_metrics))

def tune_gnn(data_module, num_samples=10, num_epochs=10, gpus_per_trial=2,
             init_params=None, checkpoint_dir=None):
    config = {
        'learn_rate': tune.loguniform(1e-6, 1e-1),
        'pos_weight': tune.uniform(1.0, 100.0),
        }
    metrics = ['ptl/val_loss', 'ptl/val_accuracy', 'ptl/val_f']
    callbacks = [
        TuneReportCheckpointCallback(
            metrics,
            filename='checkpoint',
            on='validation_end')
        ]
    scheduler = ASHAScheduler(
        time_attr='epoch',
        max_t=num_epochs,
        )
    search_alg = HyperOptSearch(points_to_evaluate=init_params)
    reporter = CLIReporter(
        parameter_columns=[
            'learn_rate',
            'pos_weight',
            ],
        )
    trainable = tune.with_parameters(
        train_gnn,
        data_module=data_module,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
        callbacks=callbacks,
        checkpoint_dir=checkpoint_dir,
        )
    analysis = tune.run(
        trainable,
        resources_per_trial={
            'cpu': 1,
            'gpu': gpus_per_trial,
            },
        metric='ptl/val_f',
        mode='max',
        config=config,
        num_samples=num_samples,
        search_alg=search_alg,
        scheduler=scheduler,
        progress_reporter=reporter,
        local_dir=MODEL_DIR,
        verbose=3,
        name='tune_gnn')
    print('Best hp found: ', analysis.best_config)

if __name__ == '__main__':
    num_gpus = 2
    cur_best_params = [{
        'learn_rate': 3.75e-5,
        'pos_weight': 21.5,
        }]
    graph_data = loader.GraphDataModule(
        '/home/jlc1n20/projects/cluster_gnn/data/', num_workers=num_gpus)

    tune_gnn(data_module=graph_data, num_samples=1, num_epochs=1,
             gpus_per_trial=num_gpus, init_params=cur_best_params)

edgarriba commented 3 years ago

@jacanchaplais thanks for providing feedback. I was trying to reproduce your code, but I believe that to investigate in depth we need the parts of the code that you didn't provide, such as gnn.loader, and at least a minimal data sample.

In the meantime I created this small repo to play around with this issue: https://github.com/edgarriba/pl_issue_7671

Please, let us know your thoughts so that we can help to solve this issue.

jacanchaplais commented 3 years ago

Thanks @edgarriba. I use the data loader class provided by PyTorch Geometric, and my full code can be seen here https://github.com/jacanchaplais/cluster_gnn.

The data loader is defined here https://github.com/jacanchaplais/cluster_gnn/blob/main/src/cluster_gnn/data/loader.py.

As I'm prototyping stuff, I wrote the data processing separately in a Jupyter notebook found here https://github.com/jacanchaplais/cluster_gnn/blob/main/notebooks/convert_data.ipynb.

Here is a small sample data set of 100 graphs. small_data.hdf5.zip

Although, have you tried reproducing this error on a simpler case, like the standard MNIST classifier? If not, might be better to see if it is a problem there before trying to reproduce my rather specific case, as it hasn't yet been polished for easy portability.

EDIT: if you do want to install my codebase, you can do conda env create -f environment.yml, followed by bash pyg-pip.sh ptg, and finally pip install -e ..

jacanchaplais commented 3 years ago

I should note that I have reproduced this myself with MNIST, see https://github.com/optuna/optuna-examples/blob/c8df375e5bd9d741538491f87d607244ae6e9746/pytorch/pytorch_lightning_simple.py.

Optuna have since changed the number of GPUs they use in this example to 1, rather than a variable number, as I informed them of the issue that I'm reporting in the current thread.

edgarriba commented 3 years ago

@jacanchaplais thanks for the insights. I'll investigate this in detail.

edgarriba commented 3 years ago

After some investigation, it seems that after this call to mp.spawn https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/training_type/ddp_spawn.py#L157

the self.lightning_module.trainer is touched and the callback_metrics variable is cleared. In addition, I have noticed that self.lightning_module.trainer is a weak-reference object, which makes me suspect that it is somehow de-referenced, but I'm not sure about this one. /cc @awaelchli @justusschock do you have any intuition about this?
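
For readers unfamiliar with weak references, here is a minimal standard-library sketch of what a weakref proxy does once its referent is gone. ToyTrainer is a hypothetical stand-in, not Lightning code, and the sketch only illustrates the suspicion voiced above; it does not show that this is what happens inside Lightning.

```python
import gc
import weakref


class ToyTrainer:
    """Hypothetical stand-in for a trainer object; not a Lightning class."""

    def __init__(self):
        self.callback_metrics = {"val_loss": 0.25}


def proxy_demo():
    trainer = ToyTrainer()
    proxy = weakref.proxy(trainer)

    # While the real object is alive, the proxy forwards attribute access.
    seen_while_alive = dict(proxy.callback_metrics)

    # Drop the only strong reference; the proxy does not keep the object alive.
    del trainer
    gc.collect()

    try:
        proxy.callback_metrics
        raised = False
    except ReferenceError:
        # Accessing a proxy whose referent was collected raises ReferenceError.
        raised = True
    return seen_while_alive, raised


if __name__ == "__main__":
    print(proxy_demo())
```

Note that a dead proxy raises an error rather than silently returning an empty dictionary, which is one way to tell the two situations apart.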

justusschock commented 3 years ago

~@edgarriba From my feeling you might be right, when passing a weakref proxy the original object might not exist anymore in the new process. maybe we can use the passed trainer instead of the one in mp_kwargs? This one shouldn't be a proxy IIRC~

EDIT: nvm my reply, @awaelchli is right, didn't think of this.

awaelchli commented 3 years ago

Hey guys! Don't get misled by this. From what I can see the OP is trying to access the callback_metrics outside in the main process:

    trainer.fit(model, data_module)
    print('callback metrics are:\n {}'.format(trainer.callback_metrics))

Please note that in DDP spawn the main process never trains, therefore there are no callback metrics! The only thing it does is wait for the worker processes to enter join() when finished. This is just how it is in ddp spawn. Lightning will make sure to add the weights to the queue so they get back to the main process but that's pretty much all.

The recommendation is always to avoid ddp spawn whenever possible. So my recommendation is accelerator="ddp".
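
The behaviour described above can be mimicked with plain Python multiprocessing. This is only an analogy under stated assumptions, not Lightning's implementation: ToyTrainer, worker and fit_in_child are hypothetical names, and the sketch uses the "fork" start method (so it assumes a POSIX system) to stay short. The point carries over to spawned processes as well: a child process works on its own copy of the trainer, so anything it writes stays in the child unless it is explicitly sent back, here through a queue.

```python
import multiprocessing as mp


class ToyTrainer:
    """Hypothetical stand-in for a trainer object; not a Lightning class."""

    def __init__(self):
        self.callback_metrics = {}


def worker(trainer, queue):
    # Runs in the child process and mutates the child's copy of the trainer.
    trainer.callback_metrics["ptl/val_loss"] = 0.25
    # Only what is put on the queue travels back to the parent.
    queue.put(dict(trainer.callback_metrics))


def fit_in_child(trainer):
    # "fork" is an assumption made for brevity; it is not available on Windows.
    ctx = mp.get_context("fork")
    queue = ctx.Queue()
    proc = ctx.Process(target=worker, args=(trainer, queue))
    proc.start()
    returned = queue.get()  # read before join() so the child can flush the queue
    proc.join()
    return returned


if __name__ == "__main__":
    trainer = ToyTrainer()
    returned = fit_in_child(trainer)
    print("main process sees:", trainer.callback_metrics)
    print("sent back by child:", returned)
```

In this sketch the parent's trainer.callback_metrics stays empty after the child finishes, while the dictionary sent through the queue carries the value the child computed.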

edgarriba commented 3 years ago

@awaelchli you might be right, but the process gets blocked when I try with ddp

justusschock commented 3 years ago

@edgarriba ddp doesn't work with notebooks, that's the only reason we still have ddp_spawn around. So the callback_metrics in spawn are populated, just not in the main process, since that one just waits and does nothing.

edgarriba commented 3 years ago

@justusschock I'm running in an aws instance

s-rog commented 3 years ago

I thought no form of ddp works in notebooks? Also, I can confirm that in ddp_spawn callback_metrics work internally.

justusschock commented 3 years ago

Yes, no form of plain DDP works (since we usually call the script multiple times, which is not possible in Jupyter), but with spawn we start processes that are tied not to the script but to one specific function we pass (and thus they work).

tchaton commented 3 years ago

Dear @jacanchaplais,

Any progress on this issue?

Summary of responses:

Best, T.C

jacanchaplais commented 3 years ago

I will test specifying accelerator="ddp" soon and get back to you, thanks for the updates.