google-deepmind / penzai

A JAX research toolkit for building, editing, and visualizing neural networks.
https://penzai.readthedocs.io/
Apache License 2.0

FR: expose device/sharding argument for NamedArray.wrap #80

Open amifalk opened 3 months ago

danieldjohnson commented 3 months ago

Hm, I'm not sure this is necessary? You can provide the device and sharding for the array before wrapping it, and pz.nx.wrap should preserve it, e.g.

arr = pz.nx.wrap(jax.device_put(my_array, my_sharding))
amifalk commented 3 months ago

Yep, just a little syntactic sugar :). When training large models, I often use a data loader to read data from disk and process it into a NumPy array, then convert it directly to a NamedArray sharded over devices.