SciML / ModelingToolkitNeuralNets.jl

Symbolic-Numeric Universal Differential Equations for Automating Scientific Machine Learning (SciML)
MIT License

NaN in gradients #6

Closed: SebastianM-C closed this issue 6 months ago

SebastianM-C commented 6 months ago

When working locally on this, I initially ran into an issue where the gradient was always NaN, which I think is what's causing #5. Enabling NaN-safe mode seemed to fix it. Is that a bug, or should we just document this?

Also, what's the best way of setting this up in CI?

ChrisRackauckas commented 6 months ago

What do you mean by NaN safe mode?

SebastianM-C commented 6 months ago

https://juliadiff.org/ForwardDiff.jl/stable/user/advanced/#Fixing-NaN/Inf-Issues
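As a rough illustration of the kind of failure NaN-safe mode targets (a zero partial multiplied by an infinite local derivative), here is a toy forward-mode dual number in plain Julia with no packages. `MiniDual`, `scale`, `mini_sqrt`, `mini_abs2` and the `nansafe` keyword are hypothetical names made up for this sketch; the `nansafe` branch only approximates the idea behind ForwardDiff's NaN-safe mode and is not its actual implementation.

```julia
# Toy forward-mode dual number with one partial (hypothetical, illustration only).
struct MiniDual
    value::Float64
    partial::Float64
end

# Chain-rule step: scale an incoming partial by a local derivative.
# With nansafe = true, a zero incoming partial stays zero even when the
# local derivative is Inf, which avoids 0 * Inf = NaN.
scale(p, d; nansafe = false) = (nansafe && iszero(p)) ? zero(p) : p * d

# d/dx sqrt(x) = 1 / (2 sqrt(x)), which evaluates to Inf at x = 0.
mini_sqrt(x::MiniDual; nansafe = false) =
    MiniDual(sqrt(x.value), scale(x.partial, 1 / (2 * sqrt(x.value)); nansafe))

# d/dx x^2 = 2x, which is 0 at x = 0.
mini_abs2(x::MiniDual) = MiniDual(abs2(x.value), 2 * x.value * x.partial)

# A residual that is exactly zero (prediction equals reference), seeded with partial 1.
r = MiniDual(3.1 - 3.1, 1.0)

plain = mini_sqrt(mini_abs2(r))                  # partial: 0 * Inf = NaN
safe = mini_sqrt(mini_abs2(r); nansafe = true)   # partial forced to 0

println(plain)
println(safe)
```

The NaN-safe variant trades a branch per multiplication for never producing NaN from `0 * Inf`, which is why it hides the symptom without changing the loss itself.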

ChrisRackauckas commented 6 months ago

You shouldn't need that here.

ChrisRackauckas commented 6 months ago

Why do you get a NaN in the first place?

SebastianM-C commented 6 months ago

It looks like the problem comes from the loss computation:

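using ForwardDiff, OrdinaryDiffEq, SciMLStructures
using SciMLStructures: Tunable

# x0, prob, sol_ref, get_vars and get_refs come from the test setup (not shown here).
# Seed every tunable parameter with a single partial equal to 1.
x0′ = ForwardDiff.Dual{:tag}.(x0, 1)

# Swap the Dual-valued tunables into the parameter object and rebuild the problem.
test_p = SciMLStructures.replace(Tunable(), prob.p, x0′)
test_prob = remake(prob, p = test_p)

# autodiff=false: Rodas4 uses finite differences for its internal Jacobian.
test_sol = solve(test_prob, Rodas4(autodiff=false), saveat=sol_ref.t)
# Loss term at the first saved time point.
sum(sqrt.(abs2.(get_vars(test_sol, 1) .- get_refs(sol_ref, 1))))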

gives

Dual{:tag}(0.0,NaN)
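A minimal scalar reproduction without the ODE solve, as a sketch assuming the prediction exactly equals the reference value and carries a zero partial (an initial condition that does not depend on the parameters). The numbers are illustrative, and `pred`, `ref`, `r` and `term` are hypothetical names.

```julia
using ForwardDiff

# Prediction equal to the reference, with a zero partial because it does
# not depend on the parameters.
pred = ForwardDiff.Dual{:tag}(3.1, 0.0)
ref = 3.1

r = pred - ref          # value 0.0, partial 0.0
term = sqrt(abs2(r))    # abs2 keeps the partial at 0; sqrt's local derivative at 0 is Inf

# The value is 0.0, but the partial is 0 * Inf = NaN.
println(ForwardDiff.value(term), " ", ForwardDiff.partials(term, 1))
```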

I also see the NaNs appear when I print from inside the loss:

for i in eachindex(new_sol.u)
    loss += sum(sqrt.(abs2.(get_vars(new_sol, i) .- get_refs(sol_ref, i))))
    # Log the time-point index once the accumulated loss has a NaN partial.
    if any(isnan, ForwardDiff.partials(loss))
        @info i
    end
end

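The loop above tests the accumulated loss, so once one term is NaN every later index is flagged as well. A sketch that tests each term on its own pinpoints the first offending time point instead; `first_nan_index` is a hypothetical helper, and the residual values below are made up for illustration.

```julia
using ForwardDiff

# Hypothetical helper: index of the first loss term with a NaN partial,
# or `nothing` if every partial is NaN-free.
function first_nan_index(terms)
    for (i, t) in enumerate(terms)
        any(isnan, ForwardDiff.partials(t)) && return i
    end
    return nothing
end

# Illustrative residuals, one per time point. The first is exactly zero,
# as happens when prediction and reference share their initial condition.
residuals = [ForwardDiff.Dual{:tag}(0.0, 0.0),
             ForwardDiff.Dual{:tag}(0.27, 0.09),
             ForwardDiff.Dual{:tag}(-0.4, 0.1)]

terms = [sqrt(abs2(r)) for r in residuals]
idx = first_nan_index(terms)

# A single NaN term is enough to make the summed loss's partial NaN.
total = sum(terms)
```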
ChrisRackauckas commented 6 months ago

Where does the first NaN appear?

SebastianM-C commented 6 months ago

It's due to sqrt.

ChrisRackauckas commented 6 months ago

What are the values?

SebastianM-C commented 6 months ago
julia> sqrt.(abs2.(get_vars(test_sol, 1) .- get_refs(sol_ref, 1)))
2-element Vector{ForwardDiff.Dual{:tag, Float64, 1}}:
 Dual{:tag}(0.0,NaN)
 Dual{:tag}(0.0,NaN)
ChrisRackauckas commented 6 months ago

Yes but what are the values that go in?

SebastianM-C commented 6 months ago
julia> get_vars(test_sol, 1)
2-element Vector{ForwardDiff.Dual{:tag, Float64, 1}}:
 Dual{:tag}(3.1,0.0)
 Dual{:tag}(1.5,0.0)

julia> get_refs(sol_ref, 1)
2-element Vector{Float64}:
 3.1
 1.5

Hmm, let me check why they are the same. 🤔

SebastianM-C commented 6 months ago

Ah, it's because both solutions start from the same initial conditions. At the second saved time point,

sum(sqrt.(abs2.(get_vars(test_sol, 2) .- get_refs(sol_ref, 2))))

gives Dual{:tag}(0.2685941909005718,0.08984350230039442)

ChrisRackauckas commented 6 months ago

Yeah, the gradient of sqrt at zero is NaN: its derivative is infinite there, and Inf times a zero partial is NaN. That seems like a loss function issue.
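The chain rule makes this concrete. Writing the residual at one time point as r(p) = x(p) - y, with x(p) the predicted state and y the reference value (notation introduced here for illustration), forward mode evaluates the two factors below separately in floating point:

```latex
r(p) = x(p) - y

\frac{d}{dp}\sqrt{r(p)^2}
  = \underbrace{\frac{1}{2\sqrt{r(p)^2}}}_{= \infty \text{ at } r = 0}
    \cdot
    \underbrace{2\, r(p)\, r'(p)}_{= 0 \text{ at } r = 0}
```

Under IEEE 754 arithmetic, 0 · ∞ evaluates to NaN, and that NaN then propagates through every sum the term enters, including the total loss.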

SebastianM-C commented 6 months ago

I started from the same initial conditions as in https://docs.sciml.ai/Overview/stable/showcase/missing_physics/, so at the very first time point the residual is exactly 0 and its gradient is NaN, which then poisons the gradient of the whole loss.

SebastianM-C commented 6 months ago

So not a bug, but we should document this.
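One way the documentation could show a fix, sketched here on a toy problem rather than the actual ODE: replace the sum of `sqrt.(abs2.(...))` terms with a squared-error loss, whose derivative at a zero residual is 0 instead of NaN. `pred`, `ref`, `loss_sqrt` and `loss_sq` are hypothetical stand-ins, with `pred(p)` built to equal `ref` at `p = 1.0`, mimicking a shared initial condition.

```julia
using ForwardDiff

# Hypothetical toy model: the prediction equals the reference at p = 1.0,
# but its derivative with respect to p there is nonzero.
ref = [3.1, 1.5]
pred(p) = ref .+ (p - 1.0) .* [0.2, -0.3]

# Original style: sum of |residual| written as sqrt(abs2(.)).
# Its derivative is NaN whenever a residual is exactly 0.
loss_sqrt(p) = sum(sqrt.(abs2.(pred(p) .- ref)))

# Squared-error loss: smooth at a zero residual, so its derivative there is 0.
loss_sq(p) = sum(abs2, pred(p) .- ref)

d_sqrt = ForwardDiff.derivative(loss_sqrt, 1.0)
d_sq = ForwardDiff.derivative(loss_sq, 1.0)
```

The squared loss weights large errors more heavily than the absolute-error form, so it is a change in the objective, not only a numerical fix. Skipping the first saved time point, where the residual is zero by construction, is another option if the absolute-error form should be kept.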