smistad / FAST

A framework for high-performance medical image processing, neural network inference and visualization
https://fast.eriksmistad.no
BSD 2-Clause "Simplified" License
433 stars 101 forks

pyFAST incompatible with Python multiprocessing #169

Closed andreped closed 7 months ago

andreped commented 1 year ago

Describe the bug

We have started to use pyFAST for streaming image patches from whole slide images (WSIs). The patch generator samples random patches from a WSI.

In TensorFlow, we have used multithreading to speed up patch generation. This is available through the tf.data.Dataset API, where the threading is performed in C++.
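For reference, a minimal sketch of what we mean; generate_patch is a hypothetical placeholder for the pyFAST patch sampling:

import tensorflow as tf

def generate_patch(i):
    # hypothetical placeholder for the pyFAST patch sampling
    # (e.g. wrapped in tf.numpy_function)
    return tf.zeros((512, 512, 3), dtype=tf.float32)

# num_parallel_calls parallelizes the map on a C++ thread pool,
# so no Python multiprocessing is involved
ds = tf.data.Dataset.range(100).map(
    generate_patch, num_parallel_calls=tf.data.AUTOTUNE)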

However, PyTorch 1.x does not seem to support multithreaded data loading (this may change in the upcoming major release, PyTorch 2.x). In PyTorch it is therefore common to use multiprocessing instead, which TensorFlow also supports, but with both libraries we run into the same issue.

When enabling multiprocessing (use_multiprocessing=True) and setting workers > 1 in model.fit() in TensorFlow, a RuntimeError is raised. The same happens when setting num_workers > 0 for PyTorch's DataLoader.

I believe the root cause is that pyFAST does not work with Python multiprocessing, which both TensorFlow and PyTorch use.
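To make this concrete, a minimal sketch of how the error is triggered from PyTorch (hypothetical dataset; the path and patch coordinates are the same placeholders as in the reproduction script below):

import fast
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class WSIPatchDataset(Dataset):
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        # same pyFAST calls as in the reproduction script below
        importer = fast.WholeSlideImageImporter.create("./OS-2.vsi")
        wsi = importer.runAndGetOutputData()
        access = wsi.getAccess(fast.ACCESS_READ)
        patch = access.getPatchAsImage(0, 50, 100, 512, 512, False)
        return torch.from_numpy(np.asarray(patch).copy())

if __name__ == '__main__':
    # num_workers > 0 spawns worker processes
    loader = DataLoader(WSIPatchDataset(), num_workers=2)
    for batch in loader:  # raises RuntimeError: clEnqueueReadImage
        pass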

To Reproduce

I observe the same issue if I try to read a patch from a WSI from separate processes (does it matter if I move the fast import inside the function?):

import fast
import multiprocessing as mp
import numpy as np
from tqdm import tqdm

def func(value=1):
    path = "./OS-2.vsi"
    # import the whole slide image (WSI)
    importer = fast.WholeSlideImageImporter.create(path)
    wsi = importer.runAndGetOutputData()
    # extract a 512x512 patch from pyramid level 0
    patch_access = wsi.getAccess(fast.ACCESS_READ)
    patch = patch_access.getPatchAsImage(0, 50, 100, 512, 512, False)
    # converting to a numpy array triggers the host data transfer that fails
    patch = np.asarray(patch)
    print(patch.shape)
    return patch

def main():
    # this works
    """
    for i in range(100):
        print(i)
        p = mp.Process(target=func)
        p.start()
        p.join()
        del p
    """

    # this does not work -> I believe TF/PyTorch do something like this in their data loaders
    with mp.Pool(4) as p:
        ret = list(tqdm(p.map(func, range(100))))
    print(len(ret))

if __name__ == '__main__':
    main()

We ran into this before, but as we could use multithreading in TF, we did not look further into it. However, it needs to be resolved now, as there is no other way to stream patches in parallel in PyTorch.

System

Expected behavior

pyFAST should be compatible with Python multiprocessing for both TensorFlow and PyTorch.

Logs

multiprocessing.pool.RemoteTraceback:
"""
Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "/usr/lib/python3.8/multiprocessing/pool.py", line 48, in mapstar
    return list(map(*args))
  File "test.py", line 15, in func
    patch = np.asarray(patch)
  File "/lhome/username/.local/lib/python3.8/site-packages/fast/fast.py", line 12836, in __array_interface__
    'data': (self._getHostDataPointer(), False),
  File "/lhome/username/.local/lib/python3.8/site-packages/fast/fast.py", line 12814, in _getHostDataPointer
    return _fast.Image__getHostDataPointer(self)
RuntimeError: clEnqueueReadImage
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "test.py", line 39, in <module>    main()
  File "test.py", line 35, in main
    ret = tqdm(p.map(func, range(100)))
  File "/usr/lib/python3.8/multiprocessing/pool.py", line 364, in map
    return self._map_async(func, iterable, mapstar, chunksize).get()
  File "/usr/lib/python3.8/multiprocessing/pool.py", line 771, in get
    raise self._value
RuntimeError: clEnqueueReadImage

cc @markusdrange.

smistad commented 1 year ago

Move "import fast" into func, and it works. Not sure why this error happens, but understand that threads and processes are fundamentally different. Threads share memory, processes do not.

andreped commented 1 year ago

Just tried it on my end. It works fine with import fast inside func().

From a TensorFlow perspective, I can see the value of importing within each process separately, as TF defines its session globally within the process. Hence, by importing TF in separate processes, you are able to completely clear the session when killing the subprocess. With TF, however, this still works and does not produce any errors or warnings.

When importing fast in separate processes, however, we get lots of these splash banners (see below). I assume this is written to standard output, and thus I cannot mute it easily, or can I?

     - Powered by -     
   _______   __________   
  / __/ _ | / __/_  __/   https://fast.eriksmistad.no
 / _// __ |_\ \  / /               v4.6.0
/_/ /_/ |_/___/ /_/       

(repeated once per worker process)

smistad commented 1 year ago

There is no way to mute this "splash" right now. I could add an environment variable to mute it. Otherwise, you can redirect the stdout of your multiprocessing workers so it doesn't show in the console, but then you will not see error messages either.
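For the redirect approach, a rough sketch (assuming the splash is printed from the C++ side, so the redirect has to happen at the file-descriptor level rather than on sys.stdout):

import multiprocessing as mp
import os

def quiet_func(value=1):
    # redirect file descriptor 1 (stdout) to /dev/null before importing fast,
    # so the splash printed during import is swallowed -- along with any other
    # messages FAST writes to stdout
    devnull = os.open(os.devnull, os.O_WRONLY)
    os.dup2(devnull, 1)
    import fast
    # ... patch extraction as before ...

if __name__ == '__main__':
    with mp.Pool(4) as p:
        p.map(quiet_func, range(4))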