lucpaoli / SAFT_ML


Set up training on saturation pressure. If training is slow (because of differentiating through the saturation pressure solver), we can use the implicit function theorem. See packages in description #12

Closed lucpaoli closed 10 months ago

lucpaoli commented 11 months ago

https://docs.juliahub.com/NonconvexUtils/N9Uwz/0.4.1/#:~:text=The%20implicit%20function%20theorem%20assumes,since%20its%20assumption%20is%20violated.

https://github.com/gdalle/ImplicitDifferentiation.jl
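For comparison, here is a hypothetical sketch of wrapping the saturation solver with ImplicitDifferentiation.jl instead of a hand-written rrule. This is untested: it assumes the `ImplicitFunction(forward, conditions)` API from that package, and the `chemical_potential_VT` helper is a placeholder for whatever Clapeyron.jl routine returns the chemical potential at a given volume and temperature (the exact signature may differ).

```julia
using ImplicitDifferentiation

# Forward pass: solve saturation at fixed T, returning the solution vector.
function sat_forward(X)
    model = make_model(X...)
    p, Vₗ, Vᵥ = saturation_pressure(model, T)
    return [p, Vₗ, Vᵥ]
end

# Conditions satisfied at the solution: mechanical equilibrium in each phase
# and chemical equilibrium between phases. AD only ever differentiates these
# residuals, never the iterative solver itself.
function sat_conditions(X, y)
    p, Vₗ, Vᵥ = y
    model = make_model(X...)
    return [
        pressure(model, Vₗ, T) - p,
        pressure(model, Vᵥ, T) - p,
        chemical_potential_VT(model, Vₗ, T) - chemical_potential_VT(model, Vᵥ, T),
    ]
end

implicit_sat = ImplicitFunction(sat_forward, sat_conditions)
```

The appeal is that gradients of `implicit_sat(X)` with respect to `X` then come from a linear solve on the conditions' Jacobian, with no ChainRules rules needed inside Clapeyron.jl.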

longemen3000 commented 11 months ago

does this require defining ChainRules rules on Clapeyron.jl?

lucpaoli commented 11 months ago

@longemen3000 that's what I've done for now, you can see it here: https://github.com/lucpaoli/SAFT_ML/blob/master/6_sat_solver_NN.ipynb

Or just the relevant part:

function saturation_pressure_NN(X, T)
    model = make_model(X...)
    p, Vₗ, Vᵥ = saturation_pressure(model, T)

    return p
end

function ChainRulesCore.rrule(::typeof(saturation_pressure_NN), X, T)
    model = make_model(X...)
    p, Vₗ, Vᵥ = saturation_pressure(model, T)

    function f_pullback(Δy)
        #* Newton step from perfect initialisation
        function f_p(X, T)
            model = make_model(X...)  # same constructor as the primal above
            p2 = -(eos(model, Vᵥ, T) - eos(model, Vₗ, T))/(Vᵥ - Vₗ);
            return p2
        end

        ∂X = @thunk(ForwardDiff.gradient(X -> f_p(X, T), X) .* Δy)
        ∂T = @thunk(ForwardDiff.derivative(T -> f_p(X, T), T) .* Δy)
        return (NoTangent(), ∂X, ∂T)
    end

    return p, f_pullback
end
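With that rrule defined, gradients through the saturation pressure should be obtainable directly from Zygote. A hypothetical usage sketch (untested, with `X` and `T` as above):

```julia
using Zygote

# Zygote picks up the custom rrule, so the backward pass only differentiates
# the single "perfect Newton step" f_p rather than the iterative solver.
∂X, ∂T = Zygote.gradient(saturation_pressure_NN, X, T)
```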

The only problem with this approach is that I can't get it to work when differentiating the saturation densities with respect to the parameters: ForwardDiff appears to get confused tracking changes, and Zygote's costs blow up. A package that looks like it may solve this is TaylorDiff, but it is also stated to be in alpha and I don't fully understand it. An example of being unable to differentiate the saturation densities can be seen here: https://github.com/lucpaoli/SAFT_ML/blob/1455488cf4d2b4ef288ab90ae9994b4e142fb48c/6_sat_solver_NN.ipynb

With this not functioning as I would expect:

T = 100.0
p, Vₗ, Vᵥ = impl_sat_p_v2(X)

# function f_pullback(Δy)
# Functions to calculate p, Vₗ, Vᵥ from a Newton step starting from the converged point
# This is the approach in Winter et al.
function f_p(X)
    T = 100.0 
    model = make_NN_model(16.04, X...)
    p2 = -(eos(model, Vᵥ, T) - eos(model, Vₗ, T))/(Vᵥ - Vₗ);
    return p2
end

function pressure(model::SAFTVRMieNN, V, T, z=[1.0])
    p = -ForwardDiff.derivative(V -> eos(model, V, T, z), V)
    return p
end

function f_V(X, V)
    T = 100.0 
    model = make_NN_model(16.04, X...)
    # ∂p∂V = Zygote.gradient(V -> pressure(model, V, T), V)[1]
    ∂p∂V = ForwardDiff.derivative(V -> pressure(model, V, T), V)
    V2 = V - (pressure(model, V, T) - p)/∂p∂V
    return V2
end

@show p, f_p(X), ForwardDiff.gradient(f_p, X);
@show Vₗ, f_V(X, Vₗ), ForwardDiff.gradient(X -> f_V(X, Vₗ), X);

Possibly this comes from mixing first- and second-order ForwardDiff calls all at once? I believe this is a relevant thread, though I need some more time to understand it: https://discourse.julialang.org/t/is-it-possible-to-do-nested-ad-elegantly-in-julia-pinns/98888
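One way to sidestep the nested-AD issue for the saturation densities might be to apply the implicit function theorem by hand: at convergence the volume satisfies F(X, V) = pressure(model(X), V, T) − p = 0, so dV/dX = −(∂F/∂X)/(∂F/∂V). A hypothetical sketch (untested), reusing `make_NN_model`, `pressure`, and the converged `p` from above:

```julia
# Parameter sensitivity of a converged saturation volume V* via the implicit
# function theorem, avoiding differentiation through the volume solver.
function dVdX(X, V, p, T)
    F(X, V) = pressure(make_NN_model(16.04, X...), V, T) - p
    ∂F∂X = ForwardDiff.gradient(X -> F(X, V), X)   # first order in X
    ∂F∂V = ForwardDiff.derivative(V -> F(X, V), V) # nested: pressure is already a V-derivative
    return -∂F∂X ./ ∂F∂V
end
```

Each call still nests ForwardDiff (since `pressure` itself contains a `ForwardDiff.derivative`), but the nesting is confined to one scalar residual rather than the whole solver, which may avoid the tag confusion above.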

lucpaoli commented 11 months ago

This follows the approach set out in the SI of Winter et al. (arXiv:2309.12404), though I imagine speeding up derivatives with this "perfect Newton step" trick is far from novel.