jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.62k stars 2.82k forks source link

PyTorch Dataloading doesn't work with >0 workers #9190

Open SarthakYadav opened 2 years ago

SarthakYadav commented 2 years ago

Hi!

I'm new to the JAX ecosystem, but I have used PyTorch and TensorFlow extensively for over 5 years.

My issue is that I can't get PyTorch data loading to work with JAX/Flax when num_workers > 0. The following is a minimal example that reproduces the issue:

import argparse
from typing import Sequence
from functools import partial
import flax
from typing import Any
import optax
from flax.training import train_state
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import tqdm
from torchvision.datasets import CIFAR10
from flax.training import common_utils
import torch
import torchvision.transforms as transforms
import torch.multiprocessing as multiprocessing
# Use the 'spawn' start method so DataLoader workers begin in a fresh
# interpreter rather than forking a parent that has already loaded JAX's
# native runtime, whose internal C++ libraries are not fork-safe.
# The call is guarded by __name__ because spawned workers re-import this
# module as '__mp_main__'. Repeating the call there raises
# "RuntimeError: context has already been set".
if __name__ == "__main__":
    multiprocessing.set_start_method('spawn')

NUM_CLASSES = 10
NUM_EPOCHS = 50
BATCH_SIZE = 512

parser = argparse.ArgumentParser()
parser.add_argument("--num_workers", default=0, type=int,
                    help="DataLoader worker processes (0 = load in the main process).")

def collate_fn(batch):
    """Turn a list of ``(image_tensor, label)`` samples into NumPy arrays.

    Each image tensor is permuted with ``(1, 2, 0)``, which turns the CHW
    layout produced by ``ToTensor()`` into HWC. It is then detached and
    converted to NumPy. Labels are collected unchanged.

    Returns:
      ``(images, labels)``, each built with ``np.asarray``.
    """
    samples = list(batch)
    images = [sample[0].permute(1, 2, 0).detach().numpy() for sample in samples]
    labels = [sample[1] for sample in samples]
    return np.asarray(images), np.asarray(labels)

class CNN(nn.Module):
    """Three stride-2 conv/BatchNorm/ReLU stages, a max-pool, then a dense classifier."""

    @nn.compact
    def __call__(self, inputs, train=False):
        # Shared constructor settings for every conv stage and its BatchNorm.
        make_conv = partial(nn.Conv, kernel_size=(3, 3), strides=(2, 2),
                            use_bias=False,
                            kernel_init=jax.nn.initializers.kaiming_normal())
        # In eval mode (train=False) BatchNorm uses its stored running averages.
        make_norm = partial(nn.BatchNorm, use_running_average=not train,
                            momentum=0.9, epsilon=1e-5)

        h = inputs
        for width in (32, 64, 128):
            h = make_conv(features=width)(h)
            h = make_norm()(h)
            h = nn.relu(h)

        h = nn.max_pool(h, window_shape=(4, 4), strides=(1, 1))
        h = h.reshape((h.shape[0], -1))  # flatten everything but the batch axis
        return nn.Dense(NUM_CLASSES)(h)

def initialize(key, inp_shape, model):
    """Initialise ``model`` under ``jax.jit`` on an all-ones batch of one example.

    Args:
      key: PRNG key supplied as the 'params' RNG stream.
      inp_shape: shape of a single example, as a tuple without the batch axis.
      model: object whose ``init(rngs, x)`` returns variables that include
        'params' and 'batch_stats'.

    Returns:
      ``(params, batch_stats)`` from the initialised variables.
    """
    dummy_batch = jnp.ones((1,) + inp_shape)

    def _init(*init_args):
        return model.init(*init_args)

    variables = jax.jit(_init)({'params': key}, dummy_batch)
    return variables['params'], variables['batch_stats']

@jax.jit
def cross_entropy_loss(logits, labels):
    """Mean softmax cross-entropy between ``logits`` and integer class ``labels``.

    Labels are one-hot encoded over NUM_CLASSES classes before being passed
    to ``optax.softmax_cross_entropy``.
    """
    targets = common_utils.onehot(labels, num_classes=NUM_CLASSES)
    per_example = optax.softmax_cross_entropy(logits=logits, labels=targets)
    return per_example.mean()

@jax.jit
def calculate_accuracy(logits, labels):
    """Fraction of rows whose highest-scoring entry (last axis) equals ``labels``."""
    predicted = logits.argmax(axis=-1)
    return (predicted == labels).mean()

@jax.jit
def train_step(state, images, labels):
    """Apply one optimizer update on a batch and report its loss and accuracy.

    The objective is softmax cross-entropy plus ``0.5 * 1e-4 * sum(p**2)``
    over every parameter leaf with ``ndim > 1``. In the CNN above, those leaves
    are the conv and dense kernels, not biases or BatchNorm scale/bias.
    BatchNorm runs in training mode, and its updated 'batch_stats' collection
    is stored in the returned state.

    Args:
      state: TrainState with params, optimizer state and batch_stats.
      images: input batch passed to ``state.apply_fn``.
      labels: integer class labels for ``images``.

    Returns:
      ``(new_state, loss, accuracy)``; ``loss`` includes the weight penalty.
    """
    weight_decay = 0.0001

    # No inner @jax.jit: the outer jit already traces and compiles the whole
    # step, so jitting cost_fn again adds nothing.
    def cost_fn(params):
        logits, new_model_state = state.apply_fn(
            {"params": params, "batch_stats": state.batch_stats},
            images,
            mutable=['batch_stats'],
            train=True,
        )
        loss = cross_entropy_loss(logits, labels)
        # Use jax.tree_util.tree_leaves: the top-level jax.tree_leaves alias
        # is deprecated and no longer exists in newer JAX releases.
        weight_l2 = sum(
            jnp.sum(p ** 2)
            for p in jax.tree_util.tree_leaves(params)
            if p.ndim > 1
        )
        loss = loss + weight_decay * 0.5 * weight_l2
        return loss, (new_model_state, logits)

    grad_fn = jax.value_and_grad(cost_fn, has_aux=True)
    # With has_aux=True, grad_fn returns ((loss, aux), grads).
    (loss, (new_model_state, logits)), grads = grad_fn(state.params)
    acc = calculate_accuracy(logits, labels)
    new_state = state.apply_gradients(
        grads=grads, batch_stats=new_model_state['batch_stats'])
    return new_state, loss, acc

@jax.jit
def eval_step(state, images, labels):
    """Accuracy of the model in ``state`` on one batch, with BatchNorm in eval mode.

    ``mutable=False`` is passed, so the result of ``apply_fn`` is used
    directly as the logits and no collection updates are produced.
    """
    variables = dict(params=state.params, batch_stats=state.batch_stats)
    predictions = state.apply_fn(variables, images, train=False, mutable=False)
    return calculate_accuracy(predictions, labels)

class TrainState(train_state.TrainState):
    """Flax TrainState that also carries BatchNorm statistics."""

    # The 'batch_stats' variable collection. It is read by apply_fn and
    # replaced with the mutated collection after each training step.
    batch_stats: Any

def _make_loader(dataset, batch_size, num_workers, *, shuffle, drop_last):
    """Wrap ``dataset`` in a DataLoader that yields NumPy ``(images, labels)`` batches."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        collate_fn=collate_fn,
        # Starting 'spawn' workers is slow, so keep them alive across epochs
        # instead of re-creating them on every pass. DataLoader only accepts
        # persistent_workers=True when num_workers > 0.
        persistent_workers=num_workers > 0,
    )


def _train_epoch(state, loader, epoch):
    """Run train_step over every batch; return the new state and per-step losses/accuracies."""
    num_steps = len(loader)
    losses, accs = [], []
    for step, (images, labels) in enumerate(loader):
        state, loss, acc = train_step(state, images, labels)
        if step == 0:
            print("Input shape:", images.shape)
            print("labels shape:", labels.shape)
        if step % 25 == 0:
            print("[{:03d}] | Step: [{:04d}/{:04d}] | Loss: {:.04f} | Acc: {:.04f}".format(
                epoch, step, num_steps, float(loss), float(acc)
            ))
        losses.append(jax.device_get(loss))
        accs.append(jax.device_get(acc))
    return state, losses, accs


def _evaluate(state, loader):
    """Return accuracy over all samples in ``loader``, weighting each batch by its size."""
    accs, sizes = [], []
    for images, labels in loader:
        accs.append(jax.device_get(eval_step(state, images, labels)))
        sizes.append(len(labels))
    if not sizes:
        return float("nan")
    return float(np.average(accs, weights=sizes))


def main(argv=None):
    """Train the CNN on CIFAR-10 for NUM_EPOCHS epochs, validating after each epoch.

    Args:
      argv: Optional argument list for ``parser``; ``None`` means
        ``sys.argv[1:]``.
    """
    args = parser.parse_args(argv)
    cnn = CNN()
    key = jax.random.PRNGKey(0)
    _, *subkeys = jax.random.split(key, 4)
    params, batch_stats = initialize(subkeys[0], (32, 32, 3), cnn)
    state = TrainState.create(
        apply_fn=cnn.apply,
        params=params,
        tx=optax.adam(1e-3),
        batch_stats=batch_stats,
    )
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    trainset = CIFAR10(root='./data', train=True,
                       download=True, transform=transform)
    # drop_last=True keeps every training batch the same shape, so the
    # jitted train_step compiles only once.
    trainloader = _make_loader(trainset, BATCH_SIZE, args.num_workers,
                               shuffle=True, drop_last=True)
    testset = CIFAR10(root='./data', train=False,
                      download=True, transform=transform)
    # drop_last=False so every test image is scored. A smaller final batch
    # costs at most one extra eval_step compilation, and _evaluate weights
    # it by its size.
    testloader = _make_loader(testset, BATCH_SIZE, args.num_workers,
                              shuffle=False, drop_last=False)

    for epoch in range(1, NUM_EPOCHS + 1):
        print("Starting epoch {}".format(epoch))
        state, train_loss, train_acc = _train_epoch(state, trainloader, epoch)
        print("Validating...")
        val_acc = _evaluate(state, testloader)

        print("Epoch {:03d} done...".format(epoch))
        print("\t Train loss: {:.04f} | Train Acc: {:.04f}".format(
            np.mean(train_loss), np.mean(train_acc)))
        print("\t Val Acc: {:.04f}".format(val_acc))


if __name__ == "__main__":
    main()

Problem encountered:

I've tried running the script on both TPU and GPU: it works fine when num_workers = 0, but doesn't work with num_workers > 0.

An earlier issue from 2020 recommended setting torch.multiprocessing.set_start_method('spawn'), but that didn't fix the issue for me. Unlike the author of that issue, I'm not using jax primitives in the data loading pipeline at all (as can be seen in the collate_fn() function)

With num_workers>0, I get the following errors:

On GPU

On TPUv2-8 VM

Following are the packages being used:

torch==1.9.0+cu111
jax==0.2.26
jaxlib==0.1.75       #+cuda11.cudnn82 for GPU

Any help is appreciated!

jakevdp commented 2 years ago

You might get more traction asking about this in the torch project.

SarthakYadav commented 2 years ago

You might get more traction asking about this in the torch project.

I'll give that a try, but this only happens when the model itself is implemented in Jax.

In the same env, everything-torch works absolutely fine.

levskaya commented 2 years ago

Hi! I think that it's going to be really hard to make this work. We generally don't try to support python-multiprocessing: all the internal C++ libs we use aren't written to be fork-safe, and I'm not sure that TPU libtpu.so can be used with multiprocessing at all.

Usually we recommend that people use TFDS / tf.data based dataloaders as they're far more CPU efficient for feeding multiple GPUs or TPUs than torch dataloaders with multiprocessing.

SarthakYadav commented 2 years ago

Hi! I think that it's going to be really hard to make this work. We generally don't try to support python-multiprocessing: all the internal C++ libs we use aren't written to be fork-safe, and I'm not sure that TPU libtpu.so can be used with multiprocessing at all.

Usually we recommend that people use TFDS / tf.data based dataloaders as they're far more CPU efficient for feeding multiple GPUs or TPUs than torch dataloaders with multiprocessing.

Thanks for the detailed reply. I'll see how much time moving all data operations to tensorflow based ops will take.

However, I believe this information should be added to the official tutorial on using pytorch dataloaders with Jax, as this is quite the limitation. Using multiple workers in dataloaders is a standard practice in the PyTorch realm.

It will work for small datasets and a quick proof of concept for researchers/teams thinking about making the move to Jax, sure, but for full-bore training, using torch data loaders with Jax would not be feasible. Adding this as a disclaimer to the above-mentioned tutorial will save valuable time in my opinion.

jaanli commented 2 years ago

+1. I agree these gotchas are major and should be mentioned front and center; it took me a long time and many lost hours this week to figure this out for myself. It's a limitation of the JAX ecosystem right now!

nikitakit commented 2 years ago

This issue appears to be a regression compared to one year ago. I was using multi-worker data loaders in sabertooth and they worked fine at the time, but they no longer work with newly started TPU VMs. I want to emphasize that the data workers neither use JAX nor access the TPUs in any way; they just do pure NumPy computation.

torch.multiprocessing.set_start_method('spawn') sort of works as a workaround. I've managed to avoid the error RuntimeError: context has already been set with the idiom if __name__ == '__main__': torch.multiprocessing.set_start_method('spawn') -- I had to wrap it so that each spawned worker doesn't itself attempt to set the start method. However, this workaround still has issues: each worker takes a really long time to spawn and generates a bunch of "libtpu.so already in use by another process" messages. Setting persistent_workers=True helps cut down on these, but it's still annoying.

Given that this is a regression, is it really the case that it can't be fixed? None of the child processes are actually doing anything with the TPU.

jaanli commented 2 years ago

Agreed, would be great to find a solution ASAP thank you @nikitakit !

haoliuhl commented 2 years ago

Are there any updates on this? It is frustrating to find out that the PyTorch data loader cannot work with JAX on TPU, even though it is used in JAX's official examples.

levskaya commented 2 years ago

@nikitakit - do you perhaps have a small repro for the failure?

haoliuhl commented 2 years ago

@levskaya there is a code snippet for the failure in #9767.

jaanli commented 2 years ago

Following up on this

rensushan commented 1 year ago

I have run into similar problems on an A100 GPU, but I have no idea how to fix them.

noahzhy commented 8 months ago

Same problem on an NVIDIA T4 GPU; it's a disaster when training a tiny model on huge datasets.