lululxvi / deepxde

A library for scientific machine learning and physics-informed learning
https://deepxde.readthedocs.io
GNU Lesser General Public License v2.1

Can't seem to learn Lotka-Volterra #85

Closed. ViktorC closed this issue 4 years ago.

ViktorC commented 4 years ago

Hi @lululxvi,

Thanks for the great library.

I have been trying to solve the Lotka-Volterra equations with DeepXDE, but I can't seem to get very good results. I have tried various numbers of layers and layer sizes, different numbers of collocation points, different learning rates, etc., but none of these seemed to help. Could you please help me identify what I might be doing wrong?

import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

def ode_system(x, y):
    # y[:, 0:1] is the prey population r(t), y[:, 1:2] is the predator population p(t).
    r = y[:, 0:1]
    p = y[:, 1:2]
    d_r_over_d_t = tf.gradients(r, x)[0]
    d_p_over_d_t = tf.gradients(p, x)[0]
    # Residuals of dr/dt = 2r - 0.04*r*p and dp/dt = 0.02*r*p - 1.06*p.
    return [
        d_r_over_d_t - (2. * r - .04 * r * p),
        d_p_over_d_t - (.02 * r * p - 1.06 * p)
    ]

def boundary(_, on_initial):
    # Marks t = 0 so the initial conditions are enforced there.
    return on_initial

# Time domain [0, 10] with initial populations r(0) = 100 and p(0) = 15.
geom = dde.geometry.TimeDomain(0., 10.)
ic1 = dde.IC(geom, lambda _: np.full(1, 100.), boundary, component=0)
ic2 = dde.IC(geom, lambda _: np.full(1, 15.), boundary, component=1)
data = dde.data.PDE(
    geom, ode_system, [ic1, ic2], 2000, 2, num_test=100
)

# Fully connected network: 1 input (t), 3 hidden layers of width 50, 2 outputs (r, p).
layer_size = [1] + [50] * 3 + [2]
activation = "tanh"
initializer = "Glorot uniform"
net = dde.maps.FNN(layer_size, activation, initializer)

# Train with Adam first, then fine-tune with L-BFGS-B.
model = dde.Model(data, net)
model.compile("adam", lr=0.002)
model.train(epochs=50000)
model.compile("L-BFGS-B")
model.train()

# Plot the predicted populations over the full time domain.
t = np.linspace(0., 10., 1000).reshape(1000, 1)
y = model.predict(t)

plt.xlabel('t')
plt.ylabel('y')
plt.plot(t, y[:, 0])
plt.plot(t, y[:, 1])
plt.show()

As you can see below, the test loss of the optimised PINN indicates good performance. However, the plotted prediction seems off.

Step      Train loss                                  Test loss                                   Test metric
0         [3.75e-02, 1.59e+00, 1.00e+04, 2.25e+02]    [3.72e-02, 1.59e+00, 0.00e+00, 0.00e+00]    []  
1000      [2.92e+02, 4.68e+01, 2.87e+02, 1.72e+02]    [3.00e+02, 3.39e+01, 0.00e+00, 0.00e+00]    []  
2000      [1.82e+02, 4.93e+01, 2.75e+00, 1.11e+02]    [2.36e+02, 4.33e+01, 0.00e+00, 0.00e+00]    []  
3000      [8.46e+01, 5.66e+01, 1.23e+01, 5.30e+01]    [1.14e+02, 8.83e+01, 0.00e+00, 0.00e+00]    []  
4000      [2.20e+01, 6.58e+00, 2.77e-01, 3.15e+00]    [2.80e+01, 1.85e+01, 0.00e+00, 0.00e+00]    []  
5000      [1.69e+01, 2.48e+00, 3.16e-02, 3.83e-01]    [1.74e+01, 4.36e+00, 0.00e+00, 0.00e+00]    []  
6000      [1.51e+01, 1.22e+00, 2.71e-03, 2.15e-02]    [1.49e+01, 1.20e+00, 0.00e+00, 0.00e+00]    []  
7000      [1.30e+01, 1.38e+00, 1.98e-02, 3.60e-02]    [1.28e+01, 1.36e+00, 0.00e+00, 0.00e+00]    []  
8000      [1.19e+01, 1.23e+00, 6.84e-03, 1.41e-02]    [1.54e+01, 1.67e+00, 0.00e+00, 0.00e+00]    []  
9000      [9.54e+00, 7.16e-01, 1.12e-02, 4.47e-03]    [9.42e+00, 1.11e+00, 0.00e+00, 0.00e+00]    []  
10000     [8.19e+00, 5.47e-01, 8.19e-03, 3.83e-02]    [8.14e+00, 8.09e-01, 0.00e+00, 0.00e+00]    []  
11000     [6.90e+00, 3.71e-01, 5.60e-03, 3.28e-02]    [6.81e+00, 5.77e-01, 0.00e+00, 0.00e+00]    []  
12000     [5.79e+00, 1.93e-01, 4.73e-03, 1.64e-02]    [5.71e+00, 1.70e-01, 0.00e+00, 0.00e+00]    []  
13000     [4.70e+00, 1.23e-01, 2.70e-03, 5.92e-03]    [4.66e+00, 1.19e-01, 0.00e+00, 0.00e+00]    []  
14000     [3.84e+00, 1.00e-01, 1.96e-03, 1.88e-03]    [3.81e+00, 1.12e-01, 0.00e+00, 0.00e+00]    []  
15000     [3.15e+00, 1.67e-01, 1.35e-03, 7.76e-04]    [3.12e+00, 1.80e-01, 0.00e+00, 0.00e+00]    []  
16000     [2.79e+00, 1.62e-01, 1.77e-04, 2.36e-03]    [2.80e+00, 1.87e-01, 0.00e+00, 0.00e+00]    []  
17000     [2.21e+00, 5.76e-02, 4.50e-04, 4.23e-04]    [2.18e+00, 5.69e-02, 0.00e+00, 0.00e+00]    []  
18000     [1.89e+00, 5.29e-02, 3.41e-04, 2.90e-04]    [1.87e+00, 4.77e-02, 0.00e+00, 0.00e+00]    []  
19000     [1.68e+00, 2.76e-02, 3.10e-04, 1.24e-04]    [1.67e+00, 2.35e-02, 0.00e+00, 0.00e+00]    []  
20000     [1.52e+00, 2.92e-02, 2.68e-04, 8.89e-05]    [1.51e+00, 2.50e-02, 0.00e+00, 0.00e+00]    []  
21000     [1.44e+00, 3.03e-02, 2.33e-04, 3.50e-05]    [1.43e+00, 2.19e-02, 0.00e+00, 0.00e+00]    []  
22000     [1.43e+00, 7.56e-02, 9.30e-05, 1.00e-04]    [1.44e+00, 7.53e-02, 0.00e+00, 0.00e+00]    []  
23000     [1.50e+00, 9.43e-02, 1.94e-05, 1.05e-05]    [1.52e+00, 1.01e-01, 0.00e+00, 0.00e+00]    []  
24000     [1.26e+00, 1.66e-01, 1.11e-04, 1.00e-03]    [1.23e+00, 1.85e-01, 0.00e+00, 0.00e+00]    []  
25000     [1.19e+00, 5.75e-02, 3.91e-04, 4.33e-04]    [1.25e+00, 5.53e-02, 0.00e+00, 0.00e+00]    []  
26000     [1.07e+00, 1.30e-01, 6.02e-05, 2.95e-04]    [1.17e+00, 1.81e-01, 0.00e+00, 0.00e+00]    []  
27000     [8.55e-01, 8.70e-03, 2.59e-05, 5.52e-05]    [8.49e-01, 1.09e-02, 0.00e+00, 0.00e+00]    []  
28000     [8.63e-01, 5.18e-02, 2.78e-04, 2.56e-06]    [9.36e-01, 6.71e-02, 0.00e+00, 0.00e+00]    []  
29000     [7.53e-01, 2.01e-02, 2.32e-05, 6.20e-05]    [7.49e-01, 2.23e-02, 0.00e+00, 0.00e+00]    []  
30000     [8.72e-01, 2.67e-01, 3.93e-06, 1.94e-04]    [8.77e-01, 2.87e-01, 0.00e+00, 0.00e+00]    []  
31000     [7.20e-01, 3.76e-02, 5.86e-05, 4.04e-05]    [7.26e-01, 4.41e-02, 0.00e+00, 0.00e+00]    []  
32000     [6.51e-01, 9.69e-03, 2.27e-05, 1.08e-05]    [6.52e-01, 1.41e-02, 0.00e+00, 0.00e+00]    []  
33000     [6.08e-01, 6.92e-03, 2.29e-05, 1.95e-05]    [6.06e-01, 7.11e-03, 0.00e+00, 0.00e+00]    []  
34000     [6.82e-01, 5.77e-02, 6.50e-05, 1.28e-04]    [7.13e-01, 1.12e-01, 0.00e+00, 0.00e+00]    []  
35000     [5.55e-01, 1.33e-02, 3.35e-08, 2.48e-05]    [5.75e-01, 3.74e-02, 0.00e+00, 0.00e+00]    []  
36000     [5.40e-01, 7.49e-03, 1.81e-05, 1.41e-05]    [5.38e-01, 6.64e-03, 0.00e+00, 0.00e+00]    []  
37000     [5.17e-01, 6.41e-03, 6.20e-05, 1.59e-05]    [5.21e-01, 1.52e-02, 0.00e+00, 0.00e+00]    []  
38000     [5.61e-01, 2.29e-02, 4.33e-05, 2.19e-08]    [5.83e-01, 2.49e-02, 0.00e+00, 0.00e+00]    []  
39000     [9.00e-01, 3.74e-01, 8.49e-05, 3.80e-07]    [9.56e-01, 4.58e-01, 0.00e+00, 0.00e+00]    []  
40000     [4.64e-01, 5.03e-03, 5.10e-06, 1.05e-05]    [4.60e-01, 5.08e-03, 0.00e+00, 0.00e+00]    []  
41000     [5.49e-01, 3.93e-02, 1.35e-05, 1.39e-06]    [6.37e-01, 5.62e-02, 0.00e+00, 0.00e+00]    []  
42000     [4.63e-01, 3.12e-02, 5.70e-07, 4.40e-06]    [4.78e-01, 3.23e-02, 0.00e+00, 0.00e+00]    []  
43000     [4.59e-01, 4.47e-02, 2.42e-06, 4.21e-09]    [4.92e-01, 4.06e-02, 0.00e+00, 0.00e+00]    []  
44000     [4.11e-01, 1.05e-02, 2.34e-05, 7.00e-06]    [4.08e-01, 1.35e-02, 0.00e+00, 0.00e+00]    []  
45000     [3.91e-01, 3.76e-03, 9.22e-06, 4.33e-06]    [3.89e-01, 3.25e-03, 0.00e+00, 0.00e+00]    []  
46000     [4.16e-01, 4.45e-02, 1.17e-05, 3.93e-06]    [4.13e-01, 4.50e-02, 0.00e+00, 0.00e+00]    []  
47000     [3.66e-01, 1.22e-02, 1.29e-05, 6.66e-06]    [3.63e-01, 1.25e-02, 0.00e+00, 0.00e+00]    []  
48000     [3.54e-01, 2.83e-03, 9.98e-06, 5.91e-06]    [3.51e-01, 2.86e-03, 0.00e+00, 0.00e+00]    []  
49000     [3.53e-01, 3.48e-03, 3.09e-05, 2.16e-06]    [3.57e-01, 7.11e-03, 0.00e+00, 0.00e+00]    []  
50000     [3.64e-01, 1.65e-02, 1.17e-05, 1.07e-05]    [3.62e-01, 1.69e-02, 0.00e+00, 0.00e+00]    []  

Best model at step 49000:
  train loss: 3.57e-01
  test loss: 3.64e-01
  test metric: []

'train' took 178.948568 s

Compiling model...
'compile' took 0.245634 s

Training model...

Step      Train loss                                  Test loss                                   Test metric
50000     [3.64e-01, 1.65e-02, 1.17e-05, 1.07e-05]    [3.62e-01, 1.69e-02, 0.00e+00, 0.00e+00]    []  
51000     [4.44e-02, 7.43e-03, 3.16e-06, 1.07e-05]                                                    
52000     [1.37e-02, 2.57e-03, 3.21e-05, 7.00e-06]                                                    
53000     [6.53e-03, 5.52e-04, 3.08e-08, 5.78e-07]                                                    
53308     [6.10e-03, 5.07e-04, 2.67e-06, 8.50e-07]    [5.72e-03, 5.44e-04, 0.00e+00, 0.00e+00]    []  

Best model at step 53308:
  train loss: 6.61e-03
  test loss: 6.26e-03
  test metric: []

'train' took 21.291070 s

[Image: reference solution (solve_serial_fine)]

[Image: PINN prediction (lotka_volterra)]
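
For context, a reference solution like the one above can be obtained with a conventional ODE integrator. Below is a minimal sketch (not from the original post) using scipy.integrate.solve_ivp with the same parameters and initial conditions as the script:

import numpy as np
from scipy.integrate import solve_ivp

def lotka_volterra(t, y):
    # Same right-hand side as the DeepXDE script: dr/dt = 2r - 0.04*r*p, dp/dt = 0.02*r*p - 1.06*p.
    r, p = y
    return [2. * r - .04 * r * p, .02 * r * p - 1.06 * p]

# Initial populations r(0) = 100 and p(0) = 15 on t in [0, 10].
t_eval = np.linspace(0., 10., 1000)
sol = solve_ivp(lotka_volterra, (0., 10.), [100., 15.], t_eval=t_eval, rtol=1e-8)
r_ref, p_ref = sol.y  # reference trajectories to compare against model.predict(t)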

Do you know what the problem could be? Any pointers would be much appreciated.

Many thanks, Viktor

lululxvi commented 4 years ago

As you can see, the first two losses, i.e., the ODE losses, are very large. Could you try using a smaller time domain first, e.g., [0, 2]?
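
A minimal sketch of that change to the script above (everything else stays the same):

# Restrict the problem to t in [0, 2] before attempting the full [0, 10] domain.
geom = dde.geometry.TimeDomain(0., 2.)

# Predict and plot over the shorter domain as well.
t = np.linspace(0., 2., 1000).reshape(1000, 1)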

ViktorC commented 4 years ago

Thanks for the quick response.

Ah, I was under the impression that an MSE on the order of 10^-3 or 10^-4 is a small loss. What would be good values to aim for generally?

As per your suggestion, I tried a smaller domain first, with fairly good results. I then expanded the domain gradually while also increasing the number of training points. To be able to drive the loss down, I had to increase the depth of the network as well. Finally, with 4000 training points and 5 hidden layers, I got really good results on the full [0, 10] interval (a sketch of that configuration is below). Interestingly, the magnitude of the MSE was roughly the same as before (10^-4 to 10^-3), but the solution looked essentially perfect this time around.
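
A minimal sketch of that final configuration, reusing the definitions from the original script; the thread only specifies 4000 training points and 5 hidden layers, so the layer width of 50 and the other settings are assumptions carried over from the original code:

# Same problem setup as before, but with more collocation points and a deeper network.
geom = dde.geometry.TimeDomain(0., 10.)
data = dde.data.PDE(
    geom, ode_system, [ic1, ic2], 4000, 2, num_test=100
)

layer_size = [1] + [50] * 5 + [2]  # 5 hidden layers; width 50 assumed
net = dde.maps.FNN(layer_size, "tanh", "Glorot uniform")

model = dde.Model(data, net)
model.compile("adam", lr=0.002)
model.train(epochs=50000)
model.compile("L-BFGS-B")
model.train()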

lululxvi commented 4 years ago

Usually I aim for MSE smaller than 10^-4 to achieve good accuracy.
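
In the log above, for example, the first ODE-residual loss is still around 3.6e-1 after Adam and about 6e-3 after L-BFGS-B, well above that target, which is consistent with the prediction being visibly off despite the small IC losses.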

ViktorC commented 4 years ago

Great, I'll use that as the target in the future. 🙂 Thank you for your help!

richieakka commented 3 years ago

Hi ViktorC,

I am not an expert in this area and am just starting to explore this field. It would be a great help if you could explain the initial conditions you used for this problem.

Rdfing commented 3 years ago

> As you can see, the first two losses, i.e., the ODE losses, are very large. Could you try using a smaller time domain first, e.g., [0, 2]?

Hi Lu,

Why would a smaller domain help? Is this related to the feature scaling or higher frequency?

Thanks, Haochen

lululxvi commented 3 years ago

Yes, that is one reason. It is also related to how well SGD can optimize the network.
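
As an illustration of the feature-scaling point (this is not something posted in the thread), one option is to rescale the input time to [0, 1] before it reaches the first hidden layer, assuming the network object supports DeepXDE's apply_feature_transform:

# Minimal sketch: map t in [0, 10] to a normalized feature in [0, 1].
net = dde.maps.FNN([1] + [50] * 3 + [2], "tanh", "Glorot uniform")
net.apply_feature_transform(lambda t: t / 10.)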

camolina2 commented 3 years ago

Hi Lulu, first of all I want to thank you for your awesome library!! Could you please explain what each column of the train loss output means? Thank you so much!

lululxvi commented 3 years ago

@camolina2 Each column is the value of one loss term.
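
For the script in this thread, for example, the four columns correspond to the two ODE-residual losses followed by the two initial-condition losses (ic1 and ic2), matching the earlier remark that the first two are the ODE losses.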

Bensayah commented 2 years ago

Hi ViktorC, many thanks for sharing the code. I did the same as you, 4000 training points and 5 hidden layers, but I still get the second graph from your post. I would be thankful if you could share the updated code.