samholt / NeuralLaplace

Neural Laplace: Differentiable Laplace Reconstructions for modelling any time observation with O(1) complexity.
https://samholt.github.io/NeuralLaplace/
MIT License

Can't reproduce test RMSE for Lotka-Volterra DDE #3

Closed: thibmonsel closed this issue 1 year ago

thibmonsel commented 1 year ago

I am trying to reproduce some of the results from your paper on the Lotka-Volterra DDE experiment. I can't seem to reach a test RMSE in the range of .0475 ± .0061 from Table 5 in the paper.

I basically cloned the repository and launched:

python exp_all_baselines.py -d lotka_volterra_system_with_delay

This is the value I get:

17:08:54,689 root INFO [epoch=999] epoch_duration=0.52 | train_loss=0.9681546149729995  | val_loss=0.8175302209565037   | val_mse=0.8175302209565037
17:08:54,690 root INFO epoch_duration_mean=0.50497
17:08:54,715 root INFO Result: Neural Laplace - TEST RMSE: 0.9646045838282256

I was wondering whether some hyperparameters were changed in later commits, or whether I'm doing something wrong or not following the README.md instructions properly.

Thanks!

samholt commented 1 year ago

Hi @thibmonsel,

Thank you for opening an issue. That is interesting. I have indeed changed the algorithm since the initial commit for the paper (hopefully improving a few corner cases). The initial code and results should be reproducible from this commit: https://github.com/samholt/NeuralLaplace/tree/3582b411986a50f11a9aee1c3a3cfc34527afcc1 .

I noticed you have closed this issue. Are you still experiencing the problem? Please do let me know, and I will investigate further to restore the state of the code that originally replicated those results.

Thank you. Apologies for the delayed response; I was on holiday for the past week and am back now.

All the best, Sam

thibmonsel commented 1 year ago

Thanks for the response! I seem to have an issue with saving and reloading the trained model for further investigation. This is most likely a separate issue not related to torchlaplace, but I'm not sure. I have never seen anything of the sort and I'm a bit clueless.

Here is the code snippet to reproduce the issue (normally you just need to set YOUR_PATH). I checked that the model weights are the same and that the models are in eval mode. It may be because of the GRU units; however, in ReverseGRUEncoder the first hidden state is 0, so this doesn't seem to be the problem.

I tried on my local machine and on Colab too.

I basically took your Tutorial 1 code and saved the model:


# Commented out IPython magic to ensure Python compatibility.
from torch import nn
import torch
import matplotlib.pyplot as plt
# %matplotlib inline
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model (encoder and Laplace representation func)
class ReverseGRUEncoder(nn.Module):
    # Encodes observed trajectory into latent vector
    def __init__(self, dimension_in, latent_dim, hidden_units, encode_obs_time=True):
        super(ReverseGRUEncoder, self).__init__()
        self.encode_obs_time = encode_obs_time
        if self.encode_obs_time:
            dimension_in += 1
        self.gru = nn.GRU(dimension_in, hidden_units, 2, batch_first=True)
        self.linear_out = nn.Linear(hidden_units, latent_dim)
        nn.init.xavier_uniform_(self.linear_out.weight)

    def forward(self, observed_data, observed_tp):
        trajs_to_encode = observed_data  # (batch_size, t_observed_dim, observed_dim)
        if self.encode_obs_time:
            trajs_to_encode = torch.cat(
                (
                    observed_data,
                    observed_tp.view(1, -1, 1).repeat(observed_data.shape[0], 1, 1),
                ),
                dim=2,
            )
        reversed_trajs_to_encode = torch.flip(trajs_to_encode, (1,))
        out, _ = self.gru(reversed_trajs_to_encode)
        return self.linear_out(out[:, -1, :])

class LaplaceRepresentationFunc(nn.Module):
    # SphereSurfaceModel : C^{b+k} -> C^{bxd} - In Riemann Sphere Co ords : b dim s reconstruction terms, k is latent encoding dimension, d is output dimension
    def __init__(self, s_dim, output_dim, latent_dim, hidden_units=64):
        super(LaplaceRepresentationFunc, self).__init__()
        self.s_dim = s_dim
        self.output_dim = output_dim
        self.latent_dim = latent_dim
        self.linear_tanh_stack = nn.Sequential(
            nn.Linear(s_dim * 2 + latent_dim, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, (s_dim) * 2 * output_dim),
        )

        for m in self.linear_tanh_stack.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

        phi_max = torch.pi / 2.0
        self.phi_scale = phi_max - -torch.pi / 2.0

    def forward(self, i):
        out = self.linear_tanh_stack(i.view(-1, self.s_dim * 2 + self.latent_dim)).view(
            -1, 2 * self.output_dim, self.s_dim
        )
        theta = nn.Tanh()(out[:, : self.output_dim, :]) * torch.pi  # From - pi to + pi
        phi = (
            nn.Tanh()(out[:, self.output_dim :, :]) * self.phi_scale / 2.0
            - torch.pi / 2.0
            + self.phi_scale / 2.0
        )  # From -pi / 2 to + pi / 2
        return theta, phi

from torch.utils.data import DataLoader
from torchlaplace.data_utils import basic_collate_fn

normalize_dataset = True
batch_size = 128
extrapolate = True

def sawtooth(trajectories_to_sample=100, t_nsamples=200):
    # Toy sawtooth waveform. Simple to generate, for Differential Equation Datasets see datasets.py (Note more complex DE take time to sample from, in some cases minutes).
    t_end = 20.0
    t_begin = t_end / t_nsamples
    ti = torch.linspace(t_begin, t_end, t_nsamples).to(device)

    def sampler(t, x0=0):
        return (t + x0) / (2 * torch.pi) - torch.floor((t + x0) / (2 * torch.pi))

    x0s = torch.linspace(0, 2 * torch.pi, trajectories_to_sample)
    trajs = []
    for x0 in x0s:
        trajs.append(sampler(ti, x0))
    y = torch.stack(trajs)
    trajectories = y.view(trajectories_to_sample, -1, 1)
    return trajectories, ti

trajectories, t = sawtooth(
    trajectories_to_sample=1000,
    t_nsamples=200,
)
if normalize_dataset:
    samples = trajectories.shape[0]
    dim = trajectories.shape[2]
    traj = (
        torch.reshape(trajectories, (-1, dim))
        - torch.reshape(trajectories, (-1, dim)).mean(0)
    ) / torch.reshape(trajectories, (-1, dim)).std(0)
    trajectories = torch.reshape(traj, (samples, -1, dim))
train_split = int(0.8 * trajectories.shape[0])
test_split = int(0.9 * trajectories.shape[0])
traj_index = torch.randperm(trajectories.shape[0])
train_trajectories = trajectories[traj_index[:train_split], :, :]
val_trajectories = trajectories[traj_index[train_split:test_split], :, :]
test_trajectories = trajectories[traj_index[test_split:], :, :]

input_dim = train_trajectories.shape[2]
output_dim = input_dim

dltrain = DataLoader(
    train_trajectories,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=lambda batch: basic_collate_fn(
        batch,
        t,
        data_type="train",
        extrap=extrapolate,
    ),
)
dlval = DataLoader(
    val_trajectories,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: basic_collate_fn(
        batch,
        t,
        data_type="train",
        extrap=extrapolate,
    ),
)
dltest = DataLoader(
    test_trajectories,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: basic_collate_fn(
        batch,
        t,
        data_type="train",
        extrap=extrapolate,
    ),
)

"""### Instantiate model and sample prediction"""

latent_dim = 2
hidden_units = 64
encode_obs_time = True
s_recon_terms = 33

encoder = ReverseGRUEncoder(
    input_dim,
    latent_dim,
    hidden_units // 2,
    encode_obs_time=encode_obs_time,
).to(device)
laplace_rep_func = LaplaceRepresentationFunc(
    s_recon_terms, output_dim, latent_dim
).to(device)

"""We can generate predictions using `laplace_reconstruct` as follows"""

from torchlaplace import laplace_reconstruct

laplace_rep_func.eval(), encoder.eval()
for batch in dlval:
    trajs_to_encode = batch[
        "observed_data"
    ]  # (batch_size, t_observed_dim, observed_dim)
    observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
    p = encoder(
        trajs_to_encode, observed_tp
    )  # p is the latent tensor encoding the initial states
    tp_to_predict = batch["tp_to_predict"]
    predictions = laplace_reconstruct(
        laplace_rep_func, p, tp_to_predict, recon_dim=output_dim
    )

"""### Train end to end """

from time import time
from copy import deepcopy

learning_rate = 1e-3
epochs = 500
patience = None

if not patience:
    patience = epochs

params = list(laplace_rep_func.parameters()) + list(encoder.parameters())
optimizer = torch.optim.Adam(params, lr=learning_rate)
loss_fn = torch.nn.MSELoss()

best_loss = float("inf")
waiting = 0

for epoch in range(epochs):
    iteration = 0
    epoch_train_loss_it_cum = 0
    start_time = time()
    laplace_rep_func.train(), encoder.train()
    for batch in dltrain:
        optimizer.zero_grad()
        trajs_to_encode = batch[
            "observed_data"
        ]  # (batch_size, t_observed_dim, observed_dim)
        observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
        p = encoder(
            trajs_to_encode, observed_tp
        )  # p is the latent tensor encoding the initial states
        tp_to_predict = batch["tp_to_predict"]
        predictions = laplace_reconstruct(
            laplace_rep_func, p, tp_to_predict, recon_dim=output_dim
        )
        loss = loss_fn(
            torch.flatten(predictions), torch.flatten(batch["data_to_predict"])
        )
        loss.backward()
        torch.nn.utils.clip_grad_norm_(params, 1)
        optimizer.step()
        epoch_train_loss_it_cum += loss.item()
        iteration += 1
    epoch_train_loss = epoch_train_loss_it_cum / iteration
    epoch_duration = time() - start_time

    # Validation step
    laplace_rep_func.eval(), encoder.eval()
    cum_val_loss = 0
    cum_val_batches = 0
    for batch in dlval:
        trajs_to_encode = batch[
            "observed_data"
        ]  # (batch_size, t_observed_dim, observed_dim)
        observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
        p = encoder(
            trajs_to_encode, observed_tp
        )  # p is the latent tensor encoding the initial states
        tp_to_predict = batch["tp_to_predict"]
        predictions = laplace_reconstruct(
            laplace_rep_func, p, tp_to_predict, recon_dim=output_dim
        )
        cum_val_loss += loss_fn(
            torch.flatten(predictions), torch.flatten(batch["data_to_predict"])
        ).item()            
        cum_val_batches += 1
    val_mse = cum_val_loss / cum_val_batches
    if epoch % 100 == 0:
        print(
            "[epoch={}] epoch_duration={:.2f} | train_loss={}\t| val_mse={}\t|".format(
                epoch, epoch_duration, epoch_train_loss, val_mse
            )
        )
        plt.plot(observed_tp.detach().cpu()[0], trajs_to_encode.detach().cpu()[0])
        plt.plot(
                tp_to_predict.detach().cpu()[0],
                batch["data_to_predict"].detach().cpu()[0],
                "--",
            )
        plt.plot(tp_to_predict.detach().cpu()[0], predictions.detach().cpu()[0], c="r")
        plt.show()

    # Early stopping procedure
    if val_mse < best_loss:
        best_loss = val_mse
        best_laplace_rep_func = deepcopy(laplace_rep_func.state_dict())
        best_encoder = deepcopy(encoder.state_dict())
        waiting = 0
    elif waiting > patience:
        break
    else:
        waiting += 1

# Load best model
laplace_rep_func.load_state_dict(best_laplace_rep_func)
encoder.load_state_dict(best_encoder)

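# YOUR_PATH is a placeholder: the two calls need distinct file paths,
# otherwise the second save overwrites the first.
torch.save(best_encoder, YOUR_PATH)
torch.save(best_laplace_rep_func, YOUR_PATH)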

# Test step
laplace_rep_func.eval(), encoder.eval()
cum_test_loss = 0
cum_test_batches = 0
for batch in dltest:
    trajs_to_encode = batch[
        "observed_data"
    ]  # (batch_size, t_observed_dim, observed_dim)
    observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
    p = encoder(
        trajs_to_encode, observed_tp
    )  # p is the latent tensor encoding the initial states
    tp_to_predict = batch["tp_to_predict"]
    predictions = laplace_reconstruct(laplace_rep_func, p, tp_to_predict,  recon_dim=output_dim)
    print('predictions',predictions.shape)
    cum_test_loss += loss_fn(
        torch.flatten(predictions), torch.flatten(batch["data_to_predict"])
    ).item()
    plt.plot(observed_tp.detach().cpu()[0], trajs_to_encode.detach().cpu()[0])
    plt.plot(
        tp_to_predict.detach().cpu()[0],
        batch["data_to_predict"].detach().cpu()[0],
        "--",
    )
    plt.plot(tp_to_predict.detach().cpu()[0], predictions.detach().cpu()[0])
    plt.show()
    cum_test_batches += 1
test_mse = cum_test_loss / cum_test_batches
print(f"test_mse= {test_mse}")

plt.plot(observed_tp.detach().cpu()[0], trajs_to_encode.detach().cpu()[0])
plt.plot(
    tp_to_predict.detach().cpu()[0],
    batch["data_to_predict"].detach().cpu()[0],
    "--",
)
plt.plot(tp_to_predict.detach().cpu()[0], predictions.detach().cpu()[0])
plt.show()

In a separate file I load the model and run inference, i.e.:


from torch import nn
import torch
import matplotlib.pyplot as plt
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model (encoder and Laplace representation func)
class ReverseGRUEncoder(nn.Module):
    # Encodes observed trajectory into latent vector
    def __init__(self, dimension_in, latent_dim, hidden_units, encode_obs_time=True):
        super(ReverseGRUEncoder, self).__init__()
        self.encode_obs_time = encode_obs_time
        if self.encode_obs_time:
            dimension_in += 1
        self.gru = nn.GRU(dimension_in, hidden_units, 2, batch_first=True)
        self.linear_out = nn.Linear(hidden_units, latent_dim)
        nn.init.xavier_uniform_(self.linear_out.weight)

    def forward(self, observed_data, observed_tp):
        trajs_to_encode = observed_data  # (batch_size, t_observed_dim, observed_dim)
        if self.encode_obs_time:
            trajs_to_encode = torch.cat(
                (
                    observed_data,
                    observed_tp.view(1, -1, 1).repeat(observed_data.shape[0], 1, 1),
                ),
                dim=2,
            )
        reversed_trajs_to_encode = torch.flip(trajs_to_encode, (1,))
        out, _ = self.gru(reversed_trajs_to_encode)
        return self.linear_out(out[:, -1, :])

class LaplaceRepresentationFunc(nn.Module):
    # SphereSurfaceModel : C^{b+k} -> C^{bxd} - In Riemann Sphere Co ords : b dim s reconstruction terms, k is latent encoding dimension, d is output dimension
    def __init__(self, s_dim, output_dim, latent_dim, hidden_units=64):
        super(LaplaceRepresentationFunc, self).__init__()
        self.s_dim = s_dim
        self.output_dim = output_dim
        self.latent_dim = latent_dim
        self.linear_tanh_stack = nn.Sequential(
            nn.Linear(s_dim * 2 + latent_dim, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, (s_dim) * 2 * output_dim),
        )

        for m in self.linear_tanh_stack.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

        phi_max = torch.pi / 2.0
        self.phi_scale = phi_max - -torch.pi / 2.0

    def forward(self, i):
        out = self.linear_tanh_stack(i.view(-1, self.s_dim * 2 + self.latent_dim)).view(
            -1, 2 * self.output_dim, self.s_dim
        )
        theta = nn.Tanh()(out[:, : self.output_dim, :]) * torch.pi  # From - pi to + pi
        phi = (
            nn.Tanh()(out[:, self.output_dim :, :]) * self.phi_scale / 2.0
            - torch.pi / 2.0
            + self.phi_scale / 2.0
        )  # From -pi / 2 to + pi / 2
        return theta, phi

from torch.utils.data import DataLoader
from torchlaplace.data_utils import basic_collate_fn

normalize_dataset = True
batch_size = 128
extrapolate = True

def sawtooth(trajectories_to_sample=100, t_nsamples=200):
    # Toy sawtooth waveform. Simple to generate, for Differential Equation Datasets see datasets.py (Note more complex DE take time to sample from, in some cases minutes).
    t_end = 20.0
    t_begin = t_end / t_nsamples
    ti = torch.linspace(t_begin, t_end, t_nsamples).to(device)

    def sampler(t, x0=0):
        return (t + x0) / (2 * torch.pi) - torch.floor((t + x0) / (2 * torch.pi))

    x0s = torch.linspace(0, 2 * torch.pi, trajectories_to_sample)
    trajs = []
    for x0 in x0s:
        trajs.append(sampler(ti, x0))
    y = torch.stack(trajs)
    trajectories = y.view(trajectories_to_sample, -1, 1)
    return trajectories, ti

trajectories, t = sawtooth(
    trajectories_to_sample=1000,
    t_nsamples=200,
)
if normalize_dataset:
    samples = trajectories.shape[0]
    dim = trajectories.shape[2]
    traj = (
        torch.reshape(trajectories, (-1, dim))
        - torch.reshape(trajectories, (-1, dim)).mean(0)
    ) / torch.reshape(trajectories, (-1, dim)).std(0)
    trajectories = torch.reshape(traj, (samples, -1, dim))
train_split = int(0.8 * trajectories.shape[0])
test_split = int(0.9 * trajectories.shape[0])
traj_index = torch.randperm(trajectories.shape[0])
train_trajectories = trajectories[traj_index[:train_split], :, :]
val_trajectories = trajectories[traj_index[train_split:test_split], :, :]
test_trajectories = trajectories[traj_index[test_split:], :, :]

input_dim = train_trajectories.shape[2]
output_dim = input_dim

dltest = DataLoader(
    test_trajectories,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: basic_collate_fn(
        batch,
        t,
        data_type="train",
        extrap=extrapolate,
    ),
)

latent_dim = 2
hidden_units = 64
encode_obs_time = True
s_recon_terms = 33

encoder = ReverseGRUEncoder(
    input_dim,
    latent_dim,
    hidden_units // 2,
    encode_obs_time=encode_obs_time,
).to(device)
laplace_rep_func = LaplaceRepresentationFunc(
    s_recon_terms, output_dim, latent_dim
).to(device)

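# YOUR_PATH is a placeholder: each call must point at the file that was
# saved for that specific module.
encoder.load_state_dict(torch.load(YOUR_PATH))
laplace_rep_func.load_state_dict(torch.load(YOUR_PATH))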

# Test step
from torchlaplace import laplace_reconstruct

loss_fn = torch.nn.MSELoss()  # same loss as in the training script

laplace_rep_func.eval(), encoder.eval()
cum_test_loss = 0
cum_test_batches = 0
for batch in dltest:
    trajs_to_encode = batch[
        "observed_data"
    ]  # (batch_size, t_observed_dim, observed_dim)
    observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
    p = encoder(
        trajs_to_encode, observed_tp
    )  # p is the latent tensor encoding the initial states
    tp_to_predict = batch["tp_to_predict"]
    predictions = laplace_reconstruct(laplace_rep_func, p, tp_to_predict, recon_dim=output_dim)
    print("predictions", predictions.shape)
    cum_test_loss += loss_fn(
        torch.flatten(predictions), torch.flatten(batch["data_to_predict"])
    ).item()
    for i in range(predictions.shape[0]):
        plt.plot(observed_tp.detach().cpu()[0], trajs_to_encode.detach().cpu()[i])
        plt.plot(
            tp_to_predict.detach().cpu()[0],
            batch["data_to_predict"].detach().cpu()[i],  # ground truth for trajectory i
            "--",
        )
        plt.plot(tp_to_predict.detach().cpu()[0], predictions.detach().cpu()[i])
        plt.show()
    cum_test_batches += 1
test_mse = cum_test_loss / cum_test_batches
print(f"test_mse= {test_mse}")

The test results from the two scripts are not identical. Here is an issue that looks similar.
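Two things in the scripts above are worth ruling out before suspecting the GRU. This is a suggestion, not a confirmed diagnosis: `torch.randperm` is called without a seed, so the second script draws a different test split, and both `torch.save` calls use the same `YOUR_PATH` placeholder. A minimal sketch of a round-trip check follows. The small `nn.GRU` and `nn.Linear` are hypothetical stand-ins for `ReverseGRUEncoder` and `LaplaceRepresentationFunc`, and the checkpoint layout and file name are assumptions.

```python
import os
import tempfile

import torch
from torch import nn


def make_models():
    # Hypothetical stand-ins for the encoder and the representation function.
    encoder = nn.GRU(2, 8, 2, batch_first=True)
    head = nn.Linear(8, 2)
    return encoder, head


def predict(encoder, head, x):
    out, _ = encoder(x)
    return head(out[:, -1, :])


torch.manual_seed(0)
encoder, head = make_models()
encoder.eval(), head.eval()

# A seeded generator makes the permutation reproducible; storing the index
# in the checkpoint lets another script rebuild exactly the same split.
split_index = torch.randperm(10, generator=torch.Generator().manual_seed(0))
x = torch.rand(10, 5, 2)[split_index[8:]]

with torch.no_grad():
    before = predict(encoder, head, x)

# One file holding both state dicts, so neither save overwrites the other.
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save(
    {
        "encoder": encoder.state_dict(),
        "head": head.state_dict(),
        "split_index": split_index,
    },
    path,
)

checkpoint = torch.load(path)
new_encoder, new_head = make_models()
new_encoder.load_state_dict(checkpoint["encoder"])
new_head.load_state_dict(checkpoint["head"])
new_encoder.eval(), new_head.eval()

with torch.no_grad():
    after = predict(new_encoder, new_head, x)

print(torch.allclose(before, after))
```

If the same inputs give matching outputs after the round trip, any remaining difference between the two scripts probably comes from the data split or the plotting rather than from the saved weights.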

samholt commented 1 year ago

This is interesting. My first thought is that it is good to always use double-precision floating point for the dataset values. Although PyTorch uses single precision by default, double precision has been shown to increase the accuracy of numerical inverse Laplace transform algorithms, such as the ones used here.
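As a rough illustration of the suggestion above, the sketch below casts both the data and a module to float64 with `.double()`. The small `nn.GRU` and the random tensor are hypothetical stand-ins for the encoder and the sawtooth trajectories; whether the rest of the pipeline accepts float64 inputs unchanged is an assumption to verify.

```python
import torch
from torch import nn

# Hypothetical stand-ins for the encoder and the trajectories tensor.
encoder = nn.GRU(2, 8, 2, batch_first=True)
trajectories = torch.rand(4, 5, 2)

# Cast parameters and data together; mixing float32 and float64 tensors
# in one forward pass raises a dtype mismatch error.
encoder = encoder.double()
trajectories = trajectories.double()

out, _ = encoder(trajectories)
print(out.dtype)
```

An alternative is to call `torch.set_default_dtype(torch.float64)` once at the top of the script, before any tensors or modules are created.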

Thank you for posting the code. I will investigate this further and check the hyperparameters from the paper as well. It may take me a while.

This has most likely happened because the original code used for the paper was heavily refactored into this library. Thank you for pointing this out.

Just checking, a few questions:

thibmonsel commented 1 year ago

As far as I have tested, the Lotka-Volterra baseline

python exp_all_baselines.py -d lotka_volterra_system_with_delay

yields a test RMSE outside the confidence interval (https://github.com/samholt/NeuralLaplace/issues/3#issue-1500496796). However, python dde_demo.py --viz in the examples folder does give a test MSE in the paper's range. But when the best model is saved and loaded in another file, we get some weird behavior.

samholt commented 1 year ago

Hi @thibmonsel

I was able to replicate the original results from the paper for the originally reported problem, the test RMSE on the Lotka-Volterra DDE.

For instance, running python exp_all_baselines.py -d lotka_volterra_system_with_delay gives, for one random seed, 14:52:16,534 root INFO Result: Neural Laplace - TEST RMSE: 0.046900781612333464, which is within the test RMSE of .0475 ± .0061 quoted in the paper and in your first comment.
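The comparison can be made explicit with a small check. This sketch treats the paper's ± value as a symmetric interval around the mean, which is an assumption about how Table 5 should be read, and the helper name `within_reported_range` is made up for illustration. The two RMSE values are the ones logged in this thread.

```python
def within_reported_range(value, mean, spread):
    """Return True if value lies in [mean - spread, mean + spread]."""
    return mean - spread <= value <= mean + spread


paper_mean, paper_spread = 0.0475, 0.0061

rerun_rmse = 0.046900781612333464  # re-run reported in this comment
first_rmse = 0.9646045838282256    # run reported in the first comment

print(within_reported_range(rerun_rmse, paper_mean, paper_spread))  # True
print(within_reported_range(first_rmse, paper_mean, paper_spread))  # False
```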

I just refactored the code a little, cleaning up some new functionality in core, adding more unit tests, and publishing an updated PyPI package, which bumps the repo to version 0.0.3.

I was able to run the Google Colab notebooks, and their results on the Colab scripted datasets agree with the original paper's results.

Could you uninstall your version of torchlaplace and install the latest version, 0.0.3?

Please let me know what the results are once you have updated the package and re-run your original command. If you still get differing results, could you try installing all the packages in a fresh virtual environment?
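To confirm which version is actually active in an environment, something like the following may help. It uses only the standard library; the distribution name `torchlaplace` is assumed to match the PyPI package mentioned above, and `installed_version` is a helper name made up for this sketch.

```python
from importlib import metadata


def installed_version(distribution):
    """Return the installed version string, or None if it is not installed."""
    try:
        return metadata.version(distribution)
    except metadata.PackageNotFoundError:
        return None


version = installed_version("torchlaplace")
if version is None:
    print("torchlaplace is not installed in this environment")
else:
    print(f"torchlaplace {version}")
```

Running this inside the fresh virtual environment, before re-running the experiment, checks that the interpreter picks up the updated package rather than an older copy.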

Keen to resolve this. Thank you!

samholt commented 1 year ago

Thanks for the response ! I seem to have an issue with saving and reloading the trained model for further investigation. This is most likely an other issue not related to torchlaplace but Im not sure, I have never seen anything of the sort, I'm a bit clueless.

Here is the code snipped in order to reproduce the issue (normally you just need to give YOUR_PATH) I check the model weights are the same, models are in eval mode. This is maybe because of the GRU units however in ReverseGRUEncoder your first hidden state is 0 so this doesn't seem to be the problem.

I tried on my local machine and on colab too.

I basically took your Tutorial 1 code and saved the model

# Commented out IPython magic to ensure Python compatibility.
from torch import nn
import torch
import matplotlib.pyplot as plt
# %matplotlib inline
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model (encoder and Laplace representation func)
class ReverseGRUEncoder(nn.Module):
    # Encodes observed trajectory into latent vector
    def __init__(self, dimension_in, latent_dim, hidden_units, encode_obs_time=True):
        super(ReverseGRUEncoder, self).__init__()
        self.encode_obs_time = encode_obs_time
        if self.encode_obs_time:
            dimension_in += 1
        self.gru = nn.GRU(dimension_in, hidden_units, 2, batch_first=True)
        self.linear_out = nn.Linear(hidden_units, latent_dim)
        nn.init.xavier_uniform_(self.linear_out.weight)

    def forward(self, observed_data, observed_tp):
        trajs_to_encode = observed_data  # (batch_size, t_observed_dim, observed_dim)
        if self.encode_obs_time:
            trajs_to_encode = torch.cat(
                (
                    observed_data,
                    observed_tp.view(1, -1, 1).repeat(observed_data.shape[0], 1, 1),
                ),
                dim=2,
            )
        reversed_trajs_to_encode = torch.flip(trajs_to_encode, (1,))
        out, _ = self.gru(reversed_trajs_to_encode)
        return self.linear_out(out[:, -1, :])

class LaplaceRepresentationFunc(nn.Module):
    # SphereSurfaceModel : C^{b+k} -> C^{bxd} - In Riemann Sphere Co ords : b dim s reconstruction terms, k is latent encoding dimension, d is output dimension
    def __init__(self, s_dim, output_dim, latent_dim, hidden_units=64):
        super(LaplaceRepresentationFunc, self).__init__()
        self.s_dim = s_dim
        self.output_dim = output_dim
        self.latent_dim = latent_dim
        self.linear_tanh_stack = nn.Sequential(
            nn.Linear(s_dim * 2 + latent_dim, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, (s_dim) * 2 * output_dim),
        )

        for m in self.linear_tanh_stack.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

        phi_max = torch.pi / 2.0
        self.phi_scale = phi_max - -torch.pi / 2.0

    def forward(self, i):
        out = self.linear_tanh_stack(i.view(-1, self.s_dim * 2 + self.latent_dim)).view(
            -1, 2 * self.output_dim, self.s_dim
        )
        theta = nn.Tanh()(out[:, : self.output_dim, :]) * torch.pi  # From - pi to + pi
        phi = (
            nn.Tanh()(out[:, self.output_dim :, :]) * self.phi_scale / 2.0
            - torch.pi / 2.0
            + self.phi_scale / 2.0
        )  # Form -pi / 2 to + pi / 2
        return theta, phi

from torch.utils.data import DataLoader
from torchlaplace.data_utils import basic_collate_fn

normalize_dataset = True
batch_size = 128
extrapolate = True

def sawtooth(trajectories_to_sample=100, t_nsamples=200):
    # Toy sawtooth waveform. Simple to generate, for Differential Equation Datasets see datasets.py (Note more complex DE take time to sample from, in some cases minutes).
    t_end = 20.0
    t_begin = t_end / t_nsamples
    ti = torch.linspace(t_begin, t_end, t_nsamples).to(device)

    def sampler(t, x0=0):
        return (t + x0) / (2 * torch.pi) - torch.floor((t + x0) / (2 * torch.pi))

    x0s = torch.linspace(0, 2 * torch.pi, trajectories_to_sample)
    trajs = []
    for x0 in x0s:
        trajs.append(sampler(ti, x0))
    y = torch.stack(trajs)
    trajectories = y.view(trajectories_to_sample, -1, 1)
    return trajectories, ti

trajectories, t = sawtooth(
    trajectories_to_sample=1000,
    t_nsamples=200,
)
if normalize_dataset:
    samples = trajectories.shape[0]
    dim = trajectories.shape[2]
    traj = (
        torch.reshape(trajectories, (-1, dim))
        - torch.reshape(trajectories, (-1, dim)).mean(0)
    ) / torch.reshape(trajectories, (-1, dim)).std(0)
    trajectories = torch.reshape(traj, (samples, -1, dim))
train_split = int(0.8 * trajectories.shape[0])
test_split = int(0.9 * trajectories.shape[0])
traj_index = torch.randperm(trajectories.shape[0])
train_trajectories = trajectories[traj_index[:train_split], :, :]
val_trajectories = trajectories[traj_index[train_split:test_split], :, :]
test_trajectories = trajectories[traj_index[test_split:], :, :]

input_dim = train_trajectories.shape[2]
output_dim = input_dim

dltrain = DataLoader(
    train_trajectories,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=lambda batch: basic_collate_fn(
        batch,
        t,
        data_type="train",
        extrap=extrapolate,
    ),
)
dlval = DataLoader(
    val_trajectories,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: basic_collate_fn(
        batch,
        t,
        data_type="train",
        extrap=extrapolate,
    ),
)
dltest = DataLoader(
    test_trajectories,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: basic_collate_fn(
        batch,
        t,
        data_type="train",
        extrap=extrapolate,
    ),
)

"""### Instantiate model and sample prediction"""

latent_dim = 2
hidden_units = 64
encode_obs_time = True
s_recon_terms = 33

encoder = ReverseGRUEncoder(
    input_dim,
    latent_dim,
    hidden_units // 2,
    encode_obs_time=encode_obs_time,
).to(device)
laplace_rep_func = LaplaceRepresentationFunc(
    s_recon_terms, output_dim, latent_dim
).to(device)

"""We can generate predictions using `laplace_reconstruct` as follows"""

from torchlaplace import laplace_reconstruct

laplace_rep_func.eval(), encoder.eval()
for batch in dlval:
    trajs_to_encode = batch[
        "observed_data"
    ]  # (batch_size, t_observed_dim, observed_dim)
    observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
    p = encoder(
        trajs_to_encode, observed_tp
    )  # p is the latent tensor encoding the initial states
    tp_to_predict = batch["tp_to_predict"]
    predictions = laplace_reconstruct(
        laplace_rep_func, p, tp_to_predict, recon_dim=output_dim
    )

"""### Train end to end """

from time import time
from copy import deepcopy

learning_rate = 1e-3
epochs = 500
patience = None

if not patience:
    patience = epochs

params = list(laplace_rep_func.parameters()) + list(encoder.parameters())
optimizer = torch.optim.Adam(params, lr=learning_rate)
loss_fn = torch.nn.MSELoss()

best_loss = float("inf")
waiting = 0

for epoch in range(epochs):
    iteration = 0
    epoch_train_loss_it_cum = 0
    start_time = time()
    laplace_rep_func.train(), encoder.train()
    for batch in dltrain:
        optimizer.zero_grad()
        trajs_to_encode = batch[
            "observed_data"
        ]  # (batch_size, t_observed_dim, observed_dim)
        observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
        p = encoder(
            trajs_to_encode, observed_tp
        )  # p is the latent tensor encoding the initial states
        tp_to_predict = batch["tp_to_predict"]
        predictions = laplace_reconstruct(
            laplace_rep_func, p, tp_to_predict, recon_dim=output_dim
        )
        loss = loss_fn(
            torch.flatten(predictions), torch.flatten(batch["data_to_predict"])
        )
        loss.backward()
        torch.nn.utils.clip_grad_norm_(params, 1)
        optimizer.step()
        epoch_train_loss_it_cum += loss.item()
        iteration += 1
    epoch_train_loss = epoch_train_loss_it_cum / iteration
    epoch_duration = time() - start_time

    # Validation step
    laplace_rep_func.eval(), encoder.eval()
    cum_val_loss = 0
    cum_val_batches = 0
    for batch in dlval:
        trajs_to_encode = batch[
            "observed_data"
        ]  # (batch_size, t_observed_dim, observed_dim)
        observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
        p = encoder(
            trajs_to_encode, observed_tp
        )  # p is the latent tensor encoding the initial states
        tp_to_predict = batch["tp_to_predict"]
        predictions = laplace_reconstruct(
            laplace_rep_func, p, tp_to_predict, recon_dim=output_dim
        )
        cum_val_loss += loss_fn(
            torch.flatten(predictions), torch.flatten(batch["data_to_predict"])
        ).item()            
        cum_val_batches += 1
    val_mse = cum_val_loss / cum_val_batches
    if epoch % 100 == 0:
        print(
            "[epoch={}] epoch_duration={:.2f} | train_loss={}\t| val_mse={}\t|".format(
                epoch, epoch_duration, epoch_train_loss, val_mse
            )
        )
        plt.plot(observed_tp.detach().cpu()[0], trajs_to_encode.detach().cpu()[0])
        plt.plot(
                tp_to_predict.detach().cpu()[0],
                batch["data_to_predict"].detach().cpu()[0],
                "--",
            )
        plt.plot(tp_to_predict.detach().cpu()[0], predictions.detach().cpu()[0], c="r")
        plt.show()

    # Early stopping procedure
    if val_mse < best_loss:
        best_loss = val_mse
        best_laplace_rep_func = deepcopy(laplace_rep_func.state_dict())
        best_encoder = deepcopy(encoder.state_dict())
        waiting = 0
    elif waiting > patience:
        break
    else:
        waiting += 1

# Load best model
laplace_rep_func.load_state_dict(best_laplace_rep_func)
encoder.load_state_dict(best_encoder)

torch.save(best_encoder, YOUR_PATH)
torch.save(best_laplace_rep_func, YOUR_PATH)

# Test step
laplace_rep_func.eval(), encoder.eval()
cum_test_loss = 0
cum_test_batches = 0
for batch in dltest:
    trajs_to_encode = batch[
        "observed_data"
    ]  # (batch_size, t_observed_dim, observed_dim)
    observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
    p = encoder(
        trajs_to_encode, observed_tp
    )  # p is the latent tensor encoding the initial states
    tp_to_predict = batch["tp_to_predict"]
    predictions = laplace_reconstruct(laplace_rep_func, p, tp_to_predict,  recon_dim=output_dim)
    print('predictions',predictions.shape)
    cum_test_loss += loss_fn(
        torch.flatten(predictions), torch.flatten(batch["data_to_predict"])
    ).item()
    plt.plot(observed_tp.detach().cpu()[0], trajs_to_encode.detach().cpu()[0])
    plt.plot(
        tp_to_predict.detach().cpu()[0],
        batch["data_to_predict"].detach().cpu()[0],
        "--",
    )
    plt.plot(tp_to_predict.detach().cpu()[0], predictions.detach().cpu()[0])
    plt.show()
    cum_test_batches += 1
test_mse = cum_test_loss / cum_test_batches
print(f"test_mse= {test_mse}")

plt.plot(observed_tp.detach().cpu()[0], trajs_to_encode.detach().cpu()[0])
plt.plot(
    tp_to_predict.detach().cpu()[0],
    batch["data_to_predict"].detach().cpu()[0],
    "--",
)
plt.plot(tp_to_predict.detach().cpu()[0], predictions.detach().cpu()[0])
plt.show()

In a separate file, I load the model and run inference, i.e.:

from torch import nn
import torch
import matplotlib.pyplot as plt
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model (encoder and Laplace representation func)
class ReverseGRUEncoder(nn.Module):
    # Encodes observed trajectory into latent vector
    def __init__(self, dimension_in, latent_dim, hidden_units, encode_obs_time=True):
        super(ReverseGRUEncoder, self).__init__()
        self.encode_obs_time = encode_obs_time
        if self.encode_obs_time:
            dimension_in += 1
        self.gru = nn.GRU(dimension_in, hidden_units, 2, batch_first=True)
        self.linear_out = nn.Linear(hidden_units, latent_dim)
        nn.init.xavier_uniform_(self.linear_out.weight)

    def forward(self, observed_data, observed_tp):
        trajs_to_encode = observed_data  # (batch_size, t_observed_dim, observed_dim)
        if self.encode_obs_time:
            trajs_to_encode = torch.cat(
                (
                    observed_data,
                    observed_tp.view(1, -1, 1).repeat(observed_data.shape[0], 1, 1),
                ),
                dim=2,
            )
        reversed_trajs_to_encode = torch.flip(trajs_to_encode, (1,))
        out, _ = self.gru(reversed_trajs_to_encode)
        return self.linear_out(out[:, -1, :])

class LaplaceRepresentationFunc(nn.Module):
    # SphereSurfaceModel : C^{b+k} -> C^{bxd} - In Riemann Sphere Co ords : b dim s reconstruction terms, k is latent encoding dimension, d is output dimension
    def __init__(self, s_dim, output_dim, latent_dim, hidden_units=64):
        super(LaplaceRepresentationFunc, self).__init__()
        self.s_dim = s_dim
        self.output_dim = output_dim
        self.latent_dim = latent_dim
        self.linear_tanh_stack = nn.Sequential(
            nn.Linear(s_dim * 2 + latent_dim, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, (s_dim) * 2 * output_dim),
        )

        for m in self.linear_tanh_stack.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

        phi_max = torch.pi / 2.0
        self.phi_scale = phi_max - -torch.pi / 2.0

    def forward(self, i):
        out = self.linear_tanh_stack(i.view(-1, self.s_dim * 2 + self.latent_dim)).view(
            -1, 2 * self.output_dim, self.s_dim
        )
        theta = nn.Tanh()(out[:, : self.output_dim, :]) * torch.pi  # From - pi to + pi
        phi = (
            nn.Tanh()(out[:, self.output_dim :, :]) * self.phi_scale / 2.0
            - torch.pi / 2.0
            + self.phi_scale / 2.0
        )  # Form -pi / 2 to + pi / 2
        return theta, phi

from torch.utils.data import DataLoader
from torchlaplace.data_utils import basic_collate_fn

normalize_dataset = True
batch_size = 128
extrapolate = True

def sawtooth(trajectories_to_sample=100, t_nsamples=200):
    # Toy sawtooth waveform. Simple to generate, for Differential Equation Datasets see datasets.py (Note more complex DE take time to sample from, in some cases minutes).
    t_end = 20.0
    t_begin = t_end / t_nsamples
    ti = torch.linspace(t_begin, t_end, t_nsamples).to(device)

    def sampler(t, x0=0):
        return (t + x0) / (2 * torch.pi) - torch.floor((t + x0) / (2 * torch.pi))

    x0s = torch.linspace(0, 2 * torch.pi, trajectories_to_sample)
    trajs = []
    for x0 in x0s:
        trajs.append(sampler(ti, x0))
    y = torch.stack(trajs)
    trajectories = y.view(trajectories_to_sample, -1, 1)
    return trajectories, ti

trajectories, t = sawtooth(
    trajectories_to_sample=1000,
    t_nsamples=200,
)
if normalize_dataset:
    samples = trajectories.shape[0]
    dim = trajectories.shape[2]
    traj = (
        torch.reshape(trajectories, (-1, dim))
        - torch.reshape(trajectories, (-1, dim)).mean(0)
    ) / torch.reshape(trajectories, (-1, dim)).std(0)
    trajectories = torch.reshape(traj, (samples, -1, dim))
train_split = int(0.8 * trajectories.shape[0])
test_split = int(0.9 * trajectories.shape[0])
traj_index = torch.randperm(trajectories.shape[0])
train_trajectories = trajectories[traj_index[:train_split], :, :]
val_trajectories = trajectories[traj_index[train_split:test_split], :, :]
test_trajectories = trajectories[traj_index[test_split:], :, :]

input_dim = train_trajectories.shape[2]
output_dim = input_dim

dltest = DataLoader(
    test_trajectories,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: basic_collate_fn(
        batch,
        t,
        data_type="train",
        extrap=extrapolate,
    ),
)
input_dim = train_trajectories.shape[2]
output_dim = input_dim

latent_dim = 2
hidden_units = 64
encode_obs_time = True
s_recon_terms = 33

encoder = ReverseGRUEncoder(
    input_dim,
    latent_dim,
    hidden_units // 2,
    encode_obs_time=encode_obs_time,
).to(device)
laplace_rep_func = LaplaceRepresentationFunc(
    s_recon_terms, output_dim, latent_dim
).to(device)

encoder.load_state_dict(torch.load(YOUR_PATH))
laplace_rep_func.load_state_dict(torch.load(YOUR_PATH))

# Test step
from torchlaplace import laplace_reconstruct

laplace_rep_func.eval(), encoder.eval()
cum_test_loss = 0
cum_test_batches = 0
for batch in dltest:
    trajs_to_encode = batch[
        "observed_data"
    ]  # (batch_size, t_observed_dim, observed_dim)
    observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
    p = encoder(
        trajs_to_encode, observed_tp
    )  # p is the latent tensor encoding the initial states
    tp_to_predict = batch["tp_to_predict"]
    predictions = laplace_reconstruct(laplace_rep_func, p, tp_to_predict,  recon_dim=output_dim)
    print('predictions',predictions.shape)
    # cum_test_loss += loss_fn(
    #     torch.flatten(predictions), torch.flatten(batch["data_to_predict"])
    # ).item()
    for i in range(predictions.shape[0]):
        plt.plot(observed_tp.detach().cpu()[0], trajs_to_encode.detach().cpu()[i])
        plt.plot(
            tp_to_predict.detach().cpu()[0],
            batch["data_to_predict"].detach().cpu()[0],
            "--",
        )
        plt.plot(tp_to_predict.detach().cpu()[0], predictions.detach().cpu()[i])
        plt.show()
    cum_test_batches += 1
test_mse = cum_test_loss / cum_test_batches
print(f"test_mse= {test_mse}")

The test results from the two scripts are not identical. Here is an issue that looks similar.

Hi @thibmonsel

I was able to fix your posted code example, and have posted it below.

A few points, though. I would not recommend building on the Colab tutorials, as these were designed to show how to use NeuralLaplace and get something working quickly. For a full, concrete example I recommend building on the examples in experiments/exp_all_baselines.py, as these were used in the paper and I still use them today for new projects. A key part of them is using double precision for both the model and the dataset, which improves accuracy; examples/simple_demo_double_precision.py also shows this. Most importantly, experiments/exp_all_baselines.py sets the random seed, which makes all the results consistent and reproducible; the Colab scripts lack this.
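
As a rough illustration of the double-precision point, here is a minimal sketch. The tiny `model` below is a hypothetical stand-in, not the model from experiments/exp_all_baselines.py; the only thing it is meant to show is that both the parameters and the data are cast to float64.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in model (an assumption for illustration only).
# .double() casts every parameter to float64.
model = nn.Sequential(nn.Linear(3, 8), nn.Tanh(), nn.Linear(8, 1)).double()

# The data has to be float64 as well, so that model and inputs agree on dtype.
x = torch.rand(4, 3, dtype=torch.float64)
y = model(x)

print(y.dtype)
```

In the scripts in this thread, the same idea would presumably mean calling `.double()` on both `encoder` and `laplace_rep_func` and casting `trajectories` and `t` to `torch.float64`.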

For the specific example you highlighted, a few things are incorrect. "YOUR_PATH" doesn't work like that: you need to specify a separate path for each model, i.e. one for the encoder and one for the Laplace representation function. If both state dicts are saved to the same path, the second torch.save overwrites the first.
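
A minimal sketch of the separate-paths pattern, assuming two small stand-in modules in place of the real encoder and Laplace representation function (the `nn.Linear` layers and the temporary directory are assumptions for illustration):

```python
import os
import tempfile

import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins for the encoder and the Laplace representation function.
encoder = nn.Linear(2, 4)
laplace_rep_func = nn.Linear(4, 3)

with tempfile.TemporaryDirectory() as tmp_dir:
    # One file per model, so neither state dict overwrites the other.
    encoder_path = os.path.join(tmp_dir, "encoder.pt")
    laplace_rep_path = os.path.join(tmp_dir, "laplace_rep.pt")
    torch.save(encoder.state_dict(), encoder_path)
    torch.save(laplace_rep_func.state_dict(), laplace_rep_path)

    # Rebuild the same architectures, then load each state dict from its own file.
    new_encoder = nn.Linear(2, 4)
    new_laplace_rep_func = nn.Linear(4, 3)
    new_encoder.load_state_dict(torch.load(encoder_path, map_location="cpu"))
    new_laplace_rep_func.load_state_dict(
        torch.load(laplace_rep_path, map_location="cpu")
    )

# The reloaded pair should reproduce the original pair's outputs exactly.
x = torch.rand(5, 2)
with torch.no_grad():
    same = torch.equal(
        laplace_rep_func(encoder(x)), new_laplace_rep_func(new_encoder(x))
    )
print(same)
```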

In the loading script I think there was a mistake in the data generation process, so I copied it exactly from the first script and set the torch random seed in both scripts. Running the following two scripts trains the model and then loads it, and both scripts produce identical test MSEs.
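
Before the full scripts, a minimal sketch of why the seed matters for the data split: the test set is chosen with `torch.randperm`, so without a fixed seed each script would most likely draw a different set of test trajectories. The helper `split_indices` is hypothetical and only mirrors the 90% test split used in the scripts.

```python
import torch


def split_indices(n, seed=None):
    # Hypothetical helper: returns the indices that would form the test split.
    if seed is not None:
        torch.manual_seed(seed)
    perm = torch.randperm(n)
    return perm[int(0.9 * n):]


# Seeding identically in both "scripts" yields the same test indices.
a = split_indices(1000, seed=0)
b = split_indices(1000, seed=0)
print(torch.equal(a, b))
```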


# Commented out IPython magic to ensure Python compatibility.
from torch import nn
import torch
import matplotlib.pyplot as plt
# %matplotlib inline
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.random.manual_seed(0)

# Model (encoder and Laplace representation func)
class ReverseGRUEncoder(nn.Module):
    # Encodes observed trajectory into latent vector
    def __init__(self, dimension_in, latent_dim, hidden_units, encode_obs_time=True):
        super(ReverseGRUEncoder, self).__init__()
        self.encode_obs_time = encode_obs_time
        if self.encode_obs_time:
            dimension_in += 1
        self.gru = nn.GRU(dimension_in, hidden_units, 2, batch_first=True)
        self.linear_out = nn.Linear(hidden_units, latent_dim)
        nn.init.xavier_uniform_(self.linear_out.weight)

    def forward(self, observed_data, observed_tp):
        trajs_to_encode = observed_data  # (batch_size, t_observed_dim, observed_dim)
        if self.encode_obs_time:
            trajs_to_encode = torch.cat(
                (
                    observed_data,
                    observed_tp.view(1, -1, 1).repeat(observed_data.shape[0], 1, 1),
                ),
                dim=2,
            )
        reversed_trajs_to_encode = torch.flip(trajs_to_encode, (1,))
        out, _ = self.gru(reversed_trajs_to_encode)
        return self.linear_out(out[:, -1, :])

class LaplaceRepresentationFunc(nn.Module):
    # SphereSurfaceModel : C^{b+k} -> C^{bxd} - In Riemann Sphere Co ords : b dim s reconstruction terms, k is latent encoding dimension, d is output dimension
    def __init__(self, s_dim, output_dim, latent_dim, hidden_units=64):
        super(LaplaceRepresentationFunc, self).__init__()
        self.s_dim = s_dim
        self.output_dim = output_dim
        self.latent_dim = latent_dim
        self.linear_tanh_stack = nn.Sequential(
            nn.Linear(s_dim * 2 + latent_dim, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, (s_dim) * 2 * output_dim),
        )

        for m in self.linear_tanh_stack.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

        phi_max = torch.pi / 2.0
        self.phi_scale = phi_max - -torch.pi / 2.0

    def forward(self, i):
        out = self.linear_tanh_stack(i.view(-1, self.s_dim * 2 + self.latent_dim)).view(
            -1, 2 * self.output_dim, self.s_dim
        )
        theta = nn.Tanh()(out[:, : self.output_dim, :]) * torch.pi  # From - pi to + pi
        phi = (
            nn.Tanh()(out[:, self.output_dim :, :]) * self.phi_scale / 2.0
            - torch.pi / 2.0
            + self.phi_scale / 2.0
        )  # Form -pi / 2 to + pi / 2
        return theta, phi

from torch.utils.data import DataLoader
from torchlaplace.data_utils import basic_collate_fn

normalize_dataset = True
batch_size = 128
extrapolate = True

def sawtooth(trajectories_to_sample=100, t_nsamples=200):
    # Toy sawtooth waveform. Simple to generate, for Differential Equation Datasets see datasets.py (Note more complex DE take time to sample from, in some cases minutes).
    t_end = 20.0
    t_begin = t_end / t_nsamples
    ti = torch.linspace(t_begin, t_end, t_nsamples).to(device)

    def sampler(t, x0=0):
        return (t + x0) / (2 * torch.pi) - torch.floor((t + x0) / (2 * torch.pi))

    x0s = torch.linspace(0, 2 * torch.pi, trajectories_to_sample)
    trajs = []
    for x0 in x0s:
        trajs.append(sampler(ti, x0))
    y = torch.stack(trajs)
    trajectories = y.view(trajectories_to_sample, -1, 1)
    return trajectories, ti

trajectories, t = sawtooth(
    trajectories_to_sample=1000,
    t_nsamples=200,
)
if normalize_dataset:
    samples = trajectories.shape[0]
    dim = trajectories.shape[2]
    traj = (
        torch.reshape(trajectories, (-1, dim))
        - torch.reshape(trajectories, (-1, dim)).mean(0)
    ) / torch.reshape(trajectories, (-1, dim)).std(0)
    trajectories = torch.reshape(traj, (samples, -1, dim))
train_split = int(0.8 * trajectories.shape[0])
test_split = int(0.9 * trajectories.shape[0])
traj_index = torch.randperm(trajectories.shape[0])
train_trajectories = trajectories[traj_index[:train_split], :, :]
val_trajectories = trajectories[traj_index[train_split:test_split], :, :]
test_trajectories = trajectories[traj_index[test_split:], :, :]

input_dim = train_trajectories.shape[2]
output_dim = input_dim

dltrain = DataLoader(
    train_trajectories,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=lambda batch: basic_collate_fn(
        batch,
        t,
        data_type="train",
        extrap=extrapolate,
    ),
)
dlval = DataLoader(
    val_trajectories,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: basic_collate_fn(
        batch,
        t,
        data_type="train",
        extrap=extrapolate,
    ),
)
dltest = DataLoader(
    test_trajectories,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: basic_collate_fn(
        batch,
        t,
        data_type="train",
        extrap=extrapolate,
    ),
)

"""### Instantiate model and sample prediction"""

latent_dim = 2
hidden_units = 64
encode_obs_time = True
s_recon_terms = 33

encoder = ReverseGRUEncoder(
    input_dim,
    latent_dim,
    hidden_units // 2,
    encode_obs_time=encode_obs_time,
).to(device)
laplace_rep_func = LaplaceRepresentationFunc(
    s_recon_terms, output_dim, latent_dim
).to(device)

"""We can generate predictions using `laplace_reconstruct` as follows"""

from torchlaplace import laplace_reconstruct

laplace_rep_func.eval(), encoder.eval()
for batch in dlval:
    trajs_to_encode = batch[
        "observed_data"
    ]  # (batch_size, t_observed_dim, observed_dim)
    observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
    p = encoder(
        trajs_to_encode, observed_tp
    )  # p is the latent tensor encoding the initial states
    tp_to_predict = batch["tp_to_predict"]
    predictions = laplace_reconstruct(
        laplace_rep_func, p, tp_to_predict, recon_dim=output_dim
    )

"""### Train end to end """

from time import time
from copy import deepcopy

learning_rate = 1e-3
epochs = 500
patience = None

if not patience:
    patience = epochs

params = list(laplace_rep_func.parameters()) + list(encoder.parameters())
optimizer = torch.optim.Adam(params, lr=learning_rate)
loss_fn = torch.nn.MSELoss()

best_loss = float("inf")
waiting = 0

for epoch in range(epochs):
    iteration = 0
    epoch_train_loss_it_cum = 0
    start_time = time()
    laplace_rep_func.train(), encoder.train()
    for batch in dltrain:
        optimizer.zero_grad()
        trajs_to_encode = batch[
            "observed_data"
        ]  # (batch_size, t_observed_dim, observed_dim)
        observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
        p = encoder(
            trajs_to_encode, observed_tp
        )  # p is the latent tensor encoding the initial states
        tp_to_predict = batch["tp_to_predict"]
        predictions = laplace_reconstruct(
            laplace_rep_func, p, tp_to_predict, recon_dim=output_dim
        )
        loss = loss_fn(
            torch.flatten(predictions), torch.flatten(batch["data_to_predict"])
        )
        loss.backward()
        torch.nn.utils.clip_grad_norm_(params, 1)
        optimizer.step()
        epoch_train_loss_it_cum += loss.item()
        iteration += 1
    epoch_train_loss = epoch_train_loss_it_cum / iteration
    epoch_duration = time() - start_time

    # Validation step
    laplace_rep_func.eval(), encoder.eval()
    cum_val_loss = 0
    cum_val_batches = 0
    for batch in dlval:
        trajs_to_encode = batch[
            "observed_data"
        ]  # (batch_size, t_observed_dim, observed_dim)
        observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
        p = encoder(
            trajs_to_encode, observed_tp
        )  # p is the latent tensor encoding the initial states
        tp_to_predict = batch["tp_to_predict"]
        predictions = laplace_reconstruct(
            laplace_rep_func, p, tp_to_predict, recon_dim=output_dim
        )
        cum_val_loss += loss_fn(
            torch.flatten(predictions), torch.flatten(batch["data_to_predict"])
        ).item()            
        cum_val_batches += 1
    val_mse = cum_val_loss / cum_val_batches
    if epoch % 100 == 0:
        print(
            "[epoch={}] epoch_duration={:.2f} | train_loss={}\t| val_mse={}\t|".format(
                epoch, epoch_duration, epoch_train_loss, val_mse
            )
        )
        plt.plot(observed_tp.detach().cpu()[0], trajs_to_encode.detach().cpu()[0])
        plt.plot(
                tp_to_predict.detach().cpu()[0],
                batch["data_to_predict"].detach().cpu()[0],
                "--",
            )
        plt.plot(tp_to_predict.detach().cpu()[0], predictions.detach().cpu()[0], c="r")
        plt.show()

    # Early stopping procedure
    if val_mse < best_loss:
        best_loss = val_mse
        best_laplace_rep_func = deepcopy(laplace_rep_func.state_dict())
        best_encoder = deepcopy(encoder.state_dict())
        waiting = 0
    elif waiting > patience:
        break
    else:
        waiting += 1

# Load best model
laplace_rep_func.load_state_dict(best_laplace_rep_func)
encoder.load_state_dict(best_encoder)

torch.save(best_encoder, "./encoder.pt")
torch.save(best_laplace_rep_func, "./laplace_rep.pt")

# Test step
laplace_rep_func.eval(), encoder.eval()
cum_test_loss = 0
cum_test_batches = 0
for batch in dltest:
    trajs_to_encode = batch[
        "observed_data"
    ]  # (batch_size, t_observed_dim, observed_dim)
    observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
    p = encoder(
        trajs_to_encode, observed_tp
    )  # p is the latent tensor encoding the initial states
    tp_to_predict = batch["tp_to_predict"]
    predictions = laplace_reconstruct(laplace_rep_func, p, tp_to_predict,  recon_dim=output_dim)
    print('predictions',predictions.shape)
    cum_test_loss += loss_fn(
        torch.flatten(predictions), torch.flatten(batch["data_to_predict"])
    ).item()
    plt.plot(observed_tp.detach().cpu()[0], trajs_to_encode.detach().cpu()[0])
    plt.plot(
        tp_to_predict.detach().cpu()[0],
        batch["data_to_predict"].detach().cpu()[0],
        "--",
    )
    plt.plot(tp_to_predict.detach().cpu()[0], predictions.detach().cpu()[0])
    plt.show()
    cum_test_batches += 1
test_mse = cum_test_loss / cum_test_batches
print(f"test_mse= {test_mse}")

plt.plot(observed_tp.detach().cpu()[0], trajs_to_encode.detach().cpu()[0])
plt.plot(
    tp_to_predict.detach().cpu()[0],
    batch["data_to_predict"].detach().cpu()[0],
    "--",
)
plt.plot(tp_to_predict.detach().cpu()[0], predictions.detach().cpu()[0])
plt.show()

Output

[epoch=0] epoch_duration=0.09 | train_loss=1.1995722906930106 | val_mse=1.0132559537887573 |
[epoch=100] epoch_duration=0.02 | train_loss=0.1175929041845458 | val_mse=0.11761683225631714 |
[epoch=200] epoch_duration=0.02 | train_loss=0.07985144640718188 | val_mse=0.07714516669511795 |
[epoch=300] epoch_duration=0.02 | train_loss=0.09400784117834908 | val_mse=0.077095627784729 |
[epoch=400] epoch_duration=0.02 | train_loss=0.05828039348125458 | val_mse=0.06377138197422028 |
predictions torch.Size([100, 100, 1])
test_mse= 0.04920018091797829

Then run the loading script separately:

# Commented out IPython magic to ensure Python compatibility.
from torch import nn
import torch
import matplotlib.pyplot as plt
# %matplotlib inline
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.random.manual_seed(0)

# Model (encoder and Laplace representation func)
class ReverseGRUEncoder(nn.Module):
    # Encodes observed trajectory into latent vector
    def __init__(self, dimension_in, latent_dim, hidden_units, encode_obs_time=True):
        super(ReverseGRUEncoder, self).__init__()
        self.encode_obs_time = encode_obs_time
        if self.encode_obs_time:
            dimension_in += 1
        self.gru = nn.GRU(dimension_in, hidden_units, 2, batch_first=True)
        self.linear_out = nn.Linear(hidden_units, latent_dim)
        nn.init.xavier_uniform_(self.linear_out.weight)

    def forward(self, observed_data, observed_tp):
        trajs_to_encode = observed_data  # (batch_size, t_observed_dim, observed_dim)
        if self.encode_obs_time:
            trajs_to_encode = torch.cat(
                (
                    observed_data,
                    observed_tp.view(1, -1, 1).repeat(observed_data.shape[0], 1, 1),
                ),
                dim=2,
            )
        reversed_trajs_to_encode = torch.flip(trajs_to_encode, (1,))
        out, _ = self.gru(reversed_trajs_to_encode)
        return self.linear_out(out[:, -1, :])

class LaplaceRepresentationFunc(nn.Module):
    # SphereSurfaceModel : C^{b+k} -> C^{bxd} - In Riemann Sphere Co ords : b dim s reconstruction terms, k is latent encoding dimension, d is output dimension
    def __init__(self, s_dim, output_dim, latent_dim, hidden_units=64):
        super(LaplaceRepresentationFunc, self).__init__()
        self.s_dim = s_dim
        self.output_dim = output_dim
        self.latent_dim = latent_dim
        self.linear_tanh_stack = nn.Sequential(
            nn.Linear(s_dim * 2 + latent_dim, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, (s_dim) * 2 * output_dim),
        )

        for m in self.linear_tanh_stack.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

        phi_max = torch.pi / 2.0
        self.phi_scale = phi_max - -torch.pi / 2.0

    def forward(self, i):
        out = self.linear_tanh_stack(i.view(-1, self.s_dim * 2 + self.latent_dim)).view(
            -1, 2 * self.output_dim, self.s_dim
        )
        theta = nn.Tanh()(out[:, : self.output_dim, :]) * torch.pi  # From - pi to + pi
        phi = (
            nn.Tanh()(out[:, self.output_dim :, :]) * self.phi_scale / 2.0
            - torch.pi / 2.0
            + self.phi_scale / 2.0
        )  # Form -pi / 2 to + pi / 2
        return theta, phi

from torch.utils.data import DataLoader
from torchlaplace.data_utils import basic_collate_fn

normalize_dataset = True
batch_size = 128
extrapolate = True
loss_fn = torch.nn.MSELoss()

def sawtooth(trajectories_to_sample=100, t_nsamples=200):
    # Toy sawtooth waveform. Simple to generate, for Differential Equation Datasets see datasets.py (Note more complex DE take time to sample from, in some cases minutes).
    t_end = 20.0
    t_begin = t_end / t_nsamples
    ti = torch.linspace(t_begin, t_end, t_nsamples).to(device)

    def sampler(t, x0=0):
        return (t + x0) / (2 * torch.pi) - torch.floor((t + x0) / (2 * torch.pi))

    x0s = torch.linspace(0, 2 * torch.pi, trajectories_to_sample)
    trajs = []
    for x0 in x0s:
        trajs.append(sampler(ti, x0))
    y = torch.stack(trajs)
    trajectories = y.view(trajectories_to_sample, -1, 1)
    return trajectories, ti

trajectories, t = sawtooth(
    trajectories_to_sample=1000,
    t_nsamples=200,
)
if normalize_dataset:
    samples = trajectories.shape[0]
    dim = trajectories.shape[2]
    traj = (
        torch.reshape(trajectories, (-1, dim))
        - torch.reshape(trajectories, (-1, dim)).mean(0)
    ) / torch.reshape(trajectories, (-1, dim)).std(0)
    trajectories = torch.reshape(traj, (samples, -1, dim))
train_split = int(0.8 * trajectories.shape[0])
test_split = int(0.9 * trajectories.shape[0])
traj_index = torch.randperm(trajectories.shape[0])
train_trajectories = trajectories[traj_index[:train_split], :, :]
val_trajectories = trajectories[traj_index[train_split:test_split], :, :]
test_trajectories = trajectories[traj_index[test_split:], :, :]

input_dim = train_trajectories.shape[2]
output_dim = input_dim

dltrain = DataLoader(
    train_trajectories,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=lambda batch: basic_collate_fn(
        batch,
        t,
        data_type="train",
        extrap=extrapolate,
    ),
)
dlval = DataLoader(
    val_trajectories,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: basic_collate_fn(
        batch,
        t,
        data_type="train",
        extrap=extrapolate,
    ),
)
dltest = DataLoader(
    test_trajectories,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: basic_collate_fn(
        batch,
        t,
        data_type="train",
        extrap=extrapolate,
    ),
)

input_dim = train_trajectories.shape[2]
output_dim = input_dim

latent_dim = 2
hidden_units = 64
encode_obs_time = True
s_recon_terms = 33

encoder = ReverseGRUEncoder(
    input_dim,
    latent_dim,
    hidden_units // 2,
    encode_obs_time=encode_obs_time,
).to(device)
laplace_rep_func = LaplaceRepresentationFunc(
    s_recon_terms, output_dim, latent_dim
).to(device)

encoder.load_state_dict(torch.load("./encoder.pt"))
laplace_rep_func.load_state_dict(torch.load("./laplace_rep.pt"))

from torchlaplace import laplace_reconstruct

# Test step
laplace_rep_func.eval(), encoder.eval()
cum_test_loss = 0
cum_test_batches = 0
for batch in dltest:
    trajs_to_encode = batch[
        "observed_data"
    ]  # (batch_size, t_observed_dim, observed_dim)
    observed_tp = batch["observed_tp"]  # (1, t_observed_dim)
    p = encoder(
        trajs_to_encode, observed_tp
    )  # p is the latent tensor encoding the initial states
    tp_to_predict = batch["tp_to_predict"]
    predictions = laplace_reconstruct(laplace_rep_func, p, tp_to_predict,  recon_dim=output_dim)
    print('predictions',predictions.shape)
    cum_test_loss += loss_fn(
        torch.flatten(predictions), torch.flatten(batch["data_to_predict"])
    ).item()
    plt.plot(observed_tp.detach().cpu()[0], trajs_to_encode.detach().cpu()[0])
    plt.plot(
        tp_to_predict.detach().cpu()[0],
        batch["data_to_predict"].detach().cpu()[0],
        "--",
    )
    plt.plot(tp_to_predict.detach().cpu()[0], predictions.detach().cpu()[0])
    plt.show()
    cum_test_batches += 1
test_mse = cum_test_loss / cum_test_batches
print(f"test_mse= {test_mse}")

plt.plot(observed_tp.detach().cpu()[0], trajs_to_encode.detach().cpu()[0])
plt.plot(
    tp_to_predict.detach().cpu()[0],
    batch["data_to_predict"].detach().cpu()[0],
    "--",
)
plt.plot(tp_to_predict.detach().cpu()[0], predictions.detach().cpu()[0])
plt.show()

Output

predictions torch.Size([100, 100, 1])
test_mse= 0.04920018091797829

The test MSE is identical to the one from the training script. I am closing this; if you do not get the same results, please do re-open and let me know. Please also let me know if this fixes your issue.

Thank you!

thibmonsel commented 1 year ago

@samholt, everything looks good. I can reproduce the results; a clean env did the trick! I appreciate your help. I'm going to close this issue. Cheers