wesselb / neuralprocesses

A framework for composing Neural Processes in Python
https://wesselb.github.io/neuralprocesses
MIT License
73 stars 11 forks

[FR] Ability to predict on grid by passing tuple of coords for `xt` #9

Closed: tom-andersson closed this issue 1 year ago

tom-andersson commented 1 year ago

Among current and potential users of neuralprocesses in my network, there is increasing interest in training a model with targets on a grid. I believe it would be a great addition to the codebase if users could predict on a grid by passing a tuple of per-dimension coordinate vectors as `xt`, rather than having to construct the full coordinate tensor themselves.

My guess is that this would just involve converting the tuple into a single coordinate tensor, proceeding as normal, and then reshaping at the end. This would save users from having to do that themselves. Would it be that simple?
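The flatten-then-reshape approach described above can be sketched with plain NumPy. Note that `grid_to_coords` is a hypothetical helper for illustration, not part of neuralprocesses:

```python
import numpy as np

def grid_to_coords(x1, x2):
    """Expand per-axis coordinate vectors into a single (2, N_x * N_y)
    coordinate tensor, as one might do to feed a model that only accepts
    flat target inputs. (Hypothetical helper for illustration.)"""
    g1, g2 = np.meshgrid(x1, x2, indexing="ij")  # each of shape (N_x, N_y)
    return np.stack([g1.ravel(), g2.ravel()])    # shape (2, N_x * N_y)

x1 = np.linspace(0, 1, 15)  # N_x = 15
x2 = np.linspace(0, 1, 20)  # N_y = 20
coords = grid_to_coords(x1, x2)
print(coords.shape)  # (2, 300)

# A flat prediction over these 300 points reshapes back onto the grid:
flat_pred = np.zeros(coords.shape[1])
grid_pred = flat_pred.reshape(len(x1), len(x2))
print(grid_pred.shape)  # (15, 20)
```

This is the bookkeeping the feature request asks the library to absorb.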

tom-andersson commented 1 year ago

Note, I just double-checked the neuralprocesses docs and realise the shape would be (1, N, N_x, N_y) to align with the standard in the codebase (https://wesselb.github.io/neuralprocesses/basic_usage.html#shape-of-tensors).

tom-andersson commented 1 year ago

Hey @wesselb, correct me if I'm wrong, but I think this feature is already implemented?! I just tried running a model with `xt` as a tuple of coords, and the `.mean` and `.variance` attrs of the resulting `MultiOutputNormal` object had shape `(1, 1, 1, N_x, N_y)`.

Apologies if I have missed some documentation describing this functionality. Feel free to close!

N.B. I'm a bit confused about why the shapes are `(1, 1, 1, N_x, N_y)`. There seems to be an extra leading 1 relative to the `(b, c, *n)` shape convention.

wesselb commented 1 year ago

@tom-andersson You're indeed right that this is already supported. :)

Here's an example:

```python
import lab as B
import torch

import neuralprocesses.torch as nps

cnp = nps.construct_convgnp(dim_x=2, dim_y=1, likelihood="lowrank", points_per_unit=16)

dist = cnp(
    B.randn(torch.float32, 16, 2, 10),
    B.randn(torch.float32, 16, 1, 10),
    (B.randn(torch.float32, 16, 1, 15), B.randn(torch.float32, 16, 1, 20)),
)

dist.sample().shape  # (16, 1, 15, 20)
```

If you just call `dist.sample()`, the shape will be `(16, 1, 15, 20)`, which is `(b, c, n1, n2)`. If you call `dist.sample(2)`, the shape will be `(2, 16, 1, 15, 20)`, which is `(b1, b2, c, n1, n2)`. Similarly, `dist.sample(1)` gives shape `(1, 16, 1, 15, 20)`, again corresponding to `(b1, b2, c, n1, n2)`. Hence, if you want a single sample without extra batch dimensions, call `dist.sample()` with no arguments.

wesselb commented 1 year ago

Note that this is also compatible with multiple outputs. Even mixtures of grid-like and list-like predictions are possible:

```python
import lab as B
import torch

import neuralprocesses.torch as nps

cnp = nps.construct_convgnp(dim_x=2, dim_yc=1, dim_yt=2, likelihood="lowrank", points_per_unit=16)

dist = cnp(
    B.randn(torch.float32, 16, 2, 10),
    B.randn(torch.float32, 16, 1, 10),
    nps.AggregateInput(
        ((B.randn(torch.float32, 16, 1, 15), B.randn(torch.float32, 16, 1, 20)), 0),
        (B.randn(torch.float32, 16, 2, 25), 1),
    ),
)

print(dist.sample()[0].shape)  # torch.Size([16, 1, 15, 20])
print(dist.sample()[1].shape)  # torch.Size([16, 1, 25])

print(dist.sample(4)[0].shape)  # torch.Size([4, 16, 1, 15, 20])
print(dist.sample(4)[1].shape)  # torch.Size([4, 16, 1, 25])
```

tom-andersson commented 1 year ago

Hey @wesselb, thanks for confirming all this! Very cool how multi-output architectures can be run simultaneously on-grid and off-grid.

Regarding the unexpected extra dim I was finding, this was with regard to the `.mean` and `.variance` attributes of the distribution object, not `.sample` calls. However, I am passing `num_samples=10` as a kwarg when calling the model (to work around the ConvLNP behaviour, I think). Could it be that this adds a spurious extra sample dimension to `.mean` etc.?

wesselb commented 1 year ago

> Regarding the unexpected extra dim I was finding, (...) Could it be that this adds a spurious extra sample dimension to `.mean` etc.?

@tom-andersson Models can contain two sources of stochasticity: a latent variable in the middle of the architecture, and the likelihood. Passing `num_samples` as a keyword argument samples the latent variable in the middle of the architecture `num_samples` times, whereas calling `sample()` at the end samples the likelihood. (For an LNP, you want to sample both the intermediate latent variable and the likelihood, so you want to set `num_samples` and call `sample()`.)
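The two stages of stochasticity described above can be sketched with plain NumPy. This is a conceptual illustration only, not the neuralprocesses internals:

```python
import numpy as np

rng = np.random.default_rng(0)
b, c, n = 16, 1, 10   # batch, channels, targets
num_samples = 5       # samples of the intermediate latent variable

# Stage 1: sample the latent variable num_samples times; the leading axis
# plays the role of the num_samples keyword argument in the model call.
z = rng.normal(size=(num_samples, b, c, n))

# Stage 2: sample the likelihood conditional on each latent sample; this
# plays the role of calling .sample() on the resulting distribution.
noise_scale = 0.1
y = z + noise_scale * rng.normal(size=z.shape)

print(y.shape)  # (5, 16, 1, 10)
```

For an LNP both stages are genuinely random, which is why both `num_samples` and `sample()` are needed to get a full draw from the model.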

For models without a latent variable, you can still pass `num_samples`. What this does is sample the latent variable, which in this case is a Dirac delta distribution, `num_samples` times, which consequently adds `(num_samples,)` to the beginning of the shape. However, since all these samples would be identical, the model doesn't copy them: it simply prepends a `(1,)` dimension and lets broadcasting handle the rest.
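The broadcasting trick described above can be illustrated with plain NumPy: a leading size-1 axis stands in for `num_samples` identical copies without materialising them.

```python
import numpy as np

num_samples = 10

# Shape the model would conceptually produce: (num_samples, b, c, n1, n2).
# Since a Dirac latent yields identical samples, a size-1 leading axis
# suffices and broadcasting expands it on demand:
mean = np.zeros((1, 16, 1, 15, 20))  # instead of (10, 16, 1, 15, 20)
weights = np.ones((num_samples, 1, 1, 1, 1)) / num_samples

# Broadcasting treats the size-1 axis as if it were repeated num_samples
# times, e.g. when averaging over the sample dimension:
averaged = (weights * mean).sum(axis=0)
print(averaged.shape)  # (16, 1, 15, 20)
```

This is why `.mean` and `.variance` pick up the extra leading 1 observed above when `num_samples` is passed to a model without a latent variable.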

A uniform interface is provided by `nps.predict`, which takes a `num_samples` keyword argument and automatically sets `num_samples` in the model call only when necessary.

tom-andersson commented 1 year ago

@wesselb thank you for clarifying that the extra dim I was finding comes from the convention for handling latent-variable samples. And good to know that `nps.predict` handles the different cases under the hood. Will close this now!