SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License
847 stars 151 forks source link

Example/Implementation of Neural Controlled Differential Equations #408

Open johnp-4dvanalytics opened 3 years ago

johnp-4dvanalytics commented 3 years ago

The paper provides a good method for encoding data with an ODE. This would be useful for being the encoder of an encoder/decoder architecture as an alternative to using an RNN encoder.

https://arxiv.org/abs/2005.08926 https://github.com/patrick-kidger/NeuralCDE

I have played around with the example in the code repo, but it is very slow and could probably be significantly faster if written with DiffEqFlux.

ChrisRackauckas commented 3 years ago

Yeah, that could be a nice model to implement. Let me know if you need any help optimizing it.

johnp-4dvanalytics commented 3 years ago

@ChrisRackauckas I was able to get a version of this working although it is very slow. If there are any optimizations that stand out please let me know. Thanks!

Here is the code for it:

using Statistics
using DataInterpolations
using Flux
using OrdinaryDiffEq
using DiffEqSensitivity
using Zygote
using ProgressBars
using Random

T = Float32

bs = 512
X = [rand(T, 10, 50) for _ in 1:bs*10]

function create_spline(i)
    x = X[i]
    t = x[end, :]
    t = (t .- minimum(t)) ./ (maximum(t) - minimum(t))

    spline = QuadraticInterpolation(x, t)
end

splines = [create_spline(i) for i in tqdm(1:length(X))]

rand_inds = randperm(length(X))

i_sz = size(X[1], 1)
h_sz = 16

use_gpu = true
batches = [[splines[rand_inds[(i-1)*bs+1:i*bs]]] for i in tqdm(1:length(X)÷bs)]

data_ = Iterators.cycle(batches)

function call_and_cat(splines, t)
    vals = Zygote.ignore() do
        vals = hcat([spline(t) for spline in splines]...)
    end
    vals |> (use_gpu ? gpu : cpu)
end

function derivative(A::QuadraticInterpolation, t::Number)
    idx = findfirst(x -> x >= t, A.t) - 1
    idx == 0 ? idx += 1 : nothing
    if idx == length(A.t) - 1
        i₀ = idx - 1; i₁ = idx; i₂ = i₁ + 1;
    else
        i₀ = idx; i₁ = i₀ + 1; i₂ = i₁ + 1;
    end
    dl₀ = (2t - A.t[i₁] - A.t[i₂]) / ((A.t[i₀] - A.t[i₁]) * (A.t[i₀] - A.t[i₂]))
    dl₁ = (2t - A.t[i₀] - A.t[i₂]) / ((A.t[i₁] - A.t[i₀]) * (A.t[i₁] - A.t[i₂]))
    dl₂ = (2t - A.t[i₀] - A.t[i₁]) / ((A.t[i₂] - A.t[i₀]) * (A.t[i₂] - A.t[i₁]))
    A.u[:, i₀] * dl₀ + A.u[:, i₁] * dl₁ + A.u[:, i₂] * dl₂
end

function derivative_call_and_cat(splines, t)
    vals = Zygote.ignore() do
        vals = hcat([derivative(spline, t) for spline in splines]...)
    end
    vals |> (use_gpu ? gpu : cpu)
end

cde = Chain(
    Dense(h_sz, h_sz, relu),
    Dense(h_sz, h_sz*i_sz, tanh),
) |> (use_gpu ? gpu : cpu)

h_to_out = Dense(h_sz, 2) |> (use_gpu ? gpu : cpu)

initial = Dense(i_sz, h_sz) |> (use_gpu ? gpu : cpu)

cde_p, cde_re = Flux.destructure(cde)
initial_p, initial_re = Flux.destructure(initial)
h_to_out_p, h_to_out_re = Flux.destructure(h_to_out)

basic_tgrad(u,p,t) = zero(u)

function predict_func(p, BX)
    By = call_and_cat(BX, 1)

    x0 = call_and_cat(BX, 0)
    i = 1
    j = (i-1)+length(initial_p)

    h0 = initial_re(p[i:j])(x0)

    function dhdt(h,p,t)
        x = derivative_call_and_cat(BX, t)
        bs = size(h, 2)
        a = reshape(cde_re(p)(h), (i_sz, h_sz, bs))
        b = reshape(x, (1, i_sz, bs))

        dh = batched_mul(b,a)[1,:,:]
    end

    i = j+1
    j = (i-1)+length(cde_p)

    tspan = (0.0f0, 0.8f0)

    ff = ODEFunction{false}(dhdt,tgrad=basic_tgrad)
    prob = ODEProblem{false}(ff,h0,tspan,p[i:j])
    sense = InterpolatingAdjoint(autojacvec=ZygoteVJP())
    solver = Tsit5()

    h = solve(prob,solver,u0=h0,saveat=tspan[end], save_start=false).u[end]

    i = j+1
    j = (i-1)+length(h_to_out_p)

    y_hat = h_to_out_re(p[i:j])(h)

    y_hat, By[1:2, :]
end

function loss_func(p, BX)
    y_hat, y = predict_func(p, BX)

    mean(sum(sqrt.((y .- y_hat).^2), dims=1))
end

p = vcat(initial_p, cde_p, h_to_out_p)

callback = function (p, l)
  display(l)
  return false
end

using DiffEqFlux

result_neuralode = DiffEqFlux.sciml_train(loss_func, p, ADAM(0.05),
    data_,
    cb = callback,
    maxiters = 10)
ChrisRackauckas commented 3 years ago

Hey, here's an updated version with comments on what was done and timings. That little training step improved by about 6.6x:

using Statistics
using DataInterpolations
using Flux
using OrdinaryDiffEq
using DiffEqSensitivity
using Zygote
using ProgressBars
using Random

T = Float32

bs = 512
X = [rand(T, 10, 50) for _ in 1:bs*10]

function create_spline(i)
    x = X[i]
    t = x[end, :]
    t = (t .- minimum(t)) ./ (maximum(t) - minimum(t))

    spline = QuadraticInterpolation(x, t)
end

splines = [create_spline(i) for i in tqdm(1:length(X))]

rand_inds = randperm(length(X))

i_sz = size(X[1], 1)
h_sz = 16

use_gpu = true
batches = [[splines[rand_inds[(i-1)*bs+1:i*bs]]] for i in tqdm(1:length(X)÷bs)]

data_ = Iterators.cycle(batches)

function call_and_cat(splines, t)
    vals = Zygote.ignore() do
        vals = reduce(hcat,[spline(t) for spline in splines])
    end
    vals |> (use_gpu ? gpu : cpu)
end

function derivative(A::QuadraticInterpolation, t::Number)
    idx = findfirst(x -> x >= t, A.t) - 1
    idx == 0 ? idx += 1 : nothing
    if idx == length(A.t) - 1
        i₀ = idx - 1; i₁ = idx; i₂ = i₁ + 1;
    else
        i₀ = idx; i₁ = i₀ + 1; i₂ = i₁ + 1;
    end
    dl₀ = (2t - A.t[i₁] - A.t[i₂]) / ((A.t[i₀] - A.t[i₁]) * (A.t[i₀] - A.t[i₂]))
    dl₁ = (2t - A.t[i₀] - A.t[i₂]) / ((A.t[i₁] - A.t[i₀]) * (A.t[i₁] - A.t[i₂]))
    dl₂ = (2t - A.t[i₀] - A.t[i₁]) / ((A.t[i₂] - A.t[i₀]) * (A.t[i₂] - A.t[i₁]))
    @views @. A.u[:, i₀] * dl₀ + A.u[:, i₁] * dl₁ + A.u[:, i₂] * dl₂
end

function derivative_call_and_cat(splines, t)
    vals = Zygote.ignore() do
        reduce(hcat,[derivative(spline, t) for spline in splines]) |> (use_gpu ? gpu : cpu)
    end
end

cde = Chain(
    Dense(h_sz, h_sz, relu),
    Dense(h_sz, h_sz*i_sz, tanh),
) |> (use_gpu ? gpu : cpu)

h_to_out = Dense(h_sz, 2) |> (use_gpu ? gpu : cpu)

initial = Dense(i_sz, h_sz) |> (use_gpu ? gpu : cpu)

cde_p, cde_re = Flux.destructure(cde)
initial_p, initial_re = Flux.destructure(initial)
h_to_out_p, h_to_out_re = Flux.destructure(h_to_out)

basic_tgrad(u,p,t) = zero(u)

function predict_func(p, BX)
    By = call_and_cat(BX, 1)

    x0 = call_and_cat(BX, 0)
    i = 1
    j = (i-1)+length(initial_p)

    h0 = initial_re(p[i:j])(x0)

    function dhdt(h,p,t)
        x = derivative_call_and_cat(BX, t)
        bs = size(h, 2)
        a = reshape(cde_re(p)(h), (i_sz, h_sz, bs))
        b = reshape(x, (1, i_sz, bs))

        dh = batched_mul(b,a)[1,:,:]
    end

    i = j+1
    j = (i-1)+length(cde_p)

    tspan = (0.0f0, 0.8f0)

    ff = ODEFunction{false}(dhdt,tgrad=basic_tgrad)
    prob = ODEProblem{false}(ff,h0,tspan,p[i:j])
    sense = InterpolatingAdjoint(autojacvec=ZygoteVJP())
    solver = Tsit5()

    sol = solve(prob,solver,u0=h0,saveat=tspan[end], save_start=false, sensealg=sense)
    #@show sol.destats
    i = j+1
    j = (i-1)+length(h_to_out_p)

    y_hat = h_to_out_re(p[i:j])(sol[end])

    y_hat, By[1:2, :]
end

function loss_func(p, BX)
    y_hat, y = predict_func(p, BX)

    mean(sum(sqrt.((y .- y_hat).^2), dims=1))
end

p = vcat(initial_p, cde_p, h_to_out_p)

callback = function (p, l)
  display(l)
  return false
end

using DiffEqFlux

Zygote.gradient((p)->loss_func(p, first(data_)...),p)
@time Zygote.gradient((p)->loss_func(p, first(data_)...),p)

@time result_neuralode = DiffEqFlux.sciml_train(loss_func, p, ADAM(0.05),
    data_,
    cb = callback,
    maxiters = 10)

# Start
# 13.178288 seconds (40.04 M allocations: 3.362 GiB, 5.01% gc time)

# Reduce(hcat)
# 6.273153 seconds (21.84 M allocations: 2.443 GiB, 7.86% gc time)

# @views @.
# 3.315527 seconds (5.05 M allocations: 495.695 MiB, 3.81% gc time)

# gpu in do
# 2.652512 seconds (5.11 M allocations: 466.430 MiB, 2.64% gc time)

# Training time before:
# 199.442675 seconds (218.19 M allocations: 23.603 GiB, 66.46% gc time)

# Training time after:
# 30.587359 seconds (58.69 M allocations: 5.210 GiB, 3.74% gc time)

The rate limiting step here is that the spline data is on the CPU while your computations are on the GPU, so the most costly portion now is simply moving the spline output to the GPU. That's like 90% of the cost or something ridiculous now, so you'd have to tackle that problem and I was only giving myself 30 minutes to play with this. One way you could do this would be to make your quadratic spline asynchronously pre-cache some of the next time points onto the GPU while other computations are taking place. Or, even cooler, train a neural network to mimic the spline but be all on the GPU, and then use that in place of the spline. But shipping that much data every step is going to dominate the computation so it's gotta be dealt with somehow.

What's the baseline you want to beat here? Do you have that code around to time?

johnp-4dvanalytics commented 3 years ago

Awesome, thanks for reviewing the code and giving those optimizations! That's quite a speed up.

I will try to figure out how to avoid moving the data from CPU to GPU every step.

Btw the author of the paper / repo above recently released this new repo for this type of model: https://github.com/patrick-kidger/torchcde

So probably we would want to show it outperforming some examples in that repo.

The baseline I was comparing against was based off of the code in https://github.com/patrick-kidger/NeuralCDE/blob/master/example/example.py

I'll do a speed comparison against that and the optimized code you posted.

johnp-4dvanalytics commented 3 years ago

Here is the code I am using as the efficiency baseline:

On my system the baseline code take about ~6 seconds to run whereas the optimized version you posted takes about ~18 seconds. I think once we fix moving the data to GPU at each step it should be a lot faster though.

import controldiffeq
import math
import torch

class CDEFunc(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels):
        super(CDEFunc, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.linear1 = torch.nn.Linear(hidden_channels, 128)
        self.linear2 = torch.nn.Linear(128, input_channels * hidden_channels)

    def forward(self, z):
        z = self.linear1(z)
        z = z.relu()
        z = self.linear2(z)
        z = z.tanh()
        z = z.view(*z.shape[:-1], self.hidden_channels, self.input_channels)
        return z

class NeuralCDE(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels, output_channels):
        super(NeuralCDE, self).__init__()
        self.hidden_channels = hidden_channels

        self.func = CDEFunc(input_channels, hidden_channels)
        self.initial = torch.nn.Linear(input_channels, hidden_channels)
        self.readout = torch.nn.Linear(hidden_channels, output_channels)

    def forward(self, times, coeffs):
        spline = controldiffeq.NaturalCubicSpline(times, coeffs)

        z0 = self.initial(spline.evaluate(times[0]))

        z_T = controldiffeq.cdeint(dX_dt=spline.derivative,
                                   z0=z0,
                                   func=self.func,
                                   t=times[[0, int(len(times)*.8)]],
                                   atol=1e-2,
                                   rtol=1e-2)
        z_T = z_T[1]
        pred_y = self.readout(z_T)
        return pred_y

def get_data():
    X = torch.rand(512, 50, 10)
    t = torch.linspace(0., 1, X.shape[1])
    y = X[:, -1, :2]

    return t, X, y

from ipdb import set_trace
def main():
    train_t, train_X, train_y = get_data()

    model = NeuralCDE(input_channels=train_X.shape[-1], hidden_channels=16, output_channels=2)
    optimizer = torch.optim.Adam(model.parameters())

    train_coeffs = controldiffeq.natural_cubic_spline_coeffs(train_t, train_X)

    import time
    from tqdm import tqdm

    start = time.time()
    for epoch in tqdm(range(10)):
        pred_y = model(train_t, train_coeffs).squeeze(-1)
        loss = (pred_y - train_y).norm(dim=-1).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print('Epoch: {}   Training loss: {}'.format(epoch, loss.item()))
    end = time.time()
    print(end - start)

if __name__ == '__main__':
    main()
ChrisRackauckas commented 3 years ago

I don't see any shuttling to GPUs there: are the spline coefficients on the GPU in that implementation? If they are, that would make a massive difference.

Also, doing this as 1 spline instead of 5000 splines probably makes a decent difference.

johnp-4dvanalytics commented 3 years ago

This it the file where the spline code is defined: https://github.com/patrick-kidger/NeuralCDE/blob/master/controldiffeq/interpolate.py#L229

It does seem like the spline operations are being done as a batch on the GPU

def derivative(self, t):
        """Evaluates the derivative of the natural cubic spline at a point t, which should be a scalar tensor."""
        fractional_part, index = self._interpret_t(t)
        inner = self._two_c[..., index, :] + self._three_d[..., index, :] * fractional_part
        deriv = self._b[..., index, :] + inner * fractional_part
        return deriv
johnp-4dvanalytics commented 3 years ago

I wrapped the PyTorch spline functions using PyCall and CUDA.jl and I was able to get a speed up of ~5.5x over the torch version, it was 92.7 seconds for the 10 batches of torch version and ~17 seconds for the 10 batches of the DiffEqFlux version.

Btw thanks for the awesome library!

Here's what the calls looked like: spl is the python spline object

x0, tspan = Zygote.ignore() do
        tspan = spl.interval.cpu().numpy()
        x0 = Base.unsafe_wrap(CuArray, CuPtr{Float32}(spl.evaluate(0).permute(1,0).contiguous().data_ptr()), (10, 512))
        x0, tspan
    end

    function dhdt(h,p,t)
        x = Zygote.ignore() do
            x = Base.unsafe_wrap(CuArray, CuPtr{Float32}(spl.derivative(t).permute(1,0).contiguous().data_ptr()), (10, 512))
        end
        ....
ChrisRackauckas commented 3 years ago

Awesome. So yeah, it would really be nice to get that directly implemented in Julia as a library function for people who want to use this method. It's the rate-limiting step.

patrick-kidger commented 3 years ago

Just been pointed at this. Three quick comments:

Anyway, I'm not Julia-proficient but let me know if I can help out over here.

ChrisRackauckas commented 3 years ago

I think his latest version (the version that was timed) is using your spline code IIUC

johnp-4dvanalytics commented 3 years ago

@patrick-kidger

In the Julia code I used for the benchmark I was doing a call out to your spline code so that I could do a consistent comparison.

I did have some errors in the python code I posted above that I fixed in my script that I used for the benchmarking. Here is the updated script:

class CDEFunc(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels):
        super(CDEFunc, self).__init__()
        self.linear = torch.nn.Linear(hidden_channels,
                                      hidden_channels * input_channels)
        self.f1 = torch.nn.Linear(hidden_channels, hidden_channels)
        store_attr()

    def forward(self, t, z):
#         return (self.linear(z).view(len(z), self.hidden_channels, self.input_channels))
        z = F.relu(self.f1(z))

        return torch.tanh(self.linear(z).view(len(z), self.hidden_channels, self.input_channels))

class NeuralCDE(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels, output_channels):
        super(NeuralCDE, self).__init__()

        self.func = CDEFunc(input_channels, hidden_channels)
        self.initial = torch.nn.Linear(input_channels, hidden_channels)
        self.readout = torch.nn.Linear(hidden_channels, output_channels)

    def forward(self, coeffs):
        spline = torchcde.NaturalCubicSpline(coeffs)

        z0 = self.initial(spline.evaluate(0.))

        zt = torchcde.cdeint(X=spline,
                              z0=z0,
                              func=self.func,
#                              t=t,
                              t=spline.interval,
                              method='dopri5',
#                               atol=1e-2,
#                               rtol=1e-2,
                              options=dict(grid_points=spline.grid_points, 
                                           eps=1e-5,

#                                            , dtype=torch.float32
                                          ))
        z_T = zt[:, -1]

        pred_y = self.readout(z_T)
        return pred_y

bs = 512
input_channels = X.size(2)
hidden_channels = 16  # hyperparameter, we can pick whatever we want for this

device="cuda"
torch_core.defaults.device = device

dls = DataLoaders(TfmdDL(train_ds, shuffle=True, bs=bs), 
                  TfmdDL(val_ds, shuffle=True, bs=bs), device=device)

model = NeuralCDE(input_channels=input_channels, hidden_channels=hidden_channels, output_channels=2).type(dtype).to(device)

optimizer = torch.optim.Adam(model.parameters())

import time
from tqdm import tqdm

start = time.time()
total = 0
for i, (BX, By) in enumerate(tqdm(dls[0])):
    pred_y = model(BX).squeeze(-1)
    loss = (pred_y - By).norm(dim=-1).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print('Batch: {}   Training loss: {}'.format(i+1, loss.item()))
    total += 1
    if total == 10:
        break
end = time.time()
end-start
patrick-kidger commented 3 years ago

TfmdDL (amongst others) isn't defined so I can't run that.

I wouldn't set the default device in the way that you're doing. In particular this doesn't perform the usual CPU-to-GPU copy you'll usually see when training a model.

You're using grid_points wrong: this should be passed when using linear interpolation, but not with cubic interpolation.

I suggest using https://github.com/patrick-kidger/torchcde/blob/master/example/example.py as a reference point.

johnp-4dvanalytics commented 3 years ago

Thanks for the clarification about the grid_points, I'll fix that. I am using fastai for creating the dataloaders, it automatically puts the batches on the gpu. I'll post a minimal runnable script shortly that takes out the fastai code.

johnp-4dvanalytics commented 3 years ago

Here are updated Python and Julia scripts based on the example you linked to:

EDIT: the loss for the Julia version isn't decreasing like the Python version, so I may have a bug with the model. I will try to fix that.

Python: Time to run the training part at the bottom: ~58 seconds

import math
import torch
import torchcde

######################
# A CDE model looks like
#
# z_t = z_0 + \int_0^t f_\theta(z_s) dX_s
#
# Where X is your data and f_\theta is a neural network. So the first thing we need to do is define such an f_\theta.
# That's what this CDEFunc class does.
# Here we've built a small single-hidden-layer neural network, whose hidden layer is of width 128.
######################
class CDEFunc(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels):
        ######################
        # input_channels is the number of input channels in the data X. (Determined by the data.)
        # hidden_channels is the number of channels for z_t. (Determined by you!)
        ######################
        super(CDEFunc, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.linear1 = torch.nn.Linear(hidden_channels, 128)
        self.linear2 = torch.nn.Linear(128, input_channels * hidden_channels)

    ######################
    # For most purposes the t argument can probably be ignored; unless you want your CDE to behave differently at
    # different times, which would be unusual. But it's there if you need it!
    ######################
    def forward(self, t, z):
        # z has shape (batch, hidden_channels)
        z = self.linear1(z)
        z = z.relu()
        z = self.linear2(z)
        ######################
        # Easy-to-forget gotcha: Best results tend to be obtained by adding a final tanh nonlinearity.
        ######################
        z = z.tanh()
        ######################
        # Ignoring the batch dimension, the shape of the output tensor must be a matrix,
        # because we need it to represent a linear map from R^input_channels to R^hidden_channels.
        ######################
        z = z.view(z.size(0), self.hidden_channels, self.input_channels)
        return z

######################
# Next, we need to package CDEFunc up into a model that computes the integral.
######################
class NeuralCDE(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels, output_channels):
        super(NeuralCDE, self).__init__()

        self.func = CDEFunc(input_channels, hidden_channels)
        self.initial = torch.nn.Linear(input_channels, hidden_channels)
        self.readout = torch.nn.Linear(hidden_channels, output_channels)

    def forward(self, coeffs):
        X = torchcde.NaturalCubicSpline(coeffs)

        ######################
        # Easy to forget gotcha: Initial hidden state should be a function of the first observation.
        ######################
        z0 = self.initial(X.evaluate(0.))

        ######################
        # Actually solve the CDE.
        ######################
        z_T = torchcde.cdeint(X=X,
                              z0=z0,
                              func=self.func,
                              t=X.interval)

        ######################
        # Both the initial value and the terminal value are returned from cdeint; extract just the terminal value,
        # and then apply a linear map.
        ######################
        z_T = z_T[:, 1]
        pred_y = self.readout(z_T)
        return pred_y

######################
# Now we need some data.
# Here we have a simple example which generates some spirals, some going clockwise, some going anticlockwise.
######################
def get_data():
    t = torch.linspace(0., 4 * math.pi, 100)

    start = torch.rand(128) * 2 * math.pi
    x_pos = torch.cos(start.unsqueeze(1) + t.unsqueeze(0)) / (1 + 0.5 * t)
    x_pos[:64] *= -1
    y_pos = torch.sin(start.unsqueeze(1) + t.unsqueeze(0)) / (1 + 0.5 * t)
    x_pos += 0.01 * torch.randn_like(x_pos)
    y_pos += 0.01 * torch.randn_like(y_pos)
    ######################
    # Easy to forget gotcha: time should be included as a channel; Neural CDEs need to be explicitly told the
    # rate at which time passes. Here, we have a regularly sampled dataset, so appending time is pretty simple.
    ######################
    X = torch.stack([t.unsqueeze(0).repeat(128, 1), x_pos, y_pos], dim=2)
    y = torch.zeros(128)
    y[:64] = 1

    perm = torch.randperm(128)
    X = X[perm]
    y = y[perm]

    ######################
    # X is a tensor of observations, of shape (batch=128, sequence=100, channels=3)
    # y is a tensor of labels, of shape (batch=128,), either 0 or 1 corresponding to anticlockwise or clockwise respectively.
    ######################
    return X, y

def main(num_epochs=30):
    device = "cuda"
    dtype = torch.float32
    train_X, train_y = get_data()
    train_X, train_y = train_X.type(dtype), train_y.type(dtype)

    ######################
    # input_channels=3 because we have both the horizontal and vertical position of a point in the spiral, and time.
    # hidden_channels=8 is the number of hidden channels for the evolving z_t, which we get to choose.
    # output_channels=1 because we're doing binary classification.
    ######################
    model = NeuralCDE(input_channels=3, hidden_channels=8, output_channels=1).type(dtype).to(device)
    optimizer = torch.optim.Adam(model.parameters())

    ######################
    # Now we turn our dataset into a continuous path. We do this here via natural cubic spline interpolation.
    # The resulting `train_coeffs` is a tensor describing the path.
    # For most problems, it's probably easiest to save this tensor and treat it as the dataset.
    ######################
    train_coeffs = torchcde.natural_cubic_spline_coeffs(train_X)

    train_dataset = torch.utils.data.TensorDataset(train_coeffs, train_y)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)
    for epoch in range(num_epochs):
        for batch in train_dataloader:
            batch_coeffs, batch_y = batch
            batch_coeffs, batch_y = batch_coeffs.to(device), batch_y.to(device)
            pred_y = model(batch_coeffs).squeeze(-1)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_y, batch_y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print('Epoch: {}   Training loss: {}'.format(epoch, loss.item()))

def return_data():
    device = "cuda"
    dtype = torch.float32
    train_X, train_y = get_data()
    train_X, train_y = train_X.type(dtype), train_y.type(dtype)
    train_coeffs = torchcde.natural_cubic_spline_coeffs(train_X)

    train_dataset = torch.utils.data.TensorDataset(train_coeffs, train_y)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)
    data = [(torchcde.NaturalCubicSpline(batch[0]), batch[1]) for batch in train_dataloader]

    return data

import pickle as p

#get data for spline to use in Julia code
data = return_data()
p.dump(data, open("cde_data.p", "wb"))

import time
start = time.time()
main(10)
end = time.time()
end-start

Julia code using the spline from Python time to run training: ~27 seconds

using Statistics
using DataInterpolations
using Flux
using OrdinaryDiffEq
using DiffEqSensitivity
using Zygote
using ProgressBars
using Random
using CUDA
using PyCall

py"""
    import torch
    import pickle as p

    torch.cuda.set_device(0)

    spl_targ = p.load(open("cde_data.p", "rb"))
    spl = [x[0].cuda() for x in spl_targ]
    targ = [x[1].cuda() for x in spl_targ]
    bs = targ[0].shape[0]
    """

bs = py"bs"

T = Float32

i_sz = 3
h_sz = 8
o_sz = 1

use_gpu = true

batches = [(py"spl"[i], py"targ"[i]) for i in 1:py"len(spl)"]

data_ = Iterators.cycle(batches)

cde = Chain(
    Dense(h_sz, h_sz, relu),
    Dense(h_sz, h_sz*i_sz, tanh),
) |> (use_gpu ? gpu : cpu)

h_to_out = Dense(h_sz, o_sz) |> (use_gpu ? gpu : cpu)

initial = Dense(i_sz, h_sz) |> (use_gpu ? gpu : cpu)

cde_p, cde_re = Flux.destructure(cde)
initial_p, initial_re = Flux.destructure(initial)
h_to_out_p, h_to_out_re = Flux.destructure(h_to_out)

basic_tgrad(u,p,t) = zero(u)

function predict_func(p, spl)
    x0, tspan = Zygote.ignore() do
        tspan = spl.interval.cpu().numpy()
        x0 = Base.unsafe_wrap(CuArray, CuPtr{Float32}(spl.evaluate(0).permute(1,0).contiguous().data_ptr()), (i_sz, bs))
        x0, tspan
    end
    i = 1
    j = (i-1)+length(initial_p)

    local batch_size = bs

    h0 = initial_re(p[i:j])(x0)

    function dhdt(h,p,t)
        x = Zygote.ignore() do
            x = Base.unsafe_wrap(CuArray, CuPtr{Float32}(spl.derivative(t).permute(1,0).contiguous().data_ptr()), (i_sz, batch_size))
        end
        bs = size(h, 2)
        a = reshape(cde_re(p)(h), (i_sz, h_sz, bs))
        b = reshape(x, (1, i_sz, bs))

        dh = batched_mul(b,a)[1,:,:]
    end

    i = j+1
    j = (i-1)+length(cde_p)

    ff = ODEFunction{false}(dhdt,tgrad=basic_tgrad)
    prob = ODEProblem{false}(ff,h0,tspan,p[i:j])
    sense = InterpolatingAdjoint(autojacvec=ZygoteVJP())
    solver = Tsit5()

    sol = solve(prob,solver,u0=h0,saveat=tspan[end], save_start=false, sensealg=sense)
    i = j+1
    j = (i-1)+length(h_to_out_p)

    y_hat = h_to_out_re(p[i:j])(sol[end])

    y_hat
end

function loss_func(p, spl, targ)
    y_hat = predict_func(p, spl)

    y = Zygote.ignore() do
        y = Base.unsafe_wrap(CuArray, CuPtr{Float32}(targ.cuda().unsqueeze(-1).permute(1,0).contiguous().data_ptr()), (o_sz, bs))
    end

    return Flux.Losses.logitbinarycrossentropy(y_hat, y, agg=mean)
end

p = vcat(initial_p, cde_p, h_to_out_p)

callback = function (p, l)
  display(l)
  return false
end

predict_func(p, first(data_)[1])

using DiffEqFlux

Zygote.gradient((p)->loss_func(p, first(data_)...),p)
@time Zygote.gradient((p)->loss_func(p, first(data_)...),p)

@time result_neuralode = DiffEqFlux.sciml_train(loss_func, p, ADAM(0.05),
    data_,
    cb = callback,
    maxiters = 10*length(batches))

I suspect the gap between the Julia and Python code may widen in Julia's favor with larger batch sizes, hidden sizes, etc. but I still think doing something with DiffEqGPU will probably be the way to get a major speed up. I'll be trying some of the ideas Chris mentioned in the other thread to see if I can get a speed up with that.

ChrisRackauckas commented 3 years ago

I wouldn't get your hopes up there.

think doing something with DiffEqGPU will probably be the way to get a major speed up

That's for a different use case, i.e. ensembles of small ODEs.

I suspect the gap between the Julia and Python code may widen in Julia's favor with larger batch sizes, hidden sizes, etc.

On the contrary, it would probably shrink as the rate limiting step is sooner or later going to be the cost of the GPU kernels, which if both are calling into CuBLAS then it'll be the same. So if they are taking the same number of steps (which they likely aren't due to some stabilizing tricks, but those are like 50% gains), then you'd expect the cost to be the same. It's when the kernels aren't fully saturated that the extra codegen and fusion matters.

johnp-4dvanalytics commented 3 years ago

Okay, thanks for keeping me from going down that route. I was mainly thinking that if each trajectory has it's own solver they should be required to do many less function calls, whereas (if I'm not mistaken) the batch version has to do a stop whenever any of the trajectories needs to stop, so there would be many more function calls done for the current batch version. Is there a good way to reduce the number of function calls?

ChrisRackauckas commented 3 years ago

Oh I read you wrong. If you goal is to make use of like 4 GPUs by running 4 trajectories at a time on different GPUs, yeah DiffEqGPU isn't the right tool but EnsembleDistributed where each Julia process has a different GPU will do this. You could also try to pack multiple onto the same GPU, but sooner or later you'll get memory limited.

Is there a good way to reduce the number of function calls?

That's a great research question. In the general context, that's just developing a "better" differential equation solver which can be hard work given how much they've been optimized, but there's still some tricks no one has done and we will have some new methods coming out soonish. But in the context of training a CDE, there's some other tricks one can employ. For example, you don't necessarily need to fit the ODE solves themselves: you can regularize to find solutions that are fast to solve, you drop accuracy and only increase accuracy after a decending a bit, etc.

patrick-kidger commented 3 years ago

Some discrepancies:

@ChrisRackauckas: You mention that the number of solver steps can be reduced via some stabilising tricks - I'm curious what you're referring to specifically?

johnp-4dvanalytics commented 3 years ago

@patrick-kidger @ChrisRackauckas

I created a non-gpu, non-batch version that performs fairly well. I wasn't able to figure out the issue with the GPU version, I believe that there is some non-trivial issue with the gradient calculation.

Below is the script for training on the data from your example. It probably could easily incorporate multiprocessing to speed it up. It has a use_linear variable for choosing whether to use linear interpolation or natural cubic interpolation. The linear interpolation is quite slow since it has the additional tstops, but the cubic interpolation version is quite fast. It takes ~83 seconds to run on 10 epochs of the data of your example. The ODE solves get faster as the loss goes down, so the time taken is dependent on the starting parameters and the order of the examples seen, so there is some variance in the time.

I benchmarked against your example with the changes that you mentioned and found that on average the code took ~60 seconds to run. The script for that is included at the bottom.

EDIT: I was accidentally using h_sz=16 instead of 8 for the Julia version, and I forgot to include the compilation warmup for the sciml_train, updated the script with those values.

using Statistics
using DataInterpolations
using Flux
using OrdinaryDiffEq
using DiffEqSensitivity
using Zygote
using ProgressBars
using Random
using CUDA
using PyCall
using DiffEqFlux
using BenchmarkTools

py"""
    from scipy.interpolate import CubicSpline
    import numpy as np
    import math
    import torch
    from tqdm import tqdm

    def get_data(N):
        t = torch.linspace(0., 4 * math.pi, 100)

        start = torch.rand(N) * 2 * math.pi
        x_pos = torch.cos(start.unsqueeze(1) + t.unsqueeze(0)) / (1 + 0.5 * t)
        x_pos[:64] *= -1
        y_pos = torch.sin(start.unsqueeze(1) + t.unsqueeze(0)) / (1 + 0.5 * t)
        x_pos += 0.01 * torch.randn_like(x_pos)
        y_pos += 0.01 * torch.randn_like(y_pos)
        ######################
        # Easy to forget gotcha: time should be included as a channel; Neural CDEs need to be explicitly told the
        # rate at which time passes. Here, we have a regularly sampled dataset, so appending time is pretty simple.
        ######################
        X = torch.stack([t.unsqueeze(0).repeat(N, 1), x_pos, y_pos], dim=2)
        y = torch.zeros(N)
        y[:64] = 1

        perm = torch.randperm(N)
        X = X[perm]
        y = y[perm]

        ######################
        # X is a tensor of observations, of shape (batch=128, sequence=100, channels=3)
        # y is a tensor of labels, of shape (batch=128,), either 0 or 1 corresponding to anticlockwise or clockwise respectively.
        ######################
        return X, y
    X, y = get_data(N=256)

    y = y.unsqueeze(-1)

    t = np.linspace(0, 1, X.shape[1], dtype=np.float32)

    y = y.permute(1, 0).numpy()

    Xy = [(CubicSpline(x=t,y=X[i].numpy(), axis=0, bc_type="natural"), t, y[:, i]) for i in tqdm(range(len(X)))]
    """

T = Float32
use_linear = false

Xy = py"Xy"
if use_linear
    Xy = [(DataInterpolations.LinearInterpolation(T[permutedims(Xy[i][1](Xy[i][2]), (2,1));], Xy[i][2]), Xy[i][2], Xy[i][3]) for i in tqdm(1:length(Xy))]
end

i_sz = length(Xy[1][1](0))
h_sz = 8
o_sz = length(Xy[1][3])

use_gpu = false
device = (use_gpu ? gpu : cpu)

cde = Chain(
    Dense(h_sz, h_sz, relu),
    Dense(h_sz, h_sz*i_sz, tanh),
) |> device
h_to_out = Dense(h_sz, o_sz) |> device
initial = Dense(i_sz, h_sz) |> device

cde_p, cde_re = Flux.destructure(cde)
initial_p, initial_re = Flux.destructure(initial)
h_to_out_p, h_to_out_re = Flux.destructure(h_to_out)

basic_tgrad(u,p,t) = zero(u)

function derivative(A::DataInterpolations.LinearInterpolation{<:AbstractArray{<:Number}}, t::Number)
    idx = findfirst(x -> x >= t, A.t) - 1
    idx == 0 ? idx += 1 : nothing
    θ = 1 / (A.t[idx+1] - A.t[idx])
    (A.u[:, idx+1] - A.u[:, idx]) / (A.t[idx+1] - A.t[idx])
end

function predict_func(p, spl, t)
    x0, tspan = Zygote.ignore() do
        tspan = (t[1], t[end])

        x0 = T[spl(0);]
        x0, tspan
    end
    i = 1
    j = (i-1)+length(initial_p)

    h0 = initial_re(p[i:j])(x0)

    function dhdt(h,p,t)
        dx = Zygote.ignore() do
            # dx = SMatrix{i_sz, 1}(reshapeT[x_int_i.derivative(1)(t);], (i_sz, 1)))
            if use_linear
                dx = derivative(spl, t)
            else
                dx = T[spl.derivative(1)(t);]
            end
            dx
        end
        dh = reshape(cde_re(p)(h), (h_sz, i_sz))*dx
    end

    i = j+1
    j = (i-1)+length(cde_p)

    ff = ODEFunction{false}(dhdt,tgrad=basic_tgrad)
    prob = ODEProblem{false}(ff,h0,tspan,p[i:j])
    sense = InterpolatingAdjoint(autojacvec=ZygoteVJP())

    if use_linear
        tstops = t
        d_discontinuities = t
    else
        tstops = []
        d_discontinuities = []
    end

    solver = Tsit5()

    sol = solve(prob,solver,u0=h0,saveat=tspan[end],
        save_start=false,
        sensealg=sense,
        tstops=tstops,
        d_discontinuities=tstops,
        atol=1e-6, rtol=1e-4
    )

    out = sol[end]

    i = j+1
    j = (i-1)+length(h_to_out_p)

    y_hat = h_to_out_re(p[i:j])(sol[end])

    y_hat
end

N = length(Xy)

inds = randperm(N)
train_inds = inds[1:trunc(Int, length(inds)*.5)]
val_inds = inds[trunc(Int, length(inds)*.5)+1:end]
@assert length(train_inds) + length(val_inds) == length(inds)

train_dl = Iterators.cycle([Xy[i] for i in train_inds])
val_dl = [Xy[i] for i in val_inds]

function loss_func(p, spl, t, y; train=true)
    y_hat = predict_func(p, spl, t)

    loss = Flux.Losses.logitbinarycrossentropy(y_hat, y, agg=mean)

    if train
        for _ in 1:num_additional
            spl_i, t_i, y_i = Zygote.ignore() do
                train_i = train_inds[rand(1:length(train_inds))]
                spl_i, t_i, y_i = Xy[train_i]
            end

            y_hat_i = predict_func(p, spl_i, t_i)

            loss += Flux.Losses.logitbinarycrossentropy(y_hat, y, agg=mean)
        end
        loss = loss/(num_additional+1)
    end
    return loss
end

p = vcat(initial_p, cde_p, h_to_out_p)

# Zygote.gradient((p)->loss_func(p, first(train_dl)...),p)

callback = function (p, _)
    global display_i
    if display_i % display_every == 0
        l = Zygote.ignore() do
            l = 0
            subset_inds = randperm(length(val_inds))[1:num_val_to_test]
            val_inds_subset = val_inds[subset_inds]

            for i in val_inds_subset
                l += loss_func(p, Xy[i]..., train=false)
            end

            l/length(subset_inds)
        end

        display(l)
    end
    display_i += 1
    return false
end

display_i = 1
num_val_to_test = 16
display_every = Inf
# num_additional = 4 - 1
num_additional = 0

result_neuralode = DiffEqFlux.sciml_train(loss_func, p, ADAM(0.01),
    cb = callback,
    train_dl,
    maxiters = 1)

@time result_neuralode = DiffEqFlux.sciml_train(loss_func, p, ADAM(0.01),
    cb = callback,
    train_dl,
    maxiters = length(train_inds)*10) #time = 82.803470

function final_loss(p)
    l = 0

    for i in val_inds
        l += loss_func(p, Xy[i]..., train=false)
    end

    l/length(val_inds)
end

final_loss(result_neuralode.minimizer) #0.00112
import math
import torch
import torchcde

######################
# A CDE model looks like
#
# z_t = z_0 + \int_0^t f_\theta(z_s) dX_s
#
# Where X is your data and f_\theta is a neural network. So the first thing we need to do is define such an f_\theta.
# That's what this CDEFunc class does.
# Here we've built a small single-hidden-layer neural network, whose hidden layer is of width 128.
######################
class CDEFunc(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels):
        ######################
        # input_channels is the number of input channels in the data X. (Determined by the data.)
        # hidden_channels is the number of channels for z_t. (Determined by you!)
        ######################
        super(CDEFunc, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.linear1 = torch.nn.Linear(hidden_channels, 128)
        self.linear2 = torch.nn.Linear(128, input_channels * hidden_channels)

    ######################
    # For most purposes the t argument can probably be ignored; unless you want your CDE to behave differently at
    # different times, which would be unusual. But it's there if you need it!
    ######################
    def forward(self, t, z):
        # z has shape (batch, hidden_channels)
        z = self.linear1(z)
        z = z.relu()
        z = self.linear2(z)
        ######################
        # Easy-to-forget gotcha: Best results tend to be obtained by adding a final tanh nonlinearity.
        ######################
        z = z.tanh()
        ######################
        # Ignoring the batch dimension, the shape of the output tensor must be a matrix,
        # because we need it to represent a linear map from R^input_channels to R^hidden_channels.
        ######################
        z = z.view(z.size(0), self.hidden_channels, self.input_channels)
        return z

######################
# Next, we need to package CDEFunc up into a model that computes the integral.
######################
class NeuralCDE(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels, output_channels):
        super(NeuralCDE, self).__init__()

        self.func = CDEFunc(input_channels, hidden_channels)
        self.initial = torch.nn.Linear(input_channels, hidden_channels)
        self.readout = torch.nn.Linear(hidden_channels, output_channels)

    def forward(self, coeffs):
        X = torchcde.NaturalCubicSpline(coeffs)

        ######################
        # Easy to forget gotcha: Initial hidden state should be a function of the first observation.
        ######################
        z0 = self.initial(X.evaluate(0.))

        ######################
        # Actually solve the CDE.
        ######################
        z_T = torchcde.cdeint(X=X,
                              z0=z0,
                              func=self.func,
                              t=X.interval)

        ######################
        # Both the initial value and the terminal value are returned from cdeint; extract just the terminal value,
        # and then apply a linear map.
        ######################
        z_T = z_T[:, 1]
        pred_y = self.readout(z_T)
        return pred_y

######################
# Now we need some data.
# Here we have a simple example which generates some spirals, some going clockwise, some going anticlockwise.
######################
def get_data():
    t = torch.linspace(0., 4 * math.pi, 100)

    start = torch.rand(128) * 2 * math.pi
    x_pos = torch.cos(start.unsqueeze(1) + t.unsqueeze(0)) / (1 + 0.5 * t)
    x_pos[:64] *= -1
    y_pos = torch.sin(start.unsqueeze(1) + t.unsqueeze(0)) / (1 + 0.5 * t)
    x_pos += 0.01 * torch.randn_like(x_pos)
    y_pos += 0.01 * torch.randn_like(y_pos)
    ######################
    # Easy to forget gotcha: time should be included as a channel; Neural CDEs need to be explicitly told the
    # rate at which time passes. Here, we have a regularly sampled dataset, so appending time is pretty simple.
    ######################
    X = torch.stack([t.unsqueeze(0).repeat(128, 1), x_pos, y_pos], dim=2)
    y = torch.zeros(128)
    y[:64] = 1

    perm = torch.randperm(128)
    X = X[perm]
    y = y[perm]

    ######################
    # X is a tensor of observations, of shape (batch=128, sequence=100, channels=3)
    # y is a tensor of labels, of shape (batch=128,), either 0 or 1 corresponding to anticlockwise or clockwise respectively.
    ######################
    return X, y

def main(num_epochs=30):
    device = "cuda"
    dtype = torch.float32
    train_X, train_y = get_data()
    train_X, train_y = train_X.type(dtype), train_y.type(dtype)

    ######################
    # input_channels=3 because we have both the horizontal and vertical position of a point in the spiral, and time.
    # hidden_channels=8 is the number of hidden channels for the evolving z_t, which we get to choose.
    # output_channels=1 because we're doing binary classification.
    ######################
    model = NeuralCDE(input_channels=3, hidden_channels=8, output_channels=1).type(dtype).to(device)
    optimizer = torch.optim.Adam(model.parameters())

    ######################
    # Now we turn our dataset into a continuous path. We do this here via natural cubic spline interpolation.
    # The resulting `train_coeffs` is a tensor describing the path.
    # For most problems, it's probably easiest to save this tensor and treat it as the dataset.
    ######################
    train_coeffs = torchcde.natural_cubic_spline_coeffs(train_X)

    train_dataset = torch.utils.data.TensorDataset(train_coeffs, train_y)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)

    #warm up
    for batch in train_dataloader:
        batch_coeffs, batch_y = batch
        batch_coeffs, batch_y = batch_coeffs.to(device), batch_y.to(device)
        pred_y = model(batch_coeffs).squeeze(-1)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_y, batch_y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        break

    start = time.time()
    for epoch in range(num_epochs):
        for batch in train_dataloader:
            batch_coeffs, batch_y = batch
            batch_coeffs, batch_y = batch_coeffs.to(device), batch_y.to(device)
            pred_y = model(batch_coeffs).squeeze(-1)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_y, batch_y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print('Epoch: {}   Training loss: {}'.format(epoch, loss.item()))
    end = time.time()
    print("Time taken: {} seconds".format(end-start))

    return end-start

import time
from tqdm.notebook import tqdm

times_taken = [main(10) for _ in tqdm(range(64))]

import numpy as np

np.mean(times_taken) #59.79856628552079
pharringtonp19 commented 2 years ago

What's the latest status on this project? Seems useful

ChrisRackauckas commented 2 years ago

I don't think anyone has picked it up. In terms of differentiable interpolations, DataInterpolations.jl got some nice stable differentiability overloads, so this should be easy pickings but someone needs to package it all up.

johnp-4dvanalytics commented 2 years ago

@ChrisRackauckas I refactored the CPU version and changed it to only use Julia code. I created a simple .md example file for it. Could you point me to a guide for how to open a pull request for it?

ChrisRackauckas commented 2 years ago

https://www.youtube.com/watch?v=QVmU29rCjaA is a tutorial for all of that kind of stuff.