TorchSpatiotemporal / tsl

tsl: a PyTorch library for processing spatiotemporal data.
https://torch-spatiotemporal.readthedocs.io/
MIT License

Add parameter to specify which is the time dimension in MaskedMetric #23

Open LucaButera opened 1 year ago

LucaButera commented 1 year ago

https://github.com/TorchSpatiotemporal/tsl/blob/154cccf6a07c5cc0558569f4faa18928f3755603/tsl/metrics/torch/metric_base.py#L112-L122

In the referenced snippet, MaskedMetric's update function assumes that the time dimension is the second one. However, this forces us to add an unnecessary dummy batch dimension when we represent a batch of graphs as a single big graph.

I suggest two possible solutions to avoid this:

1. Add a `t_dim` parameter and use `x.select(t_dim, self.at)` instead of `x[:, self.at]`.
2. Add a pattern string and use it to identify the time dimension.
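A minimal sketch of what the first solution could look like. The `t_dim` parameter and the `select_time_step` helper are hypothetical names for illustration, not part of tsl's API:

```python
import torch

# Hypothetical sketch of solution (1): select along a configurable time
# dimension instead of hard-coding x[:, at].
def select_time_step(x: torch.Tensor, at: int, t_dim: int = 1) -> torch.Tensor:
    # x.select(dim, index) drops the selected dimension, exactly like
    # x[:, at] does when t_dim == 1.
    return x.select(t_dim, at)

# Batched layout [batch, time, nodes, features]: time is dim 1.
x_batched = torch.randn(8, 12, 4, 2)
assert torch.equal(select_time_step(x_batched, at=3, t_dim=1), x_batched[:, 3])

# "Single big graph" layout [time, nodes, features]: time is dim 0,
# so no dummy batch dimension is needed.
x_flat = torch.randn(12, 32, 2)
assert torch.equal(select_time_step(x_flat, at=3, t_dim=0), x_flat[3])
```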

In both cases, this can be either a class attribute or a method parameter, depending on whether we prefer to set it once at instantiation or to allow changing it on each metric update.

The first solution is the easier to implement; however, the second may make it easier to implement further aggregations that depend on dimension semantics down the road.
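For the second solution, the time dimension could be derived from a tsl-style pattern string (e.g. `"b t n f"` for batched data or `"t n f"` for a single big graph). This is only a sketch under that assumption; the helper name is hypothetical:

```python
# Hypothetical sketch of solution (2): locate the time dimension from a
# pattern string whose tokens name the dimensions, with "t" as time.
def time_dim_from_pattern(pattern: str) -> int:
    dims = pattern.split()
    if "t" not in dims:
        raise ValueError(f"pattern {pattern!r} has no time dimension")
    return dims.index("t")

assert time_dim_from_pattern("b t n f") == 1  # batched: time is dim 1
assert time_dim_from_pattern("t n f") == 0    # single big graph: time is dim 0
```

The same pattern string could later drive other dimension-aware aggregations, which is the advantage mentioned above.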

If this is deemed useful I can implement this behavior with an agreed solution.

marshka commented 1 year ago

Hi Luca, sorry for the late reply! Your suggestion seems useful and nice; I would opt for the first solution you proposed. Also, I think it is better to set the dim attribute just once at instantiation, as at the moment I don't see any advantage in having it as a method parameter. Feel free to propose a PR!