DBraun opened 1 month ago
Here's another version that uses `flax.jax_utils.prefetch_to_device`. Maybe this achieves what `worker_buffer_size = 2` usually does, without actually setting `worker_buffer_size = 2`. However, I would still like a way to multiprocess the `DataSimpleSource`: multiple random arrays could be generated in parallel. The random generation is just a stand-in for some other data loading process that I want to parallelize.
```python
from typing import SupportsIndex

import jax
import jax.numpy as jnp
import jax.random as random
import flax.linen as nn
from flax.jax_utils import prefetch_to_device
from tqdm import tqdm
from absl import logging

import grain.python as grain
class Model(nn.Module):
    n_layers: int = 10
    features: int = 10

    @nn.compact
    def __call__(self, x):
        for _ in range(self.n_layers):
            x = nn.Dense(features=self.features)(x)
            x = nn.relu(x)
        return x


Model = nn.vmap(Model, variable_axes={'params': None}, split_rngs={'params': False})
B = 4
IN_FEATURES = 100
N_LAYERS = 20
FEATURES = 20
dummy_input = jnp.zeros(shape=(B, IN_FEATURES))
model = Model(n_layers=N_LAYERS, features=FEATURES)
params = model.init({'params': random.key(0), 'rng_stream': random.key(1)}, dummy_input)['params']
print(model.tabulate({'params': random.key(0), 'rng_stream': random.key(1)}, dummy_input))
@jax.jit
def jit_batch_inference(x):
    return model.apply({'params': params}, x)
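# params and model are captured from module scope; any process that runs this
# function needs its own copy of them.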
class DataSimpleSource(grain.RandomAccessDataSource):
    def __init__(self, num_steps):
        self._num_steps = num_steps

    def __len__(self) -> int:
        return self._num_steps

    def __getitem__(self, record_key: SupportsIndex):
        record_key = int(record_key)
        return random.uniform(random.key(record_key), shape=(IN_FEATURES,))
class JITBatchTransform(grain.MapTransform):
    def map(self, batch: jnp.ndarray):
        assert batch.ndim == 2
        assert batch.shape == (B, IN_FEATURES)
        return jit_batch_inference(batch)
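# Note (my assumption): with worker_count > 0, grain runs these operations in
# child worker processes, so each worker would initialize JAX and execute
# jit_batch_inference on its own, competing for the same GPU. That may be why
# throughput gets sporadic, as described at the end of this post.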
if __name__ == '__main__':
    logging.set_verbosity(logging.INFO)

    num_steps = 1000000
    worker_count = 0        # todo:
    worker_buffer_size = 1  # todo:
    prefetch_size = 2       # todo:
    datasource = DataSimpleSource(num_steps=num_steps)
    index_sampler = grain.IndexSampler(
        num_records=len(datasource),
        num_epochs=1,
        shard_options=grain.NoSharding(),
        shuffle=False,
        seed=0,
    )
    pygrain_ops = [
        # grain.BatchOperation(batch_size=B, drop_remainder=True),  # deprecated alternative to grain.Batch
        grain.Batch(batch_size=B, drop_remainder=True),
        JITBatchTransform(),
    ]
    batched_dataloader = grain.DataLoader(
        data_source=datasource,
        sampler=index_sampler,
        operations=pygrain_ops,
        worker_count=worker_count,
        worker_buffer_size=worker_buffer_size,
        enable_profiling=False,  # todo:
    )
    def prepare_for_prefetch(xs):
        # Add a leading device axis, similar to flax.jax_utils.replicate,
        # so prefetch_to_device can shard each batch across local devices.
        local_device_count = jax.local_device_count()

        def _prepare(x):
            return x.reshape((local_device_count, -1) + x.shape[1:])

        return jax.tree_util.tree_map(_prepare, xs)

    batched_dataloader = map(prepare_for_prefetch, batched_dataloader)

    if prefetch_size > 1:
        # prefetch_to_device requires the leading device axis added by prepare_for_prefetch.
        batched_dataloader = prefetch_to_device(batched_dataloader, size=prefetch_size)

    # num_steps records yield num_steps // B batches with drop_remainder=True.
    for x in tqdm(batched_dataloader, total=num_steps // B, desc='Grain Dataset'):
        pass
```
I want to put JAX-jitted batched data augmentation inside my grain dataloader. I'm currently pretending this augmentation is a jitted batch inference of a Flax model. With `worker_count=0`, it smoothly processes about 390-400 batches per second. However, with `worker_count=1` it becomes more sporadic and slower. I suppose `worker_count=0` is acceptable, and I can use this to feed a model for training. Still, it might be useful to have a spare batch ready with `worker_count=1` and `worker_buffer_size=2`, assuming my GPU has the memory for two instances of the jitted function to run in parallel. In this case it does, and I still see issues even when I make the Flax model much smaller. What is your advice?
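For completeness, here's a minimal sketch of the alternative I'm considering, reusing `B`, `IN_FEATURES`, `num_steps`, `index_sampler`, and `jit_batch_inference` from the listing above. The names and settings (`NumpySimpleSource`, `worker_count=4`) are my own placeholders, not from any library: the idea is to let the grain workers produce only raw NumPy batches and keep the jitted function in the main process, so a single process owns the GPU.

```python
import numpy as np

# Sketch (untested): workers only generate and batch raw NumPy arrays;
# the jitted augmentation runs in the main process.
class NumpySimpleSource(grain.RandomAccessDataSource):
    def __init__(self, num_steps):
        self._num_steps = num_steps

    def __len__(self) -> int:
        return self._num_steps

    def __getitem__(self, record_key):
        # NumPy instead of jax.random, so worker processes never initialize JAX.
        rng = np.random.default_rng(int(record_key))
        return rng.uniform(size=(IN_FEATURES,)).astype(np.float32)

raw_dataloader = grain.DataLoader(
    data_source=NumpySimpleSource(num_steps=num_steps),
    sampler=index_sampler,
    operations=[grain.Batch(batch_size=B, drop_remainder=True)],  # no JIT transform here
    worker_count=4,        # parallelizes the source
    worker_buffer_size=2,  # keeps spare raw batches ready
)

for batch in raw_dataloader:
    x = jit_batch_inference(jnp.asarray(batch))  # GPU work stays in this process
```

This keeps the source multiprocessed without running JAX in child processes; whether it actually beats `worker_count=0` is what I'd like advice on.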