SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License

Supervised NNODE Example Broken? #318

Closed · caseykneale closed this 4 years ago

caseykneale commented 4 years ago

I tried working through the basic supervised example and hit a really strange error related to the output of the NeuralODE layer. Peculiarly, when the model is called from Flux.train! the data disappears, but when called outside of train! it's all fine. No clue where this bug originates; it might be Zygote?

using Flux, DiffEqFlux, OrdinaryDiffEq

f = 5
x = randn(f,20)
y = randn(f,1)

nn  = Flux.Chain(   Dense( f, f, identity ) )
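# NB: NeuralODE's second argument is the tspan, which should be a (t0, t1) tuple; see the resolution below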
nn_ode = NeuralODE( nn, (0.f0), Tsit5(),
                        save_everystep = false,
                        reltol = 1e-3, abstol = 1e-3,
                        save_start = false )

fc  = Dense( f, 1,  identity )

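# drop the trailing singleton time dimension from the ODE solution so Dense can consume a matrix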
function DiffEqArray_to_Array( x )
    xarr = Array( x )
    return reshape( xarr, size( xarr )[ 1:2 ] )
end

model    = Flux.Chain( nn_ode, DiffEqArray_to_Array, fc )
#Test it
model( x )

function loss( ex, ey)
    global model
    model(ex)
    return sum( model( ex ) .- ey )
end
#test that too
loss(x,y)

opt = ADAM(0.05)
Flux.train!(    loss,
                Flux.params( nn, nn_ode.p, fc ),
                [ ( x, y ) ],
                opt )
ChrisRackauckas commented 4 years ago

@dhairyagandhi96 do you know what this could be? It sounds like it could be a Flux.train! thing, not really DiffEqFlux-related.

caseykneale commented 4 years ago

I tried rolling back the Flux, Zygote, and OrdinaryDiffEq versions, with no luck. I think I may have upgraded to Julia 1.4.2 since this last worked. Could it be a Base bug? An Array is disappearing, after all. Still trying to debug, but no luck so far...

avik-pal commented 4 years ago

It seems something is going wrong in the gradient. gradient(() -> loss(x, y), Flux.params( nn, nn_ode.p, fc )) is causing the issue.
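
For reference, the failing call can be reproduced in isolation (a minimal repro sketch, assuming the definitions from the original snippet are in scope):

using Zygote  # Flux's AD backend

# this is essentially what Flux.train! does internally; running it directly isolates the failure
gs = Zygote.gradient(() -> loss(x, y), Flux.params(nn, nn_ode.p, fc))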

avik-pal commented 4 years ago

I was able to resolve the issue by changing the tspan to (0.0, 1.0) instead of (0.f0).

@caseykneale Apart from that, I am not sure if it is intended, but in the code x has a batch size of 20 while y has a batch size of 1.
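
Concretely, the fix is to pass a (t0, t1) tuple as the tspan; a sketch of the corrected constructor (using Float32 endpoints to match the Float32 inputs is my choice, not something stated in the thread):

nn_ode = NeuralODE( nn, (0.0f0, 1.0f0), Tsit5(),
                        save_everystep = false,
                        reltol = 1e-3, abstol = 1e-3,
                        save_start = false )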

caseykneale commented 4 years ago

Oh yes, that's a mistake on my part... I forgot that was the tspan arg.

Looks like I'm missing a transpose too. This resolved the issue. I'm okay if this gets closed :).
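
For completeness, a minimal sketch of the example with both fixes applied: a proper (t0, t1) tspan and a 1×20 target to match the model output. The exact target shape and the squared-error loss are assumptions on my part, not confirmed in the thread.

using Flux, DiffEqFlux, OrdinaryDiffEq

f = 5
x = Float32.(randn(f, 20))   # 20 samples, each of dimension f
y = Float32.(randn(1, 20))   # one scalar target per sample (assumed intent)

nn     = Flux.Chain(Dense(f, f, identity))
nn_ode = NeuralODE(nn, (0.0f0, 1.0f0), Tsit5(),   # (t0, t1) tspan, per the fix above
                   save_everystep = false,
                   reltol = 1e-3, abstol = 1e-3,
                   save_start = false)
fc = Dense(f, 1, identity)

function DiffEqArray_to_Array(x)
    xarr = Array(x)                        # f × 20 × 1 with these save options
    return reshape(xarr, size(xarr)[1:2])  # drop the singleton time dimension
end

model = Flux.Chain(nn_ode, DiffEqArray_to_Array, fc)

loss(ex, ey) = sum(abs2, model(ex) .- ey)   # squared error, substituted for the original sum

opt = ADAM(0.05)
Flux.train!(loss, Flux.params(nn, nn_ode.p, fc), [(x, y)], opt)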