Open jejjohnson opened 2 years ago
FWIW I tend to use the PyTorch dataloaders, mostly just out of familiarity. I know many folks use TF as well.
(The custom dataloaders in the docs are just to minimise dependencies for the examples.)
If I'm working with a small dataset then it can be advantageous to keep the entire dataset in-memory, see for example mmap.ninja. (Or you can probably just slice NumPy arrays - I've not checked too closely how that compares.)
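For the in-memory case, a rough sketch of what just slicing NumPy arrays might look like (the arrays, shapes, and batch size here are placeholders, not anything from Equinox itself):

```python
import numpy as np
import jax.numpy as jnp

xs = np.random.randn(1000, 32).astype(np.float32)  # placeholder in-memory dataset
ys = np.random.randint(0, 10, size=1000)

def numpy_epoch(batch_size: int, *, rng: np.random.Generator):
    """Yield shuffled minibatches by slicing the in-memory NumPy arrays."""
    perm = rng.permutation(len(xs))
    for start in range(0, len(xs), batch_size):
        idx = perm[start:start + batch_size]
        # jnp.asarray copies the slice onto the default JAX device.
        yield jnp.asarray(xs[idx]), jnp.asarray(ys[idx])

for x_batch, y_batch in numpy_epoch(64, rng=np.random.default_rng(0)):
    ...  # call your training step here
```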
Thanks for the advice @patrick-kidger! I will do the same. I'm also more familiar with the PyTorch dataloaders. Do you follow similar practices to, for example, what's done here? (i.e. standard dataset, dataloader, and collate function?)
Yep pretty much.
A word of caution: do not make your `collate` return `np.array` instead of `Tensor` (and I think the same applies to return values from the dataset class). It's convenient but PyTorch has some custom hackery in `Tensor` to make them efficient to be exchanged between processes (I think by avoiding serialization at process boundaries), which NumPy arrays lack. So it's much better to just return (structs of) `Tensor`s and then `jax.tree_map(lambda tensor: tensor.numpy(), batch)`.
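For concreteness, a minimal sketch of that pattern; the toy tensors, batch size, and worker count are placeholders, and only the `jax.tree_map` conversion step is the bit described above:

```python
import jax
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy in-memory data; the default collate already returns (structs of) torch.Tensors.
xs = torch.randn(1000, 32)
ys = torch.randint(0, 10, (1000,))
loader = DataLoader(TensorDataset(xs, ys), batch_size=64, shuffle=True, num_workers=2)

for batch in loader:
    # Convert the Tensors to NumPy only once they reach the JAX side of the pipeline.
    batch = jax.tree_map(lambda tensor: tensor.numpy(), batch)
    x_batch, y_batch = batch
    # ...feed x_batch / y_batch into your jitted training step here.
```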
That I didn't know! Thanks, that's really useful information.
Thanks! This is good advice! So just to be clear:
Did I understand that correctly?
This feels very ugly, numpy -> torch -> jax at every dataloading step? Why must I download a huge library like torch just for dataloading?
There must be a better solution...
> This feels very ugly, numpy -> torch -> jax at every dataloading step?
You can use NumPy rather than torch inside the PyTorch dataloader; the NumPy-array-to-torch-tensor conversion happens in `collate_fn`. So you can convert NumPy arrays directly to JAX arrays if you define your own collate function.
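For example, a minimal sketch of such a collate function (the toy dataset, shapes, and batch size are placeholders):

```python
import numpy as np
import jax.numpy as jnp
from torch.utils.data import DataLoader, Dataset

class NumpyDataset(Dataset):
    """Toy map-style dataset returning NumPy samples (placeholder for real data)."""
    def __init__(self, xs: np.ndarray, ys: np.ndarray):
        self.xs, self.ys = xs, ys
    def __len__(self):
        return len(self.xs)
    def __getitem__(self, i):
        return self.xs[i], self.ys[i]

def numpy_collate(batch):
    """Stack a list of (x, y) NumPy samples straight into JAX arrays."""
    xs, ys = zip(*batch)
    return jnp.asarray(np.stack(xs)), jnp.asarray(np.stack(ys))

loader = DataLoader(
    NumpyDataset(np.random.randn(1000, 32), np.random.randint(0, 10, size=1000)),
    batch_size=64,
    shuffle=True,
    collate_fn=numpy_collate,
    # NB: see the caution above -- with num_workers > 0, non-Tensor batches miss out on
    # PyTorch's fast inter-process Tensor transfer.
)
```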
> Why must I download a huge library like torch just for dataloading?

If you only need the index dataloader:
```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from equinox import Module, filter_jit


def _np_sort(x: jax.Array, axis: int | None = None) -> np.ndarray:
    """Argsort with NumPy; cast to int32 to match the callback's declared dtype."""
    return np.argsort(np.asarray(x), axis=axis).astype(np.int32)


@filter_jit
def fallback_argsort(x: jax.Array, axis: int | None = None) -> jax.Array:
    """Fall back to NumPy argsort on CPU."""
    if jax.devices()[0].platform == "cpu":
        return jax.pure_callback(
            partial(_np_sort, axis=axis),
            jax.ShapeDtypeStruct(x.shape, jnp.int32),
            x,
        )
    return x.argsort(axis=axis)


class IdxDataloader(Module):
    """Simple index dataloader."""

    length: int
    pad: int
    batch_size: int
    drop_num: int

    def __init__(
        self,
        length: int,
        batch_size: int,
        drop_last: bool = False,
    ) -> None:
        """Initialise the dataloader."""
        self.length = length
        length = length if not drop_last else length - length % batch_size
        pad = (batch_size - r) % batch_size if (r := length % batch_size) else 0
        self.pad = pad if pad != batch_size else 0
        self.batch_size = batch_size
        self.drop_num = self.length % batch_size if drop_last else 0

    def __call__(self, key: jax.Array | None = None) -> tuple[jax.Array, jax.Array]:
        """Get the indexes for one epoch, optionally shuffled with `key`."""
        idxes = jnp.arange(self.length)
        if key is not None:
            idxes = jnp.take_along_axis(
                idxes,
                # NOTE: Fall back to NumPy argsort since jnp.argsort has a performance
                # issue on CPU. https://github.com/google/jax/issues/10434
                fallback_argsort(jax.random.uniform(key, (self.length,))),
                axis=0,
            )
        length = self.length if not self.drop_num else self.length - self.drop_num
        idxes = jnp.r_[idxes, jnp.full(self.pad, -1, idxes.dtype)]
        idxes = idxes[: length + self.pad].reshape(-1, self.batch_size)
        # Per-batch indices plus a mask that is True where a slot is padding.
        return idxes, jnp.where(idxes == -1, 1, 0).astype(bool)

    def __len__(self) -> int:
        """Number of batches per epoch."""
        return self.length // self.batch_size + (
            1 if (not self.drop_num) and self.length % self.batch_size else 0
        )
```
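Usage would look roughly like this (the toy array, dataset length, and batch size are just for illustration); each call yields a full epoch of batch indices plus the mask marking padded slots:

```python
import jax
import jax.numpy as jnp

xs = jnp.arange(10 * 3).reshape(10, 3)  # toy in-memory dataset
loader = IdxDataloader(length=10, batch_size=4)

key = jax.random.PRNGKey(0)
for epoch in range(2):
    key, subkey = jax.random.split(key)
    idxes, pad_mask = loader(subkey)  # shapes: (num_batches, batch_size)
    for batch_idx, mask in zip(idxes, pad_mask):
        batch = xs[batch_idx]  # gather the actual data for this batch
        # `mask` is True where the batch was padded with -1; ignore those rows.
```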
I believe Nvidia's DALI would be worth looking into. I am doing that now and will report back with any updates.
I was wondering if the equinox community has good advice / best practices for creating dataloaders that work well with JAX? I've done some stuff from scratch, but I tend to find that training (on GPU) is a bit slower than if I use PyTorch + PyTorch Lightning (nothing fancy, just a single GPU). But then I see claims such as this one with Flax vs PyTorch Lightning, and I wonder whether it's something on my end. So I wanted to ask the equinox community about their experience.
For example, the JAX website shows a TensorFlow dataloader example and a PyTorch dataloader example. But if you run these, the PyTorch dataloader is 1-2 seconds faster per epoch than the TensorFlow dataloader. In the docs I see custom iterators, and other similar libraries have more elaborate schemes, e.g. treex, objax.