jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Added stream annotation support via @compute_on('stream:#') decorator #25056

Open · chaserileyroberts opened this pull request 6 days ago

chaserileyroberts commented 6 days ago

This is a small change that adds the stream-annotation frontend_attribute to an operation when compute_on is used with the stream:# device type, where # is the stream number.

This feature is not yet fully enabled in XLA and will sit behind the --xla_gpu_experimental_stream_annotation flag for the foreseeable future. When the flag is disabled, XLA simply ignores the attribute.
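The PR description includes no code. The pure-Python sketch below models the two behaviours it describes, under stated assumptions. First, a compute_on device type of the form 'stream:#' becomes a stream-annotation frontend attribute. Second, the backend honours that attribute only when the experimental flag is on. In JAX itself the frontend would be used as @compute_on('stream:1') on a function, as the PR title shows. The helper names stream_frontend_attributes and effective_stream, and the attribute key _xla_stream_annotation, are hypothetical illustrations, not JAX's or XLA's actual implementation.

```python
from typing import Dict, Optional

# Hypothetical attribute key (assumption); the key JAX/XLA actually uses may differ.
STREAM_ATTR = "_xla_stream_annotation"


def stream_frontend_attributes(compute_type: str) -> Dict[str, str]:
    """Hypothetical: map a compute_on device type like 'stream:2' to frontend attributes.

    Device types that are not of the form 'stream:<non-negative int>' get no
    stream annotation; a malformed stream id raises ValueError.
    """
    prefix = "stream:"
    if not compute_type.startswith(prefix):
        return {}
    stream_id = compute_type[len(prefix):]
    if not stream_id.isdigit():
        raise ValueError(f"expected 'stream:<int>', got {compute_type!r}")
    # Normalise the id (e.g. '007' -> '7') before storing it as a string attribute.
    return {STREAM_ATTR: str(int(stream_id))}


def effective_stream(attrs: Dict[str, str], flag_enabled: bool) -> Optional[int]:
    """Hypothetical model of the backend: the annotation is ignored unless the
    experimental stream-annotation flag is enabled."""
    if not flag_enabled or STREAM_ATTR not in attrs:
        return None
    return int(attrs[STREAM_ATTR])


if __name__ == "__main__":
    attrs = stream_frontend_attributes("stream:1")
    print(attrs)                                       # {'_xla_stream_annotation': '1'}
    print(effective_stream(attrs, flag_enabled=False))  # None: flag off, attribute ignored
    print(effective_stream(attrs, flag_enabled=True))   # 1
```

Keeping the attribute as a plain string mirrors how frontend attributes are free-form key/value pairs, which is why an unsupported backend can ignore it without failing.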

yashk2810 commented 6 days ago

It would be nice if the description explained why we should support this, i.e. flesh the description out more.