jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Improve jax.Array documentation #19342

Open hawkinsp opened 9 months ago

hawkinsp commented 9 months ago

https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array

Many of the methods and properties of jax.Array are not sufficiently documented.

addressable_data(index) | Return an array of the addressable data at a particular index.

I'm none the wiser from that description what exactly index is. Is it an integer? Is it a NumPy index? If it is an integer, what does the integer mean? What is the "index-th addressable data"?

addressable_shards | List of addressable shards.

What type of object is a shard? Do they come in any particular order?
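
Until the docs are improved, one way to get a feel for these two members is to inspect them directly. The following is a minimal sketch, assuming a recent JAX install and whatever devices happen to be available (a single CPU device is enough). The variable names are illustrative only, and it makes no claim about shard ordering, which is the open question above.

```python
import jax.numpy as jnp
import numpy as np

# A small array placed on the default device(s).
x = jnp.arange(8)

# addressable_shards is a sequence of Shard objects, one per addressable device.
shards = x.addressable_shards
for i, shard in enumerate(shards):
    # shard.index: tuple of slices locating this shard in the global array
    # shard.device: the device holding the shard
    # shard.data: the shard's contents as an array
    print(i, type(shard).__name__, shard.device, shard.index, shard.data.shape)

# addressable_data takes an integer position; compare it with the
# data of the shard at the same position in addressable_shards.
first = x.addressable_data(0)
same = np.array_equal(np.asarray(first), np.asarray(shards[0].data))
print(same)
```

On a multi-device setup with a non-trivial sharding, the same loop would show how each shard's `index` maps into the global array.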

jakevdp commented 9 months ago

I don't think ArrayImpl appears in the documentation at all

This came up when we were creating the linked page. I recall @yashk2810 was a strong -1 on documenting ArrayImpl. Some impl-only functions are not documented, because they don't exist on the base Array class.

For the others, getting more detailed docs is mainly about changing how we declare jax.Array in the sphinx sources.

hawkinsp commented 9 months ago

I think we should document the existence of ArrayImpl but not say a whole lot about it, pointing the user to `Array.