patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Best Dataloader practices #137

jejjohnson opened this issue 2 years ago (status: open)

jejjohnson commented 2 years ago

I was wondering if the Equinox community has good advice or best practices for creating dataloaders that work well with JAX. I've written some from scratch, but I tend to find that training (on a single GPU, nothing fancy) is a bit slower than with PyTorch + PyTorch Lightning. Then again, I see claims such as this one comparing Flax with PyTorch Lightning, and I wonder whether the problem is on my end. So I wanted to ask the Equinox community about their experience.

For example, the JAX website has both a TensorFlow dataloader example and a PyTorch dataloader example. When I run them, the PyTorch dataloader is 1-2 seconds faster per epoch than the TensorFlow one. The Equinox docs use custom iterators, and other similar libraries such as Treex and Objax have more elaborate schemes.

patrick-kidger commented 2 years ago

FWIW I tend to use the PyTorch dataloaders, mostly just out of familiarity. I know many folks use TF as well.

(The custom dataloaders in the docs are just to minimise dependencies for the examples.)

If I'm working with a small dataset, it can be advantageous to keep the entire dataset in memory; see for example mmap.ninja. (Or you can probably just slice NumPy arrays; I've not checked closely how that compares.)
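For the in-memory case, the "just slice NumPy arrays" option might look roughly like the sketch below. The helper name `numpy_batches` and its parameters are hypothetical, not part of Equinox or any library; it draws one permutation per epoch (if an `rng` is given) and yields fancy-indexed slices of every array.

```python
import numpy as np


def numpy_batches(arrays, batch_size, *, rng=None, drop_last=False):
    """Yield tuples of batches from equally sized in-memory NumPy arrays.

    Hypothetical helper: shuffles with `rng` (a np.random.Generator) if given,
    otherwise iterates in order.
    """
    n = len(arrays[0])
    assert all(len(a) == n for a in arrays), "all arrays must share the leading dimension"
    perm = rng.permutation(n) if rng is not None else np.arange(n)
    # Drop the incomplete final batch if requested; otherwise it is simply shorter.
    stop = n - n % batch_size if drop_last else n
    for start in range(0, stop, batch_size):
        idx = perm[start : start + batch_size]
        # Fancy indexing copies only the rows needed for this batch.
        yield tuple(a[idx] for a in arrays)
```

Each yielded batch can be passed straight to a jitted training step; JAX transfers NumPy inputs to the default device when the function is called.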

jejjohnson commented 2 years ago

Thanks for the advice, @patrick-kidger! I'll do the same; I'm also more familiar with the PyTorch dataloaders. Do you follow practices similar to, for example, what's done here (i.e. a standard dataset, dataloader, and collate function)?

patrick-kidger commented 2 years ago

Yep pretty much.

jatentaki commented 2 years ago

A word of caution: do not make your collate function return np.array instead of Tensor (and I think the same applies to the return values of the dataset class). It's convenient, but PyTorch has some custom machinery in Tensor that makes tensors efficient to exchange between processes (I think by avoiding serialization at process boundaries), which NumPy arrays lack. So it's much better to return (structures of) Tensors and then convert with jax.tree_util.tree_map(lambda tensor: tensor.numpy(), batch).
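A minimal sketch of that advice, assuming both PyTorch and JAX are installed: the dataset returns tensors, PyTorch's default collate stacks them, and the conversion to NumPy happens once per batch in the main process. `TensorDataset` and `DataLoader` are standard PyTorch; the toy data and the helper name `to_numpy` are illustrative. (`jax.tree_util.tree_map` is the current spelling of `jax.tree_map`.)

```python
import jax
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def to_numpy(batch):
    """Convert every tensor leaf of a (nested) batch to a NumPy array."""
    # torch.Tensor is not a registered pytree node, so JAX treats each tensor
    # as a leaf and tree_map walks the surrounding lists/tuples/dicts.
    return jax.tree_util.tree_map(lambda t: t.numpy(), batch)


# Toy data: 10 samples, each with 3 features and an integer label.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)  # __getitem__ returns a tuple of tensors

# num_workers=0 keeps the sketch runnable anywhere; the tensor-vs-NumPy
# difference only matters once batches cross process boundaries (num_workers > 0).
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

for batch in loader:
    x, y = to_numpy(batch)  # NumPy arrays, ready to pass to a jitted JAX function
```

With `num_workers > 0` on platforms that spawn worker processes (macOS, Windows), the loop should sit under an `if __name__ == "__main__":` guard.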

patrick-kidger commented 2 years ago

That I didn't know! Thanks, that's really useful information.

jejjohnson commented 2 years ago

Thanks! This is good advice! So just to be clear:

Did I understand that correctly?

elyxlz commented 1 month ago

This feels very ugly: NumPy -> torch -> JAX at every data-loading step? Why must I download a huge library like torch just for data loading?

There must be a better solution...

nasyxx commented 1 month ago

> this feels very ugly, numpy -> torch -> jax at every dataloading step?

You can use NumPy rather than torch in the PyTorch DataLoader: the conversion from NumPy arrays to torch tensors happens in collate_fn. So if you define your own collate function, you can convert NumPy arrays directly to JAX arrays.
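A minimal sketch of such a collate function, in the spirit of the `numpy_collate` used in JAX's PyTorch-dataloading tutorial (the exact code there may differ). It stacks NumPy samples recursively, so batches never become torch tensors; the intended usage is `DataLoader(dataset, batch_size=..., collate_fn=numpy_collate)`. Keep in mind the trade-off raised earlier in the thread: with `num_workers > 0`, NumPy batches are pickled between processes rather than handed over as tensors.

```python
import numpy as np


def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays, recursing into tuples/lists/dicts."""
    elem = batch[0]
    if isinstance(elem, np.ndarray):
        return np.stack(batch)
    if isinstance(elem, dict):
        return {k: numpy_collate([sample[k] for sample in batch]) for k in elem}
    if isinstance(elem, (tuple, list)):
        # Transpose [(x0, y0), (x1, y1), ...] into ([x0, x1, ...], [y0, y1, ...]).
        fields = [numpy_collate(list(field)) for field in zip(*batch)]
        return tuple(fields) if isinstance(elem, tuple) else fields
    # Python or NumPy scalars become a 1-D array.
    return np.asarray(batch)
```

Converting to JAX arrays in the main process (e.g. `jnp.asarray(x)`, or just passing `x` to a jitted function) rather than inside the collate function is generally safer, since the collate function runs in the DataLoader worker processes and JAX does not play well with `fork`.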

> Why must I download a huge library like torch just for dataloading?

If you only need an index dataloader, here is one written with JAX and Equinox:

```python
from functools import partial

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np


def _np_argsort(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Argsort on the host with NumPy (called via jax.pure_callback)."""
    # Cast to int32 so the result matches the ShapeDtypeStruct declared below.
    return np.argsort(np.asarray(x), axis=axis).astype(np.int32)


@eqx.filter_jit
def fallback_argsort(x: jax.Array, axis: int = -1) -> jax.Array:
    """Argsort, falling back to NumPy when running on CPU."""
    if jax.devices()[0].platform == "cpu":
        # Bind `axis` with partial so that only the array `x` goes through the callback.
        return jax.pure_callback(
            partial(_np_argsort, axis=axis),
            jax.ShapeDtypeStruct(x.shape, jnp.int32),
            x,
        )
    return jnp.argsort(x, axis=axis)


class IdxDataloader(eqx.Module):
    """Simple index dataloader returning a (num_batches, batch_size) index grid."""

    length: int
    pad: int
    batch_size: int
    drop_num: int

    def __init__(self, length: int, batch_size: int, drop_last: bool = False) -> None:
        """Initialise the dataloader."""
        self.length = length
        self.batch_size = batch_size
        # drop_last: discard the incomplete final batch; otherwise pad it with -1 entries.
        self.drop_num = length % batch_size if drop_last else 0
        self.pad = 0 if drop_last else (-length) % batch_size

    def __call__(self, key: jax.Array | None = None) -> tuple[jax.Array, jax.Array]:
        """Return (indices, padding mask); the mask is True where an entry is padding."""
        idxes = jnp.arange(self.length)
        if key is not None:
            # NOTE: fall back to NumPy argsort since it has a performance issue on CPU.
            # https://github.com/google/jax/issues/10434
            idxes = idxes[fallback_argsort(jax.random.uniform(key, (self.length,)))]
        idxes = idxes[: self.length - self.drop_num]
        idxes = jnp.concatenate([idxes, jnp.full(self.pad, -1, idxes.dtype)])
        idxes = idxes.reshape(-1, self.batch_size)
        return idxes, idxes == -1

    def __len__(self) -> int:
        """Number of batches."""
        return (self.length - self.drop_num + self.pad) // self.batch_size
```

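The loader only produces indices, so gathering the data and applying the padding mask are left to the caller. A hypothetical sketch of that step, written with NumPy so it stands alone: `gather_batch` and `masked_mean` are illustrative names, and the hard-coded grid mimics what `IdxDataloader(length=5, batch_size=2)` returns when called without a key.

```python
import numpy as np


def gather_batch(data, idx_row, pad_mask_row):
    """Fetch rows of `data` for one batch, using index 0 as a stand-in for padding slots."""
    # A raw -1 would silently wrap around to the last row, so replace it first.
    safe_idx = np.where(pad_mask_row, 0, idx_row)
    return data[safe_idx]


def masked_mean(per_example, pad_mask_row):
    """Average a per-example quantity over real (non-padding) entries only."""
    valid = ~pad_mask_row
    return np.sum(np.where(valid, per_example, 0.0)) / np.maximum(valid.sum(), 1)


# The index grid and mask for length=5, batch_size=2, no shuffling:
idxes = np.array([[0, 1], [2, 3], [4, -1]])
mask = idxes == -1
data = np.arange(5, dtype=np.float32) * 10.0  # toy dataset: [0, 10, 20, 30, 40]

for idx_row, mask_row in zip(idxes, mask):
    batch = gather_batch(data, idx_row, mask_row)
    loss = masked_mean(batch, mask_row)  # stand-in "loss": mean of the gathered values
```

Padding rather than emitting a shorter final batch keeps every batch the same shape, so a jitted step does not recompile for the last batch; the same two functions work with `jnp` in place of `np` inside such a step.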
buchholzmd commented 1 month ago

I believe NVIDIA's DALI would be worth looking into. I'm doing that now and will report back with any updates.