openclimatefix / graph_weather

PyTorch implementation of Ryan Keisler's 2022 "Forecasting Global Weather with Graph Neural Networks" paper (https://arxiv.org/abs/2202.07575)
MIT License

Error after Installation #36

Open MoHawastaken opened 2 years ago

MoHawastaken commented 2 years ago

Hey, I just installed your package and all its requirements, tried to run the first usage example, and got the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/Users/me/Documents/gnn_code/graph_weather.py in <cell line: 15>()
     12 model = GraphWeatherForecaster(lat_lons)
     14 features = torch.randn((2, len(lat_lons), 78))
---> 16 out = model(features)
     17 criterion = NormalizedMSELoss(lat_lons=lat_lons, feature_variance=torch.randn((78,)))
     18 loss = criterion(out, features)

File ~/Documents/gnn_code/graphenv/lib/python3.9/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Documents/gnn_code/graph_weather/models/forecast.py:97, in GraphWeatherForecaster.forward(self, features)
     87 def forward(self, features: torch.Tensor) -> torch.Tensor:
     88     """
     89     Compute the new state of the forecast
     90 
   (...)
     95         The next state in the forecast
     96     """
---> 97     x, edge_idx, edge_attr = self.encoder(features)
     98     x = self.processor(x, edge_idx, edge_attr)
     99     x = self.decoder(x, features.shape[0])

File ~/Documents/gnn_code/graphenv/lib/python3.9/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Documents/gnn_code/graph_weather/models/layers/encoder.py:158, in Encoder.forward(self, features)
    156 self.graph = self.graph.to(features.device)
    157 self.latent_graph = self.latent_graph.to(features.device)
--> 158 features = torch.cat(
    159     [features, einops.repeat(self.h3_nodes, "n f -> b n f", b=batch_size)], dim=1
    160 )
    161 # Cat with the h3 nodes to have correct amount of nodes, and in right order
    162 features = einops.rearrange(features, "b n f -> (b n) f")

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 78 but got size 102 for tensor number 1 in the list.

Is this a mistake on my side?

Best, Moritz

jacobbieker commented 2 years ago

Hi,

No, you just need to change the input feature dimension to 102. By default it's 78, to match the input size in the paper, but I think the example is outdated right now. I'll update the docs for it soon.
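For anyone else hitting this: line 158 of the encoder concatenates the input features with the H3 node features along the node dimension (dim=1), so every other dimension, including the feature dimension, has to agree. Below is a minimal plain-Python sketch of that shape rule; `cat_shape` is a hypothetical helper that mirrors the check `torch.cat` performs, and the node counts are made-up examples. Only the 78 and 102 feature widths come from the error above.

```python
def cat_shape(shapes, dim):
    """Shape of concatenating tensors with the given shapes along `dim`.

    Mirrors the rule behind torch.cat: every size must match except along `dim`.
    """
    first = shapes[0]
    for i, shape in enumerate(shapes[1:], start=1):
        if len(shape) != len(first):
            raise RuntimeError(f"Tensor {i} has {len(shape)} dims, expected {len(first)}.")
        for axis, (expected, got) in enumerate(zip(first, shape)):
            if axis != dim and expected != got:
                raise RuntimeError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {expected} but got size {got} for tensor number {i} in the list."
                )
    out = list(first)
    out[dim] = sum(shape[dim] for shape in shapes)
    return tuple(out)


# Hypothetical sizes: batch of 2, 1000 lat/lon points, 500 H3 nodes with 102 features each.
h3_shape = (2, 500, 102)
print(cat_shape([(2, 1000, 102), h3_shape], dim=1))  # (2, 1500, 102)
try:
    cat_shape([(2, 1000, 78), h3_shape], dim=1)
except RuntimeError as err:
    print(err)  # same wording as the RuntimeError in the traceback
```

With 102 input features the node dimension simply grows; with 78 the feature dimension disagrees and the concatenation fails.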

MoHawastaken commented 2 years ago

Hey, you are right of course, thanks for the quick reply!

A very unrelated question. The paper says:

"The initial edge features are the positions of the lat/lon nodes connected to each icosahedron node. These positions are provided in a local coordinate system that is defined relative to each icosahedron node."

You set cos(distance) and sin(distance) as edge attributes in the encoder and decoder. Shouldn't that rather be the relative position (lon_original - lon_icosa, lat_original - lat_icosa)? I agree that adding the distance also makes sense, but maybe without cos/sin? My simple logic: the smaller the distance, the more related the points.
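To make the comparison concrete, here is a minimal sketch of the two edge encodings for a single icosahedron-node/grid-point pair. It assumes the "distance" is the angular great-circle distance in radians (computed here with the haversine formula) and that positions are (lat, lon) in degrees; both helper names are hypothetical and this is not the encoder's actual code.

```python
import math


def great_circle_angle(lat1, lon1, lat2, lon2):
    """Angular great-circle distance in radians (haversine formula); inputs in degrees."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlmb = math.radians(lon2 - lon1)
    h = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlmb / 2) ** 2
    return 2 * math.asin(math.sqrt(min(1.0, h)))  # clamp guards against rounding above 1


def sincos_edge_attr(node, point):
    """[cos d, sin d], d = angular distance from the icosahedron node to the grid point."""
    d = great_circle_angle(*node, *point)
    return [math.cos(d), math.sin(d)]


def relative_edge_attr(node, point):
    """The proposal above: raw (lat difference, lon difference) in degrees."""
    dlat = point[0] - node[0]
    dlon = (point[1] - node[1] + 180.0) % 360.0 - 180.0  # wrap into [-180, 180)
    return [dlat, dlon]


# A 1-degree step east, once at the equator and once at 80N (hypothetical points).
for lat in (0.0, 80.0):
    node, point = (lat, 0.0), (lat, 1.0)
    print(lat, relative_edge_attr(node, point), sincos_edge_attr(node, point))
```

The relative-position features are identical for both steps, while sin d is about 5.8 times smaller at 80°N, because a degree of longitude shrinks with cos(latitude).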

jacobbieker commented 2 years ago

Yeah, but since the distance inside the sin and cos is the great-circle distance from the icosahedron node to the lat/lon point, I think it is quite similar. The issue with just subtracting lat/lon coordinates is that the physical distance they correspond to changes depending on whether the points are near the poles or the equator. I think the cos/sin helps to fix that issue, in a similar way to how NormalizedMSELoss includes the cosine of the latitude to correct for each pixel becoming physically larger as we move from the poles to the equator.
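On the NormalizedMSELoss analogy: on a regular lat/lon grid the area of a 1° × 1° cell is proportional to the cosine of its centre latitude, so an area-aware loss weights each point by cos(lat). Below is a minimal sketch of that weighting on its own; `lat_weighted_mse` is hypothetical and is not graph_weather's NormalizedMSELoss, which also takes a `feature_variance` argument whose role is not reproduced here.

```python
import math


def lat_weighted_mse(pred, target, lats_deg):
    """MSE over grid points, each weighted by cos(latitude) and normalised by the total weight."""
    weights = [math.cos(math.radians(lat)) for lat in lats_deg]
    weighted = sum(w * (p - t) ** 2 for w, p, t in zip(weights, pred, target))
    return weighted / sum(weights)


# A unit error at the equator counts twice as much as one at 60N (cos 60° = 0.5).
print(lat_weighted_mse([1.0, 0.0], [0.0, 0.0], [0.0, 60.0]))  # ≈ 0.667
print(lat_weighted_mse([0.0, 1.0], [0.0, 0.0], [0.0, 60.0]))  # ≈ 0.333
```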

An option could also be added to use just the great-circle distance instead of the sin/cos encoding.
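A sketch of what such an option might look like, assuming the angular distances (radians) have already been computed per edge. The `mode` parameter and the helper name are hypothetical, not an existing graph_weather flag.

```python
import math


def distance_edge_attr(angles_rad, mode="sincos"):
    """Per-edge features from angular great-circle distances in radians.

    mode="sincos": [cos d, sin d] per edge, like the current encoder/decoder.
    mode="distance": [d] per edge, the alternative suggested above.
    """
    if mode == "sincos":
        return [[math.cos(d), math.sin(d)] for d in angles_rad]
    if mode == "distance":
        return [[d] for d in angles_rad]
    raise ValueError(f"unknown mode: {mode!r}")


print(distance_edge_attr([0.01, 0.02], mode="distance"))  # [[0.01], [0.02]]
print(distance_edge_attr([0.01, 0.02], mode="sincos"))
```

For short encoder/decoder edges, sin d ≈ d and cos d ≈ 1 − d²/2, so the two modes carry nearly the same information. That is consistent with the point above that the current encoding is already quite similar to using the distance directly.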

MoHawastaken commented 2 years ago

Yes, direct differences were a bad suggestion. But to get full 2D information in polar-coordinate style, you would want a direction as well as a distance; cos/sin of the distance does not add a second intrinsic dimension. Here is a suggestion for calculating such an angle that is invariant to the base icosahedron grid point (maybe a bit over the top...):

1) Find the geodesic rotation of the icosahedron point to a fixed point, say (0,0,1), via Rot = T R_n(alpha) T^-1, as described in G Cab's answer here: https://math.stackexchange.com/questions/114107/determine-the-rotation-necessary-to-transform-one-point-on-a-sphere-to-another
2) Apply the rotation to the point in the original grid.
3) So that the angle refers to north/south/west/east directions irrespective of the base icosahedron grid point, simply rotate about (0,0,1) by the angle -lat_icosa.
4) To determine the angle, project the rotated point via the normal vector (0,0,1) onto (smth,smth,1). Now you can view (0,0,1) as the origin (0,0,0) and normalize the projected point (smth,smth,0). Then the dot product (smth,smth,0) · (1,0,0) yields the cosine of the angle we are interested in.

Of course, going north from a grid point in the northern hemisphere has a different effect than going north in the southern hemisphere, so I agree the value of such an angle is questionable for GNNs, which need invariance with respect to the base grid point...
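As a simpler point of comparison for the direction part, here is a minimal sketch that computes the angle with the standard initial-bearing (forward azimuth) formula instead of the rotation procedure above. This is a swapped-in technique, not the one described in the comment. `bearing_features` is a hypothetical helper, and positions are assumed to be (lat, lon) in degrees.

```python
import math


def bearing_features(node, point):
    """[cos theta, sin theta], theta = initial bearing from node to point, clockwise from north."""
    phi1, phi2 = math.radians(node[0]), math.radians(point[0])
    dlmb = math.radians(point[1] - node[1])
    # Forward-azimuth formula: y is the east component, x the north component.
    y = math.sin(dlmb) * math.cos(phi2)
    x = math.cos(phi1) * math.sin(phi2) - math.sin(phi1) * math.cos(phi2) * math.cos(dlmb)
    theta = math.atan2(y, x)
    return [math.cos(theta), math.sin(theta)]


print(bearing_features((0.0, 0.0), (1.0, 0.0)))  # due north: cos ≈ 1, sin ≈ 0
print(bearing_features((0.0, 0.0), (0.0, 1.0)))  # due east: cos ≈ 0, sin ≈ 1
```

The bearing is undefined when the node sits exactly on a pole (every direction is south from the North Pole) or coincides with the point. `math.atan2` still returns a number there, so a real implementation would need to handle those cases explicitly.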

jacobbieker commented 2 years ago

Ah yeah, that makes sense. I guess another way could be to just decompose the distance into the north/south change and the west/east change, not in lat/lon but in terms of the great-circle distance that is currently being used. That might be a bit simpler than this approach, although if you want to write a PR for this, I'd be happy to review/merge it!
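One possible reading of that decomposition, as a minimal sketch: split the great-circle distance d along the initial bearing θ into (d·cos θ, d·sin θ). That pair is a north-up azimuthal-equidistant frame centred on the icosahedron node, which is one way to get "a local coordinate system defined relative to each icosahedron node". The helper name and the 6371 km mean Earth radius are assumptions; this is not code from the repository.

```python
import math

EARTH_RADIUS_KM = 6371.0  # mean Earth radius (assumed constant)


def local_north_east_km(node, point):
    """(north, east) components in km of the great-circle distance node -> point.

    node and point are (lat, lon) in degrees.
    """
    phi1, phi2 = math.radians(node[0]), math.radians(point[0])
    dphi = phi2 - phi1
    dlmb = math.radians(point[1] - node[1])
    # Haversine great-circle distance.
    h = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlmb / 2) ** 2
    d = 2 * EARTH_RADIUS_KM * math.asin(math.sqrt(min(1.0, h)))
    # Initial bearing, clockwise from north.
    y = math.sin(dlmb) * math.cos(phi2)
    x = math.cos(phi1) * math.sin(phi2) - math.sin(phi1) * math.cos(phi2) * math.cos(dlmb)
    theta = math.atan2(y, x)
    return d * math.cos(theta), d * math.sin(theta)


north, east = local_north_east_km((60.0, 0.0), (60.0, 1.0))
print(f"north={north:.2f} km, east={east:.2f} km")
```

For a 1° step east at 60°N this gives roughly 55.6 km east and a few hundred metres north. The small north component appears because the great circle between two points on the same parallel bulges poleward, which a plain lat/lon difference would not capture.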

peterdudfield commented 1 year ago

@all-contributors please add @MoHawastaken for bug

allcontributors[bot] commented 1 year ago

@peterdudfield

I've put up a pull request to add @MoHawastaken! :tada: