wesselb / neuralprocesses

A framework for composing Neural Processes in Python
https://wesselb.github.io/neuralprocesses
MIT License

Not getting expected results #20

Open nmearl opened 7 months ago

nmearl commented 7 months ago

Hey there. I've been revisiting neural processes for use in a project dealing with simple time series data (sets of x, y values for which we'd like to make future predictions).

I've constructed a very simple example using sine data, but even the most basic implementation doesn't seem to yield the results we'd expect. I'm curious whether you might be able to help with our understanding of an NP implementation vis-à-vis a typical RNN implementation; or, rather, with how we can make such a simple example work for us.

Here's the simplest implementation we could construct:

import matplotlib.pyplot as plt
import numpy as np
import torch

import neuralprocesses.torch as nps

def generate_sin_data(amplitude_range=(0., 2), shift_range=(0., 1.), batch_size=16, num_points=10):
    a_min, a_max = amplitude_range
    b_min, b_max = shift_range

    cxt_x = np.zeros(shape=(batch_size, 1, num_points))
    cxt_y = np.zeros(shape=(batch_size, 1, num_points))

    for i in range(batch_size):
        a = (a_max - a_min) * np.random.rand() + a_min
        b = (b_max - b_min) * np.random.rand() + b_min
        x = np.linspace(0, 2 * np.pi, num_points)
        cxt_x[i] = x
        cxt_y[i] = a * (np.sin(x - b) + 1)

    return torch.from_numpy(cxt_x).float(), torch.from_numpy(cxt_y).float()

dataset = generate_sin_data(amplitude_range=(0., 2.), 
                            shift_range=(0., 1.),
                            batch_size=16)

# Construct a ConvCNP.
convcnp = nps.construct_convgnp(dim_x=1, dim_y=1, likelihood="het", num_basis_functions=512)

# Construct optimiser.
opt = torch.optim.Adam(convcnp.parameters(), 1e-3)

# Training: optimise the model for 256 batches.
for _ in range(256):
    # Sample a batch of new context and target sets. Replace this with your data. The
    # shapes are `(batch_size, dimensionality, num_data)`.

    xt, yt = generate_sin_data(amplitude_range=(0., 2.), 
                               shift_range=(0., 1.),
                               batch_size=16,
                               num_points=15)

    idx = np.sort(np.random.choice(15, 10, replace=False))
    xc, yc = xt[:, :, idx], yt[:, :, idx]

    # Compute the loss and update the model parameters.
    loss = -torch.mean(nps.loglik(convcnp, xc, yc, xt, yt, normalise=True))
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()

# Make predictions on some new data
pxt, pyt = generate_sin_data(amplitude_range=(0., 2), 
                             shift_range=(0., 1.),
                             batch_size=1,
                             num_points=15)

# idx = np.sort(np.random.choice(15, 10, replace=False))
idx = np.random.randint(2, 15)
pxc, pyc = pxt[:, :, :idx], pyt[:, :, :idx]

# Testing: make some predictions.
mean, var, noiseless_samples, noisy_samples = nps.predict(
    convcnp,
    pxc,
    pyc,
    pxt
)

f, ax = plt.subplots()

for i in range(1):
    ctx_x, ctx_y = pxc.numpy()[i, 0, :], pyc.numpy()[i, 0, :]
    tar_x, tar_y = pxt.numpy()[i, 0, :], pyt.numpy()[i, 0, :]
    ax.scatter(ctx_x, ctx_y)
    ax.scatter(tar_x, tar_y, color='none', edgecolor='k')

for i in range(1):
    pre_x, pre_y = pxt.numpy()[i, 0, :], mean.detach().numpy()[i, 0, :]
    pre_yerr = var.detach().numpy()[i, 0, :]
    ax.plot(pre_x, pre_y)
    ax.fill_between(pre_x, pre_y - pre_yerr, pre_y + pre_yerr, alpha=0.25, color='C0')

plt.show()

And we get results similar to:

[Three attached figures showing the model's predictions on the sine data.]

Are we approaching this wrong? Should we be using many more batches? Is prediction in this way not a good use of NPs?

Thanks for any insight!

Cheers, Nick

wesselb commented 7 months ago

Hey Nick! A very brief reply to let you know that I’ve read this. :) I’m currently away, but I will get back to you in a few days.

wesselb commented 7 months ago

Hey @nmearl,

Are we approaching this wrong? Should we be using many more batches? Is prediction in this way not a good use of NPs?

I think you are approaching this in the right way! :) It's just a matter of setting up the model and the data in the right way so that the model can learn what it needs to learn to generalise.

I've made a few tweaks to your script:

See here:

import lab as B
import matplotlib.pyplot as plt
import neuralprocesses.torch as nps
import numpy as np
import torch

B.set_random_seed(0)

def generate_sin_data(
    amplitude_range=(0.0, 2), shift_range=(0.0, 1.0), batch_size=16, num_points=10
):
    a_min, a_max = amplitude_range
    b_min, b_max = shift_range

    cxt_x = np.zeros(shape=(batch_size, 1, num_points))
    cxt_y = np.zeros(shape=(batch_size, 1, num_points))

    for i in range(batch_size):
        a = (a_max - a_min) * np.random.rand() + a_min
        b = (b_max - b_min) * np.random.rand() + b_min
        x = np.linspace(0, 6, num_points)
        cxt_x[i] = x
        cxt_y[i] = a * (np.sin(2 * np.pi / 2 * (x - b)) + 1)

    return torch.from_numpy(cxt_x).float(), torch.from_numpy(cxt_y).float()

dataset = generate_sin_data(
    amplitude_range=(1.0, 2.0), shift_range=(0.0, 1.0), batch_size=16
)

# Construct a ConvCNP.
convcnp = nps.construct_convgnp(
    dim_x=1,
    dim_y=1,
    likelihood="het",
    points_per_unit=16,
)

# Construct optimiser.
opt = torch.optim.Adam(convcnp.parameters(), 1e-3)

# Training: optimise the model for 1024 batches.
for i in range(1024):
    print(i)
    # Sample a batch of new context and target sets. Replace this with your data. The
    # shapes are `(batch_size, dimensionality, num_data)`.

    x, y = generate_sin_data(
        amplitude_range=(1.5, 2.0),
        shift_range=(0.0, 1.0),
        batch_size=16,
        num_points=50,
    )

    inds = np.random.permutation(x.shape[2])
    xc, yc = x[:, :, inds[:25]], y[:, :, inds[:25]]
    xt, yt = x[:, :, inds[25:]], y[:, :, inds[25:]]

    # Compute the loss and update the model parameters.
    loss = -torch.mean(nps.loglik(convcnp, xc, yc, xt, yt, normalise=True))
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()

# Make predictions on some new data
pxt, pyt = generate_sin_data(
    amplitude_range=(1.5, 2),
    shift_range=(0.0, 1.0),
    batch_size=1,
    num_points=30,
)

# Use the first 15 points as the context set.
pxc, pyc = pxt[:, :, :15], pyt[:, :, :15]

# Testing: make some predictions.
mean, var, noiseless_samples, noisy_samples = nps.predict(convcnp, pxc, pyc, pxt)

f, ax = plt.subplots()

for i in range(1):
    ctx_x, ctx_y = pxc.numpy()[i, 0, :], pyc.numpy()[i, 0, :]
    tar_x, tar_y = pxt.numpy()[i, 0, :], pyt.numpy()[i, 0, :]
    ax.scatter(ctx_x, ctx_y)
    ax.scatter(tar_x, tar_y, color="none", edgecolor="k")

for i in range(1):
    pre_x, pre_y = pxt.numpy()[i, 0, :], mean.detach().numpy()[i, 0, :]
    pre_yerr = 2 * np.sqrt(var.detach().numpy()[i, 0, :])
    ax.plot(pre_x, pre_y)
    ax.fill_between(pre_x, pre_y - pre_yerr, pre_y + pre_yerr, alpha=0.25, color="C0")
plt.show()

With these tweaks, I get the following prediction:

[Attached figure showing the improved prediction.]

It's not perfect, but it's already much better. You can likely get even better predictions using nps.ar_predict.
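As a rough sketch (assuming nps.ar_predict accepts the same context/target arguments and returns the same four outputs as nps.predict; do check the documentation for the exact signature), the prediction step would become:

# Sketch of autoregressive prediction. This assumes `nps.ar_predict` accepts the
# same context/target arguments as `nps.predict` and returns
# `mean, var, noiseless_samples, noisy_samples`; `num_samples` is illustrative.
mean, var, noiseless_samples, noisy_samples = nps.ar_predict(
    convcnp,
    pxc,
    pyc,
    pxt,
    num_samples=50,
)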

In the end, it is all about setting up the model and the data in the right way and training for long enough.
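Concretely, whatever data you plug in just needs to end up in the (batch_size, dimensionality, num_data) layout that nps.loglik and nps.predict expect. A minimal sketch, with purely illustrative placeholder data:

# Sketch: pack a collection of equally-long (x, y) series into the
# `(batch_size, dimensionality, num_data)` layout used above. The `series`
# list here is a placeholder for your own data.
series = [(np.linspace(0, 6, 50), np.random.randn(50)) for _ in range(16)]

x = torch.stack([torch.as_tensor(xi, dtype=torch.float32) for xi, _ in series])
y = torch.stack([torch.as_tensor(yi, dtype=torch.float32) for _, yi in series])
x, y = x[:, None, :], y[:, None, :]  # Insert the dimensionality axis: (16, 1, 50).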

By the way, I would recommend running this on a GPU. You can run these simple examples on a CPU, but that won't scale to bigger data sets and bigger models.
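The change for that is the usual PyTorch one: move the model and every batch to the device. Roughly (continuing from the script above):

# Sketch: run the training loop above on a GPU by moving the model and each
# batch of data to the device.
device = "cuda" if torch.cuda.is_available() else "cpu"
convcnp = convcnp.to(device)

x, y = generate_sin_data(
    amplitude_range=(1.5, 2.0), shift_range=(0.0, 1.0), batch_size=16, num_points=50
)
x, y = x.to(device), y.to(device)  # Do this for every batch inside the loop.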

nmearl commented 7 months ago

Thanks, @wesselb! This is excellent. I do seem to be having issues with replacing the sinusoidal data with data more akin to my particular use case, but I've reached out via email so as not to clutter the issue.