google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

update progress_fn to return params at each checkpoint #299

Closed: vijaysundaram closed this issue 1 year ago

vijaysundaram commented 1 year ago

Howdy Brax team!

Can we update the progress_fn arguments to include params, in addition to metrics, at each checkpoint during training?

Yours, Vijay

erikfrey commented 1 year ago

So that we don't break existing APIs, we added a new callback, policy_params_fn, which is passed the current step, the make_policy function, and the params:

https://github.com/google/brax/blob/455e750f9354d986ef00523cb698cfe627979c9c/brax/training/agents/ppo/train.py#L87
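A minimal sketch of how such a callback might be used follows. It assumes only the argument order described above, `policy_params_fn(step, make_policy, params)`, alongside a `progress_fn(step, metrics)`. `toy_train`, the toy linear `make_policy`, the metric name, and the pickle-based checkpointing are hypothetical illustrations, not brax code. With brax itself you would pass your function as the `policy_params_fn` argument to `train` in the linked `ppo/train.py`.

```python
import os
import pickle
import tempfile
from typing import Any, Callable, Dict

# Hypothetical stand-in for a training loop; NOT brax's ppo.train. It only mimics
# the callback contract: at each checkpoint it calls
# progress_fn(step, metrics) and then policy_params_fn(step, make_policy, params).
def toy_train(num_evals: int,
              progress_fn: Callable[[int, Dict[str, float]], None],
              policy_params_fn: Callable[[int, Callable, Any], None]) -> Any:
    def make_policy(params, deterministic: bool = False):
        # Toy linear "policy": action = w * obs + b (hypothetical).
        w, b = params
        return lambda obs: w * obs + b

    params = (0.0, 0.0)
    for i in range(num_evals):
        step = i * 1000
        params = (params[0] + 0.5, params[1] - 0.1)  # pretend a training update happened
        progress_fn(step, {"eval/episode_reward": float(i)})
        policy_params_fn(step, make_policy, params)
    return params

ckpt_dir = tempfile.mkdtemp()
history = []   # (step, reward) pairs collected by progress_fn
latest = {}    # most recent policy rebuilt from params

def progress_fn(step, metrics):
    history.append((step, metrics["eval/episode_reward"]))

def policy_params_fn(step, make_policy, params):
    # Write the params for this checkpoint to disk...
    with open(os.path.join(ckpt_dir, f"params_{step}.pkl"), "wb") as f:
        pickle.dump(params, f)
    # ...and keep a callable policy built from them.
    latest["policy"] = make_policy(params)

final = toy_train(3, progress_fn, policy_params_fn)
print(sorted(os.listdir(ckpt_dir)))
```

Saving the params rather than the policy function keeps each checkpoint as plain data; the policy can be rebuilt later by calling `make_policy(params)` on whichever checkpoint you load.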

This is live as of 455e750f9354d986ef00523cb698cfe627979c9c. Enjoy!