[Closed] mikewojnowicz closed this issue 1 year ago
Hi @mikewojnowicz, thanks for showing interest in the library!
I agree that it is probably helpful to not have to rely on global variables to control the behaviour of objective functions. I wonder if your use case could also be achieved by creating appropriate 'partial' or 'wrapped' objective functions?
Something like:
```python
import functools

# using a lambda
obj = lambda params: objective(params, arg1="foo", arg2="bar")

# using functools.partial
obj = functools.partial(objective, arg1="foo", arg2="bar")

# Pass this wrapped version of the objective function to the gradient descent algorithm
run_gradient_descent(objective=obj, ...)
```
This feels like a pattern which appears often in other JAX codebases, and it avoids relying on the undeclared default value of `argnums=0` in `jax.value_and_grad`. I also worry slightly that passing around `**kwargs` might get a bit unwieldy.
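To make the wrapping pattern concrete, here is a minimal, dependency-free sketch. The `objective` arguments and the finite-difference driver are illustrative stand-ins (a real implementation would use `jax.value_and_grad` rather than finite differences), but the key point is the same: the driver only ever sees a one-argument callable, with the context arguments bound up front.

```python
import functools

# Hypothetical multi-argument objective: gradient descent should act on
# `params` only; `scale` and `offset` merely provide context.
def objective(params, scale, offset):
    return scale * (params - offset) ** 2

# Minimal single-argument gradient-descent driver. It knows nothing about
# the extra arguments; central finite differences approximate the gradient.
def run_gradient_descent(objective, x0, lr=0.1, steps=100, eps=1e-6):
    x = x0
    for _ in range(steps):
        grad = (objective(x + eps) - objective(x - eps)) / (2 * eps)
        x = x - lr * grad
    return x

# Bind the context arguments up front; the driver sees a 1-argument callable.
obj = functools.partial(objective, scale=2.0, offset=3.0)
x_min = run_gradient_descent(obj, x0=0.0)
print(round(x_min, 3))  # the minimiser converges to the offset, 3.0
```

The same wrapped `obj` could be handed straight to `jax.value_and_grad(obj)`, since differentiation then naturally applies to the sole remaining argument.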
Having said that, I don't have loads of experience with this bit of the library. @xinglong-li, any thoughts?
@gileshd Thanks for your response. Using partial functions did not initially occur to me. I tried it out and it works great. I agree that it's a preferable solution to passing around `**kwargs`. Closing this now.
Support gradient descent with multi-argument objective functions. Gradient descent is performed on the first argument; the remaining arguments, which provide context, can optionally be passed via `**kwargs`. This way, callers are not forced to supply such arguments as global variables.