I'm looking to take in a sequence of values and output a sequence as well.
So instead of doing `flatten` and computing the final output from the neural ODE, I would like to use each intermediate value the neural ODE outputs and compute a loss against a target label.
Is there an example of this somewhere using TorchDyn?
This is now the default behavior of `NeuralODE` instances. The solution is a tensor of shape `(num_time_points, batch, dims)`, so you can compute losses over every time point rather than only the final state.
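A minimal sketch of the idea in plain PyTorch. It assumes a fixed-step Euler integrator as a stand-in for the TorchDyn solver, and random tensors as a hypothetical target sequence (one target per time point). The only thing it is meant to show is that, with the time axis first, the loss is taken against the whole trajectory instead of `sol[-1]`. The names `VectorField`, `euler_trajectory` and `targets` are illustrative, not part of TorchDyn.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class VectorField(nn.Module):
    """Learned dynamics f(x); any nn.Module mapping (batch, dim) -> (batch, dim)."""
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, 32), nn.Tanh(), nn.Linear(32, dim))

    def forward(self, x):
        return self.net(x)

def euler_trajectory(f, x0, t_span):
    """Stand-in solver (assumption): explicit Euler, returning the state at
    every time in t_span, stacked to shape (len(t_span), batch, dim)."""
    xs = [x0]
    x = x0
    for t0, t1 in zip(t_span[:-1], t_span[1:]):
        x = x + (t1 - t0) * f(x)
        xs.append(x)
    return torch.stack(xs, dim=0)

num_time_points, batch, dim = 10, 16, 2
t_span = torch.linspace(0, 1, num_time_points)
x0 = torch.randn(batch, dim)
# Hypothetical target sequence: one label per time point, same layout as the solution.
targets = torch.randn(num_time_points, batch, dim)

f = VectorField(dim)
opt = torch.optim.Adam(f.parameters(), lr=1e-2)

# With TorchDyn this line would be replaced by the NeuralODE call; recent releases
# return a (t_eval, sol) pair, but check the signature of your installed version.
sol = euler_trajectory(f, x0, t_span)  # (num_time_points, batch, dim)

# Per-time-point loss, shape (num_time_points,), useful for weighting or logging.
per_step = ((sol - targets) ** 2).mean(dim=(1, 2))
# Loss over every intermediate state, not only the final one.
loss = nn.functional.mse_loss(sol, targets)

opt.zero_grad()
loss.backward()
opt.step()
```

Because every time point contributes the same number of elements, `loss` equals `per_step.mean()`; replacing that mean with a weighted sum is one way to emphasize later (or earlier) points in the sequence.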