dottxt-ai / outlines

Structured Text Generation
https://dottxt-ai.github.io/outlines/
Apache License 2.0

JAX compatible API #1027

Open borisdayma opened 4 months ago

borisdayma commented 4 months ago

Presentation of the new feature

It would be great to have a JAX-compatible API in the form of a logits processor.

The input would be the current vocabulary probabilities, and the output would mask the invalid ones based on the grammar at the current state.
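The masking step described above can be sketched in JAX as follows. This is a minimal sketch under stated assumptions, not an Outlines API: the function name mask_invalid_logits and the boolean allowed_mask input are hypothetical, standing in for whatever the grammar's current state would provide. It masks logits rather than probabilities, because setting a logit to -inf is equivalent to giving that token zero probability after the softmax.

```python
import jax
import jax.numpy as jnp


def mask_invalid_logits(logits: jax.Array, allowed_mask: jax.Array) -> jax.Array:
    """Return logits with grammar-disallowed tokens set to -inf.

    logits:       float array, shape (batch, vocab_size)
    allowed_mask: bool array, same shape; True marks a token the grammar
                  allows at the current state (hypothetical input).
    """
    # jnp.where is pure, so this function can be wrapped in jax.jit.
    return jnp.where(allowed_mask, logits, -jnp.inf)


# Toy vocabulary of 5 tokens where the grammar allows only tokens 1 and 3.
logits = jnp.array([[0.5, 1.0, 2.0, -1.0, 0.0]])
allowed = jnp.array([[False, True, False, True, False]])
masked = jax.jit(mask_invalid_logits)(logits, allowed)
probs = jax.nn.softmax(masked, axis=-1)  # disallowed tokens get probability 0
```

Token 2 has the highest raw logit, but after masking it receives zero probability, and the remaining mass is split between tokens 1 and 3.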

Where does it fit in Outlines?

I have used Outlines with transformers, and a similar experience with JAX would be great, as there is currently no equivalent functionality.

Are you willing to open a PR?

Yes. I'd love a hint on where to start (for example, the recommended high-level functions if it were written with NumPy or torch tensors).

My goal is to find an easy way to integrate it with JAX sampling functions such as the one in MaxText: https://github.com/google/maxtext/blob/5bc40298530c7b5acaa42a366da1e6c2d413fac9/MaxText/inference_utils.py#L30
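A hypothetical sketch of where such a processor could slot into a JAX sampling step. It does not reproduce MaxText's sampling function; constrained_sample and allowed_mask are illustrative names. The assumption is that the grammar state (Python-side, as in Outlines) is advanced on the host, and only the resulting boolean mask is passed into the jitted function, so the traced code stays pure.

```python
import jax
import jax.numpy as jnp


def constrained_sample(rng: jax.Array, logits: jax.Array,
                       allowed_mask: jax.Array, temperature: float = 1.0) -> jax.Array:
    """Sample one token id per row, restricted to grammar-allowed tokens."""
    masked = jnp.where(allowed_mask, logits, -jnp.inf)
    # categorical samples from unnormalized log-probabilities, so tokens
    # whose logit is -inf are never drawn.
    return jax.random.categorical(rng, masked / temperature, axis=-1)


# Hypothetical decoding step: the mask would come from the grammar state,
# computed on the host in plain Python before calling the jitted function.
sample = jax.jit(constrained_sample)
rng = jax.random.PRNGKey(0)
logits = jnp.zeros((2, 6))
allowed = jnp.array([
    [False, False, True, False, False, False],  # only token 2 allowed
    [True, False, False, False, False, True],   # tokens 0 or 5 allowed
])
tokens = sample(rng, logits, allowed)
```

Here row 0 can only produce token 2, and row 1 can only produce token 0 or 5. Keeping the grammar outside jax.jit is a design choice made for this sketch: the mask changes every step but its shape and dtype do not, so the function is compiled only once.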

lapp0 commented 4 months ago

outlines.processors supports several array frameworks via copy-free DLPack type conversions, which are very efficient and add near-zero overhead.

It seems JAX supports DLPack: https://jax.readthedocs.io/en/latest/jax.dlpack.html
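A minimal sketch of a DLPack round trip between JAX and NumPy, assuming recent JAX and NumPy releases in which jax.Array implements __dlpack__ and jnp.from_dlpack accepts any DLPack-compatible object. NumPy only stands in here for the framework a processor might work in, and Outlines' own conversion helpers are not reproduced; jax_to_numpy and numpy_to_jax are hypothetical names. np.from_dlpack only works for CPU-resident arrays; a torch target would use torch.from_dlpack instead.

```python
import jax
import jax.numpy as jnp
import numpy as np


def jax_to_numpy(x: jax.Array) -> np.ndarray:
    # jax.Array implements __dlpack__, so NumPy can import it through DLPack.
    # NumPy is CPU-only, so this assumes x lives on a CPU device.
    return np.from_dlpack(x)


def numpy_to_jax(x: np.ndarray) -> jax.Array:
    # Assumes a recent JAX where jnp.from_dlpack accepts any object that
    # implements the DLPack protocol (recent NumPy arrays do).
    return jnp.from_dlpack(x)


# Round trip: JAX logits -> NumPy (where some processing could happen) -> JAX.
logits = jnp.array([[0.1, 0.7, 0.2]], dtype=jnp.float32)
as_np = jax_to_numpy(logits)
masked_np = np.where([[True, False, True]], as_np, -np.inf).astype(np.float32)
back = numpy_to_jax(masked_np)
```

DLPack lets the two libraries share an existing buffer instead of copying it where devices and memory layout allow, which is what makes this kind of conversion cheap.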

I'm glad you're interested in contributing! The only changes necessary to support JAX are:

Please let me know if you have any questions.

sky-2002 commented 1 month ago

Hi @lapp0, I made the changes you mentioned in base_logits_processor.py and also wrote the tests, but I am getting an error from the mypy pre-commit hook, and I could not find type stubs for jax and jaxlib.

Any pointers on how to resolve this?

rlouf commented 1 month ago

Yes, you can simply ignore them by adding them to this list in pyproject.toml.

sky-2002 commented 1 month ago

Hey @rlouf, thanks. I have created a draft PR and am open to feedback.