LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License

Rework ChainRules for DynamicExpressions #608

Open avik-pal opened 2 months ago

avik-pal commented 2 months ago

DynamicExpressions supports ChainRules starting with v0.17 (https://github.com/SymbolicML/DynamicExpressions.jl/pull/71). We can remove parts of our code with CRC.rrule_via_ad. We still need to define a rule because we do an in-place node update. Additionally, we need to extract the node parameters in the final parameter gradient.

avik-pal commented 2 months ago

~Needs some investigation, I wasn't able to unthunk the Tangent coming from https://github.com/SymbolicML/DynamicExpressions.jl/pull/71~

This would need some further thought.

```julia
function Lux.__apply_dynamic_expression_rrule(
        de::Lux.DynamicExpressionsLayer, expr, operator_enum, x, ps)
    # In-place update: write the current parameters `ps` into the expression's constants.
    Lux.__update_expression_constants!(expr, ps)
    @static if pkgversion(DynamicExpressions) < v"0.17"
        error("`DynamicExpressions` v0.17 or later is required for reverse mode to work.")
    end
    # Reuse the ChainRules rule that DynamicExpressions defines for `eval_tree_array`.
    (y, _), pb_f = CRC.rrule(eval_tree_array, expr, x, operator_enum; de.turbo, de.bumper)
    __∇apply_dynamic_expression = @closure Δ -> begin
        _, ∂expr, ∂x, ∂operator_enum = pb_f((Δ, nothing))
        # Extract the node-parameter gradient from the expression tangent.
        ∂ps = CRC.unthunk(∂expr).gradient
        return NoTangent(), NoTangent(), NoTangent(), ∂operator_enum, ∂x, ∂ps, NoTangent()
    end
    return y, __∇apply_dynamic_expression
end
```

This works, but we hit a clear regression on mixed precision. Maybe once that is handled upstream, we can use the rrule directly.