LucaButera opened 1 year ago
Hi Luca, sorry for the late reply! Your suggestion seems useful, and I would opt for the first solution you proposed. I also think it is better to set the `dim` attribute just once at instantiation; at the moment I don't see any advantage in having it as a method parameter. Feel free to propose a PR!
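A minimal sketch of fixing the time dimension once at instantiation, assuming a plain PyTorch setting. `MaskedMAESketch`, its `t_dim` and `at` arguments, and the choice of mean absolute error are all hypothetical and are not tsl's `MaskedMetric`. The point is only that `t_dim` is stored in `__init__` and never passed to `update()`.

```python
from typing import Optional

import torch


class MaskedMAESketch:
    """Hypothetical masked MAE whose time dimension is fixed at construction."""

    def __init__(self, t_dim: int = 1, at: Optional[int] = None):
        self.t_dim = t_dim  # set once, never passed to update()
        # Assumed convention: a single step is kept as a length-1 slice,
        # so the time dimension survives the indexing.
        self.at = slice(None) if at is None else slice(at, at + 1)
        self.sum_err = torch.tensor(0.0)
        self.count = torch.tensor(0)

    def _select(self, x: torch.Tensor) -> torch.Tensor:
        dim = self.t_dim % x.dim()  # also accept a negative t_dim
        # Index tuple (:, ..., :, at) with `at` placed at position `dim`.
        return x[(slice(None),) * dim + (self.at,)]

    def update(self, y_hat: torch.Tensor, y: torch.Tensor,
               mask: Optional[torch.Tensor] = None) -> None:
        err = (self._select(y_hat) - self._select(y)).abs()
        if mask is None:
            mask = torch.ones_like(err, dtype=torch.bool)
        else:
            mask = self._select(mask).bool()
        # Accumulate the masked error sum and the number of valid entries.
        self.sum_err = self.sum_err + err[mask].sum()
        self.count = self.count + mask.sum()

    def compute(self) -> torch.Tensor:
        return self.sum_err / self.count
```

With the big-graph layout `[time, nodes, features]` one would instantiate `MaskedMAESketch(t_dim=0)`, while a batched layout `[batch, time, nodes, features]` keeps the default `t_dim=1`; `update()` has the same signature in both cases.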
https://github.com/TorchSpatiotemporal/tsl/blob/154cccf6a07c5cc0558569f4faa18928f3755603/tsl/metrics/torch/metric_base.py#L112-L122
In the referenced snippet, the `update` method of `MaskedMetric` assumes that time is the second dimension. However, when a batch of graphs is represented as a single large graph, the inputs have no leading batch dimension, so an unnecessary dummy batch dimension has to be added just to satisfy this assumption.
I suggest two possible solutions to avoid this:

1. Add a `t_dim` parameter and use `x.select(t_dim, self.at)` instead of `x[:, self.at]`.
2. Add a pattern string and use it to identify the time dimension.

In both cases, this can be either a class attribute or a method parameter, depending on whether we prefer to set it once or to allow changing it each time the metric is updated.
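A sketch of the first solution as a standalone helper, assuming PyTorch tensors. `select_time` is a hypothetical name. One caveat worth checking: `Tensor.select` only accepts an integer index, so if `self.at` may hold a slice rather than an int, `x.select(t_dim, self.at)` would fail; building an index tuple handles the slice case and keeps the time dimension.

```python
from typing import Union

import torch


def select_time(x: torch.Tensor, at: Union[int, slice], t_dim: int = 1) -> torch.Tensor:
    """Hypothetical helper: index `x` along its time dimension `t_dim`."""
    if isinstance(at, int):
        # Tensor.select takes an integer index and drops that dimension,
        # matching x[:, at] when t_dim == 1.
        return x.select(t_dim, at)
    # Tensor.select does not accept a slice, so build the index tuple
    # (:, ..., :, at) with `at` at position t_dim; this keeps the dimension.
    dim = t_dim % x.dim()
    return x[(slice(None),) * dim + (at,)]
```

For a batched tensor the default `t_dim=1` reproduces `x[:, at]`, while a big-graph tensor shaped `[time, nodes, features]` would use `t_dim=0` without any dummy batch dimension.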
The first solution is the easier to implement; however, the second may make it easier to add further aggregations that depend on dimension semantics later on.
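For the second solution, a pattern string such as `"b t n f"` (batch, time, nodes, features) could name each axis, so the time axis is looked up by name instead of by position. The sketch assumes that axis-naming convention and uses hypothetical helpers `axis_positions`, `select_time_by_pattern` and `mean_over`; `mean_over` only hints at how the same lookup could drive other dimension-aware aggregations.

```python
import torch


def axis_positions(pattern: str) -> dict:
    """Map each axis name of a space-separated pattern to its position."""
    names = pattern.split()
    if len(set(names)) != len(names):
        raise ValueError(f"repeated axis name in {pattern!r}")
    return {name: i for i, name in enumerate(names)}


def select_time_by_pattern(x: torch.Tensor, at: slice, pattern: str) -> torch.Tensor:
    """Slice the axis named 't', wherever the pattern places it."""
    pos = axis_positions(pattern)
    if len(pos) != x.dim():
        raise ValueError(f"pattern {pattern!r} does not match a {x.dim()}-d tensor")
    # Index tuple (:, ..., :, at) with `at` at the position of 't'.
    return x[(slice(None),) * pos["t"] + (at,)]


def mean_over(x: torch.Tensor, axis: str, pattern: str) -> torch.Tensor:
    """Average over a named axis, e.g. 'n' for nodes."""
    return x.mean(dim=axis_positions(pattern)[axis])
```

Only the pattern changes between layouts: `"t n f"` for the big-graph case and `"b t n f"` for the batched case, while the metric code stays the same.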
If this is deemed useful, I can implement whichever solution we agree on.