jax-ml / coix

Inference Combinators in JAX
https://coix.readthedocs.io/en/latest/
Apache License 2.0

jit grad update and donate args properly #39

Closed by fehiepsi 4 months ago

fehiepsi commented 4 months ago

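This is useful when optimizing large parameters, such as those of large language models (LLMs): donating the arguments lets XLA reuse the input parameter buffers for the updated parameters instead of keeping two copies in memory.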
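A minimal sketch of the idea, not the code from this PR. The loss, `sgd_step`, and the plain SGD update are hypothetical. The whole gradient-and-update step is compiled with `jax.jit`, and `donate_argnums=(0,)` marks the parameter pytree as donated, so XLA may write the new parameters into the old buffers. The caller must rebind `params` and never read the donated input again. On a backend without donation support, JAX may warn and fall back to copying.

```python
import functools

import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical least-squares loss for a linear model.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


# Donate argument 0 (the params pytree): its buffers may be reused for the output.
@functools.partial(jax.jit, donate_argnums=(0,))
def sgd_step(params, x, y, lr=0.1):
    # Loss and gradients are computed inside the same compiled function.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain SGD (a stand-in for any optimizer update).
    new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w

params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
losses = []
for _ in range(200):
    # Rebind `params`: the donated input buffers must not be used after the call.
    params, loss = sgd_step(params, x, y)
    losses.append(float(loss))
```

The same pattern applies to real optimizer states: donating both parameters and optimizer state keeps peak memory close to a single copy of each.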

Also updates the code to the newer `jax.tree` API (for example, `jax.tree.map`).
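The exact changes in this PR are not shown here. The snippet below only illustrates the `jax.tree` namespace, assuming JAX 0.4.25 or newer, where functions such as `jax.tree.map` and `jax.tree.leaves` sit alongside the older `jax.tree_util` spellings. The `params` and `grads` dicts are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter and gradient pytrees with matching structure.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
grads = {"w": jnp.full((2, 3), 0.5), "b": jnp.ones(3)}

# Newer spelling: tree utilities under the `jax.tree` namespace.
updated = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)

# Older, still-available spelling of the same operation.
updated_old = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

# Count parameters by summing the sizes of all leaves.
num_params = sum(leaf.size for leaf in jax.tree.leaves(updated))
```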