SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

Training of UDEs with recurrent networks #391

Closed: junsebas97 closed this issue 3 years ago

junsebas97 commented 4 years ago

Hello, first of all, thanks for this package; it is very useful.

I'm new to the scientific machine learning field. I'm currently learning about UDEs for modeling physical systems and am very interested in them. So far I've incorporated only multilayer perceptrons into the ODEs, with good results. However, when I tried to incorporate recurrent networks to increase the capability of these models, training with Flux.train! crashed.

For example, the following model uses an LSTM layer with 5 hidden units in its definition:

using Flux

ANN = Chain(LSTM(1, 5), Dense(5, 1))
# in-place form: write the derivative into du instead of rebinding du
function model(du, u, p, t)
    du[1] = ANN(u)[1]
end
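A note on in-place right-hand sides like this one: in the in-place form f(du, u, p, t), the derivative has to be written into du. An assignment such as `du = ...` only rebinds the local name, so the caller (the solver) never sees the result. The plain-Julia sketch below uses no Flux or DiffEq; `fake_net`, `rhs_rebind!` and `rhs_inplace!` are hypothetical names chosen for illustration.

```julia
# Hypothetical stand-in for a network: maps a state vector to a derivative vector
fake_net(u) = -2 .* u

# Wrong: `du = ...` rebinds the local variable; the caller's array is untouched
function rhs_rebind!(du, u, p, t)
    du = fake_net(u)
    return nothing
end

# Right: `du .= ...` writes the result into the caller's array
function rhs_inplace!(du, u, p, t)
    du .= fake_net(u)
    return nothing
end

u  = [1.0]
du = zeros(1)
rhs_rebind!(du, u, nothing, 0.0)
println(du)   # [0.0]  (unchanged)
rhs_inplace!(du, u, nothing, 0.0)
println(du)   # [-2.0]
```

Alternatively, the out-of-place form f(u, p, t) simply returns the derivative, which avoids the issue entirely.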

Since I used Chain instead of FastChain, I trained with Flux.train!, but it fails with: ERROR: LoadError: MethodError: no method matching similar(::DiffEqBase.NullParameters)

I think this is because Julia recognizes the weights of the network as parameters of ANN but not of model(du, u, p, t):

> params(ANN)
Params([Float32[-0.15044485; 0.63485605; … ; -0.32755396; -0.1456971], Float32[-0.12670276 0.5237209; -0.50491256 0.19825394; … ; 0.55224925 0.44527698; -0.33252665 0.08099249], Float32[-0.3998112, 0.78604937, 1.0, 1.0, 0.177412, -0.719973, 0.79453707, -0.07745761], Float32[0.0, 0.0], Float32[0.0, 0.0], Float32[0.08246351 -1.0055027], Float32[0.0]])

> params(model)
Params([])

So to avoid this error, I instead used Flux.destructure to define the recurrent network:

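ANN      = Chain(LSTM(1, 5), Dense(5, 1))
par, fun = Flux.destructure(ANN)   # flat parameter vector + rebuild function

function model(du, u, p, t)
    du .= fun(p)(u)   # rebuild the network from p and write into du in place
end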

This way, I was able to train with both Flux.train! and DiffEqFlux.sciml_train without any errors. However, the behavior of this new network is not really clear to me, because the "destructured" network does not seem to take previous states into account the way the "normal" network does:

> fun(par)(1.0)
1×1 Array{Float32,2}:
 -0.13853791

> fun(par)(1.0)
1×1 Array{Float32,2}:
 -0.13853791

>ANN(1.0)
1×1 Array{Float32,2}:
 -0.13853791

>ANN(1.0)
1×1 Array{Float32,2}:
 -0.04546199

I'd be very grateful if somebody could clarify how fun(par)(u) works (is it truly recurrent, or does it behave like a multilayer perceptron?), or tell me how to correctly train UDEs with recurrent networks.

ChrisRackauckas commented 4 years ago

If you want to use RNNs, just use Flux's RNNs and then use the destructure form, as you demonstrate here. That should then be fine. Without the destructure form, there's no way to have the parameter vector influence the differential equation for the adjoint. Does that make sense?
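To make the point about the parameter vector concrete, here is a plain-Julia sketch of the destructure idea. It is not Flux's implementation: `TinyDense`, `my_destructure` and `rhs` are hypothetical names, and the layer is a single dense layer rather than an RNN. All weights are flattened into one vector p, and a closure rebuilds a layer of the same shape from any such vector, so the ODE right-hand side depends on p explicitly.

```julia
# Hypothetical one-layer "network": y = tanh.(W * u .+ b)
struct TinyDense
    W::Matrix{Float64}
    b::Vector{Float64}
end
(d::TinyDense)(u) = tanh.(d.W * u .+ d.b)

# Flatten every weight into one vector p and return a closure that
# rebuilds a TinyDense of the same shape from any vector of that length.
function my_destructure(d::TinyDense)
    n_out, n_in = size(d.W)
    p = vcat(vec(d.W), d.b)
    rebuild(q) = TinyDense(reshape(q[1:n_out*n_in], n_out, n_in),
                           q[n_out*n_in+1:end])
    return p, rebuild
end

net = TinyDense([0.5 -0.3; 0.2 0.1], [0.1, -0.1])
p, re = my_destructure(net)

# The ODE right-hand side takes its weights from p, so p really
# influences the dynamics; this is what an adjoint differentiates.
rhs(u, p, t) = re(p)(u)

u0 = [1.0, 2.0]
rhs(u0, p, 0.0) == net(u0)        # true: rebuilding from p reproduces net
rhs(u0, 2 .* p, 0.0) == net(u0)   # false: new weights, new derivative
```

With a closure over a global network instead, the solver's p never reaches the weights, which mirrors the empty params(model) seen earlier in the thread.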

junsebas97 commented 4 years ago

Yes, it makes sense, and that is clear to me.

But I'm still not sure about the destructured network, i.e., whether it really works like a recurrent network. For example:

> ANN = Chain(LSTM(1, 5), Dense(5, 1));
> param, func = Flux.destructure(ANN);
> ANN(0.5)
1×1 Array{Float32,2}:
 0.1435646

> ANN(0.5)
1×1 Array{Float32,2}:
 0.23851033

> ANN(0.5)
1×1 Array{Float32,2}:
 0.30093935

> func(param)(0.5)
1×1 Array{Float32,2}:
 0.1435646

> func(param)(0.5)
1×1 Array{Float32,2}:
 0.1435646

> func(param)(0.5)
1×1 Array{Float32,2}:
 0.1435646

Above you can see that the normal network ANN receives the same input three times but produces a different output each time, due to the influence of previous states. The destructured network func, however, always produces the same output for the same input, like a feed-forward network.

My question is: when I solve the ODE, will func (the destructured network) behave like a recurrent network or like a feed-forward one?

ChrisRackauckas commented 4 years ago

It works exactly like whatever network you deconstructed. It just rebuilds that same network with new parameters.
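The transcripts above follow from this: func(par) returns a brand-new network object each time it is called, so every func(par)(0.5) starts from the initial hidden state, while calling the same object repeatedly carries the state forward. A plain-Julia sketch, using a hypothetical scalar cell `ToyRNN` and a hypothetical `toy_destructure` whose rebuild restores the hidden state captured at destructure time (an assumption modeled on the behavior reported in this thread):

```julia
# Hypothetical scalar recurrent cell: h <- tanh(w*x + v*h), output = h
mutable struct ToyRNN
    w::Float64
    v::Float64
    h::Float64
end
function (m::ToyRNN)(x)
    m.h = tanh(m.w * x + m.v * m.h)   # update the stored hidden state
    return m.h
end

# Hypothetical destructure: parameters are (w, v); the rebuild closure
# restores the hidden state that existed when toy_destructure was called.
function toy_destructure(m::ToyRNN)
    h0 = m.h
    rebuild(q) = ToyRNN(q[1], q[2], h0)   # a fresh object on every call
    return [m.w, m.v], rebuild
end

ann = ToyRNN(0.8, 0.5, 0.0)
par, func = toy_destructure(ann)

y1 = ann(0.5); y2 = ann(0.5)               # same object: y2 depends on y1
z1 = func(par)(0.5); z2 = func(par)(0.5)   # fresh objects: z1 == z2 == y1
m = func(par)                              # keep one rebuilt object...
w1 = m(0.5); w2 = m(0.5)                   # ...and it is recurrent again
```

Under this behavior, a right-hand side that calls fun(p)(u) on every evaluation only ever sees the network from its initial state, so inside the ODE it acts like a feed-forward map of u.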

junsebas97 commented 4 years ago

Ok, thank you so much!

But I'm still wondering why, in the code above, it doesn't behave like a recurrent network. I guess it is because I made three separate calls, so each call is a fresh initialization.

Thank you so much, and excuse me for being so persistent.

ChrisRackauckas commented 4 years ago

What do you mean "it doesn't work like a recurrent network"? I'm lost which example you're pointing to.

junsebas97 commented 4 years ago

> ANN = Chain(LSTM(1, 5), Dense(5, 1));
> param, func = Flux.destructure(ANN);

Above are two definitions (the normal ANN and the destructured func) of the same recurrent network. When the same input is passed twice to ANN, it produces different outputs due to the influence of the first input/output:

> ANN(0.5)
1×1 Array{Float32,2}:
 0.1435646

> ANN(0.5)
1×1 Array{Float32,2}:
 0.23851033

But when I do the same with func, no matter how many times I pass an input, the previous inputs/outputs are not taken into account (it is not recurrent):

> func(param)(0.5)
1×1 Array{Float32,2}:
 0.1435646

> func(param)(0.5)
1×1 Array{Float32,2}:
 0.1435646

> func(param)(0.5)
1×1 Array{Float32,2}:
 0.1435646

avik-pal commented 4 years ago

I just checked: the state is reinitialized every time you call func(param). This likely needs to be fixed in Flux.

avik-pal commented 4 years ago

@junsebas97 Could you test if this works for you?

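using Flux

# Minimal stateful wrapper around a recurrent cell (modeled on Flux's Recur)
mutable struct MyRecur{T}
  cell::T
  init     # initial hidden state
  state    # current hidden state, updated on every call
end

# One time step: run the cell, store the new hidden state, return the output
function (m::MyRecur)(xs...)
  h, y = m.cell(m.state, xs...)
  m.state = h
  return y
end

Flux.@functor MyRecur cell, init, state

# Only the cell's weights are meant to be trained; init and state are not
Flux.trainable(r::MyRecur) = Flux.trainable(r.cell)

c = Flux.LSTMCell(1, 1)

ANN = MyRecur(c, Flux.hidden(c), Flux.hidden(c)) # This is essentially your LSTM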

If it does, we can put this fix in.

EDIT: Thanks to @timdkim for pointing out that destructure ends up returning state as well. Hence this solution will not work inside NeuralODE.

junsebas97 commented 4 years ago

Hi, I've implemented it, but it doesn't seem to work (although I'm not sure I implemented it correctly).

With the LSTM cell you defined above:

ANN = MyRecur(c, Flux.hidden(c), Flux.hidden(c))

I used Flux.destructure and hit the same issue: the state is reinitialized every time I call func(par), so it doesn't take previous inputs into account.

> c  = Flux.LSTMCell(1, 1);
> ANN = MyRecur(c, Flux.hidden(c), Flux.hidden(c));
> par, func = Flux.destructure(ANN);
> func(par).state
(Float32[0.0], Float32[0.0])

> func(par)(0.5)
1×1 Array{Float64,2}:
 0.07713145256002776

> func(par).state
(Float32[0.0], Float32[0.0])

> func(par)(0.5)
1×1 Array{Float64,2}:
 0.07713145256002776

So I tried to use ANN normally (without Flux.destructure), and I was not even able to solve the following system:

> c  = Flux.LSTMCell(1, 1);
> ANN = MyRecur(c, Flux.hidden(c), Flux.hidden(c));
> dudt(u, p, t) = ANN(u);
> prob = ODEProblem(dudt, 0.0, (0.0, 10.0));
> solve(prob, Tsit5(), saveat = 0.1)
MethodError: no method matching similar(::Float64, ::Type{Float64})

Sorry if I'm missing something. Thank you so much!

avik-pal commented 4 years ago

Hi @junsebas97, the example you posted is working as expected. When you destructure ANN, its state values are 0, so func(par) reconstructs the layer with the values captured at destructure time, including that zero state.

c  = Flux.LSTMCell(1, 1);
ANN = MyRecur(c, Flux.hidden(c), Flux.hidden(c));
ANN(1.0); ANN.state
par, func = Flux.destructure(ANN);
m = func(par)
m.state # Same value as ANN.state
m(2.0)
par, func = Flux.destructure(m); # I need to destructure again for the new state to be reflected

Regarding the ODEProblem, I need to look into it a bit more.

timkimd commented 4 years ago

I'm finding that the values in par change when the state changes even though we are not updating any parameters. Is this expected?

avik-pal commented 3 years ago

Thanks for pointing that out. Flux.destructure returns the state as well (which is what's changing). I will have to look into how to fix it.
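A plain-Julia illustration of the observation above. This is not Flux code: `ToyRNN`, `flatten_all` and `flatten_trainable` are hypothetical. If the flat vector includes the hidden state, it changes whenever the model runs even though no weight was updated; flattening only the trainable fields keeps it stable.

```julia
# Hypothetical scalar recurrent cell, a plain-Julia stand-in for an LSTM
mutable struct ToyRNN
    w::Float64
    v::Float64
    h::Float64
end
function (m::ToyRNN)(x)
    m.h = tanh(m.w * x + m.v * m.h)
    return m.h
end

# Flattening every field, hidden state included (the behavior described above)
flatten_all(m::ToyRNN) = [m.w, m.v, m.h]
# Flattening only the trainable weights
flatten_trainable(m::ToyRNN) = [m.w, m.v]

ann = ToyRNN(0.8, 0.5, 0.0)
all_before, tr_before = flatten_all(ann), flatten_trainable(ann)
ann(1.0)                                    # a forward pass mutates h only
all_after, tr_after = flatten_all(ann), flatten_trainable(ann)

all_before == all_after   # false: the "parameter" vector moved
tr_before == tr_after     # true: the weights themselves are unchanged
```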

ChrisRackauckas commented 3 years ago

Wait, this state, are you trying to preserve it between f calls?

avik-pal commented 3 years ago

Yes. The OP wanted to preserve the state across f calls over time and reset it once we re-evaluate from an initial condition.

ChrisRackauckas commented 3 years ago

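That doesn't make all that much sense. ODE solvers don't step monotonically forward in time (adaptive methods evaluate f at intermediate stage points and can reject and redo steps, and the adjoint solves backward), so sharing state between f calls might not behave the way you'd expect. Extra state would need to be stored in additional ODEs.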
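One way to read "store extra state in additional ODEs": make the hidden state part of the ODE state vector and give it its own continuous-time dynamics. Solvers can then evaluate, reject and re-evaluate steps freely, and re-solving from an initial condition automatically restarts the hidden state. A minimal sketch, assuming a hypothetical two-variable system (u observed, h hidden) and a hand-written fixed-step RK4 standing in for an ODE library:

```julia
# Augmented right-hand side: z = [u, h]
function augmented_rhs(z, p, t)
    u, h = z
    a, w = p
    du = -u + a * h              # observable dynamics driven by the hidden state
    dh = -h + tanh(w * u)        # hidden state evolves continuously, like a CT-RNN
    return [du, dh]
end

# Minimal fixed-step RK4 (hypothetical helper; real code would use an ODE solver)
function rk4(f, z0, p, tspan, n)
    t0, t1 = tspan
    dt = (t1 - t0) / n
    z, t = copy(z0), t0
    traj = [copy(z)]
    for _ in 1:n
        k1 = f(z, p, t)
        k2 = f(z .+ dt/2 .* k1, p, t + dt/2)
        k3 = f(z .+ dt/2 .* k2, p, t + dt/2)
        k4 = f(z .+ dt .* k3, p, t + dt)
        z = z .+ dt/6 .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += dt
        push!(traj, copy(z))
    end
    return traj
end

p  = (1.5, 2.0)
z0 = [1.0, 0.0]                  # initial u and initial hidden state h(0) = 0
sol1 = rk4(augmented_rhs, z0, p, (0.0, 5.0), 500)
sol2 = rk4(augmented_rhs, z0, p, (0.0, 5.0), 500)
sol1 == sol2                     # true: re-solving restarts from h(0)
```

The hidden-state dynamics here are an arbitrary illustrative choice; the point is that h is carried by the solver like any other state variable, so no mutable state lives inside f.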

avik-pal commented 3 years ago

Yes, I agree that in this particular context it doesn't make a lot of sense, but there is definitely a bug on Flux's side, since with re(p) I would expect to get back the same model I destructured.

ChrisRackauckas commented 3 years ago

Yup, so this should probably get an appropriate Flux issue and be closed here.

avik-pal commented 3 years ago

Moving the discussion over to FluxML/Flux.jl#1329