talmolab / dreem

DREEM Relates Every Entities' Motion (DREEM). Global Tracking Transformers for biological multi-object tracking.
https://dreem.sleap.ai
BSD 3-Clause "New" or "Revised" License

Parallel data loading #19

Open sheridana opened 11 months ago

sheridana commented 11 months ago

Currently we just use a batch size of 1 and set num_workers=0 during training. Since the model is pretty lightweight by default, the GPU work is fast and not memory intensive, so training ends up bottlenecked by data loading on an overloaded CPU. But simply setting num_workers>0 throws cryptic CUDA errors, and there doesn't seem to be definitive documentation on the right way to handle all of the moving pieces. Some things to consider:

  1. worker_init_fn should probably be passed to the data loader when using multiple workers (for seeding). From the docs: "If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading." An example worker_init_fn:
import numpy as np
import torch

def worker_init_fn(worker_id):
    print(f'setting worker id: {worker_id}')
    # re-seed numpy per worker so workers don't produce identical samples
    np.random.seed(np.random.get_state()[1][0] + worker_id)
    if torch.cuda.is_available():
        torch.cuda.set_device(worker_id % torch.cuda.device_count())
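For reference, that function would then be handed to the loader via the worker_init_fn argument. A sketch only (train_dataset is a placeholder, not the actual dreem loader construction):

```python
from torch.utils.data import DataLoader

loader = DataLoader(
    train_dataset,                  # placeholder dataset
    batch_size=1,
    num_workers=4,                  # >0 is what triggers the issue below
    worker_init_fn=worker_init_fn,  # runs once in each worker subprocess
)
```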

Now running will actually print the underlying error instead of failing cryptically:

RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

To fix this we'd need to call torch.multiprocessing.set_start_method('spawn'), I guess inside the main train function or the config?
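A minimal sketch of where that call could live, assuming it goes at the top of the training entry point (the main wrapper and force=True are illustrative choices, not necessarily how dreem's train script is structured):

```python
import torch.multiprocessing as mp

def main():
    # switch worker processes from fork to spawn so they can use CUDA;
    # force=True overrides any previously-set start method
    mp.set_start_method("spawn", force=True)
    # ... build datasets/dataloaders here and call trainer.fit()

if __name__ == "__main__":
    main()
```

An alternative that avoids touching the global start method would be passing multiprocessing_context="spawn" to the DataLoader itself, which PyTorch also supports.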

Doing that gets things running (at least until the following error), but initialization is slow. I guess spawn is just much slower than fork? Or maybe there are other things that need to be done.

  2. Calling spawn might then throw RuntimeError: unable to open shared memory object in read-write mode: Too many open files (24), and you'd need to set torch.multiprocessing.set_sharing_strategy('file_system'), but it looks like we already do this in config.py?

  3. There are some other considerations, e.g. pin_memory, collate_fn, and batch_size (which might also need to be increased from 1); see the sketch after this list.

  4. Some of those flags are set differently inside TrackingDataset than inside Config. We just need to make sure we are overriding them prior to calling trainer.fit().
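Pulling items 2-4 together, a rough sketch of what the loader setup might look like right before the Lightning handoff. Everything here other than the torch calls is hypothetical (train_dataset, collate_fn, model, and trainer are placeholders; the real values live in Config/TrackingDataset):

```python
import torch.multiprocessing as mp
from torch.utils.data import DataLoader

# reportedly already set in config.py; shown here for completeness
mp.set_sharing_strategy("file_system")

train_loader = DataLoader(
    train_dataset,                  # placeholder for whatever TrackingDataset wraps
    batch_size=4,                   # likely needs to go above 1 for workers to pay off
    num_workers=4,
    pin_memory=True,                # faster host-to-GPU copies
    persistent_workers=True,        # keep workers alive across epochs
    collate_fn=collate_fn,          # placeholder; only needed for custom batching
    worker_init_fn=worker_init_fn,  # seeding function from item 1
)

# apply the overrides before handing off to Lightning so they aren't clobbered
trainer.fit(model, train_dataloaders=train_loader)
```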

aaprasad commented 2 months ago

Yes, I need to look into this. I think this is somewhat low priority because the model is currently fast enough with our compute. Once we get close to release and people with fewer resources want to use it, I'll try to fix it.