outlines-dev / outlines

Structured Text Generation
https://outlines-dev.github.io/outlines/
Apache License 2.0
8.16k stars 411 forks

JAX compatible API #1027

Open borisdayma opened 2 months ago

borisdayma commented 2 months ago

Presentation of the new feature

It would be great to have a JAX compatible API with the form of a logit processor.

The input would be the current vocabulary probabilities, and the output would simply mask the invalid ones based on the grammar at the current state.
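To make the request concrete, here is a minimal sketch of that masking step in JAX. It is not an existing Outlines API: `mask_invalid_logits`, `allowed_token_ids`, and the toy vocabulary are hypothetical, and the set of allowed tokens is assumed to come from whatever tracks the grammar state. The sketch works on logits; with probabilities, the invalid entries would be zeroed and the rest renormalized.

```python
import jax
import jax.numpy as jnp


def mask_invalid_logits(logits, allowed_mask):
    """Set the logits of disallowed tokens to -inf so they are never sampled."""
    return jnp.where(allowed_mask, logits, -jnp.inf)


# Toy vocabulary of 8 tokens with logits 0.0 .. 7.0 (illustrative values only).
vocab_size = 8
logits = jnp.arange(vocab_size, dtype=jnp.float32)

# Hypothetical: token ids the grammar allows at the current state.
allowed_token_ids = jnp.array([1, 3, 4])
allowed_mask = jnp.zeros(vocab_size, dtype=bool).at[allowed_token_ids].set(True)

# The masking function is pure, so it can be jitted with the sampling code.
masked = jax.jit(mask_invalid_logits)(logits, allowed_mask)

# Greedy pick over the masked logits; only allowed tokens can win.
next_token = int(jnp.argmax(masked))
```

Because the mask is a plain boolean array of vocabulary size, it could be computed outside the jitted function and passed in at each decoding step.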

Where does it fit in Outlines?

I have used Outlines with transformers, and a similar experience with JAX would be great, as there is currently no comparable functionality.

Are you willing to open a PR?

Yes, I'd love a hint on where to start (for example, recommended high-level functions if it were in NumPy or torch tensors).

My goal is to find out how to integrate it easily with JAX sampling functions such as those in MaxText: https://github.com/google/maxtext/blob/5bc40298530c7b5acaa42a366da1e6c2d413fac9/MaxText/inference_utils.py#L30

lapp0 commented 1 month ago

outlines.processors supports a number of array frameworks via dlpack copy-free type conversions. These are incredibly efficient and have near-zero overhead.

It seems Jax supports dlpack https://jax.readthedocs.io/en/latest/jax.dlpack.html
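As a rough illustration of the DLPack idea (not the actual `outlines.processors` code), the sketch below views a JAX array as a NumPy array through `np.from_dlpack`, runs a hypothetical NumPy-side processor on it, and converts the result back with `jnp.asarray`. `numpy_side_processor` and `allowed_token_ids` are made up for the example, and it assumes a CPU backend; the conversion back to JAX may copy.

```python
import numpy as np
import jax.numpy as jnp


def numpy_side_processor(logits_np, allowed_token_ids):
    """Hypothetical stand-in for a framework-agnostic logits processor."""
    mask = np.zeros(logits_np.shape[-1], dtype=bool)
    mask[allowed_token_ids] = True
    # Build a new array instead of writing in place: the DLPack view of a
    # JAX buffer should be treated as read-only.
    return np.where(mask, logits_np, -np.inf).astype(logits_np.dtype)


jax_logits = jnp.array([0.5, 2.0, -1.0, 3.0], dtype=jnp.float32)

# JAX arrays implement the DLPack protocol, so NumPy can view them directly
# (on CPU this is intended to avoid a copy).
np_view = np.from_dlpack(jax_logits)

processed = numpy_side_processor(np_view, [0, 2])

# Regular conversion back to a JAX array.
jax_result = jnp.asarray(processed)
```

The same pattern would apply to any array library that implements `__dlpack__`, which is why a JAX adapter can stay small.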

I'm glad you're interested in contributing! The only necessary changes to support JAX are

Please let me know if you have any questions.