sheridana opened 11 months ago
Yes, I need to look into this. I think this is somewhat low priority because currently the model is fast enough with our compute. Once we get close to release and people with fewer resources want to use it, I'll try to fix it.
Currently we just use a batch size of 1 and set `num_workers=0` during training. Since the model is pretty lightweight by default, we are slowing down training because we overload the CPU (the GPU processing is actually pretty fast and not memory intensive). But just setting `num_workers>0` throws cryptic CUDA errors, and there doesn't seem to be definitive documentation on the right way to handle all of the moving pieces. Some things to consider:

`worker_init_fn` should probably be passed to the data loader when using multiple workers (for seeding). From the docs: "If not None, this will be called on each worker subprocess with the worker id (an int in `[0, num_workers - 1]`) as input, after seeding and before data loading." An example `worker_init_fn`:
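Something like the following minimal sketch (stdlib-only; the base seed is a hypothetical placeholder, and in a real setup you would typically derive it from `torch.initial_seed()` instead):

```python
import random

def worker_init_fn(worker_id: int) -> None:
    # Give each DataLoader worker its own deterministic random stream so
    # augmentations differ across workers but stay reproducible.
    # BASE_SEED is a hypothetical placeholder; in practice derive the seed
    # from torch.initial_seed() inside the worker.
    BASE_SEED = 1234
    random.seed(BASE_SEED + worker_id)
```

It would then be passed along as `DataLoader(..., num_workers=n, worker_init_fn=worker_init_fn)`.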
Running with that in place will actually print the error instead of throwing cryptic CUDA errors:

```
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
```
To fix this we'd need to call `torch.multiprocessing.set_start_method('spawn')`, I guess inside the main train function or the config? Doing that will run things (at least until the following error), but it is slow to initialize; I guess spawn is just much slower than fork? Or maybe there are other things that need to be done.
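A sketch of where the call could live (using the stdlib `multiprocessing` module here, which `torch.multiprocessing` wraps with the same API; the surrounding `main` is hypothetical):

```python
import multiprocessing as mp
# torch.multiprocessing is a drop-in wrapper around the stdlib module, so
# torch.multiprocessing.set_start_method behaves the same way.

def main() -> None:
    # 'spawn' starts fresh interpreter processes instead of forking, which
    # is what CUDA requires; force=True overrides any method set earlier
    # (e.g. by another library).
    mp.set_start_method('spawn', force=True)
    print(mp.get_start_method())  # → spawn
    # ... build the dataset / dataloader / trainer, then trainer.fit(...)

if __name__ == '__main__':
    main()
```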
Calling spawn might then throw:

```
RuntimeError: unable to open shared memory object in read-write mode: Too many open files (24)
```
To get past that you'd need to set `torch.multiprocessing.set_sharing_strategy('file_system')`, but it looks like we already do this in `config.py`? There are some other considerations, e.g. `pin_memory`, `collate_fn`, and `batch_size` (might also need to increase this from 1).
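Putting those loader flags together, a hypothetical wiring might look like this (the toy dataset and the flag values are placeholders, not the real `TrackingDataset` settings):

```python
import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, TensorDataset

# Matches what config.py reportedly already sets.
mp.set_sharing_strategy('file_system')

# Toy stand-in for the real dataset: 8 samples of shape (1,).
dataset = TensorDataset(torch.arange(8).float().unsqueeze(1))

loader = DataLoader(
    dataset,
    batch_size=4,   # consider raising from 1 once workers are stable
    num_workers=0,  # bump to >0 only with spawn + worker_init_fn in place
    pin_memory=torch.cuda.is_available(),  # only helps when copying to GPU
)

for (batch,) in loader:
    print(batch.shape)  # torch.Size([4, 1]) on each of the two batches
```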
Some of those flags are set differently inside `TrackingDataset` than inside `Config`. Just make sure we are overriding these prior to calling `trainer.fit()`.
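The override step could be as simple as a small helper run before `trainer.fit()` (all flag and attribute names here are hypothetical; use whatever `TrackingDataset` and `Config` actually expose):

```python
from types import SimpleNamespace

def sync_loader_flags(cfg, dataset):
    """Copy loader-related flags from the config onto the dataset (sketch).

    Call this before trainer.fit() so the dataset and the dataloader see
    consistent values. Flag names are placeholders.
    """
    for flag in ('batch_size', 'num_workers', 'pin_memory'):
        if hasattr(cfg, flag):
            setattr(dataset, flag, getattr(cfg, flag))
    return dataset

# Stand-ins for Config and TrackingDataset with mismatched defaults.
cfg = SimpleNamespace(batch_size=8, num_workers=4, pin_memory=True)
dataset = SimpleNamespace(batch_size=1, num_workers=0, pin_memory=False)
sync_loader_flags(cfg, dataset)
print(dataset.batch_size, dataset.num_workers)  # → 8 4
```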