Open borisdayma opened 4 months ago
outlines.processors supports a number of array frameworks via dlpack copy-free type conversions. These are very efficient and have near-zero overhead.
It seems Jax supports dlpack https://jax.readthedocs.io/en/latest/jax.dlpack.html
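For illustration, a minimal sketch of the kind of dlpack round trip discussed here, assuming recent jax and torch versions in which arrays implement the `__dlpack__` protocol. This is not the code in outlines.processors, only an example of the conversion calls involved.

```python
import jax.dlpack
import jax.numpy as jnp
import torch

# Start from a JAX array.
x_jax = jnp.arange(4, dtype=jnp.float32)

# jax -> torch: torch.from_dlpack accepts any object implementing __dlpack__.
x_torch = torch.from_dlpack(x_jax)

# Do some work on the torch side; this creates a new torch tensor.
doubled = x_torch * 2

# torch -> jax: jax.dlpack.from_dlpack accepts the torch tensor directly.
y_jax = jax.dlpack.from_dlpack(doubled)
```

Both conversions go through dlpack, so on a shared device they are expected to avoid copying the underlying buffer.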
I'm glad you're interested in contributing! The only necessary changes to support Jax are to OutlinesLogitsProcessor, to support jax -> torch and torch -> jax. https://github.com/outlines-dev/outlines/blob/main/outlines/processors/base_logits_processor.py#L91-L135 (tests/processors/test_base_processor.py)
Please let me know if you have any questions.
Hi @lapp0, I made the mentioned changes in base_logits_processor.py and also wrote the tests. But I am getting an error in the mypy pre-commit hook: it could not find stubs for jax and jaxlib.
Any pointers on how to resolve this?
Yes, you can simply ignore them by adding them to this list in pyproject.toml.
Hey @rlouf , thanks. I have created a draft PR, open to feedback.
Presentation of the new feature
It would be great to have a JAX compatible API with the form of a logit processor.
The input would be the current vocab probabilities, and the output would just mask the invalid ones based on the grammar at the current state.
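A minimal sketch of what such a processor could look like in JAX. It assumes the input is a logits vector rather than probabilities, and that the grammar state has already been turned into a list of allowed token ids; `mask_invalid_logits` and `allowed_token_ids` are hypothetical names for illustration, not part of the Outlines API.

```python
import jax.numpy as jnp


def mask_invalid_logits(logits, allowed_token_ids):
    """Set every logit whose token id is not allowed to -inf."""
    vocab_size = logits.shape[-1]
    # Boolean vector over the vocabulary, True for allowed tokens.
    allowed = jnp.zeros(vocab_size, dtype=bool)
    allowed = allowed.at[jnp.asarray(allowed_token_ids)].set(True)
    # Keep allowed logits unchanged, push the rest to -inf.
    return jnp.where(allowed, logits, -jnp.inf)


logits = jnp.array([1.0, 2.0, 0.5, 3.0])
# Hypothetical: the grammar only permits tokens 1 and 2 at this state.
masked = mask_invalid_logits(logits, [1, 2])
```

After a softmax, the masked tokens get probability zero, so the sampler can only pick tokens the grammar permits.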
Where does it fit in Outlines?
I have used outlines with transformers, and a similar experience with JAX would be great, as there is currently no similar functionality.
Are you willing to open a PR?
Yes, I'd love a hint of where to start (for example, recommended high-level functions if it was in numpy or torch tensors).
My goal is to find how to integrate it easily with JAX sampling functions such as maxtext: https://github.com/google/maxtext/blob/5bc40298530c7b5acaa42a366da1e6c2d413fac9/MaxText/inference_utils.py#L30
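One way this integration might look is a jit-compiled sampling step that applies a boolean mask before drawing a token. The sketch below is only an illustration: `constrained_sample` and `allowed_mask` are hypothetical names, and the function is not MaxText's sampling API. It assumes the grammar state is supplied as a boolean array with the same shape as the logits.

```python
import jax
import jax.numpy as jnp


@jax.jit
def constrained_sample(rng, logits, allowed_mask):
    """Sample one token per row, restricted to positions where the mask is True."""
    # Disallowed tokens get -inf so they can never be drawn.
    masked = jnp.where(allowed_mask, logits, -jnp.inf)
    return jax.random.categorical(rng, masked, axis=-1)


# Batch of one sequence over a 4-token vocabulary.
logits = jnp.array([[2.0, 0.1, 1.5, 0.3]])
# Hypothetical grammar state: only tokens 1 and 3 are valid.
allowed_mask = jnp.array([[False, True, False, True]])

tokens = [
    int(constrained_sample(jax.random.PRNGKey(seed), logits, allowed_mask)[0])
    for seed in range(5)
]
```

Passing the mask as an array, rather than a Python list of ids, keeps the shapes static so the function can be compiled once and reused at every decoding step.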