wesselb / neuralprocesses

A framework for composing Neural Processes in Python
https://wesselb.github.io/neuralprocesses
MIT License

Integrating classification into predictions? #13

Closed nmearl closed 7 months ago

nmearl commented 1 year ago

I'm working on a pipeline to predict and classify 1D time series data (specifically astronomical data from nightly sky surveys). I've trained a model on some example time series observations and can get a prediction of the light curve given a few initial observations. Indeed, the predictions improve as I add more data points in sequence (simulating nightly repeated observations of a potential event).

My question is this: the training data is based on observations of different transient events, and I'm curious whether it's also possible to retrieve a labeled classification of what the model thinks the predicted light curve is. Although I've implemented neural networks (specifically RNNs) to do the same, there is a plethora of weird and strange transients that have little (real) training data and therefore fail to be identified by the RNN. I'm hoping that NPs would be able to provide a more useful estimate of both the prediction and the classification in cases where the training data is sparse and limited.

Any insight into how this might be addressed through NPs would be great!

wesselb commented 1 year ago

Hey @nmearl ! Thanks for opening this issue. :) I'm sorry for the slow reply. Been away on holiday and to a conference, and I'm only now catching up on email.

It is certainly possible to add a classification label to the model! Here's some sample code that illustrates what's possible:

import lab as B
import torch

import neuralprocesses.torch as nps

dim_x = 1
dim_y = 1

# CNN architecture:
unet = nps.UNet(
    dim=dim_x,
    in_channels=2 * dim_y,
    # 2 channels per output for the Gaussian mean and variance, 512 per output
    # for the low-rank covariance factor, plus an extra channel for
    # classification.
    out_channels=(2 + 512) * dim_y + 1,
    channels=(8, 16, 16, 32, 32, 64),
)

# Discretisation of the functional embedding:
disc = nps.Discretisation(
    points_per_unit=64,
    multiple=2**unet.num_halving_layers,
    margin=0.1,
    dim=dim_x,
)

# Create the encoder and decoder and construct the model.
encoder = nps.FunctionalCoder(
    disc,
    nps.Chain(
        nps.PrependDensityChannel(),
        nps.SetConv(scale=1 / disc.points_per_unit),
        nps.DivideByFirstChannel(),
        nps.DeterministicLikelihood(),
    ),
)
decoder = nps.Chain(
    unet,
    nps.SetConv(scale=1 / disc.points_per_unit),
    # Besides the regression prediction, also output a tensor of log-probabilities.
    nps.Splitter((2 + 512) * dim_y, 1),
    nps.Parallel(
        nps.LowRankGaussianLikelihood(512),
        lambda x: x,  # Pass the classification channel through unchanged.
    ),
)
convgnp = nps.Model(encoder, decoder)

# Run the model on some random data.
dist, log_probs = convgnp(
    B.randn(torch.float32, 16, dim_x, 10),  # Context inputs
    B.randn(torch.float32, 16, dim_y, 10),  # Context outputs
    B.randn(torch.float32, 16, dim_x, 15),  # Target inputs
)

# `dist` is now the regression prediction.
print(dist.logpdf(B.randn(torch.float32, 16, dim_y, 15)))  # Log-prob of some target outputs

# `log_probs` can be interpreted as a tensor of log-probabilities. Currently, the
# shape is `(16, 1, 15)`, so one log-probability per target point.

# Should you want to produce a global probability, you will need to do some pooling.
# E.g.,
global_log_probs = B.mean(log_probs, axis=-1)  # Shape `(16, 1)`
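
For training, you could combine the two outputs into a single objective, e.g. maximise the regression log-likelihood and add a cross-entropy term for the classification channel. A minimal sketch, treating the extra channel as a logit; `yt` and `labels` here are hypothetical placeholders, not library components:

import torch.nn.functional as F

# Hypothetical placeholders: target outputs and binary class labels.
yt = B.randn(torch.float32, 16, dim_y, 15)
labels = B.cast(torch.float32, B.randn(torch.float32, 16, 1) > 0)

# Regression term: negative log-likelihood under the predictive distribution.
regression_loss = -B.mean(dist.logpdf(yt))
# Classification term: binary cross-entropy on the pooled logit.
classification_loss = F.binary_cross_entropy_with_logits(global_log_probs, labels)
loss = regression_loss + classification_loss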

Is something like this roughly what you're after?

If it is, it would be possible to develop primitives for classification likelihoods. Currently, the library is focussed on regression, but components for classification can certainly be added. You could then write something like:

decoder = nps.Chain(
    ...
    nps.Splitter((2 + 512) * dim_y, num_classes),
    nps.Parallel(
        nps.LowRankGaussianLikelihood(512),
        nps.ClassificationLikelihood(num_classes),
    ),
)
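
For reference, here's a rough sketch of what such a `ClassificationLikelihood` might do internally. This class does not exist in the library yet; the sketch ignores the library's coder plumbing and just assumes one raw channel per class, normalised into log-probabilities:

import torch

class ClassificationLikelihood(torch.nn.Module):
    """Hypothetical sketch: normalise `num_classes` raw channels into class
    log-probabilities."""

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, z):
        # `z` has shape `(batch, num_classes, n_target)`. A log-softmax over
        # the channel dimension yields normalised log-probabilities per class.
        return torch.log_softmax(z, dim=1)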

Also note that the neural architectures in the library, such as the U-Net, are currently optimised for regression rather than classification. In particular, mean-pooling as in the example above is likely suboptimal; the classification head would ideally use an architecture designed for classification, as sketched below.
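
For instance, instead of a plain mean, a small learned pooling module could weight the per-point features before aggregating. A minimal sketch (an illustration only, not an existing library component), assuming per-point features of shape `(batch, channels, n_points)`:

import torch

class AttentionPool(torch.nn.Module):
    """Hypothetical attention-style pooling: learn a score per point and take
    a weighted average instead of a plain mean."""

    def __init__(self, channels):
        super().__init__()
        self.score = torch.nn.Conv1d(channels, 1, kernel_size=1)

    def forward(self, z):
        # `z` has shape `(batch, channels, n_points)`.
        weights = torch.softmax(self.score(z), dim=-1)  # `(batch, 1, n_points)`
        return (weights * z).sum(dim=-1)  # `(batch, channels)`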

Issue #11 might also be relevant.