Introduction to the Neural Jump ODE framework

In the following we give a brief overview of the considered prediction problem and the Neural Jump Ordinary Differential Equations (NJ-ODEs) framework. We start with an illustrative example to motivate the problem setting and the need for the NJ-ODE framework and then continue with a formal description of the problem setting and the proposed model.
Afterwards, we give short overviews of the different papers about Neural Jump ODEs, outlining their most important contributions.
Finally, we provide a short overview of how to get started with the code, in case you want to apply NJ-ODEs to your own data.
The videos below are presentations of the papers, where the first one is a short general introduction to the framework and the second one gives more in-depth explanations, focusing on the third paper. The third video is my PhD presentation, which presents the first five papers in a bit more detail.

Illustrative Example

Vital parameters of patients in a hospital are measured multiple times during their stay. For each patient, this happens at different times depending on the hospital's resources and clinical needs, hence the observation dates are irregular and exhibit some randomness. Moreover, not all vital parameters are always measured at the same time (e.g., a blood test and body temperature are not necessarily measured simultaneously), leading to incomplete observations. We assume we have a dataset with $N$ patients. For example, patient $1$ has $n_1 = 4$ measurements at hours $(t^{(1)}_1, t^{(1)}_2, t^{(1)}_3, t^{(1)}_4)$ where the values $(x^{(1)}_1, x^{(1)}_2, x^{(1)}_3, x^{(1)}_4)$ are measured. Patient $2$ only has $n_2 = 2$ measurements at hours $(t^{(2)}_1,t_2^{(2)})$ where the values $(x^{(2)}_1,x^{(2)}_2)$ are measured. Similarly, the $j$-th patient has $n_j$ measurements at times $(t^{(j)}_1, \dotsc, t^{(j)}_{n_j})$ with measured values $(x^{(j)}_1, \dots, x^{(j)}_{n_j})$. Each $x^{(j)}_i$ is a $d$-dimensional vector (representing all vital parameters of interest) that can have missing values. Based on this data, we want to forecast the vital parameters of new patients coming to the hospital. In particular, for a new patient with measured values $(x_1, \dotsc, x_k)$ at times $(t_1, \dotsc, t_k)$, we want to predict what the vital parameter values will likely be at any time $t > t_k$, for any $0 \leq k \leq n$.
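Concretely, such a dataset could be stored as a list of per-patient records holding the observation times, the measured values and a $0/1$ mask marking which coordinates were actually observed. The following is a purely illustrative sketch (field names and numbers are invented for this example; `NAN` marks an unobserved coordinate):

```python
# Hypothetical container for the irregular hospital dataset described above,
# with d = 2 vital parameters (e.g., body temperature and heart rate).
NAN = float("nan")

patients = [
    {   # patient 1: n_1 = 4 observations
        "times":  [1.0, 3.5, 7.0, 12.0],
        "values": [[36.8, 80.0], [37.1, NAN], [NAN, 95.0], [36.9, 78.0]],
        "masks":  [[1, 1], [1, 0], [0, 1], [1, 1]],
    },
    {   # patient 2: n_2 = 2 observations
        "times":  [2.0, 9.0],
        "values": [[38.2, NAN], [37.5, 88.0]],
        "masks":  [[1, 0], [1, 1]],
    },
]
```

Note that both the number of observations and the observation times differ between patients, and the mask records exactly which coordinates of each measurement vector are available.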

Mathematical Description

Problem Setting

In the simplest form, we consider a continuous-time càdlàg stochastic process $X = (X_t)_{t \in [0,T]}$ taking values in $\mathbb{R}^d$ of which we make discrete, irregular and incomplete observations. In particular, we allow for a random number $n\in \mathbb{N}$ of observations taking place at the random times $$0=t_0 < t_1 < \dotsb < t_{n} \le T. $$ For clarification, $n$ is a random variable and the $t_i$ are sorted stopping times. This setup is extremely flexible and allows for a possibly unbounded number of observations in the finite time interval $[0,T]$. To allow for incomplete observations, i.e., missing values in the observations, we define an observation mask as the sequence of random variables $M = (M_k)_{k \in \mathbb{N}}$ taking values in $\{ 0,1 \}^{d}$. If $M_{k, j}=1$, then the $j$-th coordinate $X_{t_k, j}$ is observed at observation time $t_k$.

We are interested in online predictions of the process $X$ at any time $t \in [0,T]$ given the observations up to time $t$. Therefore, we define the information that is available at any time $t$ as the $\sigma$-algebra generated by the observations made until time $t$. In particular, this leads to the filtration of the currently available information $\mathbb{A} := (\mathcal{A}_t)_{t \in [0, T]}$ given by

$$ \mathcal{A}_t := \boldsymbol{\sigma}\left(X_{t_i, j}, t_i, M_{i} \mid i \leq \kappa(t),\, j \in \{1 \leq l \leq d \mid M_{i, l} = 1 \} \right),$$

where $\kappa(t) := \sup\{ i \in \mathbb{N}_0 \mid t_i \leq t \}$ counts the number of observations made up to time $t$.

The $L^2$-optimal Prediction

We are mainly interested in the $L^2$-optimal prediction of the process $X$ at any time $t \in [0,T]$ given the available information $\mathcal{A}_t$. This optimal prediction is given by the conditional expectation $$\hat{X}_t = \mathbb{E}\left[ X_t \mid \mathcal{A}_t \right].$$ While this is a well-defined concept, it is often hard to compute in practice due to the high dimensionality of the process $X$ and the complexity of the information $\mathcal{A}_t$. Additionally, we do not assume to know the underlying dynamics (i.e., distribution) of the process $X$, but only have access to a training dataset with i.i.d. samples following the description above. Hence, we propose the Neural Jump ODE as a fully data-driven online prediction model to approximate the optimal prediction $\hat{X}_t$.
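As a sanity check of this definition: for a martingale such as a standard Brownian motion, the $L^2$-optimal prediction is simply the last observed value, $\hat{X}_t = X_{\tau(t)}$, where $\tau(t)$ is the last observation time before or at $t$. A minimal Monte Carlo sketch (a toy illustration, not part of the NJ-ODE code) verifies this by averaging simulated continuations after the last observation:

```python
import math
import random

random.seed(0)

def mc_conditional_mean(x_last, t_last, t, n_samples=20000):
    """Monte Carlo estimate of E[X_t | X_{t_last} = x_last] for a standard
    Brownian motion: simulate the remaining Gaussian increment and average."""
    total = 0.0
    for _ in range(n_samples):
        increment = random.gauss(0.0, math.sqrt(t - t_last))
        total += x_last + increment
    return total / n_samples

est = mc_conditional_mean(x_last=1.3, t_last=0.5, t=1.0)
# est should be close to the last observation 1.3, the optimal prediction.
```

For processes with drift or path-dependence, no such closed-form shortcut exists in general, which is exactly where the data-driven NJ-ODE model comes in.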

Model Architecture of Neural Jump ODEs

We propose a neural network based model, consisting of two main components: a recurrent neural network (RNN) and a neural ODE. To understand the design of the model architecture, we first review what the model should be able to provide for the online prediction task. Since we want to have a continuous-time prediction, the model needs to generate an output for every time point $t \in [0,T]$. Between any two (discrete) observation points, the model does not receive any additional information except for the passage of time, and should therefore produce a prediction based on the information at the last observation time and the current time. This is where the neural ODE comes into play; it is used to model the dynamics of the process (or rather of the optimal prediction $\hat{X}$) between the observation times. On the other hand, the model should be able to incorporate the information of new observations when they become available. Clearly, a purely ODE-based model would not be able to do this, as it is deterministic once the initial condition (and driving function) is fixed. This is where the RNN comes into play; it is used to update the internal model state (i.e., the model's latent variable) $H_t$ at observation times $t_i$ based on the new observation $X_{t_i}$. Since a new observation leads to a jump in the available information, this will in general also lead to a jump in the model's latent variable at the observation time, $\Delta H_{t_i} = H_{t_i} - H_{t_i-}$. The model output can be compactly written as the solution $Y$ of the following stochastic differential equation (SDE)

\begin{equation} \begin{split} H_0 &= \rho_{\theta_2}\left(X_{0}, 0 \right), \\ dH_t &= f_{\theta_1}\left(H_{t-}, X_{\tau(t)},\tau(t), t - \tau(t) \right) dt + \left( \rho_{\theta_2}\left(X_{t}, H_{t-} \right) - H_{t-} \right) du_t, \\ Y_t &= g_{\theta_3}(H_t), \end{split} \end{equation}
where $u_t:=\sum_{i=1}^{n} 1_{\left[t_i, \infty\right)}(t), 0 \le t \le T$ is the pure-jump process counting the current number of observations, $\tau(t)$ is the time of the last observation before or at $t$ and $f_{\theta_1}, \rho_{\theta_2}$ and $g_{\theta_3}$ are neural networks with parameters $\theta_1$, $\theta_2$ and $\theta_3$. Even though the recurrent structure already allows for path-dependent predictions, we can additionally use the truncated signature transform as input to the networks $f_{\theta_1}$ and $\rho_{\theta_2}$ (not displayed here). This often helps the model to better capture the path-dependence of $\hat{X}$, since the signature transform carries important path information that can be hard to capture with the RNN alone.
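To make these dynamics concrete, the following is a minimal pure-Python Euler discretisation of the SDE above: the hidden state drifts according to $f_{\theta_1}$ between observations, jumps via $\rho_{\theta_2}$ at observation times, and is read out by $g_{\theta_3}$ on a fixed time grid. The three "networks" here are fixed toy functions standing in for the trained networks; this is an illustrative sketch, not the reference implementation:

```python
import math

D_HIDDEN = 4  # dimension of the toy latent state H_t

def f_theta(h, x_last, tau, t_minus_tau):
    # Placeholder "neural ODE" vector field f_{theta_1}.
    return [math.tanh(0.1 * hi + 0.05 * x_last - 0.02 * t_minus_tau) for hi in h]

def rho_theta(x, h):
    # Placeholder "RNN" jump update rho_{theta_2}: mixes the new observation
    # into the hidden state.
    return [math.tanh(0.5 * x + 0.5 * hi) for hi in h]

def g_theta(h):
    # Placeholder readout g_{theta_3}.
    return sum(h) / len(h)

def njode_forward(obs_times, obs_values, T, dt=0.01):
    """Euler scheme for the NJ-ODE dynamics (1-d observations): drift between
    observations, jump update at observation times, readout at every step."""
    h = rho_theta(obs_values[0], [0.0] * D_HIDDEN)  # H_0 = rho(X_0, 0)
    tau, x_last = obs_times[0], obs_values[0]
    next_obs = 1
    t, out = 0.0, [(0.0, g_theta(h))]
    while t < T - 1e-12:
        t = round(t + dt, 10)
        # Continuous evolution between observations (Euler step of the ODE).
        drift = f_theta(h, x_last, tau, t - tau)
        h = [hi + dt * di for hi, di in zip(h, drift)]
        # Jump update when a new observation arrives.
        if next_obs < len(obs_times) and t >= obs_times[next_obs]:
            x_last, tau = obs_values[next_obs], obs_times[next_obs]
            h = rho_theta(x_last, h)
            next_obs += 1
        out.append((t, g_theta(h)))
    return out
```

For example, `njode_forward([0.0, 0.3, 0.7], [0.5, 1.0, 0.2], T=1.0)` produces a prediction on the whole grid $[0,1]$, with discontinuities of the latent state at the observation times $0.3$ and $0.7$.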

The Training Framework

Under suitable assumptions on the underlying process $X$, this model framework can approximate $\hat{X}_t$ arbitrarily well. However, to find the optimal parameters $\theta_1, \theta_2$ and $\theta_3$ achieving this approximation, we need a suitable loss function. In our work we prove that optimizing the loss function

\begin{equation}\label{equ-C0:Psi NJODE1} \Psi(Z) := \mathbb{E}\left[ \frac{1}{n} \sum_{i=1}^n \left( \left\lvert M_i \odot \left( X_{t_i} - Z_{t_i} \right) \right\rvert_2 + \left\lvert M_i \odot \left( X_{t_i} - Z_{t_{i}-} \right) \right\rvert_2 \right)^2 \right], \end{equation}
yields the optimal parameters, i.e., the model output $Y$ converges to the true conditional expectation $\hat{X}$. The intuition for this definition of the loss is that the "jump part" (i.e., the first term) of the loss function forces the RNN $\rho_{\theta_2}$ to produce good updates based on new observations, while the "continuous part" (i.e., the second term) of the loss forces the distance between the observation and the model output just before the jump to be small in the $L^2$-norm. Since the conditional expectation minimizes this $L^2$-distance, this forces the neural ODE $f_{\theta_1}$ to continuously transform the hidden state such that the output approximates the conditional expectation well. Moreover, both parts of the loss force the readout network $g_{\theta_3}$ to reasonably transform the hidden state $H_t$ to the output $Y_t$. If observations can happen at any time, this intuitively implies that the model has to approximate the conditional expectation well at any time, since it will otherwise be penalised. Vice versa, if observations never happen within a certain time interval, we cannot guarantee good approximations there, since then the training data does not allow for learning the dynamics of the underlying process within this time interval.
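A literal translation of this loss into code, for a single path with $d$-dimensional observations, reads as follows (an illustrative sketch; the variable names are ours, not the repository's):

```python
import math

def njode_loss(X, M, Y_post, Y_pre):
    """Empirical loss Psi for one path with n observations: X[i] and M[i] are
    the i-th observation and its 0/1 mask, Y_post[i] = Z_{t_i} is the model
    output after the jump and Y_pre[i] = Z_{t_i-} the output just before it."""
    def masked_dist(m, a, b):
        # Euclidean norm of M_i ⊙ (a - b): unobserved coordinates are ignored.
        return math.sqrt(sum(mi * (ai - bi) ** 2 for mi, ai, bi in zip(m, a, b)))
    n = len(X)
    return sum(
        (masked_dist(m, x, z_post) + masked_dist(m, x, z_pre)) ** 2
        for x, m, z_post, z_pre in zip(X, M, Y_post, Y_pre)
    ) / n
```

The expectation in the definition of $\Psi$ is then approximated by averaging this quantity over the paths of a training batch. Note that a perfect prediction before and after each jump yields zero loss, and that masked (unobserved) coordinates contribute nothing.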

Extensions of the Framework

We have extended the basic framework described above in several ways:

  • By predicting conditional moments or conditional characteristic functions of $X$, one can use the Neural Jump ODE model to approximate the conditional law instead of only predicting the conditional mean.
  • The noise-adapted loss function can be used if observations of the process $X$ are not noise free.
  • Long-term predictions, i.e., approximating $\mathbb{E}\left[ X_t \mid \mathcal{A}_s \right]$ for any $0 \leq s \leq t \leq T$, can be achieved by an appropriate adaptation of the training framework.
  • With the Input-Output loss function the Neural Jump ODE can learn the dynamics of input-output systems, where an output process is predicted based on observations of an input process. This makes the model applicable to filtering and classification tasks.
For more details see (the overview of) the different papers below.

Results

To demonstrate the effectiveness of the proposed model, we provide a series of experiments on synthetic and real-world data. Some results are shown below. In each of the settings, the model was trained on a training dataset and evaluated on a separate test dataset, on which the plots were also generated.


Overview of Papers

The Starting Point

Our work on Neural Jump ODEs started with this paper, in which we introduce this new model together with a theoretical framework for its analysis. The setting is still rather simple, with a restriction to Itô diffusions and complete observations, but since this paper is focused on conveying the main ideas and concepts, it is the perfect starting point for anyone interested in NJ-ODEs. Pairing an architecture based on the ODE-RNN model with a novel loss function, we prove consistency of our time-series prediction model. This is the first time such theoretical guarantees are provided for a model dealing with discrete, irregular observations of a continuous-time stochastic process. Several experiments on synthetic and real-world datasets demonstrate the effectiveness of our approach.

Lifting the Theoretical Framework to a New Level of Generality

The second work introduces a largely revised theoretical framework and an upgraded model architecture called Path-Dependent NJ-ODE. This allows for the estimation of generic continuous-time stochastic processes, while still providing theoretical guarantees of optimality under these generalised settings. In particular, we extend our model to handle processes with jumps, non-Markovian (i.e. path-dependent) processes and incomplete observations, which are all common difficulties in practice. Therefore, our model is now applicable to a much broader range of problems and can be used in a variety of applications. Additionally, we explain how to use NJ-ODEs for uncertainty estimation (and more generally to approximate the conditional distributions) and provide a first attempt for stochastic filtering with NJ-ODEs. All theoretical extensions are backed by rigorous proofs and are demonstrated through experiments on synthetic datasets. Moreover, we provide applications to the real-world datasets of Physionet and Limit Order Books.

Noisy Observations & Dependent Observation Framework

In the works so far, the process itself and the coordinate-wise observation times were assumed to be independent and observations were assumed to be noiseless. In this work we discuss two extensions to lift these restrictions and provide theoretical guarantees as well as empirical examples for them. In particular, we can lift the assumption of independence by extending the theory to much more realistic settings of conditional independence without any need to change the algorithm. Moreover, we introduce a new loss function, which allows us to deal with noisy observations and explain why the previously used loss function does not lead to a consistent estimator in the case of noisy observations. However, the original loss function still has its raison d'être in the case of noiseless (or low noise) observations, due to its better inductive bias for the training. This is illustrated on a real-world dataset. While the previous papers solely focus on approximations at observation times (which is encoded in the chosen metric), we additionally derive assumptions for which this leads to approximations at any time.

Long-Term Predictions & Chaotic Systems

The previous works all focus on approximating the conditional expectation at any time given all previous observations. While this is reasonable for many applications, it is not suitable for long-term predictions. In the case where the underlying process is deterministic, the conditional expectation coincides with the process itself. Therefore, this framework can equivalently be used to learn the dynamics of ODE or PDE systems solely from realizations of the dynamical system with different initial conditions. We showcase the potential of our method by applying it to the chaotic system of a double pendulum. When training the standard NJ-ODE method, the prediction starts to diverge from the true path after about half of the evaluation time. In this work we enhance the model with two novel ideas, which independently of each other improve the performance of our modeling setup. The resulting dynamics match the true dynamics of the chaotic system very closely. The same enhancements can be used to provably enable the NJ-ODE to learn long-term predictions for general stochastic datasets, where the standard model fails. In particular, we prove that using these enhancements, the NJ-ODE can learn to approximate the conditional expectation at any time given any arbitrary subset of the previous observations. This is verified in several experiments.

Nonparametric Filtering, Estimation and Classification

Previous works on Neural Jump ODEs focused on the estimation of a process $X$ given observations of $X$. This work extends the framework to input-output systems, where observations of an input process $U$ are used to predict an output process $V$, enabling direct applications in online filtering and classification. We establish theoretical convergence guarantees for this approach, providing a robust solution to $L^2$-optimal filtering. Empirical experiments highlight the model's superior performance over classical parametric methods, particularly in scenarios with complex underlying distributions. These results emphasize the approach's potential in time-sensitive domains such as finance and health monitoring, where real-time accuracy is crucial.

Doctoral Thesis: Neural Jump Ordinary Differential Equations

The first chapter of this PhD thesis provides an introduction to the Neural Jump ODE framework. It starts by discussing the foundations of forecasting and conditional expectations, which are at the core of this work. Then, an overview of the 5 papers above is given, with a focus on conveying the main ideas and concepts of these works. This introduction is intended to create a good intuition for the methods and results presented in later chapters. As such, it is neither meant to be fully self-contained, nor to be as precise as possible; trading rigour for a better intuitive understanding. Nevertheless, all needed concepts that go beyond the standard knowledge of probability theory, stochastic calculus, statistics and (neural network based) machine learning are provided.

Applying the Neural Jump ODE

We provide a Python implementation of the Neural Jump ODE framework. The code is available on GitHub with instructions on how to reproduce the results of the papers. Probably the best way to get started is to reproduce some of these results and try out new hyperparameter specifications.

To apply the model to a new dataset, the following things need to be done:

  • First, you need to generate a dataset from the raw data that can be used for training the model. See the existing methods in data_utils.py. If your raw data can be generated from some type of stochastic process, you can probably implement this similarly to the synthetic dataset examples provided in synthetic_datasets.py and use the standard method implemented in data_utils.py to generate the dataset. Check out the hyperparameter settings for generating datasets in the different config files and the respective sections in the README.md.
  • Then, you might need to adjust the training script train.py to your needs, in particular for loading the dataset if it was not generated with one of the already implemented methods.
  • If you want to use the parallelized training (which is suggested), you might have to adjust the train_switcher.py script based on the changes from the previous steps.
  • Finally, you can specify the hyperparameters for training the model(s) in one of the existing or in a new config file and start the training as described for the provided examples in the README.md. As a starting point for selecting the hyperparameters, we suggest having a look at the configs used in config_ParamFilter.py. Importantly, make sure to use the correct loss function for your task ('easy' for the standard loss function, 'IO' for the loss function adapted for input-output systems, and 'noisy_obs' for the noise-adapted loss function, i.e., when the observations have additive observation noise).
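The steps above can be summarised in a small hyperparameter dictionary. The sketch below is purely illustrative: the key names are assumptions in the spirit of the repository's config files, not the exact schema; check config_ParamFilter.py and the README.md for the real parameter names.

```python
# Hypothetical hyperparameter dictionary for one training run; only the loss
# names 'easy' / 'IO' / 'noisy_obs' are taken from the description above,
# all key names are illustrative placeholders.
params = {
    "dataset": "my_dataset",   # name of the dataset generated in step 1
    "epochs": 100,
    "batch_size": 50,
    "hidden_size": 50,         # dimension of the latent state H_t
    "which_loss": "easy",      # 'easy' | 'IO' | 'noisy_obs'
    "use_signature": False,    # feed the truncated signature to f and rho?
}
```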

Authors

William Andersson

Jakob Heiss

Calypso Herrera

Florian Krach

Félix Ndonfack

Marc Nübel

Thorsten Schmidt

Josef Teichmann