alan-turing-institute / deepsensor

A Python package for tackling diverse environmental prediction tasks with neural processes (NPs).
https://alan-turing-institute.github.io/deepsensor/
MIT License

Using trainer with Batch Size leads to Conv Model shape errors in Training User Guide #93

Closed · opened by nilsleh · closed 1 year ago

nilsleh commented 1 year ago

Description

I am running the new User Guide Training Notebook to better understand the details of ConvNP training. I downloaded the Jupyter notebook and it runs fine with the defaults. However, I wanted to train with a defined batch size, so I passed a batch_size argument to the trainer. With batch_size > 1 (e.g. 4, as in this example), I get a shape error in a neuralprocesses conv layer:

 File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/deepsensor/train/train.py", line 145, in train_epoch
    batch_loss = train_step(task)
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/deepsensor/train/train.py", line 116, in train_step
    task_losses.append(model.loss_fn(task, normalise=True))
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/deepsensor/model/convnp.py", line 760, in loss_fn
    logpdfs = backend.nps.loglik(
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/plum/function.py", line 392, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/neuralprocesses/model/loglik.py", line 105, in loglik
    state, logpdfs = loglik(state, model, *args, **kw_args)
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/plum/function.py", line 392, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/neuralprocesses/model/loglik.py", line 56, in loglik
    state, pred = model(
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/plum/function.py", line 459, in __call__
    return self._f(self._instance, *args, **kw_args)
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/plum/function.py", line 392, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/neuralprocesses/model/model.py", line 102, in __call__
    return self(
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/plum/function.py", line 459, in __call__
    return self._f(self._instance, *args, **kw_args)
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/plum/function.py", line 392, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/neuralprocesses/model/model.py", line 81, in __call__
    _, d = code(self.decoder, xz, z, xt, root=True, **kw_args)
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/plum/function.py", line 392, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/neuralprocesses/__init__.py", line 21, in f_wrapped
    return f(*args, **kw_args)
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/neuralprocesses/chain.py", line 56, in code
    xz, z = code(link, xz, z, x, **kw_args)
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/plum/function.py", line 392, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/neuralprocesses/__init__.py", line 21, in f_wrapped
    return f(*args, **kw_args)
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/neuralprocesses/coding.py", line 45, in code
    return xz, coder(z)
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/neuralprocesses/coders/nn.py", line 364, in __call__
    hs = [self.activations[0](self.before_turn_layers[0](x))]
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/nils/.conda/envs/sensorEnv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [32, 8, 5, 5], expected input[4, 10, 352, 512] to have 8 channels, but got 10 channels instead

I find it quite counterintuitive that changing the batch size argument would have such an effect, since I couldn't find any other place in the code where I would need to change the batch size or adapt anything. In short: batch_size=1 works, but batch_size > 1 raises the error.

These are the shape outputs from just before entering the decoder with batch_size = 1:

xz: torch.Size([1, 1, 352]) torch.Size([1, 1, 512])
z: torch.Size([1, 1, 8, 352, 512])
xt: torch.Size([1, 1, 241]) torch.Size([1, 1, 401])

and with batch_size = 4:

xz: torch.Size([4, 1, 352]) torch.Size([4, 1, 512])
z: torch.Size([1, 4, 10, 352, 512])
xt: torch.Size([4, 1, 241]) torch.Size([4, 1, 401])

Thus, it seems the sampled z from the Dirac pz output of the encoder is where the difference arises, but I haven't yet dug into why changing the batch size would have that effect.

Reproduction steps

1. Download the Training Notebook
2. In the training loop give a batch size argument to the trainer: `batch_losses = trainer(train_tasks, batch_size=4)`
3. See the error (the failing call is contrasted with the default in the snippet below)
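
In other words, only the second of these two calls fails:

# Runs fine: the notebook default (no batch_size, i.e. effectively batch_size=1)
batch_losses = trainer(train_tasks)

# Raises the conv-channel RuntimeError shown above
batch_losses = trainer(train_tasks, batch_size=4)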

Version

0.3.4

OS

Linux

tom-andersson commented 1 year ago

Thanks for raising this @nilsleh - if you are using a local version of deepsensor for this, would you be able to add a `print(task)` before the `batch_loss = train_step(task)` line in the `train_epoch` function? We can then check that all the shapes conform to `(n_batches, n_features, *n_obs)`, and that the context `n_features` plus the number of context sets (for the density channels) adds up to 8.
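
For concreteness, the debug step would look something like this (a sketch against the traceback above; the exact loop structure in `train_epoch` is assumed):

# Sketch: inside deepsensor/train/train.py's train_epoch (loop structure assumed)
for task in tasks:
    print(task)                    # dump the Task's X_c/Y_c/X_t/Y_t shapes
    batch_loss = train_step(task)  # the call that raises the RuntimeError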

nilsleh commented 1 year ago

This is for batch_size=1:

time: 2011-06-01 00:00:00
ops: []
X_c: [(2, 246), ((1, 240), (1, 400)), ((1, 72), (1, 120))]
Y_c: [(1, 246), (1, 240, 400), (3, 72, 120)]
X_t: [((1, 241), (1, 401))]
Y_t: [(1, 241, 401)]

And this for batch_size=4:

time: [Timestamp('2016-02-27 00:00:00'), Timestamp('2015-06-14 00:00:00'), Timestamp('2018-12-05 00:00:00'), Timestamp('2015-03-30 00:00:00')]
ops: ['batch_dim', 'float32', 'numpy_mask', 'nps_mask']
X_c: [(4, 2, 246), ((4, 1, 240), (4, 1, 400)), ((4, 1, 72), (4, 1, 120))]
Y_c: [<neuralprocesses.mask.Masked object at 0x7facf14ddcc0>, <neuralprocesses.mask.Masked object at 0x7facf14df6d0>, <neuralprocesses.mask.Masked object at 0x7facf14ddf60>]
X_t: [((4, 1, 241), (4, 1, 401))]
Y_t: [(4, 1, 241, 401)]

The shapes for the masked objects are: (4, 1, 246), (4, 1, 240, 400), (4, 3, 72, 120)
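
For reference, a back-of-envelope check of the channel arithmetic described above (illustrative Python, not deepsensor internals):

# Each context set should contribute its data dims plus ONE density channel
y_c_dims = [1, 1, 3]                      # data dims of the three context sets
expected = sum(y_c_dims) + len(y_c_dims)  # 5 data + 3 density channels = 8
# The batch_size=4 run instead produces 10 = 5 data + (1 + 1 + 3) density
# channels, as if the 3-dim set's density channel were repeated per data dim
buggy = sum(y_c_dims) * 2                 # = 10, matching the RuntimeError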

tom-andersson commented 1 year ago

Thanks @nilsleh! Those shapes all check out. The bug is caused by one of the context sets having more than one data dimension. Here's an MWE which reproduces the same error:

import deepsensor.torch
from deepsensor.data import DataProcessor, TaskLoader
from deepsensor.model import ConvNP
from deepsensor.train import Trainer

import xarray as xr
import pandas as pd
import numpy as np
from tqdm import tqdm

# Load raw data
ds_raw = xr.tutorial.open_dataset("air_temperature")

# Add an extra variable so the context set has more than one data dimension
ds_raw["air2"] = ds_raw["air"].copy()

# Normalise data 
data_processor = DataProcessor(x1_name="lat", x2_name="lon")
ds = data_processor(ds_raw)

# Set up task loader
task_loader = TaskLoader(context=ds, target=ds)

# Set up model
model = ConvNP(data_processor, task_loader)

# Generate training tasks with up to 100 grid cells as context and all grid
#   cells as targets
train_tasks = []
for date in pd.date_range("2013-01-01", "2014-11-30")[::7]:
    N_context = np.random.randint(0, 100)
    task = task_loader(date, context_sampling=N_context, target_sampling="all")
    train_tasks.append(task)

# Train model
trainer = Trainer(model, lr=5e-5)
for epoch in tqdm(range(10)):
    batch_losses = trainer(train_tasks, batch_size=4)
...
RuntimeError: Given groups=1, weight of size [64, 3, 5, 5], expected input[4, 4, 48, 80] to have 3 channels, but got 4 channels instead

@wesselb, do you remember when we found that calling `nps.merge_contexts` on multi-dimensional context sets resulted in repeated density channels? I'm pretty sure that's what's going on here again, and my hacky solution was to manually override the mask of the merged `nps.Masked` context observation objects, like `task["Y_c"][i].mask = task["Y_c"][i].mask[:, 0:1, :]`. Any chance something could be going wrong under the hood in neuralprocesses, either in `merge_contexts` or in the way the ConvNP encoder uses the `.mask` attr of `nps.Masked` objects?
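
For anyone needing a stopgap until this is fixed, a minimal sketch of that override (the trailing ellipsis is an assumption so it also covers gridded, 4-D masks):

# Hacky stopgap sketch, not a proper fix: keep only the first mask channel
# of each merged nps.Masked context observation, so each context set
# contributes a single density channel
for i in range(len(task["Y_c"])):
    y_c = task["Y_c"][i]
    if hasattr(y_c, "mask"):              # only nps.Masked objects carry a mask
        y_c.mask = y_c.mask[:, 0:1, ...]  # (B, C, *n_obs) -> (B, 1, *n_obs)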

N.B. Our training unit test includes batching and is passing, but it only covers a 1-D context set. Once we patch this bug we should add a test with an N-D context set (sketched below).
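
Something along these lines, reusing the MWE setup above (test name and assertions are assumptions):

import numpy as np
import pandas as pd

def test_batched_training_with_multidim_context():
    # Tasks built from the two-variable (multi-dim) context set in the MWE
    dates = pd.date_range("2013-01-01", "2013-03-01")[::7]
    tasks = [task_loader(date, context_sampling=50, target_sampling="all")
             for date in dates]
    trainer = Trainer(model, lr=5e-5)
    # Raised the conv-channel RuntimeError before the fix whenever batch_size > 1
    batch_losses = trainer(tasks, batch_size=4)
    assert np.all(np.isfinite(batch_losses))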

wesselb commented 1 year ago

@tom-andersson Ah, I don't precisely recall what that problem was. :( Any chance you could post a small example of the repeated density channel issue here?

tom-andersson commented 1 year ago

Hey @wesselb, I created an MWE in pure neuralprocesses and there's no error, so it must be on the DeepSensor side.

My hypothesis is that it's something to do with applying a numpy NaN mask after merging the context sets into nps.Masked objects: https://github.com/tom-andersson/deepsensor/blob/438295797b57bfc2b206d4e3c9a5079c9ed802bb/deepsensor/data/task.py#L554-L558
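
To illustrate the class of bug being hypothesised (plain numpy, not deepsensor internals):

import numpy as np

# A batched 3-dim gridded context set, matching the (4, 3, 72, 120) shape above
y = np.random.rand(4, 3, 72, 120)
y[:, :, 0, 0] = np.nan                # one missing observation

per_channel_mask = ~np.isnan(y)       # (4, 3, 72, 120): one mask channel per data dim
single_channel_mask = per_channel_mask.all(axis=1, keepdims=True)  # (4, 1, 72, 120)

# If the per-channel mask is propagated unreduced, the encoder sees three
# density channels for this context set instead of one, inflating the
# expected 8 channels to the 10 in the RuntimeError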

I'll dig into this.

tom-andersson commented 1 year ago

Yeah, found it. It was slightly esoteric: an array-shape bug in the way NaNs were being removed from the `nps.Masked` objects that come out of `nps.merge_contexts`. I've added a unit test for batch-wise training with multi-dimensional context sets, and it now passes.

Fixed in v0.3.5 on PyPI, thanks for catching this @nilsleh and thanks @wesselb for helping me realise the bug was on the deepsensor side, not the neuralprocesses side!

wesselb commented 1 year ago

Ah, I'm glad to hear that you managed to find the bug, @tom-andersson! :)