FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Derivative in loss function error #1464

Open stelmo opened 3 years ago

stelmo commented 3 years ago

Hi, I am trying to implement a PINN as described here using Flux. Essentially, I am trying to train a neural network whose loss function includes the network's derivative with respect to time (time is one of its inputs). Below is a very minimal example:

using Flux

m = Chain(Dense(3, 10, relu), Dense(10, 10, relu), Dense(10, 1)) # [u0, k, t] -> u(t)
ps = Flux.params(m)

function loss(x, y)
    fitloss = Flux.Losses.mse(m(x), y) # typical loss function

    derivativeloss = abs2(gradient(a -> m(a)[1], x)[1][3]) #  problem source (3rd input is time)

    return fitloss + derivativeloss
end

xt = rand(3)
yt = rand(1)

gs = gradient(ps) do
    loss(xt, yt)
end # this generates a foreigncall exception
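For reference, the loss in this minimal example can be written as follows. Here u_θ denotes the network m with parameters θ and x = (u_0, k, t) (notation introduced here, not in the original). For the single output, Flux.Losses.mse reduces to a squared error. Because the second term contains ∂u_θ/∂t, taking the gradient with respect to θ means differentiating through a derivative, i.e. nested (second-order) automatic differentiation. That is what the inner gradient call inside loss asks Zygote to do:

```latex
\mathcal{L}(\theta) = \bigl(u_\theta(x) - y\bigr)^2 + \left(\frac{\partial u_\theta}{\partial t}(x)\right)^2,
\qquad x = (u_0, k, t)
```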

This issue seems to be pervasive; see here, #1338, #1257, and here (the last one is my question on the Discourse forum). I have tried all the suggestions in those links, but nothing seems to work. Is there a workaround, or is this a built-in limitation of Flux/Zygote?

DhairyaLGandhi commented 3 years ago

Thanks for bringing this up! If you drop implicit params and pass the model to gradient explicitly, you get a more meaningful error:

function loss(m, x, y)
    fitloss = Flux.Losses.mse(m(x), y) # typical loss function

    derivativeloss = abs2(gradient(a -> m(a)[1], x)[1][3]) #  problem source (3rd input is time)

    return fitloss + derivativeloss
end

julia> gs = gradient(m, xt, yt) do m, x, y
           loss(m, x, y)
       end
ERROR: Mutating arrays is not supported
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::Zygote.var"#380#381")(::Nothing) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/lib/array.jl:58
 [3] (::Zygote.var"#2288#back#382"{Zygote.var"#380#381"})(::Nothing) at /Users/dhairyagandhi/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [4] (::Zygote.var"#152#153"{Zygote.var"#2288#back#382"{Zygote.var"#380#381"},Tuple{Tuple{Nothing,Nothing},Tuple{Nothing}}})(::Nothing) at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/lib/lib.jl:191
 [5] (::Zygote.var"#1699#back#154"{Zygote.var"#152#153"{Zygote.var"#2288#back#382"{Zygote.var"#380#381"},Tuple{Tuple{Nothing,Nothing},Tuple{Nothing}}}})(::Nothing) at /Users/dhairyagandhi/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [6] #372 at /Users/dhairyagandhi/Downloads/new_clones/Zygote.jl/src/lib/array.jl:38 [inlined]

The same error appears when using https://github.com/FluxML/Zygote.jl/pull/823 directly.

The Fix

The error comes from picking out the first element of the model output (m(a)[1]). If you instead use something like

julia> function loss(m, x, y)
           fitloss = Flux.Losses.mse(m(x), y) # typical loss function

           derivativeloss = abs2(gradient(a -> sum(m(a)), x)[1][3]) # sum instead of [1] avoids the getindex adjoint (3rd input is time)

           return fitloss + derivativeloss
       end

(notice the call to sum instead of indexing the first element), things work fine.

In cases where we do need element-wise gradients, sum doesn't cut it: the getindex adjoint would need to be rewritten so that it doesn't mutate arrays (it currently mutates for performance). We could also write a fallback rule that handles this case.
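One way to read "not do array mutation" is sketched below in plain Julia. Both helpers are hypothetical (not part of Zygote) and build the cotangent of a scalar getindex, x[inds...]: an array shaped like x that is zero except for dy at the selected position. The first fills a preallocated array with setindex!, roughly as Zygote's ∇getindex does. The second uses a comprehension instead. Registering the second one as an adjoint (via ZygoteRules or a ChainRules rrule) and checking that nested gradients then succeed is not shown here.

```julia
# Hypothetical helper: mutating cotangent, roughly mirroring Zygote's ∇getindex
# for integer indices. Allocates zeros, then writes dy in place.
function getindex_pullback_mutating(x::AbstractArray, inds::Integer...)
    return function (dy)
        dx = zeros(typeof(dy), size(x))
        dx[inds...] = dy          # setindex!: Zygote cannot differentiate through this
        return dx
    end
end

# Hypothetical helper: the same cotangent built with a comprehension, no setindex!.
function getindex_pullback_nonmutating(x::AbstractArray, inds::Integer...)
    target = LinearIndices(x)[inds...]   # linear position of the selected element
    # Iterating over LinearIndices(x) gives a result with the same shape as x
    return dy -> [k == target ? dy : zero(dy) for k in LinearIndices(x)]
end

x = rand(2, 3)
back_mut = getindex_pullback_mutating(x, 2, 3)
back_nomut = getindex_pullback_nonmutating(x, 2, 3)
back_mut(1.5) == back_nomut(1.5)   # same cotangent either way
```

The comprehension still allocates a full array the size of x, so it costs about the same as the mutating version; it only changes how that array is filled.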

stelmo commented 3 years ago

Thank you very much! This fixes the one-dimensional case :) I will link this thread from my Discourse question (and from the other question there).

How would one write the adjoint for the getindex function? I will need it to extend the PINN system to multiple dimensions...

I looked at the source code for Base.getindex, and it is not very clear to me what's going on; hopefully I won't have to change anything there? Zygote's documentation says that custom adjoint rules should be defined with ChainRules. Should I use that instead to define a fallback rule?

DhairyaLGandhi commented 3 years ago

What would the multidimensional case entail?

You're better off checking out the code for the getindex adjoint in Zygote, in src/lib/array.jl, and writing an adjoint for that.
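To see what the getindex adjoint computes for a scalar index, Zygote.pullback can be applied directly. This is a small sketch assuming a recent Zygote. The concrete array type returned may differ between Zygote versions, so the result is compared with == rather than by type:

```julia
using Zygote

x = [1.0, 2.0, 3.0]

# Forward pass of a scalar getindex, recording its pullback
y, back = Zygote.pullback(v -> v[2], x)

# The pullback returns one entry per differentiated argument; for x it is an
# array shaped like x that is zero except for the incoming sensitivity at index 2.
dx = back(1.0)[1]
```

In the version of the rule quoted later in this thread, that array is built by allocating zeros and writing dy into it, which is the step an outer gradient call cannot trace through.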

stelmo commented 3 years ago

Essentially, I mean element-wise gradients, e.g.:

m = Chain(Dense(5, 10, relu), Dense(10, 10, relu), Dense(10, 2)) # [u0_1, u0_2, k1, k2, t] -> [u1(t), u2(t)]

function loss(m, x, y)
    fitloss = Flux.Losses.mse(m(x), y) # typical loss function

    derivativeloss = 0.0f0

    for i=1:size(x, 2)
        for j=1:2 # dimension to take gradient of
            derivativeloss += abs2( gradient(a -> m(a)[j], x[:, i])[1][5] ) # accumulate ||du_j/dt||^2 for j=1,2; this mutates again :(
        end
    end

    return fitloss + derivativeloss
end

xt = rand(5, 10)
yt = rand(2, 10)

gs = gradient(m, xt, yt) do m, x, y
    loss(m, x, y)
end
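A different workaround, not proposed in this thread, is to avoid nested AD altogether by approximating du_j/dt with a central finite difference, (m(x + h e_t) - m(x - h e_t)) / 2h, where e_t is the unit vector along the time input. The sketch below assumes a step size h = 1f-2 and introduces the hypothetical names ET and loss_fd. It penalizes |du_j/dt|^2 summed over both outputs and all samples. Zygote then only differentiates ordinary forward passes of m (first-order AD), at the cost of an approximation error that depends on h:

```julia
using Flux

# Model from the example above: [u0_1, u0_2, k1, k2, t] -> [u1(t), u2(t)]
m = Chain(Dense(5, 10, relu), Dense(10, 10, relu), Dense(10, 2))

# Unit vector along the time input (row 5), built once outside any gradient call
const ET = Float32[0, 0, 0, 0, 1]

# Hypothetical loss: du/dt via central finite differences instead of a nested gradient
function loss_fd(m, x, y; h = 1f-2)
    fitloss = Flux.Losses.mse(m(x), y)
    # ET broadcasts along each column, shifting only the time row; result is (2, N)
    dudt = (m(x .+ h .* ET) .- m(x .- h .* ET)) ./ (2h)
    derivativeloss = sum(abs2, dudt)   # sum of |du_j/dt|^2 over outputs and samples
    return fitloss + derivativeloss
end

xt = rand(Float32, 5, 10)
yt = rand(Float32, 2, 10)

gs = gradient(m, xt, yt) do m, x, y
    loss_fd(m, x, y)
end
```

With rand(Float32, ...) data the inputs match the model's Float32 parameters; gs[1] holds the gradient with respect to the model.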

Okay, I've looked in Zygote/src/lib/array.jl and found this:

@adjoint getindex(x::AbstractArray, inds...) = x[inds...], ∇getindex(x, inds)

∇getindex(x::AbstractArray, inds) = dy -> begin
  if inds isa  NTuple{<:Any, Integer}
    dx = _zero(x, typeof(dy))
    dx[inds...] = dy
  else
    dx = _zero(x, eltype(dy))
    dxv = view(dx, inds...)
    dxv .= accum.(dxv, _droplike(dy, dxv))
  end
  return (dx, map(_->nothing, inds)...)
end

Is the mutation happening here:

dxv = view(dx, inds...)
dxv .= accum.(dxv, _droplike(dy, dxv))

Would I have to change that somehow (I am not sure at all what to do)? I really appreciate your help in this!

DhairyaLGandhi commented 3 years ago

Right: anywhere setindex! is involved, it gets in the way.

mkalia94 commented 3 years ago

Bump. Zygote.jl now has a jacobian function, but it results in the same mutating-array error. I think this is critical for physics-informed machine learning in general, so it would be great to have some clarity on this! @DhairyaLGandhi, I'm not quite sure I understand your last comment. Thanks!

ToucheSir commented 3 years ago

This is totally not my wheelhouse, but there are a few threads on Discourse about representing PINNs that you could look into.

NagaChaitanya96 commented 2 years ago

About @stelmo's multidimensional example above (element-wise gradients with m(a)[j]): this code isn't correct, right? I'm getting an error at the gradient call.