The default, True, preserves the existing behavior of splitting the batch
dimension across devices.
If False, arguments are passed directly to jax.pmap without reshaping,
which supports cases such as an argument that is a list of tensors to be
split across devices.
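A minimal sketch of the two modes follows. It assumes a hypothetical wrapper, `hypothetical_pmap`, with a boolean `reshape` flag; the real function and parameter names are not given here. When the flag is True, the wrapper splits each array argument's leading batch axis into `[n_devices, batch // n_devices, ...]` before calling `jax.pmap`. When it is False, the arguments go to `jax.pmap` unchanged.

```python
import jax
import jax.numpy as jnp


def reshape_by_device(x, n_devices):
    """Split the leading (batch) axis of x into [n_devices, batch // n_devices, ...]."""
    batch = x.shape[0]
    if batch % n_devices != 0:
        raise ValueError(f"batch size {batch} is not divisible by {n_devices} devices")
    return x.reshape((n_devices, batch // n_devices) + x.shape[1:])


def hypothetical_pmap(fn, reshape=True):
    """Hypothetical wrapper (not a real library API): optionally reshape args before jax.pmap."""
    n_devices = jax.local_device_count()
    pmapped = jax.pmap(fn)

    def wrapped(*args):
        if reshape:
            # Default path: split the batch axis of every array leaf across devices.
            # Assumes every leaf is an array with a batch axis divisible by n_devices.
            args = jax.tree_util.tree_map(lambda x: reshape_by_device(x, n_devices), args)
        # reshape=False: args are passed through as-is, so each leaf must
        # already have a leading axis of size n_devices for jax.pmap to map over.
        return pmapped(*args)

    return wrapped


n = jax.local_device_count()
x = jnp.arange(8.0 * n).reshape(8 * n, 1)  # a batch of 8 * n examples

# Default: the wrapper splits the batch; each device sums its own shard.
per_device = hypothetical_pmap(lambda b: b.sum(axis=0))(x)

# reshape=False: the caller lays out the device axis explicitly.
pre_split = x.reshape(n, 8, 1)
per_device_manual = hypothetical_pmap(lambda b: b.sum(axis=0), reshape=False)(pre_split)
```

With `reshape=False`, the caller is responsible for giving every argument a leading axis equal to `jax.local_device_count()`, which is the axis `jax.pmap` maps over. This sketch does not attempt to reproduce the list-of-tensors case beyond that requirement.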