SciML / SciMLDocs

Global documentation for the Julia SciML Scientific Machine Learning Organization
https://docs.sciml.ai
MIT License

docs: use `Adam` instead of `ADAM` in missing_physics tutorial #221

Closed. sathvikbhagavan closed this 4 months ago.

sathvikbhagavan commented 4 months ago

@ChrisRackauckas, the missing physics tutorial errors out at these lines:

```julia
estimation_prob = ODEProblem(recovered_dynamics!, u0, tspan, get_parameter_values(nn_eqs))
estimate = solve(estimation_prob, Tsit5(), saveat = solution.t)
```

```
julia> estimate = solve(estimation_prob, Tsit5(), saveat = solution.t)
┌ Warning: dt(8.881784197001252e-16) <= dtmin(8.881784197001252e-16) at t=2.0966388837394545, and step error estimate = 1.5129604958400626. Aborting. There is either an error in your model specification or the true solution is unstable.
└ @ SciMLBase ~/.julia/packages/SciMLBase/szsYq/src/integrator_interface.jl:599
retcode: DtLessThanMin
Interpolation: 1st order linear
t: 9-element Vector{Float64}:
 0.0
 0.25
 0.5
 0.75
 1.0
 1.25
 1.5
 1.75
 2.0
u: 9-element Vector{Vector{Float64}}:
 [3.1461493970111687, 1.5370475785612603]
 [3.6818439489850294, 1.2765416819739916]
 [4.5052732798471755, 1.1124706463392569]
 [5.778503254169173, 1.0199394404085198]
 [7.825801842962541, 0.9880422730295028]
 [11.38347487175259, 1.0240247801514342]
 [18.594915309958083, 1.1811076688536928]
 [39.26724555866004, 1.7611949065256864]
 [330.97703916071447, 14.982623648124799]
```

The `tspan` is `(0.0, 5.0)`, but the solver aborts at t ≈ 2.097 (the last saved point is t = 2.0) with retcode `DtLessThanMin`. The neural network trains well, but the symbolic regression step does not recover equations that reproduce the dynamics.
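The saved trajectory grows sharply at the end (the first component jumps from about 39 at t = 1.75 to about 331 at t = 2.0), which is consistent with the recovered equations diverging in finite time rather than with a solver bug. As a toy illustration only (not the tutorial's model), the scalar ODE du/dt = u^2 has the exact solution u(t) = u0 / (1 - u0*t), which blows up at t* = 1/u0, so any adaptive integrator must shrink its step toward zero as t approaches t*. The sketch below is a hand-written RK4 integrator with step-doubling error control in plain Julia; the function names, tolerances, and `dtmin` value are hypothetical choices for illustration and do not reproduce `Tsit5` internals.

```julia
# Hypothetical toy right-hand side: du/dt = u^2, which blows up at t* = 1/u0.
rhs(u) = u^2

# One classical RK4 step of size dt.
function rk4_step(u, dt)
    k1 = rhs(u)
    k2 = rhs(u + dt / 2 * k1)
    k3 = rhs(u + dt / 2 * k2)
    k4 = rhs(u + dt * k3)
    return u + dt / 6 * (k1 + 2k2 + 2k3 + k4)
end

# Adaptive integration by step doubling: compare one full step against two half
# steps, shrink dt when the relative difference exceeds tol, grow it after a
# successful step, and abort once dt drops below dtmin (mimicking DtLessThanMin).
function integrate(u0, tend; dt = 0.1, tol = 1e-8, dtmin = 1e-12)
    t, u = 0.0, u0
    while t < tend
        dt = min(dt, tend - t)
        full = rk4_step(u, dt)
        half = rk4_step(rk4_step(u, dt / 2), dt / 2)
        err = abs(full - half) / max(abs(half), 1.0)
        if err <= tol
            t, u = t + dt, half   # accept the more accurate two-half-step result
            dt *= 1.5
        else
            dt /= 2
            dt < dtmin && return (t, u, :DtLessThanMin)
        end
    end
    return (t, u, :Success)
end

# Integrate toward tend = 2.0, past the blow-up time t* = 1.0 for u0 = 1.0.
t_stop, u_stop, status = integrate(1.0, 2.0)
println("stopped at t = $t_stop (u = $u_stop) with status $status")
```

The toy run stops with status `:DtLessThanMin` near t = 1, well short of the requested end time, much like the early abort in the report. If the recovered model genuinely diverges this way, the problem would lie in the discovered equations rather than in the ODE solver.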


```julia
options = DataDrivenCommonOptions(maxiters = 10_000,
                                  normalize = DataNormalization(ZScoreTransform),
                                  selector = bic, digits = 1,
                                  data_processing = DataProcessing(split = 0.9,
                                                                   batchsize = 30,
                                                                   shuffle = true,
                                                                   rng = StableRNG(1111)))

nn_res = solve(nn_problem, basis, opt, options = options)
nn_eqs = get_basis(nn_res)
println(nn_res)
```

prints:

```
"DataDrivenSolution{Float64}" with 2 equations and 6 parameters.
Returncode: Success
Residual sum of squares: 235.60416708442358
```

The residual sum of squares looks very high.
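Whether an RSS of 235.6 is "very high" depends on the scale of the targets and on how many samples enter the sum, so it can help to compare it against the total sum of squares via R² = 1 - RSS/TSS, or to divide by the number of entries to get a mean squared error. A minimal plain-Julia sketch follows; the arrays are hypothetical toy data, and it is an assumption here that the reported RSS sums squared residuals over all equations and samples (it may instead be computed on normalized data).

```julia
using Statistics

# Residual sum of squares over all equations (rows) and samples (columns).
rss(Y, Yhat) = sum(abs2, Y .- Yhat)

# Total sum of squares about each equation's own mean.
tss(Y) = sum(abs2, Y .- mean(Y, dims = 2))

# Coefficient of determination: 1 is a perfect fit, 0 is no better than the row means.
r2(Y, Yhat) = 1 - rss(Y, Yhat) / tss(Y)

# Hypothetical targets and predictions: 2 equations x 4 samples.
Y    = [1.0 2.0 3.0 4.0;
        2.0 4.0 6.0 8.0]
Yhat = [1.1 1.9 3.2 3.8;
        2.0 4.1 5.9 8.2]

println("RSS = ", rss(Y, Yhat), ", MSE = ", rss(Y, Yhat) / length(Y), ", R² = ", r2(Y, Yhat))
```

The same absolute RSS can correspond to a good or a poor fit depending on the data scale and sample count, so R² or the MSE is usually the easier number to judge.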