Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Confusing recommendation to use sync_dist=True even with TorchMetrics #20153

Open srprca opened 3 months ago

srprca commented 3 months ago

Bug description

Hello!

When I train and validate a model in a multi-GPU setting (HPC, sbatch job that requests multiple GPUs on a single node), I use self.log(..., sync_dist=True) when logging PyTorch losses, and don't specify any value for sync_dist when logging metrics from TorchMetrics library. However, I still get warnings like

...
It is recommended to use `self.log('val_mean_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
...
It is recommended to use `self.log('val_bg_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.

These specific messages correspond to logging tmc.MulticlassRecall(len(self.task.class_names), average="macro", ignore_index=self.metric_ignore_index) and individual components of tmc.MulticlassRecall(len(self.task.class_names), average="none", ignore_index=self.metric_ignore_index).

Full code listing for metric object definitions and logging is provided in the "reproducing the bug" section.

As I understand from a note here, and from discussion here, one doesn't typically need to explicitly use sync_dist when using TorchMetrics.

I wonder whether I still need to enable sync_dist=True as advised in the warnings, due to some special case that I am not aware of, or whether I should follow the docs and keep the code as is. In any case, this is probably a bug, either in the documentation or in the warning code.

Thank you!

What version are you seeing the problem on?

2.3.0

How to reproduce the bug

# Aliases assumed from their usage in this snippet:
import torchmetrics as tm
import torchmetrics.classification as tmc

# Excerpt from the LightningModule: metric definitions.
self.val_metric_funs = tm.MetricCollection(
    {
        "cm_normalize_all": tmc.MulticlassConfusionMatrix(
            len(self.task.class_names),
            ignore_index=self.metric_ignore_index,
            normalize="all",
        ),
        "recall_average_macro": tmc.MulticlassRecall(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
        "recall_average_none": tmc.MulticlassRecall(
            len(self.task.class_names),
            average="none",
            ignore_index=self.metric_ignore_index,
        ),
        "precision_average_macro": tmc.MulticlassPrecision(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
        "precision_average_none": tmc.MulticlassPrecision(
            len(self.task.class_names),
            average="none",
            ignore_index=self.metric_ignore_index,
        ),
        "f1_average_macro": tmc.MulticlassF1Score(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
        "f1_average_none": tmc.MulticlassF1Score(
            len(self.task.class_names),
            average="none",
            ignore_index=self.metric_ignore_index,
        ),
    }
)

# Excerpt from the validation step: logging. `metrics` is assumed to be the
# output of self.val_metric_funs(...), as in the later listing.
if not sanity_check:
    for metric_name, metric in metrics.items():
        metric_fun = self.val_metric_funs[metric_name]
        metric_name_ = metric_name.split("_")[0]
        if isinstance(metric_fun, tmc.MulticlassConfusionMatrix):
            # Log every confusion-matrix cell as its own scalar.
            for true_class_num in range(metric.shape[0]):
                true_class_name = self.task.class_names[true_class_num]
                for pred_class_num in range(metric.shape[1]):
                    pred_class_name = self.task.class_names[pred_class_num]
                    self.log(
                        f"val_true_{true_class_name}_pred_{pred_class_name}_cm",
                        metric[true_class_num, pred_class_num].item(),
                        on_step=False,
                        on_epoch=True,
                        logger=True,
                    )
        elif isinstance(
            metric_fun,
            (
                tmc.MulticlassRecall,
                tmc.MulticlassPrecision,
                tmc.MulticlassF1Score,
            ),
        ):
            if metric_fun.average == "macro":
                self.log(
                    f"val_mean_{metric_name_}",
                    metric.item(),
                    on_step=False,
                    on_epoch=True,
                    logger=True,
                )
            elif metric_fun.average == "none":
                # Log every per-class value as its own scalar.
                for class_num, metric_ in enumerate(metric):
                    class_name = self.task.class_names[class_num]
                    self.log(
                        f"val_{class_name}_{metric_name_}",
                        metric_.item(),
                        on_step=False,
                        on_epoch=True,
                        logger=True,
                    )
            else:
                raise NotImplementedError(
                    f"Code for logging metric {metric_name} is not implemented"
                )
        else:
            raise NotImplementedError(
                f"Code for logging metric {metric_name} is not implemented"
            )

Error messages and logs

...
It is recommended to use `self.log('val_mean_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
...
It is recommended to use `self.log('val_bg_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.

Environment

Current environment

```
#- PyTorch Lightning Version: 2.3.0
#- PyTorch Version: 2.3.1
#- Python version: 3.11.9
#- OS: Linux
#- CUDA/cuDNN version: 11.8
#- How you installed Lightning: conda-forge
```

More info

No response

cc @carmocca

awaelchli commented 3 months ago

Yes that's right, the warning shouldn't occur when logging TorchMetrics. Does it occur only with MetricCollection or a regular Metric too?

srprca commented 3 months ago

Thank you for your reply!

I will be able to check this tomorrow, and will report back.

Meanwhile, my second suspicion is that since I log metric.item() objects, is it possible that somehow self.log doesn't recognize these objects as originating from TorchMetrics, and sees them as "generic" numbers or tensors...?
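
The hypothesis can be illustrated with a toy model in plain Python. This is not Lightning's code: `ToyMetric` and `toy_log` are hypothetical stand-ins, and the sketch only assumes that a logger decides how to treat a value by inspecting its type.

```python
class ToyMetric:
    """Hypothetical stand-in for a stateful metric object."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, correct, total):
        # Accumulate raw counts; nothing is averaged yet.
        self.correct += correct
        self.total += total

    def compute(self):
        # Returns a plain float that no longer references the metric.
        return self.correct / self.total


def toy_log(name, value):
    """Hypothetical logger that branches on the type of the logged value."""
    if isinstance(value, ToyMetric):
        # The object still carries its state, so aggregation can be
        # deferred to the metric itself.
        return f"{name}: metric object, aggregation handled by the metric"
    # A plain number has lost every link to the metric that produced it.
    return f"{name}: plain number, logger must aggregate it itself"


metric = ToyMetric()
metric.update(correct=3, total=4)

print(toy_log("val_recall", metric))
print(toy_log("val_recall", metric.compute()))
```

Under that assumption, a value produced by `.item()` or by any other extraction step would look like a generic number to the logger, regardless of where it came from.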

I will try to check this hypothesis, too.

awaelchli commented 3 months ago

If you pass in scalar tensors, then of course not; in that case the warning is normal and expected. To log TorchMetrics, pass the metric object itself directly to self.log. You can find a guide here: https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html

srprca commented 3 months ago

I guess this is exactly my case. I don't exactly remember why I have .item() for clearly-scalar metrics, and I guess I can try removing .item() in their case, but my use-case in general is that I want to log both aggregate metrics (mean recall over classes) and per-class metrics (recall of every individual class), and in the latter case I use multiclass recall with average="none", and then extract individual elements to be saved in separate columns in my metrics csv. Similarly, for confusion matrix, I want to log every individual entry as a separate column in the metrics csv file.
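
The use-case described above (one CSV column per class and per confusion-matrix cell) can be sketched in plain Python. The helper names `flatten_per_class` and `flatten_confusion_matrix`, the class names, and the numbers are all hypothetical; the key patterns mirror the ones used in the logging code earlier in the thread.

```python
def flatten_per_class(values, class_names, metric_name, prefix="val"):
    """Map a per-class vector to one named scalar per class."""
    if len(values) != len(class_names):
        raise ValueError("expected exactly one value per class")
    return {
        f"{prefix}_{class_name}_{metric_name}": float(value)
        for class_name, value in zip(class_names, values)
    }


def flatten_confusion_matrix(matrix, class_names, prefix="val"):
    """Map a square confusion matrix to one named scalar per cell."""
    return {
        f"{prefix}_true_{true_name}_pred_{pred_name}_cm": float(cell)
        for true_name, row in zip(class_names, matrix)
        for pred_name, cell in zip(class_names, row)
    }


# Hypothetical inputs: two classes, made-up metric values.
class_names = ["bg", "fg"]
per_class_recall = [0.9, 0.5]
confusion = [[0.6, 0.1], [0.1, 0.2]]

print(flatten_per_class(per_class_recall, class_names, "recall"))
print(flatten_confusion_matrix(confusion, class_names))
```

A dictionary of this shape could then be handed to self.log_dict in one call. Note that the entries are plain numbers at that point, so whether they are aggregated correctly across devices depends on how the input values were computed.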

So I guess it's not a bug then, thank you for clarifying this!

Now I have just a couple more questions:

  1. Does this mean that the metrics are properly reduced across the devices behind the scenes, it is just that self.log doesn't recognize that this is the case and shows me the warning? That is, is it safe and correct to ignore it?
  2. Is there a more idiomatic way to do what I want to do here, given the use-case described above?
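
On question 1, a small numeric sketch shows why the distinction matters. It assumes that the value reaching self.log is an already-computed scalar on each device and that such scalars would be combined with a mean across devices; the counts are made up for illustration. Averaging per-device results is not the same as computing the metric once from the pooled counts:

```python
def recall(tp, fn):
    """Recall from true-positive and false-negative counts."""
    return tp / (tp + fn)


# Hypothetical (tp, fn) counts for one class on two devices that
# happened to see a different number of positive samples.
per_device_counts = [(9, 1), (1, 1)]

# Option 1: compute recall on each device, then average the scalars.
per_device_recall = [recall(tp, fn) for tp, fn in per_device_counts]
mean_of_recalls = sum(per_device_recall) / len(per_device_recall)

# Option 2: pool the raw counts across devices, then compute recall once.
total_tp = sum(tp for tp, _ in per_device_counts)
total_fn = sum(fn for _, fn in per_device_counts)
recall_from_pooled_counts = recall(total_tp, total_fn)

print("mean of per-device recalls:", mean_of_recalls)
print("recall from pooled counts: ", recall_from_pooled_counts)
```

The two numbers agree only when every device contributes the same number of relevant samples, so a warning about unsynchronized scalars is not necessarily safe to ignore.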

Thank you!

srprca commented 3 months ago

Actually, this still happens when I log all the metrics properly, without using .item(). With the following metrics definition:

self.val_metric_funs = tm.MetricCollection(
    {
        "recall_average_macro": tmc.MulticlassRecall(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
        "precision_average_macro": tmc.MulticlassPrecision(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
        "f1_average_macro": tmc.MulticlassF1Score(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
    }
)

and the following logging code:

metrics = self.val_metric_funs(logits, annotator_masks_mixed)
if not sanity_check:
    for metric_name, metric in metrics.items():
        metric_fun = self.val_metric_funs[metric_name]
        metric_name_ = metric_name.split("_")[0]
        if isinstance(
            metric_fun,
            (
                tmc.MulticlassRecall,
                tmc.MulticlassPrecision,
                tmc.MulticlassF1Score,
            ),
        ):
            if metric_fun.average == "macro":
                self.log(
                    f"val_mean_{metric_name_}",
                    metric,
                    on_step=False,
                    on_epoch=True,
                    logger=True,
                )
            else:
                raise NotImplementedError(
                    f"Code for logging metric {metric_name} is not implemented"
                )
        else:
            raise NotImplementedError(
                f"Code for logging metric {metric_name} is not implemented"
            )

I get

.../miniforge3/envs/main/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:439: It is recommended to use `self.log('val_mean_f1', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
.../miniforge3/envs/main/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:439: It is recommended to use `self.log('val_mean_precision', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
.../miniforge3/envs/main/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:439: It is recommended to use `self.log('val_mean_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.

So here, metrics is the return value of calling MetricCollection with my logits and true masks, then I iterate over it like this: for metric_name, metric in metrics.items(), and finally properly log metric, not metric.item().

I will test whether this still happens without MetricCollection a bit later.

srprca commented 3 months ago

This still happens when I log "properly" (without .item()) and without the MetricCollection wrapper.

awaelchli commented 3 months ago

Can you show it with a runnable example based on https://github.com/Lightning-AI/pytorch-lightning/blob/master/examples/pytorch/bug_report/bug_report_model.py?

geometrikal commented 2 months ago

> If you pass in scalar tensors then not of course. Then the warning is normal and expected. For logging TorchMetrics you would just pass in the metric directly into self.log. You can find a guide here: https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html

I get the same thing when logging using the manual method.

Going to try directly logging with the metric, but does that support ClasswiseWrapper?

david-rohrschneider commented 6 days ago

Hello, I can confirm the confusion. I am training on just 2 GPUs and cannot find any documentation on how to use MetricCollection in distributed environments. I'm not using sync_dist, so I get the same warning, and I am not sure whether my metrics are computed and logged properly.