[Closed] tom-andersson closed this 1 year ago
Note: I just double-checked the `neuralprocesses` docs and realised the shape would be `(1, N, N_x, N_y)` to align with the standard in the codebase (https://wesselb.github.io/neuralprocesses/basic_usage.html#shape-of-tensors).
Hey @wesselb, correct me if I'm wrong, but I think this feature is already implemented?! I just tried running a model with `xt` as a tuple of coords, and the `.mean` and `.variance` attributes of the resulting `MultiOutputNormal` object had shape `(1, 1, 1, N_x, N_y)`.

Apologies if I have missed some documentation describing this functionality. Feel free to close!

N.B. I'm a bit confused about why the shapes are `(1, 1, 1, N_x, N_y)`: there seems to be an extra `1` according to the `(b, c, *n)` shape convention.
@tom-andersson You're indeed right that this is already supported. :)
Here's an example:

```python
import lab as B
import torch

import neuralprocesses.torch as nps

cnp = nps.construct_convgnp(dim_x=2, dim_y=1, likelihood="lowrank", points_per_unit=16)
dist = cnp(
    B.randn(torch.float32, 16, 2, 10),
    B.randn(torch.float32, 16, 1, 10),
    (B.randn(torch.float32, 16, 1, 15), B.randn(torch.float32, 16, 1, 20)),
)
dist.sample().shape  # (16, 1, 15, 20)
```
If you just call `dist.sample()`, the shape will be `(16, 1, 15, 20)`, which is `(b, c, n1, n2)`. If you call `dist.sample(2)`, the shape will be `(2, 16, 1, 15, 20)`, which is `(b1, b2, c, n1, n2)`. Similarly, if you call `dist.sample(1)`, the shape will be `(1, 16, 1, 15, 20)`, again corresponding to `(b1, b2, c, n1, n2)`. Hence, if you want just a single sample without extra batch dimensions, the way to call the sample function is `dist.sample()`.
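As a quick illustration of this shape convention (using a hypothetical helper on plain NumPy arrays, not part of `neuralprocesses`), the leading sample dimensions can be separated mechanically from the trailing `(b, c, n1, n2)` core:

```python
import numpy as np

def split_sample_dims(x, core_ndim=4):
    """Split an array's shape into (sample_shape, core_shape), where the
    core is the trailing (b, c, n1, n2) part of the convention above."""
    assert x.ndim >= core_ndim
    return x.shape[:x.ndim - core_ndim], x.shape[-core_ndim:]

one = np.zeros((16, 1, 15, 20))     # analogous to dist.sample()
two = np.zeros((2, 16, 1, 15, 20))  # analogous to dist.sample(2)

assert split_sample_dims(one) == ((), (16, 1, 15, 20))
assert split_sample_dims(two) == ((2,), (16, 1, 15, 20))
```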
Note that this is also compatible with multiple outputs. Even mixtures of grid-like and list-like predictions are possible:
```python
import lab as B
import torch

import neuralprocesses.torch as nps

cnp = nps.construct_convgnp(dim_x=2, dim_yc=1, dim_yt=2, likelihood="lowrank", points_per_unit=16)
dist = cnp(
    B.randn(torch.float32, 16, 2, 10),
    B.randn(torch.float32, 16, 1, 10),
    nps.AggregateInput(
        ((B.randn(torch.float32, 16, 1, 15), B.randn(torch.float32, 16, 1, 20)), 0),
        (B.randn(torch.float32, 16, 2, 25), 1),
    ),
)

print(dist.sample()[0].shape)   # torch.Size([16, 1, 15, 20])
print(dist.sample()[1].shape)   # torch.Size([16, 1, 25])
print(dist.sample(4)[0].shape)  # torch.Size([4, 16, 1, 15, 20])
print(dist.sample(4)[1].shape)  # torch.Size([4, 16, 1, 25])
```
Hey @wesselb, thanks for confirming all this! Very cool how multi-output architectures can be run simultaneously on-grid and off-grid.
Regarding the unexpected extra dim I was finding, this was with regard to the `.mean` and `.variance` attributes of the distribution object, not `.sample` calls. I am, however, passing `num_samples=10` as a kwarg when calling the model (to work around the ConvLNP behaviour, I think). Could it be that this adds a spurious extra sample dimension to `.mean` etc.?
> Regarding the unexpected extra dim I was finding, (...) Could it be that this adds a spurious extra sample dimension to `.mean` etc.?
@tom-andersson Models can contain two sources of stochasticity: a latent variable in the middle of the architecture, and the likelihood. Passing `num_samples` as a keyword argument samples the latent variable in the middle of the architecture `num_samples` times, whereas calling `sample()` at the end samples the likelihood. (For an LNP, you want to sample both the intermediate latent variable and the likelihood, so you want to set `num_samples` and call `sample()`.)

For models without a latent variable, you can still pass `num_samples`. What this will do is sample the latent variable, which in this case is a Dirac delta distribution, `num_samples` times, which consequently adds `(num_samples,)` to the beginning of the shape. However, since all these samples would be identical, the model doesn't copy the samples and simply prepends a `(1,)` dimension to let broadcasting handle the rest.
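That broadcasting trick can be sketched in NumPy, with toy arrays standing in for the model's outputs (shapes assumed from the discussion above, not taken from the library itself):

```python
import numpy as np

# A deterministic model stores its mean once, with a leading (1,) sample
# dim, instead of materialising num_samples identical copies.
mean = np.zeros((1, 16, 1, 15, 20))            # (1, b, c, n1, n2)

# Ten likelihood samples, each carrying the full sample dimension.
samples = np.random.randn(10, 16, 1, 15, 20)   # (s, b, c, n1, n2)

# Broadcasting expands the (1,) dim on the fly; no data is copied.
residuals = samples - mean
assert residuals.shape == (10, 16, 1, 15, 20)
```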
A uniform interface is provided by `nps.predict`, which takes a `num_samples` keyword argument and automatically sets `num_samples` in the model call only when necessary.
@wesselb, thank you for clarifying that the extra dim I was finding is because of the convention for handling the latent variable sample. And good to know that `nps.predict` handles the different cases under the hood. Will close this now!
In my network of current and potential fellow users of `neuralprocesses`, there is increasing interest in training a model with targets on a grid. I believe it would be a great addition to the codebase if:

- `xt` could be passed as a tuple of per-axis coordinate vectors, similarly to gridded context data, in the model call signature;
- predictions were returned in gridded shape (e.g. `(1, N_x, N_y, N)`).

My guess is that this would just involve converting the tuple into a single coordinate tensor and proceeding as normal, and then reshaping at the end. This would save users having to do this themselves. Would it be that simple?
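The conversion described above — flattening the per-axis vectors into one coordinate tensor, predicting as normal, then reshaping back onto the grid — might look like the following NumPy sketch under the `(b, c, n)` shape convention. The helper names are illustrative, not part of the library:

```python
import numpy as np

def grid_to_coords(x1, x2):
    """Flatten per-axis coordinate vectors into one (1, 2, N_x*N_y) tensor."""
    X1, X2 = np.meshgrid(x1, x2, indexing="ij")
    return np.stack([X1.ravel(), X2.ravel()])[None]

def pred_to_grid(pred, n_x, n_y):
    """Reshape a (b, c, N_x*N_y) prediction back onto a (b, c, N_x, N_y) grid."""
    b, c, _ = pred.shape
    return pred.reshape(b, c, n_x, n_y)

x1 = np.linspace(0, 1, 15)
x2 = np.linspace(0, 1, 20)
xt = grid_to_coords(x1, x2)
assert xt.shape == (1, 2, 15 * 20)

fake_pred = np.zeros((1, 1, 15 * 20))  # stand-in for a model's mean over xt
assert pred_to_grid(fake_pred, 15, 20).shape == (1, 1, 15, 20)
```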