SciML / Optimization.jl

Mathematical Optimization in Julia. Local, global, gradient-based and derivative-free. Linear, Quadratic, Convex, Mixed-Integer, and Nonlinear Optimization in one simple, fast, and differentiable interface.
https://docs.sciml.ai/Optimization/stable/
MIT License

Learning rate decay in callback function #648

Closed — KianHrz closed this issue 6 months ago

KianHrz commented 6 months ago

I was wondering if there is a way to access the learning rate through a callback function when using the Lux.jl and Optimization.jl packages. For instance, in the line below:

```julia
Optimization.solve(optprob, ADAM(1e-3), callback = callback, progress = true, maxiters = sw)
```

I have set the learning rate to 1e-3, and I would like to access it through the callback function in order to apply learning rate decay.

The currently available methods are intended for training the model in a loop, whereas the training in my code happens in the single line pasted above, so I would need to access the learning rate through the callback function. I was told that OptimizationOptimisers could be changed so that the solver state is passed to the callback function, which would make this possible.
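For reference, one workaround that does not require callback access to the optimizer state is to stage the decay across repeated `solve` calls, warm-starting each stage from the previous result via `remake`. The sketch below is only illustrative: the toy loss, the no-op callback, the schedule `etas`, and the per-stage `maxiters` are assumptions, not from this issue.

```julia
# Minimal sketch of staged learning-rate decay with Optimization.jl +
# OptimizationOptimisers. The loss here stands in for the actual Lux model loss.
using Optimization, OptimizationOptimisers, Zygote

loss(u, p) = sum(abs2, u .- p)          # toy objective (illustrative only)

optf = OptimizationFunction(loss, Optimization.AutoZygote())
optprob = OptimizationProblem(optf, zeros(3), [1.0, 2.0, 3.0])

# The callback signature varies across Optimization.jl versions;
# returning false simply continues the optimization.
callback = (args...) -> false

etas = [1e-3, 1e-4, 1e-5]               # hypothetical decay schedule
res = nothing
for eta in etas
    global res, optprob
    # ADAM is spelled Adam in newer Optimisers.jl releases
    res = Optimization.solve(optprob, ADAM(eta);
                             callback = callback, maxiters = 500)
    # Warm-start the next stage from the current best parameters;
    # remake comes from SciMLBase (re-exported by Optimization.jl).
    optprob = remake(optprob, u0 = res.u)
end
```

Each stage restarts the optimizer's internal moment estimates, so this is coarser than a per-iteration decay schedule, but it works with the current one-line `solve` interface.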