Closed SebastianM-C closed 6 months ago
What do you mean by NaN safe mode?
You shouldn't need that here.
Why do you get a NaN in the first place?
It looks like the problem comes from the loss computation
using ForwardDiff
x0′ = ForwardDiff.Dual{:tag}.(x0, 1)
test_p = SciMLStructures.replace(Tunable(), prob.p, x0′)
test_prob = remake(prob, p = test_p)
test_sol = solve(test_prob, Rodas4(autodiff=false), saveat=sol_ref.t)
sum(sqrt.(abs2.(get_vars(test_sol, 1) .- get_refs(sol_ref, 1))))
gives
Dual{:tag}(0.0,NaN)
I also see that NaNs appear if I print in the loss with
for i in eachindex(new_sol.u)
    loss += sum(sqrt.(abs2.(get_vars(new_sol, i) .- get_refs(sol_ref, i))))
    if any(isnan.(ForwardDiff.partials(loss)))
        @info i
    end
end
What's the first spot of nan?
It's due to sqrt.
What are the values?
julia> sqrt.(abs2.(get_vars(test_sol, 1) .- get_refs(sol_ref, 1)))
2-element Vector{ForwardDiff.Dual{:tag, Float64, 1}}:
Dual{:tag}(0.0,NaN)
Dual{:tag}(0.0,NaN)
Yes but what are the values that go in?
julia> get_vars(test_sol, 1)
2-element Vector{ForwardDiff.Dual{:tag, Float64, 1}}:
Dual{:tag}(3.1,0.0)
Dual{:tag}(1.5,0.0)
julia> get_refs(sol_ref, 1)
2-element Vector{Float64}:
3.1
1.5
Hmm, let me check why they are the same :thinking:
Aah, it's because we start with the same initial conditions
sum(sqrt.(abs2.(get_vars(test_sol, 2) .- get_refs(sol_ref, 2))))
gives Dual{:tag}(0.2685941909005718,0.08984350230039442)
Yeah the gradient at zero is NaN for sqrt. That seems like a loss function issue.
I started with the same initial conditions as in https://docs.sciml.ai/Overview/stable/showcase/missing_physics/, which means that at the very first time point the residual is exactly 0 and its gradient is NaN, which ends up poisoning the whole loss.
So not a bug, but we should document this.
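For documentation purposes, the behaviour can be reproduced without the ODE solve. The following is only a minimal sketch, assuming ForwardDiff's default settings (NaN-safe mode not enabled): the derivative of sqrt at 0 is Inf, and multiplying it by the zero partial coming out of abs2 gives NaN. The names `r`, `bad`, `poisoned` and `ok` are made up for illustration, and the squared-error alternative at the end is one possible workaround, not something decided in this thread.

```julia
using ForwardDiff

# A residual that is exactly zero and carries a zero partial, as at the first
# saved time point when prediction and reference share initial conditions.
r = ForwardDiff.Dual(0.0, 0.0)

# abs2 keeps the partial at 0.0; sqrt then scales it by d/dx sqrt(x) at x = 0,
# which is Inf, and 0 * Inf is NaN.
bad = sqrt(abs2(r))
nan_in_sqrt_term = any(isnan, ForwardDiff.partials(bad))

# Adding a perfectly fine later term does not recover the gradient:
# NaN plus anything stays NaN, so the whole loss is poisoned.
poisoned = bad + ForwardDiff.Dual(0.27, 0.09)
nan_in_total = any(isnan, ForwardDiff.partials(poisoned))

# Hypothetical alternative (an assumption, not the thread's fix): a squared
# error term has a finite derivative at zero.
ok = abs2(r)
finite_in_squared_term = all(isfinite, ForwardDiff.partials(ok))
```

Skipping the first time point in the loss loop would also sidestep the zero residual, since the second time point already gives a finite gradient in the output above.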
When working locally on this, I initially encountered an issue where the gradient was always NaN, which I think is what's causing #5. I enabled NaN-safe mode and that seemed to fix the issue. Is that a bug, or should we just document this?
Also, what's the best way of setting this up in CI?