LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License

Add dynamic expressions extension #585

Closed · avik-pal closed this 4 months ago

avik-pal commented 4 months ago

This turned out to be easier than I expected.

cc @MilesCranmer

avik-pal commented 4 months ago

using Lux, Random
using DynamicExpressions
using Zygote, Tracker, ReverseDiff
using ComponentArrays

# Operator set shared by both expressions
operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])

# Nodes referring to the first and second input features
x1 = Node(; feature=1)
x2 = Node(; feature=2)

# Two scalar expressions built from the feature nodes
expr_1 = x1 * cos(x2 - 3.2)
expr_2 = x2 - x1 * x2 + 3.2 - 1.0 * x1

# Wrap the expressions as a single Lux layer
layer = DynamicExpressionsLayer(operators, expr_1, expr_2)

# Initialize parameters and states
ps, st = Lux.setup(Random.default_rng(), layer)

# Batch of 16 samples with 2 features each
x = rand(Float32, 2, 16)

# Forward pass: returns (output, updated_state)
layer(x, ps, st)

# Reverse-mode gradients via Zygote and Tracker
Zygote.gradient(Base.Fix1(sum, abs2) ∘ first ∘ layer, x, ps, st)

Tracker.gradient(Base.Fix1(sum, abs2) ∘ first ∘ layer, x, ps, st)

# ReverseDiff works on array inputs, so flatten the parameters into a ComponentArray
ps2 = ComponentArray(ps)

ReverseDiff.gradient((x, ps) -> sum(abs2, first(layer(x, ps, st))), (x, ps2))
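
For context, here is a minimal sketch (not part of this PR) of how such gradients could drive a parameter update with Optimisers.jl, reusing operators, expr_1, and expr_2 from above; the Adam optimizer, the learning rate, the step count, the perturbed targets y, and the squared-error loss are all assumptions for illustration.

using Optimisers

layer = DynamicExpressionsLayer(operators, expr_1, expr_2)
ps, st = Lux.setup(Random.default_rng(), layer)

x = rand(Float32, 2, 16)
# Assumed targets: the initial predictions shifted by a constant, so shapes match
y = first(layer(x, ps, st)) .+ 0.1f0

opt_state = Optimisers.setup(Adam(0.01f0), ps)

for _ in 1:100
    # Gradient of a squared-error loss with respect to the parameters only
    gs = Zygote.gradient(p -> sum(abs2, first(layer(x, p, st)) .- y), ps)[1]
    # `global` so the top-level bindings are updated inside the loop
    global opt_state, ps = Optimisers.update(opt_state, ps, gs)
end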
avik-pal commented 4 months ago

A tutorial preview is available at https://lux.csail.mit.edu/previews/PR585/tutorials/advanced/2_SymbolicOptimalControl