SarthakYadav opened 2 years ago
You might get more traction asking about this in the torch project.
I'll give that a try, but this only happens when the model itself is implemented in Jax.
In the same env, everything torch-based works absolutely fine.
Hi! I think that it's going to be really hard to make this work. We generally don't try to support python-multiprocessing: all the internal C++ libs we use aren't written to be fork-safe, and I'm not sure that the TPU `libtpu.so` can be used with multiprocessing at all.
Usually we recommend that people use TFDS / tf.data-based dataloaders, as they're far more CPU-efficient for feeding multiple GPUs or TPUs than torch dataloaders with multiprocessing.
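For reference, a minimal sketch of that tf.data-based approach (assuming `tensorflow_datasets` is installed; the 'mnist' dataset is used purely for illustration) could look like:

```python
import tensorflow as tf
import tensorflow_datasets as tfds
import jax.numpy as jnp

# Build the input pipeline with tf.data: shuffling, batching and prefetching
# run in TensorFlow's C++ threadpool, so no Python multiprocessing is needed.
ds = tfds.load('mnist', split='train', as_supervised=True)
ds = ds.map(lambda img, label: (tf.cast(img, tf.float32) / 255.0, label))
ds = ds.shuffle(10_000).batch(128).prefetch(tf.data.AUTOTUNE)

# Iterate as plain numpy arrays and hand each batch to JAX.
for images, labels in tfds.as_numpy(ds):
    batch = jnp.asarray(images)  # device transfer happens here
    # ... call your JAX/Flax train step with (batch, labels)
```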
Thanks for the detailed reply. I'll see how much time moving all data operations to TensorFlow-based ops will take.
However, I believe this information should be added to the official tutorial on using pytorch dataloaders with Jax, as this is quite the limitation. Using multiple workers in dataloaders is a standard practice in the PyTorch realm.
It will work for small datasets and a quick proof of concept for researchers/teams thinking about making the move to Jax, sure, but for full-bore training, using torch data loaders with Jax would not be feasible. Adding this as a disclaimer to the above-mentioned tutorial will save valuable time in my opinion.
+1 I agree these gotchas are major and should be mentioned front and center; it took me a long time and many lost hours this week to figure this out for myself. It's a limitation of the jax ecosystem right now!
This issue appears to be a regression compared to one year ago. I was using multi-worker data loaders in sabertooth and they worked fine at the time, but no longer work with newly started TPU VMs. I want to emphasize that the data workers are not using JAX nor accessing the TPUs in any way, just doing pure numpy computation.
`torch.multiprocessing.set_start_method('spawn')` sort of works as a workaround. I've managed to avoid the `RuntimeError: context has already been set` error with the idiom `if __name__ == '__main__': torch.multiprocessing.set_start_method('spawn')` -- I had to wrap it so that each spawned worker doesn't itself attempt to set the start method. However, this workaround still has issues: each worker takes a really long time to spawn, and generates a bunch of `libtpu.so already in use by another process` messages. Setting `persistent_workers=True` helps cut down on these, but it's still annoying.
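A minimal sketch of that idiom, assuming a standard torch `Dataset`/`DataLoader` with pure-numpy workers (all names here are illustrative), looks roughly like:

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class NumpyDataset(Dataset):  # hypothetical pure-numpy dataset
    def __len__(self):
        return 1024

    def __getitem__(self, idx):
        return np.full((32, 32), idx, dtype=np.float32)

if __name__ == '__main__':
    # Guarded so the spawned workers, which re-import this module,
    # don't try to set the start method again and raise
    # "RuntimeError: context has already been set".
    torch.multiprocessing.set_start_method('spawn')

    loader = DataLoader(
        NumpyDataset(),
        batch_size=8,
        num_workers=4,
        persistent_workers=True,  # keep workers alive across epochs so they spawn only once
    )
    for batch in loader:
        pass  # hand the numpy batch to the JAX training step here
```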
Given that this is a regression, is it really the case that it can't be fixed? None of the child processes are actually doing anything with the TPU.
Agreed, it would be great to find a solution ASAP. Thank you @nikitakit!
Are there any updates on this? It is frustrating to find out that the PyTorch data loader cannot work with Jax on TPU despite it being used in Jax's official examples.
@nikitakit - do you perhaps have a small repro for the failure?
Following up on this
I also ran into similar problems on an A100 GPU, but I have no idea how to fix it.
Same problem on an NVIDIA T4 GPU; it's a disaster when training a tiny model on huge datasets.
Hi!
I'm new to the JAX ecosystem, but have used PyTorch and TensorFlow extensively for over 5 years.
My issue is that I can't get PyTorch data loading to work with jax/flax with `num_workers > 0`. Following is a minimal example to reproduce my issue.
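(The original repro script is not reproduced here; the snippet below is a hedged sketch of the kind of setup described: a torch `DataLoader` with a pure-numpy `collate_fn` feeding a jitted JAX function. All names in it are illustrative.)

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import jax
import jax.numpy as jnp

class RandomDataset(Dataset):  # illustrative stand-in for the real dataset
    def __len__(self):
        return 512

    def __getitem__(self, idx):
        return np.random.rand(32).astype(np.float32)

def collate_fn(batch):
    # Pure numpy; no jax primitives are used inside the worker processes.
    return np.stack(batch, axis=0)

@jax.jit
def forward(x):
    return jnp.sum(x ** 2)

if __name__ == '__main__':
    loader = DataLoader(RandomDataset(), batch_size=8,
                        num_workers=4,          # works with 0, fails with > 0
                        collate_fn=collate_fn)
    for batch in loader:
        print(forward(jnp.asarray(batch)))
```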
Problem encountered:
I've tried running the script on both TPU and GPU: it works fine when `num_workers = 0`, but doesn't work with `num_workers > 0`. An earlier issue from 2020 recommended setting `torch.multiprocessing.set_start_method('spawn')`, but that didn't fix the issue for me. Unlike the author of that issue, I'm not using jax primitives in the data loading pipeline at all (as can be seen in the `collate_fn()` function).

With `num_workers > 0`, I get the following errors:

On GPU:
- `torch.multiprocessing.set_start_method('spawn')` throws `RuntimeError: context has already been set`
- `torch.multiprocessing.set_start_method('fork')` throws `Failed setting context: CUDA_ERROR_NOT_INITIALIZED: initialization error`
On TPUv2-8 VM:
- `torch.multiprocessing.set_start_method('spawn')` throws `libtpu.so already in use by another process`, followed by `RuntimeError: context has already been set` later in the stack trace.
- With `torch.multiprocessing.set_start_method('fork')` I get no error, but the dataloader hangs indefinitely.

Following are the packages being used:
Any help is appreciated!