google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

JAXOPT Projected Gradient #582

Open pulkitchhabra19 opened 4 months ago

pulkitchhabra19 commented 4 months ago

Hello, I have been using the JAXOPT ProjectedGradient solver with a box-constraint projection. I would like to get more output from the algorithm than it currently provides; for example, I would like the loss value at each iteration to be returned as output. I believe I can set verbose to True to print the error at each iteration, but I would rather have it returned.

Secondly, I believe the tolerance passed to the ProximalGradient class (which ProjectedGradient uses internally) is not used anywhere. If that is the case, this might be a bug, since the tolerance I set would have no effect. Could you please look into it? Thanks.

Regardless of the tolerance I set, the solver runs until the maximum number of iterations, so I looked at the source and found that the tolerance is not used at all: https://github.com/google/jaxopt/blob/501cc208c2493395fbe8026b963e7867397403db/jaxopt/_src/proximal_gradient.py#L153
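For reference, here is what a tolerance-based stopping rule is expected to do, written as a minimal sketch in plain JAX rather than jaxopt's own code: a fixed-step projected gradient loop built on `jax.lax.while_loop` that stops as soon as the error drops to `tol` or `maxiter` is reached, and that also returns the loss at every iteration. The function name `projected_gradient`, the fixed step size, the toy problem, and the error definition (distance between successive iterates divided by the step size) are all assumptions for illustration.

```python
import jax
import jax.numpy as jnp

def projected_gradient(loss, x0, lower, upper, stepsize, tol, maxiter):
    """Hypothetical fixed-step projected gradient that stops once error <= tol.

    Returns the final iterate, the number of iterations run, and a buffer of
    per-iteration losses (entries past the last iteration stay NaN).
    """
    grad = jax.grad(loss)
    history = jnp.full(maxiter, jnp.nan, dtype=x0.dtype)

    def cond(carry):
        _, it, error, _ = carry
        # Keep iterating while under budget AND not yet within tolerance.
        return (it < maxiter) & (error > tol)

    def body(carry):
        x, it, _, hist = carry
        x_new = jnp.clip(x - stepsize * grad(x), lower, upper)  # box projection
        error = jnp.linalg.norm(x_new - x) / stepsize
        hist = hist.at[it].set(loss(x_new))                     # record loss
        return x_new, it + 1, error, hist

    init = (x0, jnp.asarray(0, dtype=jnp.int32),
            jnp.asarray(jnp.inf, dtype=x0.dtype), history)
    x, n_iter, _, history = jax.lax.while_loop(cond, body, init)
    return x, n_iter, history

# Same kind of toy box-constrained least-squares problem as an assumption.
A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
b = jnp.array([1.0, 2.0, 3.0])
loss = lambda x: 0.5 * jnp.sum((A @ x - b) ** 2)

x, n_iter, history = projected_gradient(
    loss, jnp.zeros(2), 0.0, 1.0, stepsize=0.01, tol=1e-3, maxiter=10000)
```

When the tolerance is honored, `n_iter` comes back below `maxiter` on this toy problem; if it always equals `maxiter` no matter how loose `tol` is, that is the symptom described above.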