probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License
636 stars · 69 forks

Support gradient descent with multi-argument objective functions #311

Closed mikewojnowicz closed 1 year ago

mikewojnowicz commented 1 year ago

Support gradient descent with multi-argument objective functions. Gradient descent would be performed on the first argument; any remaining arguments, which provide context, could optionally be passed via **kwargs. This way, callers are not forced to supply such arguments as global variables.

gileshd commented 1 year ago

Hi @mikewojnowicz, thanks for your interest in the library!

I agree that it is probably helpful to not have to rely on global variables to control the behaviour of objective functions. I wonder if your use case could also be achieved by creating appropriate 'partial' or 'wrapped' objective functions?

Something like:

import functools

# using a lambda
obj = lambda params: objective(params, arg1="foo", arg2="bar")

# using functools.partial
obj = functools.partial(objective, arg1="foo", arg2="bar")

# pass the wrapped objective function to the gradient descent algorithm
run_gradient_descent(objective=obj, ...)

This feels like a pattern which appears often in other jax codebases and avoids relying on the undeclared default value of argnums=0 in jax.value_and_grad. I also worry slightly that passing around **kwargs might get a bit unwieldy.
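To make the pattern concrete, here is a minimal, self-contained sketch. The objective, its context arguments, and the toy optimizer (which uses a finite-difference gradient rather than jax.value_and_grad, so it runs anywhere) are all hypothetical, not dynamax API; only the wrapping pattern itself is the point.

```python
import functools

# Hypothetical multi-argument objective: optimization is over `params` only;
# `scale` and `target` are context arguments.
def objective(params, scale, target):
    return scale * (params - target) ** 2

# Wrap with functools.partial so the optimizer sees a one-argument function.
obj = functools.partial(objective, scale=2.0, target=3.0)

# Toy gradient descent with a central-difference gradient (a stand-in for
# an autodiff-based optimizer; illustrative only).
def run_gradient_descent(objective, init, lr=0.1, steps=200, eps=1e-6):
    p = init
    for _ in range(steps):
        grad = (objective(p + eps) - objective(p - eps)) / (2 * eps)
        p -= lr * grad
    return p

p_opt = run_gradient_descent(obj, init=0.0)  # should approach target=3.0
```

The optimizer never needs to know about `scale` or `target`; they are baked into `obj` at wrapping time, which is exactly what avoids globals or threading **kwargs through the optimizer.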

Having said that, I don't have loads of experience with this part of the library. @xinglong-li, any thoughts?

mikewojnowicz commented 1 year ago

@gileshd Thanks for your response. Using partial functions did not initially occur to me. I tried it out and it works great. I agree that it's a preferable solution to passing around **kwargs. Closing this now.