patrick-kidger / torchcde

Differentiable controlled differential equation solvers for PyTorch with GPU support and memory-efficient adjoint backpropagation.
Apache License 2.0

Masking Coefficients? #51

Open AshCher51 opened 1 year ago

AshCher51 commented 1 year ago

Hi Patrick,

Thanks for all your work with Diffrax and torchcde.

I noticed that for each time step in a two-channel time series dataset, there are 4 coefficients associated with that particular time step (so 10 time steps would give 40 coefficients per example in the batch).

With this in mind, I am attempting to combine a Neural CDE with a Transformer decoder, and I want to mask the coefficients to avoid any look-ahead bias in the CDE model. Is this something that can be done?

My first thought was to reshape the coefficients into a (batch_size, 4, 10) tensor, apply a (10, 10) lower-triangular (tril) mask, unsqueezed to broadcast over the batch, and pass the result into a CubicSpline interpolation, but I'm not sure exactly how this could be done.

Any help with this would be greatly appreciated!

Thanks,

Aashish

patrick-kidger commented 1 year ago

Take a look at https://arxiv.org/abs/2106.11028 and https://docs.kidger.site/diffrax/api/interpolation/. These discuss handling causality in the input.

(It's quite complicated, unfortunately!)

AshCher51 commented 1 year ago

Thanks Patrick! I appreciate those two resources. Since I'm doing discrete online prediction, I've chosen to stick with Hermite cubic splines and to use the z_T output of the model at each time step, so that there isn't any look-ahead bias in the predictions.

I have a question about the z_T tensor, though: the documentation says it should have shape (..., len(t), hidden_channels), but with my data, even though X.interval is [0, 8], I'm getting a z_T whose second-to-last dimension is 2, not 9.

Is this an indication that I've done something wrong? I'm not sure why I'm not getting an output for each time step.

(edit: in case it helps, my coefficients have shape (16, 8, 12).)

(edit 2: I just realized this was because I had passed X.interval as the integration times t instead of X.grid_points.)

I appreciate all the help; feel free to close the issue unless you think I've done something incorrectly!