google / grain

Advice on using a JIT function inside a transform? #495

Open DBraun opened 1 month ago

DBraun commented 1 month ago

I want to run JAX-jitted, batched data augmentation inside my Grain dataloader. For now I'm standing in for that augmentation with jitted batch inference of a Flax model. With worker_count=0, it smoothly processes about 390-400 batches per second. With worker_count=1, however, throughput becomes more sporadic and slower. I suppose worker_count=0 is acceptable and I can use it to feed a model for training, but it would be useful to have a spare batch ready with worker_count=1 and worker_buffer_size=2, assuming my GPU has enough memory to run two instances of the jitted function in parallel. In this case it does, and I still see the slowdown even when I make the Flax model much smaller. What is your advice?

from typing import SupportsIndex

import jax
import jax.numpy as jnp
import jax.random as random

import flax.linen as nn

from tqdm import tqdm
from absl import logging

import grain.python as grain

class Model(nn.Module):

    n_layers: int = 10
    features: int = 10

    @nn.compact
    def __call__(self, x):

        for _ in range(self.n_layers):
            x = nn.Dense(features=self.features)(x)
            x = nn.relu(x)

        return x

# Map the model over the batch dimension while sharing a single copy of the params.
Model = nn.vmap(Model, variable_axes={'params': None}, split_rngs={'params': False})

B = 4
IN_FEATURES = 100
N_LAYERS = 20
FEATURES = 20

dummy_input = jnp.zeros(shape=(B, IN_FEATURES))

model = Model(n_layers=N_LAYERS, features=FEATURES)

params = model.init({'params': random.key(0), 'rng_stream': random.key(1)}, dummy_input)['params']

print(model.tabulate({'params': random.key(0), 'rng_stream': random.key(1)}, dummy_input))

@jax.jit
def jit_batch_inference(x):
    return model.apply({'params': params}, x)

class DataSimpleSource(grain.RandomAccessDataSource):

    def __init__(self, num_steps):

        self._num_steps = num_steps

    def __len__(self) -> int:
        return self._num_steps

    def __getitem__(self, record_key: SupportsIndex):
        record_key = int(record_key)
        return random.uniform(random.key(record_key), shape=(IN_FEATURES,))

class JITBatchTransform(grain.MapTransform):

    def map(self, batch: jnp.ndarray):
        assert batch.ndim == 2
        assert batch.shape == (B, IN_FEATURES)

        x = jit_batch_inference(batch)
        return x

if __name__ == '__main__':

    logging.set_verbosity(logging.INFO)

    num_steps = 1000000
    worker_count = 0  # todo:
    worker_buffer_size = 1  # todo:

    datasource = DataSimpleSource(num_steps=num_steps)

    index_sampler = grain.IndexSampler(
        num_records=len(datasource),
        num_epochs=1,
        shard_options=grain.NoSharding(),
        shuffle=False,
        seed=0,
    )

    pygrain_ops = [
        # grain.BatchOperation(batch_size=B, drop_remainder=True),  # deprecated alternative to grain.Batch
        grain.Batch(batch_size=B, drop_remainder=True),
        JITBatchTransform(),
    ]

    batched_dataloader = grain.DataLoader(
        data_source=datasource,
        sampler=index_sampler,
        operations=pygrain_ops,
        worker_count=worker_count,
        worker_buffer_size=worker_buffer_size,
        enable_profiling=False,  # todo:
    )

    for x in tqdm(batched_dataloader, total=num_steps // B, desc='Grain Dataset'):
        pass
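One workaround I'm considering (just a sketch, not something the Grain docs prescribe) is to keep the jitted call out of the DataLoader entirely: the workers only sample and batch host-side NumPy arrays, so worker_count > 0 scales the data source, and jit_batch_inference runs on each batch in the main process where the GPU already lives. NumpyDataSimpleSource and the worker_count=4 / worker_buffer_size=2 values below are placeholders; everything else reuses the definitions from the script above.

import numpy as np

class NumpyDataSimpleSource(grain.RandomAccessDataSource):
    """Host-side stand-in for DataSimpleSource: returns NumPy arrays so the
    per-record work done inside the grain workers is pure NumPy."""

    def __init__(self, num_steps):
        self._num_steps = num_steps

    def __len__(self) -> int:
        return self._num_steps

    def __getitem__(self, record_key: SupportsIndex):
        rng = np.random.default_rng(int(record_key))
        return rng.uniform(size=(IN_FEATURES,)).astype(np.float32)

cpu_datasource = NumpyDataSimpleSource(num_steps=num_steps)

cpu_only_loader = grain.DataLoader(
    data_source=cpu_datasource,
    sampler=index_sampler,
    operations=[grain.Batch(batch_size=B, drop_remainder=True)],  # no JIT transform here
    worker_count=4,          # workers only generate and stack NumPy arrays
    worker_buffer_size=2,
)

for batch in tqdm(cpu_only_loader, total=num_steps // B, desc='CPU-only Grain'):
    x = jit_batch_inference(jnp.asarray(batch))  # GPU work stays in the main process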
DBraun commented 1 month ago

Here's another version that uses flax.jax_utils.prefetch_to_device. Maybe this achieves what worker_buffer_size=2 would normally provide, without actually setting worker_buffer_size=2. However, I would still like a way to multiprocess DataSimpleSource: the random arrays could be generated in parallel, and they are just a stand-in for some other data-loading process that I want to parallelize.

from typing import SupportsIndex

import jax
import jax.numpy as jnp
import jax.random as random

import flax.linen as nn
from flax.jax_utils import prefetch_to_device

from tqdm import tqdm
from absl import logging

import grain.python as grain

class Model(nn.Module):

    n_layers: int = 10
    features: int = 10

    @nn.compact
    def __call__(self, x):

        for _ in range(self.n_layers):
            x = nn.Dense(features=self.features)(x)
            x = nn.relu(x)

        return x

Model = nn.vmap(Model, variable_axes={'params': None}, split_rngs={'params': False})

B = 4
IN_FEATURES = 100
N_LAYERS = 20
FEATURES = 20

dummy_input = jnp.zeros(shape=(B, IN_FEATURES))

model = Model(n_layers=N_LAYERS, features=FEATURES)

params = model.init({'params': random.key(0), 'rng_stream': random.key(1)}, dummy_input)['params']

print(model.tabulate({'params': random.key(0), 'rng_stream': random.key(1)}, dummy_input))

@jax.jit
def jit_batch_inference(x):
    return model.apply({'params': params}, x)

class DataSimpleSource(grain.RandomAccessDataSource):

    def __init__(self, num_steps):

        self._num_steps = num_steps

    def __len__(self) -> int:
        return self._num_steps

    def __getitem__(self, record_key: SupportsIndex):
        record_key = int(record_key)
        return random.uniform(random.key(record_key), shape=(IN_FEATURES,))

class JITBatchTransform(grain.MapTransform):

    def map(self, batch: jnp.ndarray):
        assert batch.ndim == 2
        assert batch.shape == (B, IN_FEATURES)

        x = jit_batch_inference(batch)
        return x

if __name__ == '__main__':

    logging.set_verbosity(logging.INFO)

    num_steps = 1000000
    worker_count = 0  # todo:
    worker_buffer_size = 1  # todo:
    prefetch_size = 2  # todo:

    datasource = DataSimpleSource(num_steps=num_steps)

    index_sampler = grain.IndexSampler(
        num_records=len(datasource),
        num_epochs=1,
        shard_options=grain.NoSharding(),
        shuffle=False,
        seed=0,
    )

    pygrain_ops = [
        # grain.BatchOperation(batch_size=B, drop_remainder=True),  # deprecated alternative to grain.Batch
        grain.Batch(batch_size=B, drop_remainder=True),
        JITBatchTransform(),
    ]

    batched_dataloader = grain.DataLoader(
        data_source=datasource,
        sampler=index_sampler,
        operations=pygrain_ops,
        worker_count=worker_count,
        worker_buffer_size=worker_buffer_size,
        enable_profiling=False,  # todo:
    )

    def prepare_for_prefetch(xs):
        local_device_count = jax.local_device_count()

        def _prepare(x):
            return x.reshape((local_device_count, -1) + x.shape[1:])

        return jax.tree_util.tree_map(_prepare, xs)

    # Reshape each batch to (local_device_count, per_device_batch, ...), the layout
    # that prefetch_to_device expects.
    batched_dataloader = map(prepare_for_prefetch, batched_dataloader)

    if prefetch_size > 1:
        # For prefetch to work, we must have already used prepare_for_prefetch
        batched_dataloader = prefetch_to_device(batched_dataloader, size=prefetch_size)

    for x in tqdm(batched_dataloader, total=num_steps // B, desc='Grain Dataset'):
        pass
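
If the goal is also to multiprocess DataSimpleSource itself, a variation on the script above (again only a sketch, and it assumes the data source yields host-side NumPy arrays as in the previous sketch) is to drop JITBatchTransform from pygrain_ops, raise worker_count so the workers only sample and batch, and fold jit_batch_inference into the main-process map just before prepare_for_prefetch:

# Workers only sample and batch; the jitted augmentation, the per-device reshape,
# and the device prefetch all run in the main process.
cpu_only_loader = grain.DataLoader(
    data_source=datasource,  # assumed to return NumPy arrays rather than JAX arrays
    sampler=index_sampler,
    operations=[grain.Batch(batch_size=B, drop_remainder=True)],
    worker_count=4,
    worker_buffer_size=2,
)

def augment_and_shard(batch):
    x = jit_batch_inference(jnp.asarray(batch))
    return prepare_for_prefetch(x)

prefetched = map(augment_and_shard, cpu_only_loader)
prefetched = prefetch_to_device(prefetched, size=prefetch_size)

for x in tqdm(prefetched, total=num_steps // B, desc='Grain Dataset'):
    pass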