NVIDIA / DALI

A GPU-accelerated library containing highly optimized building blocks and an execution engine for data processing to accelerate deep learning training and inference applications.
https://docs.nvidia.com/deeplearning/dali/user-guide/docs/index.html
Apache License 2.0

Performance difference between TCP and RDMA in multipath storage setup using DALI #5727

Closed: asdfry closed this issue 6 hours ago

asdfry commented 6 days ago

Describe the question.

Hello,

I converted the entire ImageNet dataset to UHD resolution and stored it in NumPy format. I then ran several ResNet50 training scenarios to compare the time taken for a single epoch. The results are shown in the attached image.
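For a sense of scale, the raw size of each converted sample can be estimated from its shape. The sketch below assumes (the thread does not say) that each image is stored as an uncompressed uint8 HWC array at 3840x2160 with 3 channels, and it ignores the small `.npy` header. `uhd_sample_bytes` and `dataset_bytes` are hypothetical helpers, not part of DALI.

```python
# Rough size estimate for ImageNet images upscaled to UHD and stored as raw .npy arrays.
# Assumptions (not stated in the thread): uint8 pixels, HWC layout, 3 channels,
# 3840x2160 resolution, .npy header overhead ignored.

UHD_WIDTH = 3840
UHD_HEIGHT = 2160


def uhd_sample_bytes(channels: int = 3, bytes_per_value: int = 1) -> int:
    """Bytes needed for one uncompressed UHD image array."""
    return UHD_WIDTH * UHD_HEIGHT * channels * bytes_per_value


def dataset_bytes(num_images: int, channels: int = 3, bytes_per_value: int = 1) -> int:
    """Total image payload for num_images samples, excluding headers and label files."""
    return num_images * uhd_sample_bytes(channels, bytes_per_value)


if __name__ == "__main__":
    per_image = uhd_sample_bytes()
    print(f"per image: {per_image / 2**20:.1f} MiB")  # about 23.7 MiB
    # 1,281,167 is the size of the ImageNet-1k (ILSVRC2012) training set.
    total = dataset_bytes(1_281_167)
    print(f"full train set: {total / 10**12:.1f} TB")
```

At roughly 24.9 MB per sample, each epoch moves a large volume of data, which is why storage bandwidth and caching dominate the comparison below.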

I am noticing a surprising outcome: when using the storage's multipath feature, TCP outperforms RDMA in terms of epoch time. I expected RDMA to be faster than TCP because of its streamlined stack and lower latency. (In the chart below, the white bar represents RDMA + Multipath and the blue bar represents TCP + Multipath.)

Could you explain why TCP appears to be faster than RDMA in this multipath scenario? Below are the mount options I used when enabling multipath:

Your insights would be greatly appreciated.

Thank you in advance for your support!


JanuszL commented 6 days ago

Hi @asdfry,

Thank you for reaching out. Can you provide more information about the pipeline you run and your system configuration? It would be beneficial to capture an Nsight Systems profile to make sure the difference comes from data loading itself.

asdfry commented 6 days ago

Hi,

I used a slightly modified script based on this example code. The part where the DALI pipeline is created is as follows:

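from nvidia.dali.pipeline import pipeline_def
import nvidia.dali.types as types
import nvidia.dali.fn as fn

@pipeline_def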
def create_dali_pipeline(
    data_dir,
    crop,
    seed,
    shard_id,
    num_shards,
    prefetch_queue,
    device,
    is_training=True,
    o_direct=False,
):
    images = fn.readers.numpy(
        device=device,
        file_root=data_dir,
        file_filter="*.image.npy",
        shard_id=shard_id,
        num_shards=num_shards,
        seed=seed,
        prefetch_queue_depth=prefetch_queue,
        pad_last_batch=True,
        shuffle_after_epoch=is_training,
        dont_use_mmap=True,
        use_o_direct=o_direct,
        name="Reader",
    )

    labels = fn.readers.numpy(
        device=device,
        file_root=data_dir,
        file_filter="*.label.npy",
        shard_id=shard_id,
        num_shards=num_shards,
        seed=seed,
        prefetch_queue_depth=prefetch_queue,
        pad_last_batch=True,
        shuffle_after_epoch=is_training,
        dont_use_mmap=True,
        use_o_direct=o_direct,
    )

    images = fn.crop_mirror_normalize(
        images.gpu(),
        dtype=types.FLOAT,
        output_layout="CHW",
        crop=(crop, crop),
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
    )

    labels = labels.gpu()

    return images, labels

The experiment was conducted on a Kubernetes cluster, and the node under test was a DGX-H100. I used Vast CSI to create two PVs, one in Multipath + TCP mode and the other in Multipath + RDMA mode, and preloaded the dataset onto both. Training ran for three epochs (each bar in the chart is the average epoch time), and the jobs were launched one after another through a Kubeflow pipeline.

The epoch times:

The network traffic:

The only difference between these two cases is the storage mounting method...

JanuszL commented 6 days ago

Thank you for providing the details. These are interesting observations. Have you run storage transfer tests outside DALI to see how the storage performs in general? Maybe there is nothing DALI-specific in this behavior.

asdfry commented 5 days ago

I mounted the storage directly on the server using the two protocols and ran a read test with dd. It turns out that RDMA is roughly twice as fast.

Left (TCP) mount command: sudo mount -o nconnect=4,remoteports=192.168.0.21-192.168.0.24 192.168.0.20:/ten-data mnt-tcp
Right (RDMA) mount command: sudo mount -o proto=rdma,nconnect=4,remoteports=192.168.0.21-192.168.0.24 192.168.0.20:/ten-data mnt-rdma
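The same kind of sequential read test can be timed from Python instead of dd, which makes it easy to run against files from the actual dataset. This is a minimal, hypothetical helper (`read_throughput` is not part of DALI); the 16 MiB default block size is an arbitrary choice that matches the later script's default `--chunk_size`. Like dd, it measures whatever the page cache allows, so a file read recently may come from memory rather than the network.

```python
# Sequential-read throughput check, roughly analogous to
# `dd if=<file> of=/dev/null bs=<block>`. Hypothetical helper, not part of DALI.
import time


def read_throughput(path: str, block_size: int = 16 * 1024 * 1024) -> tuple[int, float]:
    """Read `path` sequentially in `block_size` chunks.

    Returns (bytes_read, seconds). Results depend on the OS page cache.
    """
    total = 0
    start = time.perf_counter()
    # buffering=0 returns a raw file object, so each read() is a single read syscall
    with open(path, "rb", buffering=0) as f:
        while True:
            chunk = f.read(block_size)
            if not chunk:  # empty bytes means end of file
                break
            total += len(chunk)
    return total, time.perf_counter() - start
```

For example, calling `read_throughput` on the same file under `mnt-tcp` and `mnt-rdma` and dividing bytes by seconds gives a per-protocol bandwidth figure comparable to the dd output.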

And here are the results from running ResNet50 again for both cases yesterday:

For TCP, the time per epoch decreases significantly as training progresses. I'm not sure of the exact reason; could it be that TCP makes better use of caching than RDMA? If so, should I use only the first-epoch time to compare storage performance accurately?
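If caching is the cause, later epochs can be served partly from memory, so averaging all epochs mixes two regimes. A minimal way to keep them apart is to report the first epoch separately from the rest. This hypothetical helper assumes the cache was empty when the first epoch started, which only holds if caches were dropped (or the node was fresh) before the run.

```python
# Separate the first (cold-cache) epoch from later (possibly warm-cache) epochs.
# Hypothetical helper; epoch times come from the training or loader logs.
from statistics import mean


def summarize_epochs(epoch_seconds: list[float]) -> dict[str, float]:
    """Return the first-epoch time and the mean of the remaining epochs."""
    if not epoch_seconds:
        raise ValueError("need at least one epoch time")
    summary = {"first": epoch_seconds[0]}
    if len(epoch_seconds) > 1:
        warm = mean(epoch_seconds[1:])
        summary["warm_mean"] = warm
        # A value above 1.0 means later epochs ran faster than the first one
        summary["speedup"] = epoch_seconds[0] / warm
    return summary
```

With made-up numbers, `summarize_epochs([300.0, 150.0, 150.0])` returns `{'first': 300.0, 'warm_mean': 150.0, 'speedup': 2.0}`; a large speedup for TCP but not RDMA would point at caching rather than raw transfer speed.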

JanuszL commented 5 days ago

Hi @asdfry,

I'm not familiar with the details of these two ways of accessing network storage, or whether and how OS I/O caches are involved in either case. For a more comprehensive overview of file system performance, you can try the iozone benchmark and compare read and re-read performance (re-reads are where disk caches start to apply). Also, if you set use_o_direct=True, the disk cache should be bypassed; on the other hand, the data is then read in chunks. Finally, as I mentioned, you can check an Nsight Systems profile, which should show the syscalls involved and how much data reading contributes to overall performance.
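The read versus re-read distinction can also be checked on a single file without iozone. The sketch below is Linux-oriented and hypothetical: it asks the kernel to evict the file's cached pages with `os.posix_fadvise(..., POSIX_FADV_DONTNEED)`, which is only a hint rather than a guarantee, then times two back-to-back reads. A large gap between the two suggests caching is in play. It does not bypass the cache the way `use_o_direct=True` is meant to.

```python
# Cold read vs. re-read of one file, in the spirit of iozone's read/re-read tests.
# Assumes Linux for posix_fadvise; on other platforms the eviction hint is skipped.
import os
import time


def _timed_read(path: str, block_size: int) -> float:
    """Read the whole file once and return the elapsed seconds."""
    start = time.perf_counter()
    with open(path, "rb", buffering=0) as f:
        while f.read(block_size):
            pass
    return time.perf_counter() - start


def read_reread(path: str, block_size: int = 1 << 20) -> tuple[float, float]:
    """Return (first_read_seconds, reread_seconds) for `path`."""
    fd = os.open(path, os.O_RDONLY)
    try:
        if hasattr(os, "posix_fadvise"):
            # Hint: drop cached pages for the whole file (offset 0, length 0 = to EOF).
            os.posix_fadvise(fd, 0, 0, os.POSIX_FADV_DONTNEED)
    finally:
        os.close(fd)
    first = _timed_read(path, block_size)
    second = _timed_read(path, block_size)  # likely served from the page cache
    return first, second
```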

asdfry commented 14 hours ago

Hello, I appreciate your continued support and guidance.

Lastly, I tried running a script that only iterates over the data loader, without model training, to compare performance. (I plan to use Nsight Systems after this test.)

Here is the Python script I'm running:

import time
import os
import torch
import argparse

from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
from nvidia.dali.pipeline import pipeline_def
import nvidia.dali.types as types
import nvidia.dali.fn as fn

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--chunk_size", type=str, default="16M", choices=["1M", "2M", "4M", "8M", "16M"])
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--nconnect", type=int, default=4, choices=[4, 8])
    parser.add_argument("--protocol", type=str, required=True, choices=["tcp", "rdma"])
    parser.add_argument("--prefetch_queue", type=int, default=1)
    parser.add_argument("--workers", type=int, default=4, choices=[4, 16, 32])
    parser.add_argument("--odirect", action="store_true")
    return parser.parse_args()

@pipeline_def
def create_dali_pipeline(
    data_dir,
    crop,
    user_seed,
    shard_id,
    num_shards,
    prefetch_queue,
    device,
    is_training=True,
    o_direct=False,
):
    images = fn.readers.numpy(
        device=device,
        file_root=data_dir,
        file_filter="*.image.npy",
        shard_id=shard_id,
        num_shards=num_shards,
        seed=user_seed,
        prefetch_queue_depth=prefetch_queue,
        pad_last_batch=True,
        shuffle_after_epoch=is_training,
        dont_use_mmap=True,
        use_o_direct=o_direct,
        name="Reader",
    )

    labels = fn.readers.numpy(
        device=device,
        file_root=data_dir,
        file_filter="*.label.npy",
        shard_id=shard_id,
        num_shards=num_shards,
        seed=user_seed,
        prefetch_queue_depth=prefetch_queue,
        pad_last_batch=True,
        shuffle_after_epoch=is_training,
        dont_use_mmap=True,
        use_o_direct=o_direct,
    )

    images = fn.crop_mirror_normalize(
        images.gpu(),
        dtype=types.FLOAT,
        output_layout="CHW",
        crop=(crop, crop),
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
    )

    labels = labels.gpu()

    return images, labels

def train(epoch, train_loader, max_batches, local_rank):
    if local_rank == 0:
        print(f"Start iterating")
    for i in range(epoch):
        start_time = time.time()

        for batch_idx, data in enumerate(train_loader):
            continue
            # if local_rank == 0 and (batch_idx + 1) % 25 == 0:
            #     images, labels = data[0]["data"], data[0]["label"]
            #     print(
            #         f"Batch {batch_idx + 1}/{max_batches}: Image size={images.shape}, Label size={labels.shape}"
            #     )

        end_time = time.time()
        if local_rank == 0:
            print(f"Epoch {i+1}/{epoch} completed, Time taken: {end_time - start_time:.2f} seconds")

def main():
    args = parse_arguments()
    os.environ["DALI_GDS_CHUNK_SIZE"] = args.chunk_size

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")

    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = torch.distributed.get_world_size()

    if args.protocol == "rdma":
        data_dir = (
            f"/root/mnt-vast-data-multipath-rdma-nc{args.nconnect}/datasets/imagenet-eighth-numpy-uhd/train"
        )
    else:
        data_dir = f"/root/mnt-vast-data-multipath-nc{args.nconnect}/datasets/imagenet-eighth-numpy-uhd/train"

    if local_rank == 0:
        print(args)

    # DALI pipeline creation
    if local_rank == 0:
        print(f"Start building dali pipeline")
    train_pipe = create_dali_pipeline(
        batch_size=args.batch_size,
        num_threads=args.workers,
        device_id=local_rank,
        user_seed=12 + local_rank,
        shard_id=local_rank,
        prefetch_queue=args.prefetch_queue,
        data_dir=data_dir,
        crop=224,
        device="gpu",
        num_shards=world_size,
        is_training=True,
        o_direct=args.odirect,  # pass the parsed --odirect flag through to the readers
    )
    train_pipe.build()
    train_loader = DALIClassificationIterator(
        train_pipe,
        reader_name="Reader",
        last_batch_policy=LastBatchPolicy.PARTIAL,
        auto_reset=True,
    )

    # Train
    max_batches = (train_loader.size + args.batch_size - 1) // args.batch_size
    train(args.epochs, train_loader, max_batches, local_rank)

if __name__ == "__main__":
    main()
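Because this script does no training, each epoch time can be turned into an effective read bandwidth and compared directly with the dd read test earlier in the thread. The helper below is a hypothetical sketch: it counts only the image payload (label files are tiny), assumes every sample is read exactly once per epoch, and, with 8 processes, expects `num_samples` to be the total across all shards so the result is a per-node figure.

```python
# Convert an epoch time into effective read bandwidth so loader-only runs can be
# compared with raw dd numbers. Hypothetical helper; assumes each sample is read
# exactly once per epoch and sample_bytes is the on-disk size of one image file.


def effective_bandwidth_gbps(num_samples: int, sample_bytes: int, epoch_seconds: float) -> float:
    """Decimal gigabytes per second read during one epoch."""
    if epoch_seconds <= 0:
        raise ValueError("epoch_seconds must be positive")
    return num_samples * sample_bytes / epoch_seconds / 1e9
```

For example, 160,000 samples of 24,883,200 bytes each read in 400 s corresponds to about 9.95 GB/s; these numbers are illustrative, not measurements from this thread.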

I created a Kubernetes pod with --privileged mode (to clear caches during testing) and executed the script. However, when using the tcp protocol, I encountered the following error:

root@pnode5:~# torchrun --nproc_per_node=8 vast_test.py --protocol tcp
W1204 02:12:11.272000 140235346100672 torch/distributed/run.py:757]
W1204 02:12:11.272000 140235346100672 torch/distributed/run.py:757] *****************************************
W1204 02:12:11.272000 140235346100672 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W1204 02:12:11.272000 140235346100672 torch/distributed/run.py:757] *****************************************
Namespace(batch_size=256, chunk_size='16M', epochs=10, nconnect=4, protocol='tcp', prefetch_queue=1, workers=4, odirect=True)
Start building dali pipeline
Assertion failure, file index :0  line :1388
Assertion failure, file index :0  line :1388
Assertion failure, file index :0  line :1388
Assertion failure, file index :0  line :1388
Assertion failure, file index :0  line :1388
Assertion failure, file index :0  line :1388
Assertion failure, file index :0  line :1388
Assertion failure, file index :0  line :1388
Assertion failure, file index :0  line :1388
Assertion failure, file index :0  line :1388
E1204 02:14:56.434000 140235346100672 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: -6) local_rank: 0 (pid: 2982) of binary: /usr/bin/python
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==2.4.0a0+07cecf4168.nv24.5', 'console_scripts', 'torchrun')())
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 879, in main
    run(args)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 870, in run
    elastic_launch(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 263, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
=====================================================
vast_test.py FAILED
-----------------------------------------------------
Failures:
[1]:
  time      : 2024-12-04_02:14:56
  host      : pnode5.idc1.ten1010.io
  rank      : 1 (local_rank: 1)
  exitcode  : -6 (pid: 2983)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 2983
[2]:
  time      : 2024-12-04_02:14:56
  host      : pnode5.idc1.ten1010.io
  rank      : 2 (local_rank: 2)
  exitcode  : -6 (pid: 2984)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 2984
[3]:
  time      : 2024-12-04_02:14:56
  host      : pnode5.idc1.ten1010.io
  rank      : 3 (local_rank: 3)
  exitcode  : -6 (pid: 2985)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 2985
[4]:
  time      : 2024-12-04_02:14:56
  host      : pnode5.idc1.ten1010.io
  rank      : 4 (local_rank: 4)
  exitcode  : -6 (pid: 2986)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 2986
[5]:
  time      : 2024-12-04_02:14:56
  host      : pnode5.idc1.ten1010.io
  rank      : 5 (local_rank: 5)
  exitcode  : -6 (pid: 2987)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 2987
[6]:
  time      : 2024-12-04_02:14:56
  host      : pnode5.idc1.ten1010.io
  rank      : 6 (local_rank: 6)
  exitcode  : -6 (pid: 2988)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 2988
[7]:
  time      : 2024-12-04_02:14:56
  host      : pnode5.idc1.ten1010.io
  rank      : 7 (local_rank: 7)
  exitcode  : -6 (pid: 2989)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 2989
-----------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-12-04_02:14:56
  host      : pnode5.idc1.ten1010.io
  rank      : 0 (local_rank: 0)
  exitcode  : -6 (pid: 2982)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 2982
=====================================================

Interestingly, when switching to the rdma protocol, the script works without any issues. Furthermore, when the pod is not created in --privileged mode, the script also runs without issues, even with the tcp protocol.
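For reference, clearing caches between runs is typically done by writing to `/proc/sys/vm/drop_caches`, which requires root and a writable `/proc/sys`; that is why the pod was started with --privileged. A minimal sketch follows. `drop_page_caches` is a hypothetical helper, and its `control_file` parameter exists only so it can be pointed at a scratch file for testing. It clears only the node's local caches, not any caching on the storage side.

```python
# Flush dirty pages and drop the Linux page cache between benchmark runs.
# Writing to /proc/sys/vm/drop_caches needs root (hence a privileged pod).
# Levels: 1 = page cache, 2 = dentries and inodes, 3 = both.
import os


def drop_page_caches(control_file: str = "/proc/sys/vm/drop_caches", level: int = 3) -> None:
    if level not in (1, 2, 3):
        raise ValueError("level must be 1, 2 or 3")
    os.sync()  # write back dirty data first; drop_caches only frees clean pages
    with open(control_file, "w") as f:
        f.write(f"{level}\n")
```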

Additionally, the cufile.log file is not being generated when the error occurs.

Is GDS only available in an RDMA environment? How can the TCP + GPU reader configuration work in a container where privileged mode is disabled? Since I need to perform the experiments in a container with privileged mode enabled, the possible combinations are as follows:

Environment   Reader device="cpu"   Reader device="gpu"
TCP           Supported             Not supported
RDMA          Supported             Supported

JanuszL commented 6 hours ago

Hi @asdfry,

Is GDS only available in an RDMA environment? How is TCP + GPU configuration possible in a container where privileged mode is disabled?

I think it would be best to run the GDS diagnostic tool to see what is supported and under which conditions. The "Assertion failure, file index" message comes from libcufile, not from DALI. Since the issue is not related to DALI, I'd suggest asking on the NVIDIA Developer Forums. As far as I can see, DALI correctly interfaces with and delegates work to the corresponding GDS library, and it is that library that fails. I'm afraid we don't have sufficient insight into it to guide you further in debugging your problem.

asdfry commented 6 hours ago

Thank you once again for your helpful response! I truly appreciate your guidance.