srprca opened 3 months ago
Yes, that's right, the warning shouldn't occur when logging TorchMetrics. Does it occur only with `MetricCollection`, or with a regular `Metric` too?
Thank you for your reply!
I will be able to check this tomorrow, and will report back.
Meanwhile, my second suspicion is that since I log `metric.item()` values, is it possible that somehow `self.log` doesn't recognize these values as originating from TorchMetrics, and sees them as "generic" numbers or tensors...? I will try to check this hypothesis, too.
If you pass in scalar tensors, then no, of course it won't. In that case the warning is normal and expected. For logging TorchMetrics you would just pass the metric directly into `self.log`. You can find a guide here: https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html
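For reference, a minimal sketch of that pattern based on the linked guide (the model, layer size, and metric choice here are placeholder assumptions, not code from this thread): the metric lives as an attribute on the LightningModule, its state is updated in the step, and the Metric object itself is passed to `self.log`, so TorchMetrics handles the cross-device reduction.

```python
import torch
import torchmetrics
from lightning.pytorch import LightningModule


class LitModel(LightningModule):
    def __init__(self, num_classes: int):
        super().__init__()
        self.model = torch.nn.Linear(32, num_classes)
        # Keeping the metric as an attribute lets Lightning move it to the
        # right device and lets TorchMetrics manage its distributed state.
        self.val_recall = torchmetrics.classification.MulticlassRecall(
            num_classes=num_classes, average="macro"
        )

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        # Update the metric state, then log the Metric object itself
        # (not a scalar tensor), so no sync_dist is needed.
        self.val_recall(logits, y)
        self.log("val_recall", self.val_recall, on_step=False, on_epoch=True)
```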
I guess this is exactly my case. I don't exactly remember why I have `.item()` for clearly-scalar metrics, and I can try removing `.item()` in their case, but my general use case is that I want to log both aggregate metrics (mean recall over classes) and per-class metrics (recall of every individual class); in the latter case I use multiclass recall with `average="none"` and then extract the individual elements to be saved in separate columns of my metrics CSV (sketched after this comment). Similarly, for the confusion matrix, I want to log every individual entry as a separate column in the metrics CSV file.
So I guess it's not a bug then, thank you for clarifying this!
Now I have just a couple more questions: is it that TorchMetrics handles the synchronization itself, but `self.log` doesn't recognize that this is the case and shows me the warning? That is, is it safe and correct to ignore it? Thank you!
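To make the per-class use case above concrete, here is a hypothetical sketch (the class names, attribute names, and forward pass are placeholders): recall is computed with `average="none"` and each element of the resulting vector is logged under its own key, which yields one CSV column per class. Each element is a plain scalar tensor, so this is exactly the situation in which `self.log` sees a generic value rather than a Metric.

```python
import torchmetrics
from lightning.pytorch import LightningModule


class LitModel(LightningModule):
    def __init__(self, class_names):
        super().__init__()
        self.class_names = list(class_names)
        self.val_recall_per_class = torchmetrics.classification.MulticlassRecall(
            num_classes=len(self.class_names), average="none"
        )

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)  # assumes forward() is defined elsewhere
        per_class = self.val_recall_per_class(logits, y)  # shape: (num_classes,)
        for idx, name in enumerate(self.class_names):
            # per_class[idx] is a plain scalar tensor, so self.log treats it
            # as a generic value -- the case the sync_dist warning targets.
            self.log(f"val_recall_{name}", per_class[idx], on_step=False, on_epoch=True)
```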
Actually, this still happens when I log all the metrics properly, without using `.item()`. Now, with the following metrics definition:
```python
self.val_metric_funs = tm.MetricCollection(
    {
        "recall_average_macro": tmc.MulticlassRecall(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
        "precision_average_macro": tmc.MulticlassPrecision(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
        "f1_average_macro": tmc.MulticlassF1Score(
            len(self.task.class_names),
            average="macro",
            ignore_index=self.metric_ignore_index,
        ),
    }
)
```
and the following logging code:
```python
metrics = self.val_metric_funs(logits, annotator_masks_mixed)
if not sanity_check:
    for metric_name, metric in metrics.items():
        metric_fun = self.val_metric_funs[metric_name]
        metric_name_ = metric_name.split("_")[0]
        if isinstance(
            metric_fun,
            (
                tmc.MulticlassRecall,
                tmc.MulticlassPrecision,
                tmc.MulticlassF1Score,
            ),
        ):
            if metric_fun.average == "macro":
                self.log(
                    f"val_mean_{metric_name_}",
                    metric,
                    on_step=False,
                    on_epoch=True,
                    logger=True,
                )
            else:
                raise NotImplementedError(
                    f"Code for logging metric {metric_name} is not implemented"
                )
        else:
            raise NotImplementedError(
                f"Code for logging metric {metric_name} is not implemented"
            )
```
I get:

```
.../miniforge3/envs/main/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:439: It is recommended to use `self.log('val_mean_f1', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
.../miniforge3/envs/main/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:439: It is recommended to use `self.log('val_mean_precision', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
.../miniforge3/envs/main/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:439: It is recommended to use `self.log('val_mean_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
```
So here, `metrics` is the return value of calling the `MetricCollection` with my logits and true masks, then I iterate over it like this: `for metric_name, metric in metrics.items()`, and finally I log `metric` properly, not `metric.item()`.
I will test whether this still happens without `MetricCollection` a bit later.
This still happens when logging "properly" (without `.item()`) and without the `MetricCollection` wrapper.
Can you show it with a runnable example based on https://github.com/Lightning-AI/pytorch-lightning/blob/master/examples/pytorch/bug_report/bug_report_model.py?
> If you pass in scalar tensors, then no, of course it won't. In that case the warning is normal and expected. For logging TorchMetrics you would just pass the metric directly into `self.log`. You can find a guide here: https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html
I get the same thing when logging using the manual method.
I'm going to try logging with the metric directly, but does that support `ClasswiseWrapper`?
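The thread doesn't settle whether `self.log` accepts a `ClasswiseWrapper` directly, but one conservative sketch (attribute names and class labels below are placeholder assumptions) is to update the wrapper during validation and, at epoch end, call `compute()` yourself, which synchronizes the metric state across processes, and then log the resulting per-class dict with `self.log_dict`:

```python
from lightning.pytorch import LightningModule
from torchmetrics.classification import MulticlassRecall
from torchmetrics.wrappers import ClasswiseWrapper


class LitModel(LightningModule):
    def __init__(self, class_names):
        super().__init__()
        self.class_names = list(class_names)
        # ClasswiseWrapper turns the per-class vector into a dict keyed by label.
        self.val_recall_classwise = ClasswiseWrapper(
            MulticlassRecall(num_classes=len(self.class_names), average=None),
            labels=self.class_names,
        )

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)  # assumes forward() is defined elsewhere
        self.val_recall_classwise.update(logits, y)

    def on_validation_epoch_end(self):
        # compute() reduces the metric state across processes before the
        # per-class scalars are handed to the logger.
        self.log_dict(self.val_recall_classwise.compute())
        self.val_recall_classwise.reset()
```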
Hello, I can confirm the confusion. I am training on just 2 GPUs and cannot find any documentation on how to use `MetricCollection` in distributed environments. I'm not using `sync_dist`, so I get the same warning, and I am not sure whether my metrics are computed and logged properly.
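As a sketch only (not an answer from the maintainers; the names below are placeholders): following the maintainer's advice above, one way to use a `MetricCollection` without `sync_dist` is to update the collection in the step and then pass the underlying Metric objects, rather than the tensors returned by the forward call, into `self.log`, so that TorchMetrics performs the cross-device accumulation itself.

```python
import torchmetrics
from lightning.pytorch import LightningModule


class LitModel(LightningModule):
    def __init__(self, num_classes: int):
        super().__init__()
        self.val_metrics = torchmetrics.MetricCollection(
            {
                "recall": torchmetrics.classification.MulticlassRecall(
                    num_classes=num_classes, average="macro"
                ),
                "precision": torchmetrics.classification.MulticlassPrecision(
                    num_classes=num_classes, average="macro"
                ),
            }
        )

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)  # assumes forward() is defined elsewhere
        # Update the state of every metric in the collection...
        self.val_metrics.update(logits, y)
        # ...then log the Metric objects themselves, not returned tensors,
        # so Lightning/TorchMetrics handle the epoch-level reduction.
        for name, metric in self.val_metrics.items():
            self.log(f"val_{name}", metric, on_step=False, on_epoch=True)
```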
Bug description
Hello!
When I train and validate a model in a multi-GPU setting (HPC, sbatch job that requests multiple GPUs on a single node), I use `self.log(..., sync_dist=True)` when logging PyTorch losses, and don't specify any value for `sync_dist` when logging metrics from the TorchMetrics library. However, I still get warnings recommending `self.log(..., sync_dist=True)` when logging on epoch level in a distributed setting. These specific messages correspond to logging `tmc.MulticlassRecall(len(self.task.class_names), average="macro", ignore_index=self.metric_ignore_index)` and individual components of `tmc.MulticlassRecall(len(self.task.class_names), average="none", ignore_index=self.metric_ignore_index)`. The full code listing for metric object definitions and logging is provided in the "How to reproduce the bug" section.

As I understand from a note here, and from discussion here, one doesn't typically need to explicitly use `sync_dist` when using TorchMetrics. I wonder if I still need to enable `sync_dist=True` as advised in the warnings due to some special case that I am not aware of, or should I follow the docs and keep it as is? In any case, this is probably a bug, either in the documentation or in the warning code. Thank you!
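A minimal sketch of the logging convention described above (the model, loss, and metric attribute are placeholders): plain loss tensors are logged with `sync_dist=True`, while the TorchMetrics object is logged as-is, since it accumulates and reduces its own state.

```python
import torch
import torchmetrics
from lightning.pytorch import LightningModule


class LitModel(LightningModule):
    def __init__(self, num_classes: int):
        super().__init__()
        self.model = torch.nn.Linear(32, num_classes)
        self.val_recall_macro = torchmetrics.classification.MulticlassRecall(
            num_classes=num_classes, average="macro"
        )

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)

        # Plain loss tensor: reduced across devices only if sync_dist=True.
        loss = torch.nn.functional.cross_entropy(logits, y)
        self.log("val_loss", loss, sync_dist=True)

        # TorchMetrics metric: logged without sync_dist, since it accumulates
        # and synchronizes its own state across processes.
        self.val_recall_macro(logits, y)
        self.log("val_mean_recall", self.val_recall_macro, on_step=False, on_epoch=True)
```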
What version are you seeing the problem on?
2.3.0
How to reproduce the bug
Error messages and logs
Environment
Current environment
```
#- PyTorch Lightning Version: 2.3.0
#- PyTorch Version: 2.3.1
#- Python version: 3.11.9
#- OS: Linux
#- CUDA/cuDNN version: 11.8
#- How you installed Lightning: conda-forge
```
More info
No response
cc @carmocca