Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Add multi-threads support to samples crop #2794

Closed Nic-Ma closed 3 years ago

Nic-Ma commented 3 years ago

Is your feature request related to a problem? Please describe. Currently, there are 4 crop transforms that can generate a list of samples, and we crop the images in a for-loop. After some testing, I found that executing the crops in multiple threads can be much faster. So it would be useful to add num_workers support to these transforms, similar to CacheDataset.
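
For example, the interface could look like this (num_workers here is the proposed new argument; it does not exist on the transform yet):

from monai.transforms import RandCropByPosNegLabeld

cropper = RandCropByPosNegLabeld(
    keys=["image", "label"],
    label_key="label",
    spatial_size=(96, 96, 96),
    pos=1,
    neg=1,
    num_samples=4,
    num_workers=4,  # proposed: crop the num_samples patches in parallel threads
)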

Nic-Ma commented 3 years ago

Here is the test code for the RandCropByPosNegLabeld transform (the for-loop changed to a thread pool; ThreadPool comes from multiprocessing.pool):

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> List[Dict[Hashable, np.ndarray]]:
    d = dict(data)
    label = d[self.label_key]
    image = d[self.image_key] if self.image_key else None
    fg_indices = d.get(self.fg_indices_key) if self.fg_indices_key is not None else None
    bg_indices = d.get(self.bg_indices_key) if self.bg_indices_key is not None else None

    self.randomize(label, fg_indices, bg_indices, image)
    if not isinstance(self.spatial_size, tuple):
        raise ValueError("spatial_size must be a valid tuple.")
    if self.centers is None:
        raise ValueError("no available ROI centers to crop.")

    def _crop(idx: int):
        # initialize the returned sample dict with a shallow copy to preserve key ordering
        ret = dict(data)
        center = self.centers[idx]
        # fill in the extra keys with unmodified data
        for key in set(data.keys()).difference(set(self.keys)):
            ret[key] = deepcopy(data[key])
        for key in self.key_iterator(d):
            img = d[key]
            cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size)  # type: ignore
            orig_size = img.shape[1:]
            ret[key] = cropper(img)
            self.push_transform(ret, key, extra_info={"center": center}, orig_size=orig_size)
        # add `patch_index` to the meta data
        for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
            meta_key = meta_key or f"{key}_{meta_key_postfix}"
            if meta_key not in ret:
                ret[meta_key] = {}  # type: ignore
            ret[meta_key][Key.PATCH_INDEX] = idx
        return ret

    with ThreadPool(self.num_workers) as p:
        return list(p.imap(_crop, range(self.num_samples)))

@wyli @ericspod @rijobro , if you don't have concerns about this enhancement, I will try to make a PR for all the sample-crop transforms ASAP; it could be helpful for our performance task in v0.7.

Thanks in advance.

wyli commented 3 years ago

I don't understand why this is faster. Multi-threading can be useful when the bottleneck is IO; in other cases it is normally not efficient because of the Python GIL. Any idea?

Nic-Ma commented 3 years ago

Hi @wyli ,

Usually, multi-processing helps more than multi-threading for IO-bound reading logic. Here we don't have an IO issue; multi-threading can help crop several samples in parallel instead of using the original for-loop. To verify the idea, I tested again with the code below to compare the speed:

start1 = time.time()
with ThreadPool(self.num_samples) as p:
    out = list(p.imap(_crop, range(self.num_samples)))
print(f"multi-threads time: {(time.time() - start1):.4f}")

start2 = time.time()
out = [_crop(i) for i in range(self.num_samples)]
print(f"for loop time: {(time.time() - start2):.4f}")

Part of the output log during training for your reference:

multi-threads time: 0.0406
for loop time: 0.0878
multi-threads time: 0.0112
for loop time: 0.0261
multi-threads time: 0.0104
for loop time: 0.0155
multi-threads time: 0.0090
for loop time: 0.0126
multi-threads time: 0.0107
for loop time: 0.0116
multi-threads time: 0.0109
for loop time: 0.0162
multi-threads time: 0.0318
for loop time: 0.0925
multi-threads time: 0.0083
for loop time: 0.0098
multi-threads time: 0.0115
for loop time: 0.0175
multi-threads time: 0.0356
for loop time: 0.0902
multi-threads time: 0.0134
for loop time: 0.0227
multi-threads time: 0.0112
for loop time: 0.0141
multi-threads time: 0.0116
for loop time: 0.0152
multi-threads time: 0.0096
for loop time: 0.0098
multi-threads time: 0.0276
for loop time: 0.0717
multi-threads time: 0.0132
for loop time: 0.0533
multi-threads time: 0.0077
for loop time: 0.0083
multi-threads time: 0.0459
for loop time: 0.1059
multi-threads time: 0.0377
for loop time: 0.0789
multi-threads time: 0.0419
for loop time: 0.0841
multi-threads time: 0.0355
for loop time: 0.0512

Thanks.

wyli commented 3 years ago

Could you share a complete script for the benchmark?

wyli commented 3 years ago

More analysis of multi-threading is here: http://www.dabeaz.com/python/UnderstandingGIL.pdf

Nic-Ma commented 3 years ago

Hi @wyli ,

This is the full program I used to train the spleen model and get the previous test results. With multi-threading, it is about 20% faster on a V100 GPU with the PyTorch 21.06 docker image.

import glob
import math
import os
import shutil
import tempfile
import time

import torch
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import CacheDataset, DataLoader, ThreadDataLoader, Dataset, decollate_batch
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.layers import Norm
from monai.networks.nets import UNet
from monai.optimizers import Novograd
from monai.transforms import (
    AddChanneld,
    AsDiscrete,
    Compose,
    CropForegroundd,
    DeleteItemsd,
    FgBgToIndicesd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    Spacingd,
    EnsureTyped,
    ToDeviced,
    EnsureType,
)
from monai.utils import get_torch_version_tuple, set_determinism

print_config()

if get_torch_version_tuple() < (1, 6):
    raise RuntimeError(
        "AMP feature only exists in PyTorch version greater than v1.6."
    )

# ## Setup data directory
# 
# You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
# This allows you to save results and reuse downloads.  
# If not specified a temporary directory will be used.

# In[2]:

directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(f"root dir is: {root_dir}")

# ## Download dataset
# 
# Downloads and extracts the Decathlon Spleen dataset.

# In[3]:

resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
md5 = "410d4a301da4e5b2f6f86ec3ddba524e"

compressed_file = os.path.join(root_dir, "Task09_Spleen.tar")
data_root = os.path.join(root_dir, "Task09_Spleen")
if not os.path.exists(data_root):
    download_and_extract(resource, compressed_file, root_dir, md5)

# ## Set MSD Spleen dataset path

# In[4]:

train_images = sorted(
    glob.glob(os.path.join(data_root, "imagesTr", "*.nii.gz"))
)
train_labels = sorted(
    glob.glob(os.path.join(data_root, "labelsTr", "*.nii.gz"))
)
data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
train_files, val_files = data_dicts[:-9], data_dicts[-9:]

# In[5]:

def transformations():
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.5, 1.5, 2.0),
                mode=("bilinear", "nearest"),
            ),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            # pre-compute foreground and background indexes
            # and cache them to accelerate training
            FgBgToIndicesd(
                keys="label",
                fg_postfix="_fg",
                bg_postfix="_bg",
                image_key="image",
            ),
            EnsureTyped(keys=["image", "label"]),
            ToDeviced(keys=["image", "label"], device="cuda:0"),
            # randomly crop out patch samples from big
            # image based on pos / neg ratio
            # the image centers of negative samples
            # must be in valid image area
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(96, 96, 96),
                pos=1,
                neg=1,
                num_samples=4,
                fg_indices_key="label_fg",
                bg_indices_key="label_bg",
            ),
            DeleteItemsd(keys=["label_fg", "label_bg"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.5, 1.5, 2.0),
                mode=("bilinear", "nearest"),
            ),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
            EnsureTyped(keys=["image", "label"]),
            ToDeviced(keys=["image", "label"], device="cuda:0"),
        ]
    )
    return train_transforms, val_transforms

# ## Define the training progress
# For a typical PyTorch regular training procedure, use regular `Dataset`, `Adam` optimizer, and train the model.
# 
# For MONAI fast training progress, we mainly introduce the following features:
# 1. `CacheDataset`: Dataset with the cache mechanism that can load data and cache deterministic transforms' result during training.
# 2. `Novograd` optimizer: Novograd is based on the paper "Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks" `<https://arxiv.org/pdf/1905.11286.pdf>`.
# 3. `AMP` (auto mixed precision): AMP is an important feature released in PyTorch v1.6, NVIDIA CUDA 11 added strong support for AMP and significantly improved training speed.

# In[6]:

def train_process(fast=False):
    max_epochs = 30
    learning_rate = 2e-4
    val_interval = 5

    train_trans, val_trans = transformations()
    # set CacheDataset for MONAI training
    if fast:
        train_ds = CacheDataset(
            data=train_files,
            transform=train_trans,
            cache_rate=1.0,
            num_workers=8,
        )
        val_ds = CacheDataset(
            data=val_files, transform=val_trans, cache_rate=1.0, num_workers=5
        )
        # don't need many workers because already cached the data
        loader_workers = 0
    else:
        train_ds = Dataset(data=train_files, transform=train_trans)
        val_ds = Dataset(data=val_files, transform=val_trans)
        loader_workers = 4

    train_loader = ThreadDataLoader(
        train_ds, num_workers=loader_workers, batch_size=4, shuffle=True
    )
    val_loader = ThreadDataLoader(val_ds, num_workers=loader_workers, batch_size=1)

    device = torch.device("cuda:0")
    model = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)
    loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True)

    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, n_classes=2)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, n_classes=2)])

    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

    # set Novograd optimizer for MONAI training
    if fast:
        # Novograd paper suggests to use a bigger LR than Adam,
        # because Adam does normalization by element-wise second moments
        optimizer = Novograd(model.parameters(), learning_rate * 10)
        scaler = torch.cuda.amp.GradScaler()
    else:
        optimizer = torch.optim.Adam(model.parameters(), learning_rate)

    best_metric = -1
    best_metric_epoch = -1
    best_metrics_epochs_and_time = [[], [], []]
    epoch_loss_values = []
    metric_values = []
    epoch_times = []

    total_start = time.time()

    #nvidia_dlprof_pytorch_nvtx.init()
    #with torch.autograd.profiler.emit_nvtx():
    for epoch in range(max_epochs):
        epoch_start = time.time()
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step_start = time.time()
            step += 1
            inputs, labels = (
                batch_data["image"].to(device),
                batch_data["label"].to(device),
            )
            optimizer.zero_grad()
            # set AMP for MONAI training
            if fast:
                with torch.cuda.amp.autocast():
                    outputs = model(inputs)
                    loss = loss_function(outputs, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
            epoch_loss += loss.item()
            epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
            print(
                f"{step}/{epoch_len}, train_loss: {loss.item():.4f}"
                f" step time: {(time.time() - step_start):.4f}"
            )
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    val_inputs, val_labels = (
                        val_data["image"].to(device),
                        val_data["label"].to(device),
                    )
                    roi_size = (160, 160, 160)
                    sw_batch_size = 4
                    # set AMP for MONAI validation
                    if fast:
                        with torch.cuda.amp.autocast():
                            val_outputs = sliding_window_inference(
                                val_inputs, roi_size, sw_batch_size, model
                            )
                    else:
                        val_outputs = sliding_window_inference(
                            val_inputs, roi_size, sw_batch_size, model
                        )
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                    dice_metric(y_pred=val_outputs, y=val_labels)

                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    best_metrics_epochs_and_time[0].append(best_metric)
                    best_metrics_epochs_and_time[1].append(best_metric_epoch)
                    best_metrics_epochs_and_time[2].append(
                        time.time() - total_start
                    )
                    torch.save(model.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")
                print(
                    f"current epoch: {epoch + 1} current"
                    f" mean dice: {metric:.4f}"
                    f" best mean dice: {best_metric:.4f}"
                    f" at epoch: {best_metric_epoch}"
                )
        print(
            f"time consuming of epoch {epoch + 1} is:"
            f" {(time.time() - epoch_start):.4f}"
        )
        epoch_times.append(time.time() - epoch_start)

    print(
        f"train completed, best_metric: {best_metric:.4f}"
        f" at epoch: {best_metric_epoch}"
        f" total time: {(time.time() - total_start):.4f}"
    )
    return (
        max_epochs,
        epoch_loss_values,
        metric_values,
        epoch_times,
        best_metrics_epochs_and_time,
    )

# ## Enable determinism and execute MONAI optimized training

# In[7]:

set_determinism(seed=0)
monai_start = time.time()
(
    epoch_num,
    m_epoch_loss_values,
    m_metric_values,
    m_epoch_times,
    m_best,
) = train_process(fast=True)
m_total_time = time.time() - monai_start
print(
    f"total training time of {epoch_num} epochs with MONAI: {m_total_time:.4f}"
)

Thanks.

Nic-Ma commented 3 years ago

@wyli , BTW, here I used ToDeviced to cache the data on the GPU directly, then executed the multi-threaded crop on GPU tensors within the ThreadDataLoader.

Thanks.

Nic-Ma commented 3 years ago

And I think this is a non-breaking enhancement: users can try it to get better performance, and if it is not helpful for their cases, they can simply not use it.

Thanks.

wyli commented 3 years ago

I feel we should at least understand the mechanism and make recommendations in the component's usage docs. In this case, maybe the random crop transform reads the cache for indices and image data, and because the dataset is large the cache actually lives in virtual memory. Is multi-threading useful here because it addresses the IO bottleneck of fetching from the cache?

Nic-Ma commented 3 years ago

Hi @wyli ,

Thanks for your analysis, that makes sense to me. I also tried removing ToDeviced, or moving EnsureTyped to the end, and still got the same test result. I think we can add a note to the docstring when I develop this PR: "when the image is cached in virtual memory or other hardware storage, increasing num_workers can help improve the IO loading speed." What do you think?

Thanks.

wyli commented 3 years ago

If that's the case, perhaps we could have a generic multi-threading approach by modifying apply_transform https://github.com/Project-MONAI/MONAI/blob/a38bae30ca96a3e28207538c1e2b8022fbf07571/monai/transforms/transform.py#L91 or somewhere here in the Compose: https://github.com/Project-MONAI/MONAI/blob/a38bae30ca96a3e28207538c1e2b8022fbf07571/monai/transforms/compose.py#L160 (I haven't tested this idea yet).
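
Something along these lines, perhaps (a rough, untested sketch; apply_transform_threaded is a hypothetical helper, not an existing MONAI function):

from multiprocessing.pool import ThreadPool

def apply_transform_threaded(transform, data, num_threads=4):
    # when the transform is applied to a list of items, map it over the list
    # with a thread pool instead of a sequential loop
    if isinstance(data, (list, tuple)):
        with ThreadPool(num_threads) as pool:
            return list(pool.imap(transform, data))
    return transform(data)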

Nic-Ma commented 3 years ago

Hi @wyli ,

I don't quite understand your solution. The main idea of this ticket is to use multiple threads to accelerate the cropping done by one transform for one image, because this transform crops out a list of samples in a for-loop. I don't see how to improve that via apply_transform. Please feel free to correct me if I misunderstood.

Thanks.

Nic-Ma commented 3 years ago

Hi @wyli ,

Maybe I should start the PR for RandCropByPosNegLabeld, and we can review & discuss in the PR directly?

Thanks.

wyli commented 3 years ago

Still, I don't fully understand the source of the speedup. If it is related to cache reading, we should have a generic solution here, since it will affect all the CacheDataset-based loading pipelines; we shouldn't just provide local solutions at the level of the num_samples-related transforms. Any idea? Could you investigate more?

Nic-Ma commented 3 years ago

OK, let me investigate further for a better solution.

Thanks.

wyli commented 3 years ago

And in your benchmark script, perhaps ThreadDataLoader's buffer size should be larger https://github.com/Project-MONAI/MONAI/blob/28856b88044e6310d0c5bfff2c088dcbea00324f/monai/data/thread_buffer.py#L38
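
For reference, a larger buffer can be tried by wrapping the loader manually (a sketch, reusing train_ds from the benchmark script above; buffer_size=8 is arbitrary):

from monai.data import DataLoader, ThreadBuffer

loader = DataLoader(train_ds, batch_size=4, shuffle=True)
buffered_loader = ThreadBuffer(loader, buffer_size=8)
for batch in buffered_loader:
    pass  # training step here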

And perhaps with a multi-threaded enqueue, cc @ericspod.

Nic-Ma commented 3 years ago

Hi @wyli ,

Actually, I tried a bigger buffer size at the beginning, and no obvious difference was observed. I think it's because the queue's put side is slower than the get side, so when I used multiple threads in the samples-crop, put became faster and the total speed increased as well.

Thanks.

wyli commented 3 years ago

I think when the buffer size is larger than one, it should use multiple threads to fill the queue; might this address the issue?

Nic-Ma commented 3 years ago

Hi @wyli ,

Here we only create 1 thread to fill the queue: https://github.com/Project-MONAI/MONAI/blob/dev/monai/data/thread_buffer.py#L67 Is that because the source data is not thread-safe and doesn't support parallel index access? https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py#L550 @ericspod Do you have any idea how to improve the ThreadDataLoader here?

And @wyli , please note that even if we use more threads to fill the queue, it only parallelizes the batch-level computation, not the cropping of samples. For example, we can have batch_size=2 and num_samples=8; more threads may help with batch_size=2, but num_samples=8 is still a for-loop inside the RandCropByPosNegLabeld transform.

Thanks.

wyli commented 3 years ago

Could it be something like this (not tested)?

diff --git a/monai/data/thread_buffer.py b/monai/data/thread_buffer.py
index 2901335b..748f9fb3 100644
--- a/monai/data/thread_buffer.py
+++ b/monai/data/thread_buffer.py
@@ -33,25 +33,30 @@ class ThreadBuffer:
         timeout: Time to wait for an item from the buffer, or to wait while the buffer is full when adding items
     """

-    def __init__(self, src, buffer_size=1, timeout=0.01):
+    def __init__(self, src, buffer_size=1, num_threads=4, timeout=0.01):
         self.src = src
         self.buffer_size = buffer_size
+        self.num_threads = num_threads
         self.timeout = timeout
         self.buffer = Queue(self.buffer_size)
         self.gen_thread = None
         self.is_running = False

     def enqueue_values(self):
-        for src_val in self.src:
-            while self.is_running:
-                try:
-                    self.buffer.put(src_val, timeout=self.timeout)
-                except Full:
-                    pass  # try to add the item again
-                else:
-                    break  # successfully added the item, quit trying
-            else:  # quit the thread cleanly when requested to stop
-                break
+        from multiprocessing.pool import ThreadPool
+        with ThreadPool(self.num_threads) as p:
+            for item in p.imap(lambda x: x, self.src):
+                buffered = False
+                while not buffered and self.is_running:
+                    try:
+                        self.buffer.put(item, timeout=self.timeout)
+                        buffered = True
+                    except Full:
+                        pass
+                    else:
+                        break
+                if not self.is_running:
+                    return

     def stop(self):
         self.is_running = False  # signal the thread to exit
@@ -86,7 +91,8 @@ class ThreadDataLoader(DataLoader):

     def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs):
         super().__init__(dataset, num_workers, **kwargs)
+        self.buffer_size = kwargs.get("batch_size", 1)

     def __iter__(self):
-        buffer = ThreadBuffer(super().__iter__())
+        buffer = ThreadBuffer(super().__iter__(), buffer_size=self.buffer_size*4)
         yield from buffer

Nic-Ma commented 3 years ago

Hi @wyli ,

Thanks for your great suggestion. I totally agree that making the batch level multi-threaded is more important than the multi-threaded samples-crop, so let's focus on the multi-threaded ThreadDataLoader first. About your sample code, maybe we need a thread lock when changing the value of buffered? I will try to test this code later.

Thanks in advance.

wyli commented 3 years ago

Thanks, yes, I think we can remove the 'buffered' variable (following the original implementation).

Nic-Ma commented 3 years ago

Hi @wyli ,

I tested your proposal locally, and it didn't improve the speed with num_threads=4. I think it's because this line:

for item in p.imap(lambda x: x, self.src)

will complete all the items before starting the for-loop, so we can't put to and get from the Queue in parallel, right?

Attaching my test implementation (my test used batch_size=4):

from multiprocessing.pool import ThreadPool
from queue import Empty, Full, Queue
from threading import Thread

from monai.data import DataLoader, Dataset


class ThreadBuffer:
    def __init__(self, src, buffer_size=4, num_threads=4, timeout=0.01):
        self.src = src
        self.buffer_size = buffer_size
        self.num_threads = num_threads
        self.timeout = timeout
        self.buffer = Queue(self.buffer_size)
        self.gen_thread = None
        self.is_running = False

    def enqueue_values(self):
        with ThreadPool(self.num_threads) as p:
            for item in p.imap(lambda x: x, self.src):
                while self.is_running:
                    try:
                        self.buffer.put(item, timeout=self.timeout)
                    except Full:
                        pass
                    else:
                        break
                else:  # quit the thread cleanly when requested to stop
                    break

    def stop(self):
        self.is_running = False  # signal the thread to exit

        if self.gen_thread is not None:
            self.gen_thread.join()

        self.gen_thread = None

    def __iter__(self):

        self.is_running = True
        self.gen_thread = Thread(target=self.enqueue_values, daemon=True)
        self.gen_thread.start()

        try:
            while self.is_running and (self.gen_thread.is_alive() or not self.buffer.empty()):
                try:
                    yield self.buffer.get(timeout=self.timeout)
                except Empty:
                    pass  # queue was empty this time, try again
        finally:
            self.stop()  # ensure thread completion

class ThreadDataLoader(DataLoader):
    """
    Subclass of `DataLoader` using a `ThreadBuffer` object to implement `__iter__` method asynchronously. This will
    iterate over data from the loader as expected however the data is generated on a separate thread. Use this class
    where a `DataLoader` instance is required and not just an iterable object.
    """

    def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs):
        super().__init__(dataset, num_workers, **kwargs)

    def __iter__(self):
        buffer = ThreadBuffer(super().__iter__())
        yield from buffer

Thanks.

wyli commented 3 years ago

According to the documentation, imap is implemented in a lazy manner: "The other major difference between imap/imap_unordered and map/map_async is that with imap/imap_unordered, you can start receiving results from workers as soon as they're ready, rather than having to wait for all of them to be finished."
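
A small standalone check of that lazy behaviour (illustrative only):

import time
from multiprocessing.pool import ThreadPool

def slow_identity(x):
    time.sleep(0.5)
    return x

start = time.time()
with ThreadPool(2) as p:
    # results arrive as workers finish, not after the whole range is processed
    for item in p.imap(slow_identity, range(4)):
        print(f"got {item} after {time.time() - start:.2f}s")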

Nic-Ma commented 3 years ago

Hi @wyli ,

Another potential issue with this multi-threading approach: many of our random transforms are NOT thread-safe, so maybe we are not ready to run them this way yet.
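
The concern, roughly (a simplified toy class for illustration, not actual MONAI code): a randomizable transform stores its sampled parameters on self, so two threads calling the same instance can overwrite each other's state.

import numpy as np

class ToyRandCrop:
    """Mimics the state-on-self pattern of randomizable transforms."""

    def __init__(self):
        self.R = np.random.RandomState(0)
        self.center = None

    def randomize(self, shape):
        # shared mutable state: another thread may overwrite this
        self.center = tuple(int(self.R.randint(0, s)) for s in shape)

    def __call__(self, img):
        self.randomize(img.shape)
        # by the time we read self.center, a concurrent call may have replaced it
        return self.center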

Thanks.

wyli commented 3 years ago

Thanks, that's a good point. I think the lambda function in for item in p.imap(lambda x: x, self.src) could be replaced by a utility that sets the random state and then fetches the data item; the random seed could come from a shared counter variable or be computed based on the current buffer size. What do you think?

edit: I see, it'll still require a thread lock or a deepcopy of the transforms, which may slow things down...

Nic-Ma commented 3 years ago

Hi @wyli ,

Yes, I don't see a good thread-safe solution for providing overall multi-threading support. So I plan to go back to the original, smaller idea: investigating more deeply why the multi-threaded samples-crop helps in my test program.

Thanks.

Nic-Ma commented 3 years ago

Hi @wyli ,

After deeper analysis, I found the root cause of why the multi-threaded samples-crop only helps when I cache the foreground and background indices: the per-sample deepcopy spends a lot of time copying the fg and bg indices, which are actually unnecessary after this crop transform: https://github.com/Project-MONAI/MONAI/blob/dev/monai/transforms/croppad/dictionary.py#L1114 So I submitted a PR to avoid deep-copying this unnecessary data: https://github.com/Project-MONAI/MONAI/pull/2804 With this PR, the training is much faster, even slightly faster than the previous multi-threaded samples-crop. I also tried using multiple threads again, and no obvious improvement was observed anymore. And we also no longer need to manually remove the items in the transform chain:

DeleteItemsd(keys=["label_fg", "label_bg"])

So I think we have almost solved the perf issue of this crop, thanks very much for your great discussion & analysis. Could you please help review the PR?
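
The idea of the change, roughly (a sketch of the concept only, not the actual PR diff; the helper name is hypothetical):

from copy import deepcopy

def fill_extra_keys(sample, data, processed_keys, skip_keys):
    # copy the unprocessed items into the cropped sample, but skip the cached
    # fg/bg index arrays that are no longer needed after cropping
    for key in set(data) - set(processed_keys):
        if key in skip_keys:
            continue
        sample[key] = deepcopy(data[key])
    return sample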

Thanks in advance.

ericspod commented 3 years ago

I'm a bit late to the party, but here are a few observations. Multi-threading has a number of problems relating to the GIL, as we know, but often we can route around that by using compiled functions in NumPy, SciPy, PyTorch, etc. Mixing threads and processes will be inefficient regardless, because we typically create as many processes as we have CPU cores (virtual and physical). If we have multiple threads in these subprocesses running a transform pipeline, there will be more threads than CPU cores, and that will lose efficiency through contention. I don't think the advantages of accessing memory efficiently in threads would overcome that. I would generally suggest using either threads or processes for parallelism, and not mixing the two.

The problem here with RandCropByPosNegLabeld is that this is a one-to-many transform, where generating the many outputs with multiple threads may be faster. If it is used on its own that might be the case, but if used with multiple processes you could create too many threads. I would think it would be faster to change the transform to be one-to-one, so that you get one cropped image for each input but give each transform the same input, i.e. as if you had a batch of duplicate images. This might work for particular use cases that expect one-to-many, but I'm personally not sure how this class is used now, so I can't say for sure this makes sense.
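
Roughly what I mean, as an illustration (using RandSpatialCropd as a stand-in one-to-one crop rather than the pos/neg-ratio crop; crop_replicas is just a hypothetical helper):

from monai.transforms import RandSpatialCropd

def crop_replicas(sample, num_samples=4, roi_size=(96, 96, 96)):
    # duplicate the input and apply a one-to-one random crop to each copy,
    # instead of one transform producing a list of samples internally
    cropper = RandSpatialCropd(keys=["image", "label"], roi_size=roi_size, random_size=False)
    return [cropper(dict(sample)) for _ in range(num_samples)]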

ThreadDataLoader doesn't benefit from having a buffer size larger than 1 typically unless the sizes of the buffered objects vary wildly and so take varying amounts of time to generate. I left the option to change the buffer size in the original implementation to allow experimentation. The idea of ThreadDataLoader is to permit a separate thread to read from a data loader it had exclusive access to so that thread-safety wasn't a problem.

One idea I've been meaning to try is a DataLoader using threads instead of processes, which would lack the inter-process communication overhead the current implementation must have, but would rely on using a lot of compiled functions so as not to get bogged down by the GIL and the thread-safety of the source Dataset. This might be entirely unrelated to the problem at hand with the transforms being one-to-many.

I think @Nic-Ma has written something about the results that makes all this less relevant, but it's here for us to consider later.

Nic-Ma commented 3 years ago

Hi @ericspod ,

Thanks very much for your deep analysis! Yes, we still need to keep optimizing the ThreadDataLoader with multi-threading and solve the thread-safety problem in our current random transforms design. I got the idea for ThreadDataLoader from the NVIDIA DALI project, which uses multiple threads to build its pipeline instead of the PyTorch multi-processing DataLoader.

Thanks.

wyli commented 3 years ago

Thanks for the update, looks great. I think we can open a separate ticket to enhance the thread-based loader.