NVIDIA / modulus-makani

Massively parallel training of machine-learning based weather and climate models

Batching in ModelWrapper #9

Open dallasfoster opened 6 months ago

dallasfoster commented 6 months ago

In `makani/models/model_package.py`, the class `ModelWrapper` contains the following lines:

```python
if self.add_zenith:
    lon_grid, lat_grid = np.meshgrid(self.lons, self.lats)
    cosz = cos_zenith_angle(time, lon_grid, lat_grid)
    cosz = cosz.astype(np.float32)
    z = torch.from_numpy(cosz).to(device=x.device)
    while z.ndim != x.ndim:
        z = z[None]
    x = torch.cat([x, z], dim=1)
```
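Stripped down to its tensor shapes, the excerpt behaves like the sketch below when the batch size is 2. The sizes are hypothetical, and a random array stands in for the output of `cos_zenith_angle`, which is assumed here to be a single 2-D `(nlat, nlon)` field. This is an illustration, not makani code.

```python
import numpy as np
import torch

# Hypothetical sizes for illustration only.
batch, channels, nlat, nlon = 2, 3, 4, 8
x = torch.randn(batch, channels, nlat, nlon)

# Stand-in for cos_zenith_angle(time, lon_grid, lat_grid); assumed to
# return one 2-D field of shape (nlat, nlon) for a single time.
cosz = np.random.rand(nlat, nlon).astype(np.float32)

# Same steps as ModelWrapper: convert, then prepend singleton dims.
z = torch.from_numpy(cosz).to(device=x.device)
while z.ndim != x.ndim:
    z = z[None]
print(tuple(z.shape))  # (1, 1, 4, 8)

cat_failed = False
try:
    torch.cat([x, z], dim=1)
except RuntimeError as err:
    # Dim 0 differs (2 vs 1), and torch.cat does not broadcast.
    cat_failed = True
    print("torch.cat failed:", err)
```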

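If `x.size(0) != 1`, this fails. Assuming `cosz` is a single 2-D field of shape `(nlat, nlon)`, the `while` loop only prepends singleton dimensions, so `z` ends up with shape `(1, 1, nlat, nlon)`. `torch.cat` requires every dimension except `dim=1` to match, so any batch size other than 1 raises a `RuntimeError`. Consider a line closer to this: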

```python
z = z.repeat(x.shape[0], *[1 for i in range(x.ndim - 1)])
```
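A sketch of the proposed fix, wrapped in a small hypothetical helper `append_zenith_channel` (not part of makani) so it can be checked in isolation. It assumes every sample in the batch shares the same `time`, so one zenith field can be copied to each sample; if samples carried different timestamps, each would need its own `cosz` instead.

```python
import torch


def append_zenith_channel(x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Concatenate one zenith field z onto x along the channel dim.

    Hypothetical helper for illustration; z is assumed to hold a single
    field shared by the whole batch, e.g. shape (nlat, nlon).
    """
    # Prepend singleton dims until z has as many dims as x (as in ModelWrapper).
    while z.ndim != x.ndim:
        z = z[None]
    # Tile along the batch dimension only, matching the suggested
    # z.repeat(x.shape[0], *[1 for i in range(x.ndim - 1)]).
    z = z.repeat(x.shape[0], *([1] * (x.ndim - 1)))
    return torch.cat([x, z], dim=1)


x = torch.randn(2, 3, 4, 8)
z = torch.rand(4, 8)
out = append_zenith_channel(x, z)
print(tuple(out.shape))  # (2, 4, 4, 8)
```

As an alternative, `z.expand(x.shape[0], *z.shape[1:])` should give the same result without materializing the tiled copy first; since `torch.cat` allocates a new output tensor anyway, either choice works, and `repeat` is kept here to match the suggestion above.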