LarsHill opened this issue 2 years ago
Hi! Thanks for your contribution, great first issue!
Hey,

Are there any updates on this matter? After checking `metric.py` I don't see a particular reason for this restriction, but maybe I overlooked something. Also, the `reset` method could simply reset each state variable to the user-defined default instead of checking the type to distinguish between `List` and `torch.Tensor`. Regarding the reduction logic, one can keep the current approach: if the state type is something custom (e.g. a `Dict`), the synced state is simply a `List[Dict]`, and the user is in charge of providing a custom reduce fn to combine the different states across processes.
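A minimal sketch of what such a user-supplied reduce fn could look like (pure Python; the name `reduce_dict_states` is hypothetical): after sync, the custom state arrives as a `List[Dict]` with one dict per process, and the function merges it key-wise.

```python
from collections import defaultdict
from typing import Dict, List


def reduce_dict_states(states: List[Dict[str, float]]) -> Dict[str, float]:
    """Merge per-process dict states by summing their values key-wise.

    Hypothetical user-provided reduce fn: `states` holds one dict per
    process after the all-gather step.
    """
    merged: Dict[str, float] = defaultdict(float)
    for state in states:
        for key, value in state.items():
            merged[key] += value
    return dict(merged)
```

For example, `reduce_dict_states([{"tp": 1.0}, {"tp": 2.0, "fp": 1.0}])` yields `{"tp": 3.0, "fp": 1.0}`.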
@LarsHill, I agree that this is an important question to ask. I am also encountering a use case where I need to track my custom `Metric`'s state with a dictionary. Any thoughts on this query, @SkafteNicki or @Borda?
Once again, a short ping: are there any plans to tackle this issue in the near future?
Hi @LarsHill, it has not really been a priority for us, but I am willing to listen. At the moment I think we could consider adding support for dicts of tensors as a third option. It would not completely solve your initial example, because that requires a dict of dicts of tensors, but then you could do something like:
```python
for i in range(num_document_ids):
    # defaultdict(float) so that "+=" works for keys not seen before
    self.add_state(f"statistics_{i}", default=defaultdict(float), dist_reduce_fx=None)
```
and then in `update`:
```python
for id_, tp, tn, fp, fn in zip(document_ids, tps, tns, fps, fns):
    getattr(self, f"statistics_{id_}")['tp'] += tp.float().item()
    getattr(self, f"statistics_{id_}")['tn'] += tn.float().item()
    getattr(self, f"statistics_{id_}")['fp'] += fp.float().item()
    getattr(self, f"statistics_{id_}")['fn'] += fn.float().item()
```
This would of course mean that `num_document_ids` would need to be known in advance.
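Since the workaround registers one state per document at `__init__` time, the ids would have to be gathered in a pre-scan over the data. A minimal sketch in plain Python (the assumption that the dataset yields `(sample, document_id)` pairs is mine; adjust the unpacking to the real layout):

```python
def collect_document_ids(dataset):
    """Pre-scan the dataset once and return the sorted set of document ids.

    Assumes the dataset yields (sample, document_id) pairs.
    """
    return sorted({doc_id for _, doc_id in dataset})


# Example: three samples spread over two documents.
doc_ids = collect_document_ids([("a", 2), ("b", 0), ("c", 2)])
# doc_ids == [0, 2], so there are two documents to register states for
```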
Hi @SkafteNicki,
Thanks for your suggestion. That would work as a temporary solution, but I think it is not very elegant to iterate over the entire dataset at initialization time just to retrieve all the document ids.
Is there a particular reason why the state types are restricted in the first place? What about the suggestion to extend them to arbitrary types? If the state is something custom, the synced/reduced state is simply a List of that custom type, and the user is in charge of properly reducing it, e.g. by providing a specific reduce function.
I am referring to these lines in `metric.py`:
```python
if not isinstance(default, (Tensor, list)) or (isinstance(default, list) and default):
    raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)")

if dist_reduce_fx == "sum":
    dist_reduce_fx = dim_zero_sum
elif dist_reduce_fx == "mean":
    dist_reduce_fx = dim_zero_mean
elif dist_reduce_fx == "max":
    dist_reduce_fx = dim_zero_max
elif dist_reduce_fx == "min":
    dist_reduce_fx = dim_zero_min
elif dist_reduce_fx == "cat":
    dist_reduce_fx = dim_zero_cat
elif dist_reduce_fx is not None and not callable(dist_reduce_fx):
    raise ValueError("`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', 'min', 'max', None]")
```
It seems users already have the option to pass a custom `dist_reduce_fx` callable. So why not change the above code to something like this?
```python
# Only raise if the default is non-standard AND no callable reduce fn is given:
if (not isinstance(default, (Tensor, list)) or (isinstance(default, list) and default)) and not callable(dist_reduce_fx):
    raise ValueError(
        "If the state variable is not a tensor or an empty list (where you can append tensors), "
        "you have to provide a custom `dist_reduce_fx` that is callable."
    )

if dist_reduce_fx == "sum":
    dist_reduce_fx = dim_zero_sum
elif dist_reduce_fx == "mean":
    dist_reduce_fx = dim_zero_mean
elif dist_reduce_fx == "max":
    dist_reduce_fx = dim_zero_max
elif dist_reduce_fx == "min":
    dist_reduce_fx = dim_zero_min
elif dist_reduce_fx == "cat":
    dist_reduce_fx = dim_zero_cat
elif dist_reduce_fx is not None and not callable(dist_reduce_fx):
    raise ValueError("`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', 'min', 'max', None]")
```
For current users who work with `torch.Tensor` or `List[torch.Tensor]` nothing changes, and users who need custom metric state types get the flexibility to use them.
Are there any issues with such a change that I am not aware of?
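To illustrate the behaviour the relaxed check would enable, here is a small simulation in plain Python (no torchmetrics involved; `sum_counts` and the hard-coded `gathered` list are illustrative stand-ins for the user-supplied `dist_reduce_fx` and the sync step):

```python
def sum_counts(states):
    """User-supplied reduce fn: merge a list of per-process count dicts."""
    merged = {}
    for state in states:
        for key, value in state.items():
            merged[key] = merged.get(key, 0) + value
    return merged


# Under the proposed check, a dict default paired with a callable
# dist_reduce_fx would pass validation. After the all-gather step,
# the custom state would arrive as a list with one dict per process:
gathered = [{"doc1": 3, "doc2": 1}, {"doc1": 2}]  # two "processes"
reduced = sum_counts(gathered)
# reduced == {"doc1": 5, "doc2": 1}
```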
By the way, since I could not wait for changes in torchmetrics, I went ahead and implemented my own metric abstraction. However, I would prefer to switch to `torchmetrics`, but this seems to me like a missing feature that is limiting.
🚀 Feature
Metric states seem to be limited to `torch.Tensor` or `List[torch.Tensor]`. In my use case I want to store a dictionary as state. My dataset consists of samples that can be assigned to different documents. In order to calculate macro metrics (compute the metrics per document and average them), I want to store my metric states (e.g. true positives, false positives, etc.) as a dictionary. Here is some pseudocode:
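The pseudocode block did not survive in this copy of the issue. Based on the surrounding description, it plausibly looked something like the following (names are illustrative pseudocode, and this is exactly the kind of `add_state` call the current check rejects):

```python
class DocumentMacroMetric(Metric):
    def __init__(self):
        super().__init__()
        # Dict-valued defaults are rejected by the current add_state check.
        self.add_state("tp", default=defaultdict(float), dist_reduce_fx=None)
        self.add_state("fp", default=defaultdict(float), dist_reduce_fx=None)
        self.add_state("fn", default=defaultdict(float), dist_reduce_fx=None)

    def update(self, preds, targets, document_ids):
        for doc_id, pred, target in zip(document_ids, preds, targets):
            self.tp[doc_id] += int(pred == 1 and target == 1)
            self.fp[doc_id] += int(pred == 1 and target == 0)
            self.fn[doc_id] += int(pred == 0 and target == 1)
```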
Unfortunately the above code is not allowed. Each metric state has to be a `torch.Tensor` or a `List[torch.Tensor]`. That means plain float values or numpy arrays cannot be used as metric states either. Is there a particular reason for that?