joshday / SparseRegression.jl

Statistical Models with Regularization in Pure Julia

Line search #8

Closed patwa67 closed 6 years ago

patwa67 commented 7 years ago

Is it possible to extract the gradient from train!() in some way, and how do I call the loss function? I want to implement a backtracking line search for the learning rate:

lr, a, g = 0.7, 0.5, 0.7
for i = 1:it
   learn!(s, Fista(lr), MaxIter(1), Converged(coef))
   # `grad` is what I don't know how to get out of the algorithm:
   while loss(coef(s) - lr * grad) > loss(coef(s)) - a * lr * dot(grad, grad)
      lr *= g
   end
end
joshday commented 7 years ago

The gradient (or extrapolated gradient for Fista) is stored in the algorithm type:

https://github.com/joshday/SparseRegression.jl/blob/master/src/algorithms/proxgrad.jl#L14

Something like this (untested) could work:

lr, a, g = 0.7, 0.5, 0.7
for i = 1:it
    alg = Fista(lr)
    learn!(s, alg, MaxIter(1), Converged(coef))
    grad = alg.∇
    while loss(coef(s) - lr * grad) > loss(coef(s)) - a * lr * dot(grad, grad)
        lr *= g
    end
end

I think I have a solution to do line search generally for ProxGrad/Fista/GradientDescent, but I haven't had time to implement it yet.
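For reference, a generic backtracking (Armijo) line search can be sketched independently of SparseRegression's types; `f` and `∇` below are stand-ins for the loss and its gradient at the current coefficients, not part of the package's API:

```julia
using LinearAlgebra

# Backtracking (Armijo) line search: shrink lr until
# f(β - lr*∇) ≤ f(β) - a * lr * ‖∇‖² holds.
function backtrack(f, β, ∇; lr = 1.0, a = 0.5, g = 0.7)
    fβ = f(β)
    sqnorm = dot(∇, ∇)
    while f(β .- lr .* ∇) > fβ - a * lr * sqnorm
        lr *= g
    end
    return lr
end
```

For example, with f(β) = ‖β‖², β = [1.0], and ∇ = [2.0], the loop shrinks the step twice and accepts lr = 0.49.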

joshday commented 7 years ago

On second thought: while that is a way to get the gradient, this won't actually use the FISTA algorithm, because a new instance of Fista is created at each outer iteration and its extrapolation state is discarded. It will (I think) essentially reduce to ProxGrad. I'll think on this.
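A toy illustration of the problem (the `ToyFista` type here is hypothetical, not SparseRegression's actual algorithm type): FISTA's momentum sequence t_k only grows if the same algorithm object survives across iterations, so rebuilding it each time resets the momentum to its starting value.

```julia
# Toy stand-in holding only FISTA's momentum counter t_k.
mutable struct ToyFista
    t::Float64
    ToyFista() = new(1.0)
end

# FISTA's momentum update: t_{k+1} = (1 + sqrt(1 + 4t_k²)) / 2.
step!(alg::ToyFista) = (alg.t = (1 + sqrt(1 + 4 * alg.t^2)) / 2)

# Reusing one instance lets the momentum sequence grow...
reused = ToyFista()
for _ in 1:3
    step!(reused)
end

# ...while a fresh instance per outer iteration always restarts at t = 1,
# which is why the scheme degenerates toward plain proximal gradient.
fresh_t = [step!(ToyFista()) for _ in 1:3]
```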

patwa67 commented 7 years ago

Yes, you are right. The line search should go into the algorithms themselves, just after the call to gradient!().
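That placement could look roughly like this inside a proximal-gradient update. The function name is hypothetical and the gradient!()/prox!() calls are indicated only as comments, since this is a sketch of where the search slots in, not the package's actual internals:

```julia
using LinearAlgebra

# Hypothetical update step with backtracking inserted right after the
# gradient is computed and before the (proximal) gradient step is taken.
function update_with_linesearch!(β, ∇, f; lr = 1.0, a = 0.5, g = 0.7)
    # ... gradient!(∇, ...) would have filled ∇ at this point ...
    fβ = f(β)
    while f(β .- lr .* ∇) > fβ - a * lr * dot(∇, ∇)
        lr *= g                 # backtrack until sufficient decrease
    end
    β .-= lr .* ∇               # gradient step with the accepted lr
    # ... prox!(β, lr) would apply the penalty's proximal operator here ...
    return lr
end
```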

joshday commented 6 years ago

See http://joshday.github.io/SparseRegression.jl/latest/algorithms.html#LineSearch-1