alan-turing-institute / deepsensor

A Python package for tackling diverse environmental prediction tasks with NPs.
https://alan-turing-institute.github.io/deepsensor/
MIT License

Batching with missing data #103

Closed davidwilby closed 4 months ago

davidwilby commented 6 months ago

@tom-andersson et al. I wonder if you can help clear up some difficulty that @MartinSJRogers and I are having with batched training for gridded data with missing values using deepsensor.

As yet I'm unable to work out whether we're doing something incorrectly or whether there are bugs in deepsensor's implementation.

We're working with gridded data with missing values represented as NaNs as specified in the Data Requirements section of the docs.

When batch_size is set during training, concat_tasks is called, and within it (in the snippet below) the remove_target_nans() method is called on each task:

https://github.com/alan-turing-institute/deepsensor/blob/aeccc097699c5963fa81b4601cfa11fad5daa41b/deepsensor/data/task.py#L477-L489

This results in a ValueError raised later in concat_tasks, because the tasks in the batch end up with different numbers of targets:

ValueError: All tasks must have the same number of targets to concatenate: got [9460279, 10432117, 8255541, 10345501]. To train with Task batches containing differing numbers of targets, run the model individually over each task and average the losses.

and as a result we don't get to the calls to mask_nans_numpy and mask_nans_nps towards the end of concat_tasks.
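To make the failure concrete, here is a minimal NumPy sketch of why dropping NaN targets per task prevents concatenation. It is not deepsensor's code: the helper name `drop_nan_targets` and the array shapes, (2, n) for locations and (1, n) for values, are assumptions chosen for illustration.

```python
import numpy as np

def drop_nan_targets(x_t, y_t):
    """Keep only the target locations whose value is not NaN.

    x_t has shape (2, n) and y_t has shape (1, n); both are hypothetical
    stand-ins for one target set of a task.
    """
    keep = ~np.isnan(y_t[0])
    return x_t[:, keep], y_t[:, keep]

rng = np.random.default_rng(0)
cleaned = []
for n_missing in (1, 3):  # each toy task is missing a different number of targets
    x_t = rng.uniform(size=(2, 10))
    y_t = rng.normal(size=(1, 10))
    y_t[0, :n_missing] = np.nan
    cleaned.append(drop_nan_targets(x_t, y_t))

n_targets = [y.shape[-1] for _, y in cleaned]
print(n_targets)  # [9, 7]

try:
    # Batching adds a new leading axis, which requires identical shapes.
    np.stack([y for _, y in cleaned])
    stacked = True
except ValueError:
    stacked = False
print(stacked)  # False
```

Once the NaNs are removed, the per-task target arrays are ragged, so no single batched array can hold them without padding or subsampling.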

I'm confused by this: the message in the snippet above ("Cannot concatenate tasks that have had NaNs masked. Masking will be applied automatically after concatenation.") and the later calls to mask_nans_{numpy,nps} suggest that the NaNs should be handled by those methods.

When I remove the call to remove_target_nans here for testing, of course the batch sizes are the same and the ValueError above isn't raised, the rest of concat_tasks runs and mask_nans_{numpy,nps} are called successfully.

This, however, results in an error further down the line in which the Masked object from neuralprocesses is found not to have the dtype attribute:

Full stack trace

```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[12], line 15
     13 trainer = Trainer(model, lr=5e-5)
     14 for epoch in tqdm(range(1)):
---> 15     batch_losses = trainer(train_tasks, tqdm_notebook=True, batch_size=None) # error here due to filesize. I have attempted using batch_size = n,
     16     # but get seperate error asserting that the number of targets in each batch must be the same.
     17     # Todo- work out how to calculate number of targets in each batch, and ensure the batch size allows me to honour this assertion.
     18     losses.append(np.mean(batch_losses))

File _/deepsensor/train/train.py:177, in Trainer.__call__(self, tasks, batch_size, progress_bar, tqdm_notebook)
    170 def __call__(
    171     self,
    172     tasks: List[Task],
   (...)
    175     tqdm_notebook=False,
    176 ) -> List[float]:
--> 177     return train_epoch(
    178         model=self.model,
    179         tasks=tasks,
    180         batch_size=batch_size,
    181         opt=self.opt,
    182         progress_bar=progress_bar,
    183         tqdm_notebook=tqdm_notebook,
    184     )

File _/deepsensor/train/train.py:145, in train_epoch(model, tasks, lr, batch_size, opt, progress_bar, tqdm_notebook)
    143     else:
    144         task = tasks[batch_i]
--> 145     batch_loss = train_step(task)
    146     batch_losses.append(batch_loss)
    148 return batch_losses

File _/deepsensor/train/train.py:116, in train_epoch.<locals>.train_step(tasks)
    114 task_losses = []
    115 for task in tasks:
--> 116     task_losses.append(model.loss_fn(task, normalise=True))
    117 mean_batch_loss = B.mean(B.stack(*task_losses))
    118 mean_batch_loss.backward()

File _/deepsensor/model/convnp.py:869, in ConvNP.loss_fn(self, task, fix_noise, num_lv_samples, normalise)
    865 task = ConvNP.modify_task(task)
    867 context_data, xt, yt, model_kwargs = convert_task_to_nps_args(task)
--> 869 logpdfs = backend.nps.loglik(
    870     self.model,
    871     context_data,
    872     xt,
    873     yt,
    874     **model_kwargs,
    875     fix_noise=fix_noise,
    876     num_samples=num_lv_samples,
    877     normalise=normalise,
    878 )
    880 loss = -B.mean(logpdfs)
    882 return loss

File _conda-envs/deepsensor/lib/python3.11/site-packages/plum/function.py:399, in Function.__call__(self, *args, **kw_args)
    397 def __call__(self, *args, **kw_args):
    398     method, return_type = self._resolve_method_with_cache(args=args)
--> 399     return _convert(method(*args, **kw_args), return_type)

File _conda-envs/deepsensor/lib/python3.11/site-packages/neuralprocesses/model/loglik.py:113, in loglik(model, *args, **kw_args)
    110 @_dispatch
    111 def loglik(model: Model, *args, **kw_args):
    112     state = B.global_random_state(B.dtype(args[-2]))
--> 113     state, logpdfs = loglik(state, model, *args, **kw_args)
    114     B.set_global_random_state(state)
    115     return logpdfs

File _conda-envs/deepsensor/lib/python3.11/site-packages/plum/function.py:399, in Function.__call__(self, *args, **kw_args)
    397 def __call__(self, *args, **kw_args):
    398     method, return_type = self._resolve_method_with_cache(args=args)
--> 399     return _convert(method(*args, **kw_args), return_type)

File _conda-envs/deepsensor/lib/python3.11/site-packages/neuralprocesses/model/loglik.py:48, in loglik(state, model, contexts, xt, yt, num_samples, batch_size, normalise, fix_noise, dtype_lik, **kw_args)
     12 @_dispatch
     13 def loglik(
     14     state: B.RandomState,
   (...)
     25     **kw_args,
     26 ):
     27     """Log-likelihood objective.
     28
     29     Args:
   (...)
     46         tensor: Log-likelihoods.
     47     """
---> 48     float = B.dtype_float(yt)
     49     float64 = B.promote_dtypes(float, np.float64)
     51     # For the likelihood computation, default to using a 64-bit version of the data
     52     # type of `yt`.

File _conda-envs/deepsensor/lib/python3.11/site-packages/plum/function.py:399, in Function.__call__(self, *args, **kw_args)
    397 def __call__(self, *args, **kw_args):
    398     method, return_type = self._resolve_method_with_cache(args=args)
--> 399     return _convert(method(*args, **kw_args), return_type)

File _conda-envs/deepsensor/lib/python3.11/site-packages/lab/types.py:342, in dtype_float(x)
    340 @dispatch
    341 def dtype_float(x):
--> 342     return dtype_float(dtype(x))

File _conda-envs/deepsensor/lib/python3.11/site-packages/plum/function.py:399, in Function.__call__(self, *args, **kw_args)
    397 def __call__(self, *args, **kw_args):
    398     method, return_type = self._resolve_method_with_cache(args=args)
--> 399     return _convert(method(*args, **kw_args), return_type)

File _conda-envs/deepsensor/lib/python3.11/site-packages/lab/types.py:236, in dtype(a)
    226 @dispatch
    227 def dtype(a):
    228     """Determine the data type of an object.
    229
    230     Args:
   (...)
    234         dtype: Data type of `a`.
    235     """
--> 236     return a.dtype

AttributeError: 'Masked' object has no attribute 'dtype'
```

Are we doing something incorrectly here, or are there bugs in the implementation? Happy to add more docs when we've worked out what we're doing, or to contribute bug fixes if required!

Lastly, for non-batched training, are mask_nans_{numpy,nps} run somewhere else? I notice that they're called in modify_task but I'm not yet sure when this is called.

nilsleh commented 6 months ago

@davidwilby Not sure if this helps, but I ran into something similar a while ago, where the target sets had different numbers of targets across tasks while batching. I adapted the concat_tasks function to randomly subsample the targets to a common size:

import numpy as np

# Fragment of a modified concat_tasks: `tasks` and `n_target_sets` come from the enclosing function.
for target_set_i in range(n_target_sets):
    # Number of target points in this target set for each task
    n_target_obs = [task["Y_t"][target_set_i].shape[-1] for task in tasks]
    if not all(n == n_target_obs[0] for n in n_target_obs):
        # Instead of raising an error, subsample every task to the smallest
        # number of target points found across tasks
        min_n = min(n_target_obs)

        for task in tasks:
            rand_indices = np.random.choice(
                np.arange(task["Y_t"][target_set_i].shape[-1]),
                size=min_n,
                replace=False,
            )
            # Use the same indices for Y_t and X_t so values stay paired with their locations
            task["Y_t"][target_set_i] = task["Y_t"][target_set_i][..., rand_indices]
            task["X_t"][target_set_i] = task["X_t"][target_set_i][..., rand_indices]
I would agree that it would be nice to include some functionality that handles this for batched training, as this can happen quite frequently.

tom-andersson commented 6 months ago

Hi @davidwilby + @MartinSJRogers, thank you for raising this :) this boils down to a few things:

  1. We can have missing data/NaNs on the context side because ConvNPs represent missing data as zeros in the density channels of the context set encodings. This is handled by neuralprocesses.Masked objects which are constructed by the Task.mask_nans_{numpy,nps} methods you mentioned.
  2. However, we can't have NaNs in target values because backpropagation would fail.
  3. We can remove NaNs from target sets, and NP models can happily train on varying-length target arrays when this happens.
  4. However, we can't concatenate varying-length arrays into a single array (and there is no support in deepsensor/neuralprocesses for padding the target arrays and then masking the padded values from the loss).
  5. Therefore we can't run in batch mode if there are missing values in the targets.
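Points 1 and 2 above can be illustrated with a small NumPy sketch. This is a conceptual illustration under stated assumptions, not how neuralprocesses implements its Masked objects: the (data, density) pairing is a simplified stand-in for a context encoding, and squared error stands in for the negative log-likelihood.

```python
import numpy as np

# Toy context observations with two missing values.
y_c = np.array([0.5, np.nan, -1.2, np.nan, 2.0])

# Context side: keep a (data, density) pair. Missing entries get density 0
# and a placeholder data value of 0, so nothing downstream ever sees a NaN.
density = (~np.isnan(y_c)).astype(np.float32)
data = np.where(np.isnan(y_c), 0.0, y_c).astype(np.float32)
context_channels = np.stack([data, density])  # shape (2, 5)

# Target side: a single NaN target poisons an averaged loss.
y_t = np.array([1.0, np.nan, 3.0])
pred = np.array([0.9, 0.1, 2.5])
loss_with_nan = np.mean((pred - y_t) ** 2)

# Removing NaN targets first gives a finite loss again.
valid = ~np.isnan(y_t)
loss_without_nan = np.mean((pred[valid] - y_t[valid]) ** 2)

print(np.isnan(loss_with_nan), np.isfinite(loss_without_nan))  # True True
```

The asymmetry is the point: a missing context value can be encoded as "no observation here", whereas a missing target leaves nothing to compare the prediction against.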

The workaround, as suggested in the error message, is to manually run the model multiple times in a for loop over your 'batch', and then average the losses within your model update. This gives you the smoother loss surface of batch training, but unfortunately it doesn't give you the computational efficiency of running on multiple examples in parallel on a GPU.
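A minimal sketch of that loop-and-average pattern, assuming a toy one-parameter model in place of a ConvNP. In deepsensor the per-task call would be model.loss_fn(task, normalise=True), as the stack trace above shows; here `task_loss_and_grad` is a hypothetical stand-in with a hand-written gradient so the example runs with NumPy alone.

```python
import numpy as np

def task_loss_and_grad(w, x, y):
    """Mean squared error of the toy model y_hat = w * x, and its gradient in w."""
    residual = w * x - y
    return np.mean(residual ** 2), np.mean(2.0 * residual * x)

rng = np.random.default_rng(0)
true_w = 3.0
# Three toy tasks with different numbers of targets, as after NaN removal.
tasks = []
for n in (50, 80, 65):
    x = rng.uniform(0.0, 1.0, size=n)
    tasks.append((x, true_w * x))

w, lr = 0.0, 0.1
history = []
for step in range(200):
    # Run the model on each task separately; no concatenation is needed.
    losses, grads = zip(*(task_loss_and_grad(w, x, y) for x, y in tasks))
    history.append(float(np.mean(losses)))  # one averaged loss per "batch"
    w -= lr * float(np.mean(grads))         # one parameter update per "batch"

print(round(w, 3))
```

Each task contributes equally to the update regardless of how many targets it has, which is what averaging per-task losses implies. The tasks are still processed one after another, so there is no parallel speed-up.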

@nilsleh's workaround of subsampling to the smallest number of targets is a nice idea, although the model will see fewer target points per batch than it would otherwise, so this is a trade-off between computational efficiency and learning efficiency. If the number of non-missing target points is similar across all Tasks, which looks to be the case from [9460279, 10432117, 8255541, 10345501], then it might not be a bad shout.
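To put a number on that trade-off, the following sketch computes how many target points survive subsampling to the minimum for the four counts quoted in the error message. It uses only the standard library, and the variable names are illustrative.

```python
n_targets = [9460279, 10432117, 8255541, 10345501]

min_n = min(n_targets)
kept_per_task = [min_n / n for n in n_targets]
kept_overall = len(n_targets) * min_n / sum(n_targets)

for n, frac in zip(n_targets, kept_per_task):
    print(f"{n:>10d} targets -> keep {frac:.1%}")
print(f"overall: keep {kept_overall:.1%} of all target points")
```

For this batch, roughly 86% of the available target points are kept, so about 14% are discarded on each pass. Because the subsample is drawn at random each time, different points are dropped across epochs.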

When I remove the call to remove_target_nans here for testing, of course the batch sizes are the same and the ValueError above isn't raised, the rest of concat_tasks runs and mask_nans_{numpy,nps} are called successfully. This, however, results in an error further down the line in which the Masked object from neuralprocesses is found not to have the dtype attribute:

neuralprocesses stack traces can be confusing and the dtype error isn't clear, but you can't have neuralprocesses.Masked objects in the targets. The targets need to be vanilla tensors. Only context data can have neuralprocesses.Masked objects, and the missing data will be dealt with under the hood, as mentioned above.
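Since the dtype error gives little hint of the real cause, one option is a pre-flight check that fails with a readable message. The sketch below is an assumption-laden illustration, not part of deepsensor: `check_targets_trainable` is a hypothetical helper, and it operates on plain NumPy arrays before any conversion to tensors.

```python
import numpy as np

def check_targets_trainable(y_t_sets):
    """Raise a readable error if any target set is not a plain, NaN-free array.

    y_t_sets is a hypothetical list with one entry per target set.
    """
    for i, y_t in enumerate(y_t_sets):
        if not isinstance(y_t, np.ndarray):
            raise TypeError(
                f"Target set {i} is a {type(y_t).__name__}, not a plain array; "
                "masked containers are only valid for context data."
            )
        n_nan = int(np.isnan(y_t).sum())
        if n_nan:
            raise ValueError(
                f"Target set {i} contains {n_nan} NaN value(s); "
                "remove them before computing the loss."
            )

clean = [np.array([[0.1, 0.2, 0.3]])]
dirty = [np.array([[0.1, np.nan, 0.3]])]

check_targets_trainable(clean)  # passes silently
try:
    check_targets_trainable(dirty)
    message = ""
except ValueError as err:
    message = str(err)
print(message)
```

Calling a check like this just before the loss is computed would turn the opaque AttributeError into a message that names the offending target set.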

Hope this clears things up, and please close if so :)

nilsleh commented 6 months ago

I have a related question about NaNs in target sets, which occur in the data that I am working with. If I don't modify anything and use the provided Trainer code as such:

trainer = Trainer(model, lr=5e-5)
batch_losses = trainer(train_tasks, batch_size=None)

A single task looks like this in the loss_fn computation, after the modify_task call:

time: Timestamp/2013-06-15 12:00:00
ops: ['str/batch_dim', 'str/float32', 'str/numpy_mask', 'str/nps_mask', 'str/tensor']
X_c: ['Tensor/torch.float32/torch.Size([1, 2, 256000])', 'Tensor/torch.float32/torch.Size([1, 2, 256000])', 'Tensor/torch.float32/torch.Size([1, 2, 256000])']
Y_c: ['Masked/(y=torch.float32/torch.Size([1, 1, 256000]))/(mask=torch.float32/torch.Size([1, 1, 256000]))', 'Masked/(y=torch.float32/torch.Size([1, 1, 256000]))/(mask=torch.float32/torch.Size([1, 1, 256000]))', 'Tensor/torch.float32/torch.Size([1, 1, 256000])']
X_t: ['Tensor/torch.float32/torch.Size([1, 2, 224000])']
Y_t: ['Masked/(y=torch.float32/torch.Size([1, 1, 224000]))/(mask=torch.float32/torch.Size([1, 1, 224000]))']

And I get the error: AttributeError: 'Masked' object has no attribute 'dtype'

If I change the loss function to remove the target NaNs before modifying the task, by adding task.remove_target_nans() while keeping batch_size=None, or if I just set the trainer's batch_size>2 (because remove_target_nans() is then called in concat_tasks), a single task looks like this:

time: Timestamp/2013-06-15 12:00:00
ops: ['str/target_nans_removed', 'str/batch_dim', 'str/float32', 'str/numpy_mask', 'str/nps_mask', 'str/tensor']
X_c: ['Tensor/torch.float32/torch.Size([1, 2, 256000])', 'Tensor/torch.float32/torch.Size([1, 2, 256000])', 'Tensor/torch.float32/torch.Size([1, 2, 256000])']
Y_c: ['Masked/(y=torch.float32/torch.Size([1, 1, 256000]))/(mask=torch.float32/torch.Size([1, 1, 256000]))', 'Masked/(y=torch.float32/torch.Size([1, 1, 256000]))/(mask=torch.float32/torch.Size([1, 1, 256000]))', 'Tensor/torch.float32/torch.Size([1, 1, 256000])']
X_t: ['Tensor/torch.float32/torch.Size([1, 2, 215843])']
Y_t: ['Tensor/torch.float32/torch.Size([1, 1, 215843])']

but then I get the neuralprocesses library error: AssertionError: Expected not a parallel of elements, but got inputs and outputs in parallel.

Full Stacktrace

```python
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/deepsensor/train/train.py", line 116, in train_step
    task_losses.append(model.loss_fn(task, normalise=True))
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/deepsensor/model/convnp.py", line 870, in loss_fn
    logpdfs = backend.nps.loglik(
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/model/loglik.py", line 113, in loglik
    state, logpdfs = loglik(state, model, *args, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/model/loglik.py", line 64, in loglik
    state, pred = model(
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 489, in __call__
    return self._f(self._instance, *args, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/model/model.py", line 101, in __call__
    return self(
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 489, in __call__
    return self._f(self._instance, *args, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/model/model.py", line 72, in __call__
    xz, pz = code(self.encoder, xc, yc, xt, root=True, **enc_kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/__init__.py", line 21, in f_wrapped
    return f(*args, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/coders/functional.py", line 39, in code
    return code(coder.coder, xz, z, x, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/__init__.py", line 21, in f_wrapped
    return f(*args, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/chain.py", line 56, in code
    xz, z = code(link, xz, z, x, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/__init__.py", line 21, in f_wrapped
    return f(*args, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/chain.py", line 56, in code
    xz, z = code(link, xz, z, x, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/plum/function.py", line 399, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/__init__.py", line 21, in f_wrapped
    return f(*args, **kw_args)
  File "/opt/anaconda3/envs/oceanEnv/lib/python3.10/site-packages/neuralprocesses/coders/shaping.py", line 179, in code
    raise AssertionError(
AssertionError: Expected not a parallel of elements, but got inputs and outputs in parallel.
```

And I am not sure what I have done wrong.

tom-andersson commented 5 months ago

Hi @nilsleh, the AttributeError: 'Masked' object has no attribute 'dtype' is exactly what is described above - essentially you can't train a model with NaNs in targets. It is unfortunate that when target NaNs are present the error message is confusing. As you say, using batch_size > 1 means target NaNs are automatically removed within the concat_tasks method.

Regarding the AssertionError: Expected not a parallel of elements, but got inputs and outputs in parallel, I have never seen that neuralprocesses error before. The shape of the Task looks fine, but I am missing context for what exact code you call prior to this. Would you be able to produce an MWE in a Colab by generating random data?

nilsleh commented 5 months ago

Hi @tom-andersson thanks for the reply, I have created a gist with the accompanying data I am using. The data tar file also contains the normalization parameters for the data processor.

EDIT: I was able to resolve it thanks to Wessel, it was a misconfiguration of the data processor and model.

tom-andersson commented 4 months ago

Glad you could solve this @nilsleh - to copy over your solution from the neuralprocesses GitHub for future reference:

I had forgotten to pass in the task_loader as an argument to the ConvNP model, as I was using multiple context sets. That just initializes a model with default parameters, which then results in a mismatch when you try to pass in your "actual" data.

tom-andersson commented 4 months ago

There were no complaints about closing this issue, so closing now.