Closed lucpaoli closed 10 months ago
Does this require defining ChainRules rules in Clapeyron.jl?
@longemen3000 That's what I've done for now. You can see it here: https://github.com/lucpaoli/SAFT_ML/blob/master/6_sat_solver_NN.ipynb
Or just the relevant part:
using ChainRulesCore, ForwardDiff, Clapeyron
# make_model and make_NN_model are helpers from the linked notebook

function saturation_pressure_NN(X, T)
    model = make_model(X...)
    p, Vₗ, Vᵥ = saturation_pressure(model, T)
    return p
end

function ChainRulesCore.rrule(::typeof(saturation_pressure_NN), X, T)
    model = make_model(X...)
    p, Vₗ, Vᵥ = saturation_pressure(model, T)

    function f_pullback(Δy)
        # Newton step from perfect initialisation:
        # Vₗ and Vᵥ are captured from the converged solve and held fixed
        function f_p(X, T)
            model = make_NN_model(X...)
            p2 = -(eos(model, Vᵥ, T) - eos(model, Vₗ, T)) / (Vᵥ - Vₗ)
            return p2
        end

        ∂X = @thunk(ForwardDiff.gradient(X -> f_p(X, T), X) .* Δy)
        ∂T = @thunk(ForwardDiff.derivative(T -> f_p(X, T), T) .* Δy)
        return (NoTangent(), ∂X, ∂T)
    end

    return p, f_pullback
end
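For reference, here is a short sketch of why the pullback can hold Vₗ and Vᵥ fixed. It assumes that `eos` returns the Helmholtz energy A(X, V, T) of the pure component (consistent with the pressure being -∂eos/∂V) and that Vₗ and Vᵥ are fully converged. The symbol F is introduced here for illustration only and is not from the notebook.

```latex
% Assumption: A = eos(model(X), V, T); F is a name introduced for this sketch.
F(X, V_l, V_v) = -\frac{A(X, V_v) - A(X, V_l)}{V_v - V_l},
\qquad
p_{\mathrm{sat}}(X) = F\bigl(X, V_l(X), V_v(X)\bigr)

% Partial derivatives with respect to the volumes, using p(X, V) = -\partial A / \partial V:
\frac{\partial F}{\partial V_v} = \frac{p(X, V_v) - F}{V_v - V_l},
\qquad
\frac{\partial F}{\partial V_l} = -\frac{p(X, V_l) - F}{V_v - V_l}

% At saturation p(X, V_l) = p(X, V_v) = F, so both volume partials vanish and
\frac{d p_{\mathrm{sat}}}{d X}
  = \frac{\partial F}{\partial X}\bigg|_{V_l, V_v}
  + \underbrace{\frac{\partial F}{\partial V_l}}_{=0}\frac{d V_l}{d X}
  + \underbrace{\frac{\partial F}{\partial V_v}}_{=0}\frac{d V_v}{d X}
  = \frac{\partial F}{\partial X}\bigg|_{V_l, V_v}
```

The same argument applies to the derivative with respect to T. It only holds at the converged point: if the solver stops early, the volume partials are no longer exactly zero.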
The only problem with this approach is that I can't get it to work when differentiating the saturation densities with respect to the parameters: ForwardDiff appears to get confused tracking changes, and Zygote costs blow up. A package that looks as though it may solve this is TaylorDiff, but it is also stated as being in alpha and I don't fully understand it. An example of being unable to differentiate the saturation densities can be seen here: https://github.com/lucpaoli/SAFT_ML/blob/1455488cf4d2b4ef288ab90ae9994b4e142fb48c/6_sat_solver_NN.ipynb
This is the snippet that does not behave as I would expect:
T = 100.0
p, Vₗ, Vᵥ = impl_sat_p_v2(X)

# function f_pullback(Δy)
# Functions to calculate p, Vₗ, Vᵥ from a Newton step starting from the converged point
# This is the approach in Winter et al.
function f_p(X)
    T = 100.0
    model = make_NN_model(16.04, X...)
    p2 = -(eos(model, Vᵥ, T) - eos(model, Vₗ, T)) / (Vᵥ - Vₗ)
    return p2
end

function pressure(model::SAFTVRMieNN, V, T, z=[1.0])
    p = -ForwardDiff.derivative(V -> eos(model, V, T, z), V)
    return p
end

function f_V(X, V)
    T = 100.0
    model = make_NN_model(16.04, X...)
    # ∂p∂V = Zygote.gradient(V -> pressure(model, V, T), V)[1]
    ∂p∂V = ForwardDiff.derivative(V -> pressure(model, V, T), V)
    V2 = V - (pressure(model, V, T) - p) / ∂p∂V
    return V2
end

@show p, f_p(X), ForwardDiff.gradient(f_p, X);
@show Vₗ, f_V(X, Vₗ), ForwardDiff.gradient(X -> f_V(X, Vₗ), X);
Possibly this comes from mixing first- and second-order ForwardDiff calls all at once? I believe this is a relevant thread, though I need some more time to understand it: https://discourse.julialang.org/t/is-it-possible-to-do-nested-ad-elegantly-in-julia-pinns/98888
This follows the approach set out in the SI of Winter et al. (arXiv:2309.12404), though I imagine speeding up derivatives with this "perfect Newton step" trick is far from novel.
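One thing that may be worth checking in the `f_V` snippet is sketched below. It assumes that the `p` used inside `f_V` is the plain number returned by the converged solve, so it carries no dependence on X. The symbols V*, p̂ and G are introduced here for illustration only, and none of this has been run against the notebook.

```latex
% Assumptions: V^* is a converged saturation volume, \hat{p} is the captured constant p,
% and G is the Newton-step map implemented by f_V. All three names are illustrative.

% Exact sensitivity from the saturation condition p(X, V^*(X)) = p_{\mathrm{sat}}(X):
\frac{\partial p}{\partial X} + \frac{\partial p}{\partial V}\,\frac{d V^*}{d X}
  = \frac{d p_{\mathrm{sat}}}{d X}
\quad\Longrightarrow\quad
\frac{d V^*}{d X}
  = \frac{\dfrac{d p_{\mathrm{sat}}}{d X} - \dfrac{\partial p}{\partial X}}
         {\dfrac{\partial p}{\partial V}}

% Newton step with V and \hat{p} held constant:
G(X) = V - \frac{p(X, V) - \hat{p}}{\partial_V p(X, V)}

% Its derivative at the converged point, where p(X, V) = \hat{p}:
\frac{d G}{d X} = -\frac{\partial_X p}{\partial_V p}
```

Comparing the two expressions, the Newton-step derivative lacks the term (dp_sat/dX)/(∂p/∂V) whenever p̂ is treated as a constant. If this reading of the snippet is right, using `f_p(X)` in place of `p` inside `f_V` would restore that term. This concerns the value of the gradient and is separate from any nested-ForwardDiff issue.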
https://docs.juliahub.com/NonconvexUtils/N9Uwz/0.4.1/#:~:text=The%20implicit%20function%20theorem%20assumes,since%20its%20assumption%20is%20violated.
https://github.com/gdalle/ImplicitDifferentiation.jl