google / objax

Apache License 2.0

Add support on objax.Parallel for list arguments. #153

Closed jyh closed 3 years ago

jyh commented 3 years ago

The new option defaults to True, which preserves the existing behavior of splitting the batch dimension across devices.

If False, arguments are passed directly to jax.pmap without reshaping. This supports cases where, for example, an argument is a list of tensors to be split across devices.
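The difference between the two modes can be sketched with jax.pmap alone. This is only an illustration under assumptions, not the objax.Parallel implementation: the helper `split_batch` is hypothetical and stands in for whatever reshaping happens when the option is True, and the two mapped functions are made up for the example.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()


def split_batch(x, n):
    # Hypothetical helper: reshape (batch, ...) into (n, batch // n, ...)
    # so that jax.pmap can map over the leading device axis.
    return x.reshape((n, x.shape[0] // n) + x.shape[1:])


def per_device_sum(x):
    # x is the shard of the batch seen by a single device.
    return jnp.sum(x, axis=0)


def sum_list(xs):
    # xs is a list of tensors; each device sees one slice of every tensor.
    return sum(jnp.sum(x) for x in xs)


# Mode 1 (option True, assumed behavior): the batch dimension is split
# across devices before the call reaches jax.pmap.
batch = jnp.arange(n_dev * 4 * 3, dtype=jnp.float32).reshape(n_dev * 4, 3)
out_split = jax.pmap(per_device_sum)(split_batch(batch, n_dev))

# Mode 2 (option False): the argument goes to jax.pmap unchanged. A list is
# a pytree, so every tensor in it only needs a leading axis of size n_dev;
# the trailing shapes may differ from tensor to tensor.
tensors = [jnp.ones((n_dev, 2)), jnp.ones((n_dev, 5))]
out_list = jax.pmap(sum_list)(tensors)
```

In the second mode the caller is responsible for giving each tensor a leading axis equal to the device count, since nothing is reshaped on their behalf.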