SciML / StochasticDiffEq.jl

Solvers for stochastic differential equations which connect with the scientific machine learning (SciML) ecosystem

mul! errors with matrix input #515

Closed: ba2tro closed this issue 1 year ago

ba2tro commented 1 year ago

When using EulerHeun with BacksolveAdjoint, mul! errors because W.dW is a matrix. Wrapping it in vec() fixes it. Downstream: https://github.com/SciML/DiffEqFlux.jl/pull/750#discussion_r1006298994
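
A minimal sketch of this kind of failure, using only LinearAlgebra. The shapes are assumptions chosen for illustration (a 2x3 matrix-valued state flattened to 6 noise channels), and G, dW, out and mul_throws are hypothetical names, not anything from StochasticDiffEq.jl:

```julia
using LinearAlgebra

# Hypothetical shapes, chosen only for illustration.
G = ones(6, 6)     # stand-in for a non-diagonal noise-rate matrix
dW = ones(2, 3)    # noise increment that kept the matrix shape of the state
out = zeros(6)     # output buffer for the noise term

# Returns true if mul! throws for the given increment, false otherwise.
function mul_throws(out, G, dW)
    try
        mul!(out, G, dW)
        return false
    catch
        return true
    end
end

matrix_fails = mul_throws(out, G, dW)       # matrix dW: incompatible with a vector output
vector_fails = mul_throws(out, G, vec(dW))  # flattened dW: an ordinary matrix-vector product
```

vec(dW) reshapes without copying, so the flattened increment still shares memory with the original matrix.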

ChrisRackauckas commented 1 year ago

When doing this, ask yourself "is Euler fine?", because it's going to be the simplest method that is expansive like this. https://github.com/SciML/StochasticDiffEq.jl/blob/master/src/perform_step/low_order.jl#L54 So if you swap to EM(), is W.dW a matrix? It shouldn't be.

https://github.com/SciML/StochasticDiffEq.jl/blob/master/src/solve.jl#L293-L298 that should make it a vector. If it's passing a noise_rate_prototype in order to be non-diagonal, then rand_prototype is always a Vector, right? Check that it's a Vector. If there's anywhere to vec, it would be here.

And then this would need tests.
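
The shape rule being described can be sketched as follows. This is an illustration under stated assumptions, not the code in solve.jl: sketch_rand_prototype is a hypothetical helper, and it assumes the non-diagonal case needs one noise increment per column of noise_rate_prototype.

```julia
# Hypothetical helper; it only mirrors the shape rule discussed above.
function sketch_rand_prototype(u0, noise_rate_prototype = nothing)
    if noise_rate_prototype === nothing
        # Diagonal noise: one increment per state entry, so dW mirrors u0's shape.
        return zero(u0)
    else
        # Non-diagonal noise: one increment per column of the noise-rate matrix.
        return zeros(eltype(u0), size(noise_rate_prototype, 2))
    end
end

u0 = zeros(Float32, 2, 3)
diag_proto = sketch_rand_prototype(u0)                           # 2x3 Matrix
nondiag_proto = sketch_rand_prototype(u0, zeros(Float32, 6, 4))  # length-4 Vector
```

With a matrix-valued u0, the diagonal branch therefore yields a matrix prototype, which is the shape that later shows up as W.dW.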

ba2tro commented 1 year ago

So I tried EM() and the same thing happens: W.dW is a matrix and calling vec() resolves the error. The reason is that rand_prototype is not set from here: https://github.com/SciML/StochasticDiffEq.jl/blob/master/src/solve.jl#L293-L298 as you suggested above. Instead, the elseif condition above it takes over and sets it to a matrix: https://github.com/SciML/StochasticDiffEq.jl/blob/454bf4ea6e21a0024d2ccbb8f05c72555884501d/src/solve.jl#L287-L291 in the else part: https://github.com/SciML/StochasticDiffEq.jl/blob/454bf4ea6e21a0024d2ccbb8f05c72555884501d/src/solve.jl#L291

ChrisRackauckas commented 1 year ago

But no, that cannot be related, because that branch is only for diagonal noise. Your change only touches a branch that is for non-diagonal noise. What is your MWE? What is your test case? Is it diagonal noise or not?

ba2tro commented 1 year ago
#MWE
using Lux, StochasticDiffEq, SciMLSensitivity, Random, Zygote 

rng = Random.default_rng()

abstract type NeuralSDELayer <: Lux.AbstractExplicitContainerLayer{(:model1,:model2,)} end
basic_tgrad(u,p,t) = zero(u)

struct NeuralDSDE{M,P,RE,M2,RE2,T,A,K} <: NeuralSDELayer
    p::P
    len::Int
    model1::M
    re1::RE
    model2::M2
    re2::RE2
    tspan::T
    args::A
    kwargs::K

    function NeuralDSDE(model1::Lux.Chain,model2::Lux.Chain,tspan,args...;
                        p1 =nothing,
                        p = nothing, kwargs...)
        re1 = nothing
        re2 = nothing
        new{typeof(model1),typeof(p),typeof(re1),typeof(model2),typeof(re2),
            typeof(tspan),typeof(args),typeof(kwargs)}(p,
            Int(1),model1,re1,model2,re2,tspan,args,kwargs)
    end
end

function (n::NeuralDSDE{M})(x,p,st) where {M<:Lux.AbstractExplicitLayer}
    st1 = st.model1
    st2 = st.model2
    function dudt_(u,p,t;st=st1)
      u_, st = n.model1(u,p.model1,st)
      return u_
    end
    function g(u,p,t;st=st2)
      u_, st = n.model2(u,p.model2,st)
      return u_
    end

    ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad)
    prob = SDEProblem{false}(ff,g,x,n.tspan,p)
    return solve(prob,n.args...;sensealg=BacksolveAdjoint(),n.kwargs...), (model1 = st1, model2 = st2)
end

x = Float32[2.; 0.]
xs = Float32.(hcat([0.; 0.], [1.; 0.], [2.; 0.]))
tspan = (0.0f0,1.0f0)

dudt = Lux.Chain(Lux.Dense(2,50,tanh),Lux.Dense(50,2))
dudt2 = Lux.Chain(Lux.Dense(2,50,tanh),Lux.Dense(50,2))

sode = NeuralDSDE(dudt,dudt2,(0.0f0,.1f0),EulerHeun(),saveat=0.0:0.01:0.1,dt=0.01)
pd, st = Lux.setup(rng, sode)
pd = Lux.ComponentArray(pd)

grads = Zygote.gradient((x,p,st)->sum(sode(x,p,st)[1]),xs,pd,st)

So, the problem has diagonal noise. It starts off on the diagonal path and hits this method: https://github.com/SciML/StochasticDiffEq.jl/blob/454bf4ea6e21a0024d2ccbb8f05c72555884501d/src/perform_step/low_order.jl#L63-L68 But after a few steps it starts hitting these mul! calls: https://github.com/SciML/StochasticDiffEq.jl/blob/454bf4ea6e21a0024d2ccbb8f05c72555884501d/src/perform_step/low_order.jl#L92 https://github.com/SciML/StochasticDiffEq.jl/blob/454bf4ea6e21a0024d2ccbb8f05c72555884501d/src/perform_step/low_order.jl#L104 which are indeed in the non-diagonal noise branch. I am not sure, but could it be due to the elseif branches getting mixed up, as pointed out above?

ChrisRackauckas commented 1 year ago

No, it's because even when the forward pass is diagonal, the reverse pass is non-diagonal. So there are two different solves: the first is diagonal and solves fine, the second is non-diagonal and fails. So your MWE is just a non-diagonal case.

But... the issue is that the reverse solve uses noise from the forward pass, which is a matrix from the diagonal case, as it uses it via the NoiseWrapper. So maybe the fix is very different. https://github.com/SciML/SciMLSensitivity.jl/commit/3262518c85d12144745cb70dc140d097f601576d was supposed to solve this @frankschae, but hmm maybe it missed something? @Abhishek-1Bhatt can you check that commit is in the version you're using?

That's at least enough information to build an MWE that doesn't use DiffEqFlux. It should just be SciMLSensitivity directly, define a diagonal SDE on a matrix-defined u0 and do the gradient of it. I would think the right thing is to broadcast a vec here: https://github.com/SciML/SciMLSensitivity.jl/blob/ab35c7e77c832c6995d8575df22e5a08ddb44c3a/src/interpolating_adjoint.jl#L49
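
A sketch of what such a reduced MWE might look like, assuming StochasticDiffEq, SciMLSensitivity and Zygote are installed. The drift f, diffusion g, parameter values and loss are hypothetical choices for illustration. Whether the gradient call errors or returns depends on the package versions in use, so no output is shown:

```julia
using StochasticDiffEq, SciMLSensitivity, Zygote

# Hypothetical diagonal-noise SDE: g returns an array shaped like u,
# so each state entry gets its own noise increment.
f(u, p, t) = p[1] .* u
g(u, p, t) = p[2] .* u

u0 = [2.0 1.0 0.5; 0.0 1.5 1.0]   # matrix-valued initial condition (2x3)
p = [1.01, 0.87]
tspan = (0.0, 0.1)

function loss(u0, p)
    prob = SDEProblem(f, g, u0, tspan, p)
    sol = solve(prob, EulerHeun(); dt = 0.01, saveat = 0.0:0.01:0.1,
                sensealg = BacksolveAdjoint())
    return sum(Array(sol))
end

# Forward solve is diagonal; the adjoint (reverse) solve is the non-diagonal one.
du0, dp = Zygote.gradient(loss, u0, p)
```

This keeps Lux and DiffEqFlux out of the picture, so a failure here would point at the adjoint problem construction rather than at the neural network layers.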

frankschae commented 1 year ago

Hmm, yes. I don't think I've ever used a matrix-valued u before; I've always vectorized it or used EnsembleProblem instead. I agree that we should fix this in SciMLSensitivity.

ba2tro commented 1 year ago

So the diagonal noise of the forward pass, along with its matrix dW, is fed to the reverse SDEProblem here: https://github.com/SciML/SciMLSensitivity.jl/blob/23d9ed5d259d4097526d6c89d136f9057e6b61ba/src/backsolve_adjoint.jl#L330 This dW ends up in the non-diagonal mul! of the EulerHeun solver. I think that doing vec in this mul! is the best option here: dW is set to be Matrix{Float64} in the forward pass, so changing it to a vector by doing backwardnoise.dW = vec(backwardnoise.dW) will cause issues here: https://github.com/SciML/DiffEqNoiseProcess.jl/blob/ab356697b6553871ad3af1e5862d6918195a2672/src/types.jl#L146-L158 even though it is a mutable struct. Plus, doing vec in mul! doesn't cause any issues, as dW is meant to be a vector in the non-diagonal case; it would just act as a fail-safe.

ChrisRackauckas commented 1 year ago

> plus it doesn't cause any issues as its meant to be a vector in the non-diagonal case, it would just act as a fail-safe

No, doing it here is a bandaid that can allow bad/wrong behavior to accidentally work sometimes. This should get fixed at the source. Also, you'd have to not just do it to this one method, but all SDE methods, and test a bunch of it. Instead, we should just enforce the noise is a vector for the non-diagonal case if that's what it's supposed to be.

> I think that doing vec in this mul! is the best option here as dW is set to be Matrix{Float64} in the forward pass, so changing it to vector by doing backwardnoise.dW = vec(backwardnoise.dW) will cause issues here:

Is that link the right link? There's nothing there that would have an issue.

ba2tro commented 1 year ago

True, if it causes bugs, it should be handled at the source. The issue with the NoiseProcess is that when we instantiate it in the forward pass, typeof(dW) is set to be the same as typeof(rand_prototype) (called W0 here), which is Matrix{Float64}. If we then do backwardnoise.dW = vec(backwardnoise.dW), we change not just the value in the dW field but also its type, to a vector, which is not allowed and causes an error:

ERROR: MethodError: no method matching Matrix{Float64}(::Vector{Float64})
Stacktrace:

convert(#unused#::Type{Matrix{Float64}}, a::Vector{Float64}) at .\array.jl

setproperty!(x::DiffEqNoiseProcess.NoiseWrapper{Float64, 3, Float64, Matrix{Float64}, Nothing, DiffEqNoiseProcess.NoiseProcess{Float64, 3, Float64, Matrix{Float64}, Nothing, Nothing, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), false, ResettableStacks.ResettableStack{Tuple{Float64, Matrix{Float64}, Nothing}, false}, ResettableStacks.ResettableStack{Tuple{Float64, Matrix{Float64}, Nothing}, false}, DiffEqNoiseProcess.RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, Nothing, false}, f::Symbol, v::Vector{Float64}) at .\Base.jl

SDEAdjointProblem(sol::RODESolution{Float32, 3, Vector{Matrix{Float32}}, Nothing, Nothing, Vector{Float64}, DiffEqNoiseProcess.NoiseProcess{Float64, 3, Float64, Matrix{Float64}, Nothing,Nothing, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), false, ResettableStacks.ResettableStack{Tuple{Float64, Matrix{Float64}, Nothing}, false}, ResettableStacks.ResettableStack{Tuple{Float64, Matrix{Float64}, Nothing}, false}, DiffEqNoiseProcess.RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, SDEProblem{Matrix{Float32}, Tuple{Float64, Float64}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(model1 = ViewAxis(1:252, Axis(layer_1 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150,........

But since changing it to vec here is prone to bugs, besides being cumbersome, let me try to make it work with some changes to the NoiseProcess struct instead.
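
The same MethodError can be reproduced in plain Julia with a toy type. ToyNoise is a hypothetical stand-in for a struct whose dW field type is fixed by a type parameter; it is not the NoiseProcess or NoiseWrapper definition, and rebuilding the wrapper as shown at the end is only one possible way around the restriction, not necessarily what the eventual fix does:

```julia
# Hypothetical stand-in: the field type is pinned by the type parameter T.
mutable struct ToyNoise{T}
    dW::T
end

noise = ToyNoise(zeros(2, 3))    # T is inferred as Matrix{Float64}

# Assigning a Vector to the field makes Julia try to convert it to
# Matrix{Float64}, which has no method, even though the struct is mutable.
err = try
    noise.dW = vec(noise.dW)
    nothing
catch e
    e
end

# Building a new wrapper lets the type parameter change along with the value.
flat = ToyNoise(vec(noise.dW))   # ToyNoise{Vector{Float64}}
```

Mutability lets a field's value change, but not the concrete type the field was declared with, which is why the in-place assignment fails while constructing a new object succeeds.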

ba2tro commented 1 year ago

Completed in SciML/DiffEqNoiseProcess.jl/pull/138 and SciML/SciMLSensitivity.jl/pull/761