NeuroDiffGym / neurodiffeq

A library for solving differential equations using neural networks based on PyTorch, used by multiple research groups around the world, including at Harvard IACS.
http://pypi.org/project/neurodiffeq/
MIT License

Add the capability of training nets that are a bundle of solutions #118

Closed: at-chantada closed 3 years ago

at-chantada commented 3 years ago

Inspired by this paper, this pull request adds classes that allow one to train neural networks that are solutions of an ODE or system of ODEs for different values of its parameters and initial conditions. The highest-dimensional case the solver can handle is a bundle over t_0, u_0, u_0_prime, and all parameters of the ODE or system of ODEs. To illustrate how it works, here is an example where a network is trained to be a solution of the damped harmonic oscillator for different values of the initial position x_0 (between 0 and 1) and the damping parameter b (between 0.5 and 1.5):
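
For reference, the equation in this example is the damped oscillator ODE m*x'' + b*x' + k*x = 0, with initial conditions x(0) = x_0 and x'(0) = x_0_prime; m, k and x_0_prime are held fixed, while x_0 and b vary across the bundle.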

First, let's create the analytical solution to check how good the network is once it's trained:

import numpy as np

def damp_harm_sol(t, x_0, x_0_prime, m, b, k):
    x_t = 0

    # Under-damping:
    if b**2 < 4*m*k:

        # Frequency:
        w = np.sqrt(np.abs((b**2)-4*m*k))/(2*m)

        # Exponent factor:
        e = b/(2*m)

        # Parameters:
        c_1 = x_0
        c_2 = (x_0_prime+e*x_0)/w

        # Function:
        x_t = np.exp(-e*t)*(c_1*np.cos(w*t)+c_2*np.sin(w*t))

    # Over-damping:
    if b**2 > 4*m*k:
        # Auxiliary parameters:
        r_0 = np.sqrt((b**2)-4*m*k)/(2*m)
        r_00 = b/(2*m)

        # Exponent factors:
        r_1 = -r_00+r_0
        r_2 = -r_00-r_0

        # Parameters:
        c_2 = (x_0_prime-r_1*x_0)/(r_2-r_1)
        c_1 = x_0-c_2

        # Function:
        x_t = c_1*np.exp(r_1*t)+c_2*np.exp(r_2*t)

    # Critical Damping:
    if b**2 == 4*m*k:
        # Exponent factor:
        e = b/(2*m)

        # Parameters:
        c_1 = x_0
        c_2 = (x_0_prime+e*x_0)

        # Function:
        x_t = np.exp(-e*t)*(c_1+c_2*t)

    return x_t
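
As a quick sanity check (not part of the original snippet), the analytical solution should reproduce the initial position at t = 0. The parameter values below are just illustrative points inside the ranges used later in the example:

# Sanity check: at t = 0 the analytical solution must equal the initial position x_0
assert np.isclose(damp_harm_sol(0.0, 0.3, 1.0, 1, 0.8, 2), 0.3)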

Then the actual code to train the network:

import matplotlib.pyplot as plt
from neurodiffeq import diff      # the differentiation operation
from neurodiffeq.conditions import BundleIVP   # the initial condition
from neurodiffeq.solvers import BundleSolver1D  # the solver

# Damped oscillator parameters
m = 1
b_min = 0.5
b_max = 1.5
k = 2
x_0_min = 0
x_0_max = 1
x_0_prime = 1
t_0 = 0
t_f = np.pi
ts = np.linspace(t_0, t_f, 50)

# specify the ODE system
parametric_damp = lambda x, t, x_0, b: [m*diff(x, t, order=2) + b*diff(x, t) + k*x]  # All of the inputs to the network must be arguments of the ODE system,
                                                                                       # including the bundled conditions (here x_0), even though they do not
                                                                                       # appear explicitly in the equation itself

# specify the initial conditions

init_vals_damp = [BundleIVP(t_0=t_0, u_0_prime=x_0_prime, bundle_conditions=['u_0'])]  # List in bundle_conditions the conditions that are to be included in the bundle

# solve the ODE
solution_damp = BundleSolver1D(
    ode_system=parametric_damp, conditions=init_vals_damp, t_min=t_0, t_max=t_f,
    theta_min=(x_0_min, b_min),  # The order must always be the bundled conditions first (in the same order as in bundle_conditions),
    theta_max=(x_0_max, b_max)   # followed by the ODE parameters
    )
solution_damp.fit(max_epochs=20000)
solution = solution_damp.get_solution()
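
The object returned by get_solution() is a callable that takes an array of t values plus one array per bundled quantity, in the same order as theta_min/theta_max; with to_numpy=True it returns NumPy arrays, as in the plotting loop below. A minimal sketch, evaluating the network at a single point:

# x at t = 1.0 for x_0 = 0.5 and b = 1.0 (illustrative values inside the bundle ranges)
x_point = solution(np.array([1.0]), np.array([0.5]), np.array([1.0]), to_numpy=True)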

Finally, some plots to see the loss and, qualitatively, how well the network's solutions compare with the analytical solution:

for x_0, b in zip([0.1, 0.1, 0.9, 0.9], [0.6, 1.4, 0.6, 1.4]):
  x1_net = solution(ts, x_0 * np.ones(len(ts)), b * np.ones(len(ts)), to_numpy=True)
  x1_ana = damp_harm_sol(ts, x_0, x_0_prime, m, b, k)

  fig, axs = plt.subplots(1, 1)

  axs.plot(ts, x1_net, label='ANN-based solution of x', color='C1')
  axs.plot(ts, x1_ana, '.', label='analytical solution of x', color='C0')
  axs.set_xlabel('t')
  axs.set_ylabel('x')

  title = 'Damped Harmonic Oscillator (' + r'$m={},\;k={},\;b={},\;x_0={},\;x^\prime_0={}$'.format(m, k, b, x_0, x_0_prime) + ')'
  axs.set_title(title)
  axs.legend()
  filename = 'x={}_b={}.png'.format(x_0, b)
  plt.savefig(filename)

fig, axs = plt.subplots(1, 1)

loss_damp = solution_damp.metrics_history

axs.plot(loss_damp['train_loss'], label='training loss')
axs.plot(loss_damp['valid_loss'], label='validation loss')
axs.set_xlabel('epochs')
axs.set_yscale('log')
axs.set_title('Loss during training')
axs.legend()
plt.savefig('loss.png')

(Attached images: the four comparison plots for x_0 = 0.1, 0.9 and b = 0.6, 1.4, and the training-loss plot.)

codecov-commenter commented 3 years ago

Codecov Report

No coverage uploaded for pull request base (master@46786b0). The diff coverage is 67.91%.


@@            Coverage Diff            @@
##             master     #118   +/-   ##
=========================================
  Coverage          ?   89.63%           
=========================================
  Files             ?       17           
  Lines             ?     2807           
  Branches          ?        0           
=========================================
  Hits              ?     2516           
  Misses            ?      291           
  Partials          ?        0           
Impacted Files              Coverage Δ
neurodiffeq/solvers.py      79.94% <17.94%> (ø)
neurodiffeq/generators.py   93.93% <85.33%> (ø)
neurodiffeq/conditions.py   93.87% <100.00%> (ø)


shuheng-liu commented 3 years ago

Hi Augusto. This is amazing! I'm merging it into master. In the future, I might make some very minor changes (e.g. arg names).

I'm thinking of a way to allow more flexible bundle learning. If we rewrite the logic for every reparameterization we have, there will be a lot of code duplication sooner or later.

Another thing is that, while GeneratorND is a nice idea, we could also ensemble a generator that's essentially equivalent to a GeneratorND instance:

ga = Generator1D(...)
gb = Generator1D(...)
gc = Generator1D(...)
...
generator = ga * gb * gc * ...
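
A concrete version of that sketch might look like the following. This is a minimal, hypothetical example: it assumes Generator1D keeps its (size, t_min, t_max, method) signature and that the * operator composes per-dimension generators as suggested above; the names g_t, g_x0 and g_b are only illustrative.

import numpy as np
from neurodiffeq.generators import Generator1D

# One 1-D generator per input dimension of the bundle
g_t = Generator1D(size=32, t_min=0.0, t_max=np.pi, method='uniform')   # samples of t
g_x0 = Generator1D(size=32, t_min=0.0, t_max=1.0, method='uniform')    # samples of x_0
g_b = Generator1D(size=32, t_min=0.5, t_max=1.5, method='uniform')     # samples of b

# Compose them into a single generator over (t, x_0, b)
generator = g_t * g_x0 * g_b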