This repository is meant to conceptually introduce and highlight implementation considerations for the recent class of models called Neural State-Space Models (Neural SSMs). They combine the classic state-space model with the flexibility of deep learning to approach high-dimensional generative time-series modeling and the learning of latent dynamics functions.
Included is an abstract PyTorch-Lightning training class with several latent dynamics functions that inherit it, as well as common metrics used in their evaluation and training examples on common datasets. The implementation further breaks down the distinction between system identification and state estimation approaches, which mirror their classic SSM counterparts and arise from fundamental differences in the underlying choice of probabilistic graphical model (PGM). This repository (currently) focuses primarily on considerations related to training dynamics models for system identification and forecasting rather than per-frame state estimation or filtering.
Note: This repo is not fully finished and some of the experiments/sections may be incomplete. It is released publicly to maximize its potential benefit and hopefully inspire collaboration in improving it. Feel free to check out the "To-Do" section if you're interested in contributing!
Fig 1. Schematic of the two PGM forms of Neural SSMs.
If you found the information helpful for your work or use portions of this repo in research development, please consider citing one of the following works:
@misc{missel2022torchneuralssm,
title={TorchNeuralSSM},
author={Missel, Ryan},
publisher={Github},
journal={Github repository},
howpublished={\url{https://github.com/qu-gg/torchssm}},
year={2022},
}
@inproceedings{jiangsequentialLVM,
title={Sequential Latent Variable Models for Few-Shot High-Dimensional Time-Series Forecasting},
author={Jiang, Xiajun and Missel, Ryan and Li, Zhiyuan and Wang, Linwei},
booktitle={The Eleventh International Conference on Learning Representations},
year={2023}
}
This section introduces the concept of Neural SSMs, common considerations and limitations, and active areas of research. It assumes some familiarity with state-space models, though little background is needed for a conceptual understanding if one already comes from a latent modeling or Bayesian learning perspective. Given the breadth and depth of state-space usage, resources are abundant; this video and modern textbook are good starting points.
Variational Auto-encoders (VAEs): VAEs[28] provide a principled and popular framework to learn the generative model pθ(x|z) behind data x, involving latent variables z that follow a prior distribution p(z). Variational inference over the generative model is facilitated by a variational approximation of the posterior density of latent variables z, in the form of a recognition model qφ(z|x). Parameters of both the generative and recognition models are optimized with the objective to maximize the evidence lower bound (ELBO) of the marginal data likelihood:
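$$
\mathcal{L}(\theta, \phi; \mathbf{x}) = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}\big[\log p_\theta(\mathbf{x}|\mathbf{z})\big] - \mathrm{KL}\big(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p(\mathbf{z})\big)
$$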
where the first term encourages the reconstruction of the observed data, and the second term of Kullback–Leibler (KL) divergence constrains the estimated posterior density of qφ(z|x) with a pre-defined prior p(z), often assumed to be a zero-mean isotropic Gaussian density.
An extension of classic state-space models, neural state-space models - at their core - consist of a dynamic function of some latent states z_k and their emission to observations x_k, realized through the equations:
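$$
\mathbf{z}_{k} = f_{\theta_z}(\mathbf{z}_{k-1}), \qquad \mathbf{x}_{k} = g(\mathbf{z}_{k})
$$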
where θz represents the parameters of the latent dynamic function. The precise form of these functions can vary significantly - from deterministic or stochastic, linear or non-linear, and discrete or continuous.
Due to their explicit separation of transition and emission and their use of structured equations, they have found success in learning interpretable latent dynamic spaces[1,2,3], identifying physical systems from indirect features[4,5,6], and counterfactual forecasting[7,8,14].
Given the fast pace of progress in latent dynamics modeling over recent years, many models have been presented under a variety of terminologies and proposed frameworks - examples being variational latent recurrent models[5,9,10,11,12,22], deep state-space models[1,2,3,7,13,14], and deterministic encoding-decoding models[4,15,16]. Despite differences in appearance, they all adhere to the same conceptual framework of latent variable modeling and state-space disentanglement. As such, here we unify them under the terminology of Neural SSMs and segment them into the two base choices of probabilistic graphical models that they adhere to: system identification and state estimation. We highlight each PGM's properties and limitations with experimental evaluations on benchmark datasets.
The PGM associated with each approach is determined by the latent variable chosen for inference.
Fig 2. Schematic of latent variable PGMs in Neural SSMs.
System states as latent variables (State Estimation): The intuitive choice for the latent variable is the latent state z_k that underlies x_k, given that it is already latent in the system and is directly associated with the observations. The PGM of this form is shown under Fig. 1A where its marginal likelihood over an observed sequence x0:T is written as:
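$$
p(\mathbf{x}_{0:T}) = \int \prod_{i=0}^{T} p(\mathbf{x}_i \mid \mathbf{z}_i)\, p(\mathbf{z}_i \mid \mathbf{z}_{<i}, \mathbf{x}_{<i})\, d\mathbf{z}_{0:T}
$$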
where p(xi | zi) describes the emission model and p(zi | z<i, x<i) describes the latent dynamics function. Given the common intractability of the posterior, parameter inference is performed through a variational approximation of the posterior density q(z0:T | x0:T), expressed as:
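$$
q(\mathbf{z}_{0:T} \mid \mathbf{x}_{0:T}) = \prod_{i=0}^{T} q(\mathbf{z}_i \mid \mathbf{z}_{<i}, \mathbf{x}_{\leq i})
$$

(written here in a causal/filtering factorization; some works instead condition each zi on the full observation sequence[7].)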
Given these two components, the standard Evidence Lower Bound (ELBO) training objective is thus derived with the form:
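$$
\mathcal{L} = \mathbb{E}_{q(\mathbf{z}_{0:T} \mid \mathbf{x}_{0:T})}\Big[\sum_{i=0}^{T} \log p(\mathbf{x}_i \mid \mathbf{z}_i)\Big] - \mathrm{KL}\big(q(\mathbf{z}_{0:T} \mid \mathbf{x}_{0:T}) \,\|\, p(\mathbf{z}_{0:T})\big)
$$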
where the first term represents a reconstruction likelihood term over the sequence and the second is a Kullback-Leibler divergence loss between the variational posterior approximation and some assumed prior of the latent dynamics. This prior can come in many forms, either being the standard Gaussian Normal in variational inference, flow-based priors from ODE settings[5], or physics-based priors in problem-specific situations[20]. This is the primary design choice that separates current works in this area, specifically the modeling of the dynamics prior and its learned approximation. Many works draw inspiration for modeling this interaction by filtering techniques in standard SSMs, where a divergence term is constructed between the dynamics-predicted latent state and the data-corrected observation[7,18].
With this formulation, it is easy to see how dynamics models of this type can have strong reconstructive capacity for high-dimensional outputs and produce strong short-term predictions. In addition, input-influenced dynamics are inherent to the prediction task, as errors in the predictions of the latent dynamics are corrected by true observations at every step. However, given this data-based correction, the resulting inference of q(zi | z<i, x<i) is weakened, and without near-term observations to guide the dynamics function, its long-horizon forecasting is limited[1,3].
System parameters as latent variables (System Identification): Rather than system states, one can instead choose the system parameters (denoted as θz in Equation 1) as the latent variables. With this change, the resulting PGM is represented in Fig. 1B and its marginal likelihood over x0:T is represented now by:
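$$
p(\mathbf{x}_{0:T}) = \iint p(\mathbf{x}_{0:T} \mid \mathbf{z}_0, \theta_z)\, p(\mathbf{z}_0)\, p(\theta_z)\, d\mathbf{z}_0\, d\theta_z
$$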
where the resulting output observations are derived from an initial latent state z0 and the dynamics parameters θz. As before, a variational approximation is used in place of the intractable posterior, now for the density q(θz, z0). Given prior density assumptions p(θz) and p(z0) in a similar vein as above, the ELBO for this PGM is constructed as:
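$$
\mathcal{L} = \mathbb{E}_{q(\theta_z, \mathbf{z}_0)}\big[\log p(\mathbf{x}_{0:T} \mid \mathbf{z}_0, \theta_z)\big] - \mathrm{KL}\big(q(\theta_z) \,\|\, p(\theta_z)\big) - \mathrm{KL}\big(q(\mathbf{z}_0) \,\|\, p(\mathbf{z}_0)\big)
$$

(assuming the factorized approximation q(θz, z0) = q(θz)q(z0).)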
where again the first term is a reconstruction likelihood and the terms following represent KL-Divergence losses over the inferred variables.
The given formulation is the most general form for this line of models; other works can be covered under special assumptions or restrictions on how q(θz) and p(θz) are modeled. The original Neural SSM parameter works consider Linear-Gaussian SSMs as the transition function and introduce non-linearity by varying the transition parameters over time as θz0:T[1,2,3]. However, as shown in Fig. 2B1, this results in convoluted temporal modeling and devolves into the same state estimation problem, as the time-varying parameters now rely on near-term observations for correctness[8,20]. Rather than time-varying, the system parameters can instead be considered an optimized global variable, in which case the underlying dynamics function becomes a Bayesian neural network in a VAE's latent space[5], shown in Fig. 2B2. Restricting these parameters to be deterministic results in a model of the form presented in Latent ODE[10]. Forgoing stochasticity in the inference of z0 as well results in the suite of models presented in [4].
Regardless of the precise assumptions, this framework builds a strong latent dynamics function that enables long-term forecasting and, in some settings, even full-scale system identification[1,4] of physical systems. This comes at the cost of a harder inference task, given no access to dynamics correction during generation, and full identification tasks often require a large number of training samples over the potential system state space[4,5].
As the transition dynamics and the observation space are intentionally disconnected in this framework, the problem of inferring a strong initial latent state from which to forecast is an important consideration when designing a neural state-space model[30]. This is primarily a task- and data-dependent choice, in which the architecture follows the data structure. Thankfully, much work has been done in other research directions on designing good latent encoding models. As such, works in this area often draw from them. This section is split into three parts - one on the usual architecture for high-dimensional image tasks, one on lower-dimensional and/or miscellaneous encoders, and one on the different forms of inference for the initial state depending on which sequence portions are observed.
Image-based Encoders: Unsurprisingly, the common architecture used for latent image encoding is a convolutional neural network (CNN), given its inherent bias toward spatial feature extraction[1,3,4,5]. Works are split between reshaping the sequential input into frames stacked over the channel dimension and running the CNN over each observed frame separately, passing the concatenated embeddings into an output layer. Regardless of methodology, a few frames are assumed as observations for initialization, as multiple timesteps are required to infer the initial system movement. A subset of works considers second-order latent vector spaces, in which the encoder is explicitly split into individual position and momentum functions, taking single and multiple frames respectively[5].
Fig N. Visualization of the stacked initial state encoder, modified from [23].
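As a minimal sketch of the stacked-frame variant - every size, name, and layer choice below is illustrative rather than the repo's exact encoder:

```python
import torch.nn as nn

class StackedCNNEncoder(nn.Module):
    """Infers a distribution over the initial latent state z0 from a few observed
    frames stacked over the channel dimension (illustrative sketch)."""
    def __init__(self, num_frames=3, latent_dim=16):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(num_frames, 32, kernel_size=5, stride=2, padding=2), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2), nn.ReLU(),
            nn.Flatten(),
        )
        flat_dim = 64 * 8 * 8  # for 32x32 inputs after two stride-2 convolutions
        self.mu = nn.Linear(flat_dim, latent_dim)
        self.logvar = nn.Linear(flat_dim, latent_dim)

    def forward(self, frames):
        # frames: [batch, num_frames, 32, 32] -- observed frames as channels
        h = self.conv(frames)
        return self.mu(h), self.logvar(h)
```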
Alternate Encoders: In settings with non-image-based inputs, the initial latent encoder can take on a large variety of forms, ranging anywhere from simple linear/MLP networks in physical systems[5] to graph convolution networks for latent medical image forecasting[20]. Multi-modal and dynamics conditioning inputs can be leveraged via combinations of encoders whose embeddings go through a shared linear function.
Fig N. Visualization of the stacked graph convolutional encoder, modified from [24].
Variables z0, zk, and zinit: Beyond just the inference of this latent variable, there is one more variation seen throughout the literature - which portions of the input sequence are observed and used in the initial state inference.
Generally, there are three forms seen, corresponding to the variables z0, zk, and zinit.
Throughout the literature, these variable names aren't used as shown here (most works simply call it z0 and describe its inference), but we differentiate them specifically to highlight the distinctions. For training purposes, it is a subtle distinction but potentially has implications for the resulting likelihood optimization and learned vector space.
Fig N. Schematic of the difference between z0 and zinit formulations.

That said, there is generally a lack of work exploring the considerations of each approach, besides ad-hoc solutions to bridge the gap between the latent encoder and dynamics function distributions[5]. This gap can stem from optimization problems caused by imbalanced reconstruction terms between dynamics and initial states, or from cases where the initial state distribution is far enough away from the data distribution of downstream frames. However, a recent work, "Learning Neural State-Space Models: Do we need a state estimator?"[30], is the first detailed study into the considerations of initial state inference, providing ablations across datasets of increasing difficulty and across inference forms. They found that achieving competitive performance of neural SSMs on some dynamical systems required more advanced architectures (feed-forward or LSTM networks). Notably, they only evaluate the zk form, varying architectural choices.

A variety of empirical techniques have been proposed to tackle this gap, much in the same spirit as empirical VAE stability 'tricks.' These include separated x0 and x1:T terms (where x0 has a positive weighting coefficient), VAE pre-training for x0, and KL-regularization terms between the output distributions of the encoder and the dynamics flow[1,5]. One personal intuition regarding these two variable approaches is that there exists a theoretical trade-off between the two formulations and that the applied tricks empirically alleviate the shortcomings of either approach. This, however, requires experimentation and validation before any claims can be made.

## Reconstruction vs. Extrapolation

There are three important phases during forecasting for a neural SSM: initial state inference, reconstruction, and extrapolation.
Fig N. Breakdown of the three forecasting phases - initial state inference, reconstruction, and extrapolation.
Initial State Inference: Inferring the initial state, and how many frames are required for a good initialization, is fairly domain/problem-specific, as each problem may require more or less time to reveal distinctive patterns that enable effective dynamics separation.

Reconstruction: This refers to the number of timesteps used in training, from which the likelihood term is calculated. So far, there is no generally agreed-upon standard for how many steps to use, and works can be seen using anywhere from 1 (i.e. next-step prediction) to 60 frames in this portion[4]. Some works frame this as a hyper-parameter to tune in experiments, and there is a computational-cost consideration when scaling up to longer sequences. In our experiments, we've noticed training time scaling linearly with this sequence length. (TO-DO) In the Experiments section, we perform an ablation study on how fixed reconstruction lengths affect the extrapolation ability of models on Hamiltonian systems.

Extrapolation: This phase refers to the arbitrarily long forecasting of frames beyond the length used in the likelihood term during training. It represents whether a model has captured the system dynamics sufficiently to enable long-term forecasting, or to model energy decay in non-conserving systems. For some dynamical systems this can be a difficult task, as there is fundamentally no training signal informing the model to learn good extrapolation. Works often report metrics separately on the reconstruction and extrapolation phases to highlight a model's strength of identification[4].

Training Considerations: It is important to note that the exact structure of the likelihood loss plays a role in how this sequence length may affect extrapolation. Having the likelihood incorporate temporal information (e.g. summation over the sequence, trajectory mean, etc.) can have a detrimental effect on extrapolation, as the model optimizes with respect to the fixed reconstruction length. Figure N highlights an example of using temporal information in a likelihood term, where there is near-flawless reconstruction but immediate forecasting failure when moving into extrapolation.

Fig N. Example of failed extrapolation given an incorrect likelihood term. Red highlights the beginning of extrapolation.
As well, it is often the case that the reconstruction training metrics (e.g. likelihood/pixel MSE) and visualizations show strong convergence despite still-poor extrapolation. Especially with Neural ODE latent dynamics, more training than expected can be required to enable strong extrapolation. One intuition is that the vector field may require longer optimization than just reconstruction convergence to be robust against the error accumulation that impacts long-horizon forecasting.

Fig N. Training vs. validation pixel MSE metrics, highlighting continued extrapolation learning past training "convergence."
Tips for training good extrapolation in these models include:

## Control Inputs

For Neural SSMs, a variety of approaches have been taken thus far depending on the type of latent transition function used.
Linear Dynamics: In latent dynamics still parameterized by traditional linear Gaussian transition functions, control incorporation is as easy as adding another transition matrix Bt that modifies a control input ut at each timestep[1,2,4,7].
Fig N. Example of control input in a linear transition function[1].
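A minimal sketch of this additive structure, z_{t+1} = A z_t + B u_t (names and sizes illustrative):

```python
import torch.nn as nn

class LinearControlTransition(nn.Module):
    """Linear transition where the control enters through its own matrix B:
    z_{t+1} = A z_t + B u_t (illustrative sketch)."""
    def __init__(self, latent_dim, control_dim):
        super().__init__()
        self.A = nn.Linear(latent_dim, latent_dim, bias=False)   # state transition matrix
        self.B = nn.Linear(control_dim, latent_dim, bias=False)  # control matrix

    def forward(self, z_t, u_t):
        return self.A(z_t) + self.B(u_t)
```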
Non-Linear Dynamics: In discrete non-linear transition functions using either multi-layer perceptrons or recurrent cells, control inputs can be leveraged by either concatenating them to the latent state before the network's forward pass or applying a data transformation, e.g. element-wise addition or a weighted combination[10].
Fig N. Example of control input in a non-linear transition function[1].
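A minimal sketch of the concatenation approach (layer choices illustrative):

```python
import torch
import torch.nn as nn

class MLPControlTransition(nn.Module):
    """Discrete non-linear transition with the control concatenated to the
    latent state before the forward pass (illustrative sketch)."""
    def __init__(self, latent_dim, control_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim + control_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, z_t, u_t):
        return self.net(torch.cat([z_t, u_t], dim=-1))
```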
Continuous Dynamics: For incorporation into continuous latent dynamics functions, finding the best approaches is an ongoing topic of interest. Thus far, the reigning approaches are:
1. Directly jumping the vector field state with recurrent cells[18]
2. Influencing the vector field gradient (e.g. neural controlled differential equations)[17]
3. Introducing another dynamics mechanism, continuous or otherwise (e.g. neural ODE or attention blocks), that is combined with the latent trajectory z1:T into an auxiliary state h1:T[8,14,25]
Fig N. Visualization of the IMODE architecture, taken from [8].
# Implementation

In this section, specifics on model implementation and the datasets/metrics used are detailed. Specific data generation details are available in the URLs provided for each dataset. The models and datasets used throughout this repo are solely grayscale physics datasets with underlying Hamiltonian laws, such as pendulum and mass-spring sets. Extensions to color images and non-pixel-based tasks (or even graph-based data!) are easily done in this framework, as the only architectural change needed is the structure of the encoder and decoder networks; the state propagation happens solely in a latent space.

The project's folder structure is as follows:

```
torchssm/
│
├── train.py            # Training entry point that takes in user args and handles boilerplate
├── test.py             # Testing script to get reconstructions and metrics on a testing set
├── tune.py             # Performs a hyperparameter search for a given dataset using Ray[Tune]
├── README.md           # What you're reading right now :^)
├── requirements.txt    # Anaconda requirements file to enable easy setup
│
├── data/
│   ├── ...
```

## Data

Fig N. Pendulum-Colors Examples.
For the base presented experiments of this dataset, we consider and evaluate grayscale versions of pendulum and mass-spring - which conveniently are just the sliced red channel of the original sets. Each set has 50000 training and 5000 testing trajectories sampled at Δt = 1 time intervals. Energy is conserved (no friction is applied) and we assume constant placement of focal points for simplicity. Note that modifying this framework for color outputs is as simple as changing the number of channels in the encoder and decoder.
Bouncing Balls: Additionally, we provide a dataloader and generation scripts for the standard latent dynamics dataset of bouncing balls[1,2,5,7,8], modified from the implementation in [1]. It consists of one or multiple balls moving within a bounding box while being affected by potential external effects, e.g. gravitational forces[1,2,5], pong[2], and interventions[8]. The starting position, angle, and velocity of the ball(s) are sampled uniformly within a set range. The data is generated with the PyMunk and PyGame libraries. In this repository, we consider two sets - a simple set with one gravitational force and a mixed set of 4 gravitational forces in the cardinal directions with varying strengths. We similarly generate 50000 training and 5000 testing trajectories, but sample them at Δt = 0.1 intervals.
Fig N. Single Gravity Bouncing Ball Example.
Notably, this system is surprisingly difficult to perform long-term generation on, especially in cases of mixed gravities or multiple objects. Most works only report generation within 5-15 timesteps following a period of 3-5 observation timesteps[1,2,7]; beyond that, trajectories are lost and/or reconstructions become incoherent.
Meta-Learning Datasets: One of the latest research directions for neural SSMs is evaluating the potential of meta-learning to build domain-adaptable latent dynamics functions[26,27,29]. A representative dataset example for this task is the Turbulent Flow dataset that is affected by various buoyancy forces, highlighting a task with partially shared yet heterogeneous dynamics[27].
Fig N. Turbulent Flow Example, sourced from [27].
Multi-System Dynamics: So far in the literature, the majority of works only consider training Neural SSMs on one system of dynamics at a time - with the most variety lying in differing trajectory hyper-parameters. The ability to infer multiple dynamical systems under one model (or to learn to output dynamics functions given system observations) and leverage similarities between the sets is an ongoing research pursuit - with applications of neural unit hypernetworks[27] and dynamics functions conditioned on sequences via meta-learning[26,29] being the first dives into this.
Other Sets in Literature: Outside of the previous sets, there is a plethora of other datasets that have been explored in relevant work. The popular task of human motion prediction, in both pose estimation and video generation settings, has been considered via the Human3.6M, CMU Mocap, and Weizmann-Action datasets[5,19], though proper experimentation in this area would require problem-specific architectures given the depth of the existing field. Beyond high-dimensionality and image-space tasks, standard benchmark datasets in time-series forecasting have also been considered, including the M4, Electricity Transformer Temperature (ETT), and NASA Turbofan Degradation sets. Recent works have begun looking at medical applications in inverse image reconstruction and the incorporation of physics-inspired priors[20,29]. Regardless of the success of Neural SSMs on these tasks, the task-agnostic factor and principled structure of this framework make it a versatile and appealing option for generative time-series modeling.

## Models

Here, details are given on how the model implementation is structured and how to run experiments locally. As well, an overview of the abstract class implementation for a general Neural SSM and its types is explained.

### Implementation Structure

Provided within this repository is a PyTorch class structure in which an abstract PyTorch-Lightning Module is shared across all the given models, from which the specific VAE and dynamics functions inherit and override the relevant forward functions for training and evaluation. Swapping between dynamics functions and PGM types is as easy as passing in the model's name as an argument, e.g. `python3 train.py --model node`. As the implementation is provided in PyTorch-Lightning, an optimization and boilerplate library for PyTorch, a face-level familiarity with it is recommended.
For every model run, a new `lightning_logs/` version folder is created, as well as a new experiment version under `experiments` related to the passed-in naming arguments. Hyperparameters passed in for this run are stored both in the Tensorboard instance created and in the local files `hparams.yaml` and `config.json`. Default values and available options can be found in `train.py` or by running `python3 train.py -h`. During training and validation sequences, all of the metrics below are automatically tracked and saved into a Tensorboard instance, which can be used to compare different model runs afterward. Every 5 epochs, reconstruction sequences against the ground truth for a set of samples are saved to the experiments folder. Currently, only one checkpoint is saved, based on the last epoch run, rather than checkpoints based on the best validation score or at set epochs. Restarting training from a checkpoint or loading in a model for testing is currently done by specifying the `ckpt_path` to the base experiment folder and the `checkpt` filename.
The implemented dynamics functions are each separated into their respective PGM groups; however, they can still share the same general classes. Each dynamics subclass has a `model_specific_loss` function that allows it to return additional loss values without interrupting the abstract flow. For example, this can be used in a flow-based prior that has additional KL terms over the ODE flow density, without needing to override the `training_step` function with a duplicate copy. There is additionally a `model_specific_plotting` function to enable custom plots at the end of every training epoch.
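A rough sketch of this hook pattern (the loss helper and signatures here are illustrative, not the repo's exact interface):

```python
import pytorch_lightning as pl

class BaseSSM(pl.LightningModule):
    """Abstract trainer shared across models; dynamics subclasses override
    the hooks below without touching training_step (illustrative sketch)."""
    def model_specific_loss(self, batch, preds):
        return 0.0  # default: no additional loss terms

    def model_specific_plotting(self, outputs):
        pass  # default: no custom epoch-end plots

    def training_step(self, batch, batch_idx):
        preds = self(batch)  # subclass-defined forward pass
        loss = self.reconstruction_loss(batch, preds)  # hypothetical likelihood helper
        return loss + self.model_specific_loss(batch, preds)
```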
### Implemented Dynamics
System Identification Models: For the system identification models, we provide a variety of dynamics functions that resemble the general and special cases detailed above, provided in Fig N. below. The most general version is the Bayesian Neural ODE, in which a neural ordinary differential equation[21] is sampled from a set of optimized distributional parameters and used as the latent dynamics function z'_t = f_p(θ)(z_s)[5]. A deterministic version of a standard Neural ODE is similarly provided, e.g. z'_t = f_θ(z_s)[10,21]. Following that, two forms of a Recurrent Generative Network are provided - a residual version (RGN-Res) and a full-step version (RGN) - representing deterministic and discrete non-linear transition functions. RGN-Res is the equivalent of a Neural ODE using a fixed-step Euler integrator, while RGN is just a recurrent forward step function.
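The distinction between the two recurrent generative forms, sketched below (illustrative modules, not the repo's exact implementations):

```python
import torch.nn as nn

class RGN(nn.Module):
    """Full-step recurrent generative transition: z_{t+1} = f(z_t)."""
    def __init__(self, latent_dim, hidden_dim=64):
        super().__init__()
        self.f = nn.Sequential(nn.Linear(latent_dim, hidden_dim), nn.Tanh(),
                               nn.Linear(hidden_dim, latent_dim))

    def forward(self, z_t):
        return self.f(z_t)

class RGNRes(RGN):
    """Residual variant: z_{t+1} = z_t + f(z_t), equivalent to a Neural ODE
    with a fixed-step Euler integrator."""
    def forward(self, z_t):
        return z_t + self.f(z_t)
```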
Additionally, a representation of the time-varying Linear-Gaussian SSM transition dynamics[1,2] (LGSSM) is provided. Finally, a set of autoregressive models (i.e. Recurrent Neural Networks, Long Short-Term Memory networks, Gated Recurrent Unit networks) is considered as baselines. Their PyTorch Cell implementations are used and evaluated over the entire sequence, passing in the previously predicted state and observation as inputs.
Training for these models has one mode: taking in several observational frames to infer z0 and then outputting the full sequence autonomously, without access to subsequent observations. A likelihood function is compared and optimized over the full reconstructed sequence. Testing and generation in this setting can easily be done out to any horizon, and we provide small sample datasets of 200 timesteps to evaluate long horizons.
Fig N. Model schematics for system identification's implemented dynamics functions.
State Estimation Models: For the state estimation line, we provide a reimplementation of the classic Neural SSM work Deep Kalman Filter[7] alongside state estimation versions of the above, provided in Fig. N below. The DKF model modifies the standard Kalman Filter Gaussian transition function to incorporate non-linearity and expressivity by parameterizing the distribution parameters with neural networks, z_t ∼ N(G(z_{t−1}, Δt), S(z_{t−1}, Δt))[7].
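A minimal sketch of that transition parameterization (layer sizes and shapes illustrative):

```python
import torch
import torch.nn as nn

class DKFTransition(nn.Module):
    """Gaussian transition z_t ~ N(G(z_{t-1}, dt), S(z_{t-1}, dt)) with the
    mean and variance parameterized by small networks (illustrative sketch)."""
    def __init__(self, latent_dim, hidden_dim=64):
        super().__init__()
        self.G = nn.Sequential(nn.Linear(latent_dim + 1, hidden_dim), nn.ReLU(),
                               nn.Linear(hidden_dim, latent_dim))
        self.S = nn.Sequential(nn.Linear(latent_dim + 1, hidden_dim), nn.ReLU(),
                               nn.Linear(hidden_dim, latent_dim), nn.Softplus())

    def forward(self, z_prev, dt):
        # z_prev: [batch, latent_dim], dt: [batch, 1]
        inp = torch.cat([z_prev, dt], dim=-1)
        mean, var = self.G(inp), self.S(inp)
        return mean + var.sqrt() * torch.randn_like(var)  # reparameterized sample
```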
The autoregressive versions for this setting can be viewed as a reimplementation of the Variational Recurrent Neural
Network (VRNN), one of the starting state estimation works in Neural SSMs[22]. For the latent correction
step, we leverage a standard Gated Recurrent Unit (GRU) cell and the corrected latent state is what is passed to the
decoder and likelihood function. Notably, there are two settings these models can be run under: reconstruction
and generation. Reconstruction is used for training and incorporates ground truth observations to correct
the latent state while generation is used to test the forecasting abilities of the model, both short- and long-term.
Fig N. Model schematics for state estimation's implemented dynamics functions.
## Metrics

Mean Squared Error (MSE): A common metric in video and image tasks, used here as a per-frame average over individual pixel errors. While a multitude of papers solely use plots of frame MSE over time as an evaluation metric, it is insufficient for comparison between models - especially in cases where the dataset contains a small object for reconstruction[4]. This is especially prominent in system identification tasks, where a model that fails to predict long-term may end up with a lower average MSE than a model with better generation that is slightly off in its object placement.

Fig N. Per-Frame MSE Equation.
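In symbols, with D the number of pixels per frame:

$$
\mathrm{MSE}_t = \frac{1}{D} \sum_{d=1}^{D} \left(x_t^{(d)} - \hat{x}_t^{(d)}\right)^2
$$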
Valid Prediction Time (VPT): Introduced in [4], the VPT metric is an advance over pure pixel-based MSE metrics for latent dynamics evaluation. For each prediction sequence, the per-pixel MSE is taken over each frame individually, and the minimum timestep in which the MSE surpasses a pre-defined epsilon is considered the 'valid prediction time.' The resulting mean over the samples is often normalized by the total number of prediction timesteps to get a percentage of valid predictions.

Fig N. Per-Sequence VPT Equation.
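A rough sketch of the computation (tensor shapes and the epsilon value are illustrative; the exact formulation follows [4]):

```python
import torch

def valid_prediction_time(x_true, x_pred, epsilon):
    """Normalized index of the first frame whose per-frame MSE exceeds epsilon.
    x_true, x_pred: [T, ...] tensors over the prediction horizon (sketch)."""
    mse = ((x_true - x_pred) ** 2).flatten(start_dim=1).mean(dim=1)  # per-frame MSE, [T]
    exceeded = (mse > epsilon).nonzero(as_tuple=False)
    t = int(exceeded[0]) if exceeded.numel() > 0 else mse.shape[0]
    return t / mse.shape[0]
```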
Object Distance (DST): Another potential metric to support evaluation (useful in image-based physics forecasting tasks) is the Euclidean distance between the estimated center of the predicted object and its ground-truth center. Otsu's thresholding method can be applied to grayscale output images to get binary predictions for each pixel, from which the average location of all "active" pixels can be calculated. This approach helps alleviate the prior MSE issue of metric imbalance, as the maximum Euclidean error of the image space can be assigned to model predictions that fail to have any pixels over Otsu's threshold.

Fig N. Per-Frame DST Equation.
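One way to write the per-frame distance, using the mean coordinates of each active-pixel set:

$$
\mathrm{DST}_t = \Big\lVert \frac{1}{|s|}\sum_{p \in s} p \;-\; \frac{1}{|\hat{s}|}\sum_{p \in \hat{s}} p \Big\rVert_2
$$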
where R^N is the dimension of the output (e.g. number of image channels) and s, ŝ are the subsets of "active" ground-truth and predicted pixels.

Valid Prediction Distance (VPD): Similar in spirit to how VPT leverages MSE, VPD is the minimum timestep in which the DST metric surpasses a pre-defined epsilon[29]. This is useful for tracking how long a model can generate an object in a physical system before incorrect trajectories and/or error accumulation cause significant divergence.
Fig N. Per-Sequence VPD Equation.
R² Score: For evaluating systems where the full underlying latent system is available and known (e.g. image translations of Hamiltonian systems), the goodness-of-fit score R² can be used per dimension to show how well the latent system of the Neural SSM captures the dynamics in an interpretable way[1,3]. This is easiest to leverage with linear transition dynamics. Ref. [1], while containing linear transition dynamics, mentions the possibility of non-linear regression via vanilla neural networks, though this may run into concerns of regressor capacity and data sizes. Additionally, incorporating metrics derived from latent disentanglement learning may provide stronger evaluation capabilities.
Fig N. DVBF Latent Space Visualization for R² scores, sourced from [1,3].
# Experiments

This section details experiments that evaluate the fundamental aspects of Neural SSMs and the effects of the framework decisions one can take. Trained model checkpoints and hyperparameter files are provided for each experiment under `experiments/model`. Evaluations are done with the metrics discussed above, as well as visualizations of animated trajectories over time and latent walk visualizations.
## Hyperparameter Tuning
As is common in deep learning and variational inference tasks, the specific choices of hyper-parameters can have a significant impact on the resulting performance and generalization of the model. As such, we first perform a hyper-parameter tuning task for each model on a shared validation set to get each model's optimized hyper-parameter set. From this, the optimal set for each is carried across the various tasks, given similar task complexity.
We provide a Ray[Tune] tuning script, found in `tune.py`, to handle training and formatting the PyTorch-Lightning outputs for each model. It automatically parallelizes across GPUs and has a convenient Tensorboard output interface to compare the tuning runs. In order to run custom tuning tasks, simply create a local folder in the repository root directory and rename the tune run "name" to redirect the output there. Please refer to Ray Tune's relevant documentation for further information.
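A rough sketch of redirecting tuning output with the classic Ray Tune API (`train_model`, the search space, and paths here are illustrative, not the repo's actual configuration):

```python
from ray import tune

def train_model(config):
    # Illustrative stand-in for the repo's Lightning training loop;
    # a real trainable would fit the model and report validation metrics.
    tune.report(val_mse=config["lr"])  # placeholder metric

analysis = tune.run(
    train_model,
    name="pendulum_tune",       # the tune run "name" -- its folder lands under local_dir
    local_dir="./tune_output",  # a local folder in the repository root
    config={"lr": tune.loguniform(1e-4, 1e-2)},
    num_samples=10,
)
```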
## Pendulum Dynamics
Here we report the results of tuning each of the models on the Hamiltonian physics dataset Pendulum. For each model, we highlight its best-performing hyperparameters with respect to the validation extrapolation MSE. For experiments going forward, these hyperparameters are used in experiments of similar complexity.
We test two environments for the Pendulum dataset, a fixed-point one-color pendulum and a multi-point multi-color
pendulum set of increased complexity. As described in [4], each individual sequence is sampled from a uniform
distribution over physical parameters like mass, gravity, and pivot length.
We describe data generation above in the Data section.