google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Consider deprecating common_utils.shard #948

Open jheek opened 3 years ago

jheek commented 3 years ago

The following pattern is common in user code and our examples:

import jax

def shard(pytree, n_devices):
  # Reshape the leading batch axis of every leaf into (n_devices, batch_per_device, ...).
  def _shard_array(array):
    return array.reshape((n_devices, -1) + array.shape[1:])
  return jax.tree_map(_shard_array, pytree)

The shard utility is part of flax.common_utils.
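
For context, a minimal sketch of how this pattern typically shows up in a pmap-based training loop (the batch contents, shapes, and p_train_step are illustrative assumptions, not taken from the examples):

import jax
import numpy as np

n_devices = jax.local_device_count()

# Host-side batch coming out of the input pipeline (shapes are illustrative).
batch = {'image': np.zeros((128, 224, 224, 3), np.float32),
         'label': np.zeros((128,), np.int32)}

sharded = shard(batch, n_devices)
# Each leaf now has a leading device axis, e.g.
# sharded['image'].shape == (n_devices, 128 // n_devices, 224, 224, 3),
# which is the layout a pmapped train step expects:
# metrics = p_train_step(params, sharded)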

This pattern has 2 issues:

  1. It complicates the training loop: part of the preprocessing runs asynchronously in tf.data, while shard runs synchronously inside the training loop.
  2. The code works on both NumPy and JAX arrays but with completely different performance characteristics. On NumPy arrays the reshape is a view change, which is practically free, whereas on JAX arrays it triggers a copy that runs on the device, which can cost significant time and memory, in particular on TPUs (see the sketch after this list).
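
To make the second point concrete, here is a quick sketch of the difference (the shapes and the device count of 8 are illustrative; whether XLA ends up copying can depend on the backend, but on accelerators the reshape is dispatched as a device operation rather than a free view change):

import numpy as np
import jax.numpy as jnp

x_np = np.ones((128, 224, 224, 3), dtype=np.float32)
y_np = x_np.reshape((8, 16) + x_np.shape[1:])     # NumPy: a view, no data is copied
assert y_np.base is x_np                          # shares the original buffer

x_jax = jnp.ones((128, 224, 224, 3), dtype=jnp.float32)
y_jax = x_jax.reshape((8, 16) + x_jax.shape[1:])  # JAX: an XLA op on the device that
                                                  # generally materializes a new buffer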

Proposed solution

Make sure input data already has the correct shape at the end of the input pipeline. In tf.data we can do this by using a double batch call:

dataset = dataset.batch(device_batch_size).batch(num_devices)

The advantage here is that the input pipeline always provides us with correctly shaped data. No reshaping is done in the critical loop, and there is no confusion about NumPy vs. JAX arrays.
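
A minimal sketch of what that looks like end to end (the dataset contents, device_batch_size value, and drop_remainder settings are illustrative assumptions):

import tensorflow as tf
import jax

num_devices = jax.local_device_count()
device_batch_size = 16  # illustrative per-device batch size

dataset = tf.data.Dataset.range(1024)
dataset = dataset.map(lambda i: tf.ones((224, 224, 3)))
dataset = dataset.batch(device_batch_size, drop_remainder=True)
dataset = dataset.batch(num_devices, drop_remainder=True)

# Each element now already has shape (num_devices, device_batch_size, 224, 224, 3),
# so it can be fed straight to a pmapped train step with no reshaping in the loop.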

TODO:

peterjliu commented 3 years ago

Might still be useful for people who don't use tf.data?

1e100 commented 1 year ago

People will just implement an ad-hoc function with the exact same semantics. It's too convenient to do without. When you say "copy [...] takes a lot of time", it'd be good to supply some data to support the assertion.

Also I don't see how something like:

split_batch = shard(batch)

complicates much of anything.
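
On the request for data, a rough way to measure the cost locally (the shape is illustrative, and block_until_ready is needed because JAX dispatches work asynchronously):

import timeit
import numpy as np
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x_np = np.ones((128, 224, 224, 3), dtype=np.float32)
x_jax = jnp.asarray(x_np)

t_np = timeit.timeit(lambda: x_np.reshape((n, -1) + x_np.shape[1:]), number=100)
t_jax = timeit.timeit(
    lambda: x_jax.reshape((n, -1) + x_jax.shape[1:]).block_until_ready(), number=100)
print(f'NumPy reshape: {t_np / 100:.6f} s per call')
print(f'JAX reshape:   {t_jax / 100:.6f} s per call')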