Closed: neighthan closed this issue 3 years ago
Hi @neighthan! The best way to achieve your goal at the moment is indeed to perform the augmentation first, then call `.trajectory` later. You can definitely use losses distributed over various solution points at different times; however, note that at the moment, achieving this within torchdyn requires integral losses (check out the notebook for more information).
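To make the idea concrete, here is a minimal, dependency-free sketch of how an integral loss works: the state is augmented with a running cost `c` whose derivative is the instantaneous loss, and both are integrated together, so the training signal is spread over the whole time domain. The dynamics, cost, and forward-Euler loop below are illustrative stand-ins, not torchdyn's API:

```python
def integral_loss(steps=1000, t1=1.0):
    """Integrate dx/dt = f(x) and a running cost dc/dt = l(x) together
    with forward Euler. Here f(x) = -x stands in for the learned vector
    field and l(x) = x^2 for the instantaneous loss."""
    h = t1 / steps
    x, c = 1.0, 0.0  # initial state, accumulated cost
    for _ in range(steps):
        # Tuple assignment uses the old x on both right-hand sides.
        x, c = x + h * (-x), c + h * (x * x)
    return x, c

x1, cost = integral_loss()
# For dx/dt = -x, x(0) = 1: x(t) = exp(-t), so the integral of x(t)^2
# over [0, 1] is (1 - exp(-2)) / 2, roughly 0.432; `cost` approximates it.
```

In a real model, `cost` would be the scalar you backpropagate through, alongside (or instead of) any terminal loss on `x1`.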
I will soon (within the week) be introducing a small addition to our adjoint to handle losses distributed on specific solution points, rather than distributed across the entire time domain. This is nothing new, and will mimic the standard approach of e.g. torchdiffeq. When that is done I'll update you here.
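The torchdiffeq-style pattern referred to above can be sketched without any ODE library: ask the solver for the solution at every observation time, then compute the loss over all of those points rather than only the final one. A forward-Euler loop stands in for the solver here, and all names are illustrative:

```python
import math

def odeint_euler(f, x0, ts, steps_per_interval=100):
    """Return the solution at each time in ts (ts[0] is the initial time)."""
    xs, x = [x0], x0
    for t0, t1 in zip(ts, ts[1:]):
        h = (t1 - t0) / steps_per_interval
        for _ in range(steps_per_interval):
            x = x + h * f(x)
        xs.append(x)
    return xs

f = lambda x: -x                     # stand-in for the learned vector field
ts = [0.0, 0.25, 0.5, 0.75, 1.0]    # observation times t0 .. tn
x_hat = odeint_euler(f, 1.0, ts)    # one predicted point per observation time

x_true = [math.exp(-t) for t in ts]  # "data" sampled from the true ODE
loss = sum((a - b) ** 2 for a, b in zip(x_hat, x_true)) / len(ts)
```

The loss now penalizes the mismatch at every sampled time, which is exactly what fitting a trajectory (rather than a terminal state) requires.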
In the example notebooks (at least the ones I've been through; I haven't looked at all of them fully), it seems like the loss during training only uses the final point of each trajectory. For things like classification problems, this makes sense. However, if I'm training a neural ODE to approximate some other ODE that I have sampled trajectories from, using only the final point of the trajectory doesn't make sense. In particular, suppose my data is `x_{t0}, x_{t1}, ..., x_{tn}` for a single trajectory from time `t0` to `tn`. I see a couple of options here:
1. Use `y_hat = self.model(x)` (where `self` is the `Learner`) like in the examples, if that returns the points we need. If that's not the case, it seems like `self.model.trajectory` is required; more on this in 2.
2. Call `self.model.trajectory` and pass in a tensor of all time points that we need. This seems to work for me with a basic model, but fails when I try, e.g., to do input-layer augmentation like in the demo notebook (because then `self.model` is an `nn.Sequential`, which doesn't have `.trajectory`). We could work around this by having `Learner.forward` call the augmentation layer first, then get a trajectory with the `NeuralDE`, then run the layer that converts each point back into the space of the original variables.

There isn't anything that I can't get to work here, but since I didn't see any examples using `model.trajectory` for training, I just wanted to check if this seems right or if there's a better way to do this.
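For reference, the workaround described in option 2 can be sketched as follows. Everything here is illustrative and dependency-free: `ToyNeuralDE` is a stand-in (its `.trajectory` is a hand-rolled Euler loop, not torchdyn's), and `augment`/`project` play the roles of the augmentation layer and the layer mapping augmented points back to the original variables:

```python
class ToyNeuralDE:
    """Stand-in for a NeuralDE: Euler-integrates f, recording each time in ts."""
    def __init__(self, f):
        self.f = f

    def trajectory(self, x0, ts, steps_per_interval=100):
        xs, x = [x0], x0
        for t0, t1 in zip(ts, ts[1:]):
            h = (t1 - t0) / steps_per_interval
            for _ in range(steps_per_interval):
                x = [xi + h * dxi for xi, dxi in zip(x, self.f(x))]
            xs.append(x)
        return xs

def augment(x):
    return x + [0.0]   # input-layer augmentation: append an extra channel

def project(x):
    return x[:1]       # map an augmented point back to the data space

class Learner:
    def __init__(self, neural_de):
        self.neural_de = neural_de

    def forward(self, x, ts):
        # Augment first, integrate over all requested times, then project
        # every trajectory point back for comparison against the data.
        traj = self.neural_de.trajectory(augment(x), ts)
        return [project(p) for p in traj]

f = lambda x: [-xi for xi in x]     # stand-in dynamics on the augmented state
learner = Learner(ToyNeuralDE(f))
ts = [0.0, 0.5, 1.0]
y_hat = learner.forward([1.0], ts)  # one predicted point per time in ts
```

The loss would then compare each entry of `y_hat` against the corresponding sampled point `x_{ti}`, so the `nn.Sequential` wrapping issue never arises: the `Learner` orchestrates the three stages itself.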