Closed Joshuaalbert closed 10 months ago
    from jaxopt import NonlinearCG

    def loss(U: UType):
        return -model.log_prob_likelihood(U, allow_nan=False)

    solver = NonlinearCG(
        fun=loss,
        jit=True,
        unroll=False,
        verbose=False
    )
    results = solver.run(init_params=init_U_point)
It still prints out linesearch debug lines like:
    WARNING: jaxopt.ZoomLineSearch: Returning stepsize with sufficient decrease but curvature condition not satisfied.
    INFO: jaxopt.ZoomLineSearch: Iter: 9 Minimum Decrease & Curvature Errors (stop. crit.): 3.3350500139306405e-09 Stepsize:0.00016820144082885236 Decrease Error:0.0 Curvature Error:3.3350500139306405e-09
Thanks for bringing this up. https://github.com/google/jaxopt/pull/573 should fix the issue, but I need to investigate the failing tests.