IDRnD / ReDimNet

The official PyTorch implementation of the Interspeech 2024 paper "Reshape Dimensions Network for Speaker Recognition"
MIT License

Batched input processing is slower than sequential processing #11

Closed: debal-goyoyo closed this 1 month ago

debal-goyoyo commented 1 month ago

Hi guys, thanks for this awesome model; I have been using it for the past two weeks. Recently I tried to use it for batched inference, but oddly, processing a batched input takes longer than processing the same segments sequentially. Is this expected? The code and the relevant screenshots are attached below:

# Model loading
import torch

embedding_model = torch.hub.load('IDRnD/ReDimNet', 'b6', pretrained=True, finetuned=False)
embedding_model = embedding_model.to(torch.device("cuda"))
embedding_model.eval()

# Sequential processing code
embeds = []
with torch.no_grad():
    for s in seg:
        st = int(s['start']*16000)
        en = int(s['end']*16000)
        sample_input = waveform[None,:,st:en]
        embedding = embedding_model(sample_input.to("cuda"))
        embeds.append(embedding[0].cpu())
# Time taken: 10.5 seconds

# Batched processing

# Dataset code
import torch
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict
class SegmentationDataset(Dataset):

    def __init__(
            self,
            audio_tensor:torch.Tensor,
            segments:List[Dict],
            sample_rate:int=16000
            ) -> None:
        super().__init__()

        self.tensor = audio_tensor
        self.segment = segments
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.segment)

    def __getitem__(self, index) -> torch.Tensor:

        seg = self.segment[index]
        start = int(seg['start']*self.sample_rate)
        end = int(seg['end']*self.sample_rate)
        segment_tensor = self.tensor[0,start:end]
        return segment_tensor

def collate_fn(batch):
    max_length = max(tensor.size(0) for tensor in batch)
    padded_batch = []
    for tensor in batch:
        padding = max_length - tensor.size(0)
        if padding > 0:
            padded_tensor = torch.nn.functional.pad(tensor, (0, padding), value=0)
        else:
            padded_tensor = tensor
        padded_batch.append(padded_tensor)
    return torch.stack(padded_batch)

dataset = SegmentationDataset(waveform, seg)
dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=False, collate_fn=collate_fn)

embeds_batch = []  # must exist before the loop appends to it
with torch.no_grad():
    for batch in dataloader:
        embedding = embedding_model(batch.to("cuda"))
        embeds_batch.append(embedding.cpu())
embeds_batch = torch.vstack(embeds_batch)
# Time taken: 1 min 15 s
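The padding done by `collate_fn` above can also be written with `torch.nn.utils.rnn.pad_sequence`, which right-pads a list of variable-length tensors to the longest one. Below is a sketch of an equivalent collate function for the 1-D segments returned by `SegmentationDataset`. The name `collate_pad_sequence` is hypothetical.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def collate_pad_sequence(batch):
    """Right-pad 1-D waveforms with zeros to the longest one in the batch, giving (B, T_max)."""
    return pad_sequence(batch, batch_first=True, padding_value=0.0)
```

It is a drop-in replacement: `DataLoader(dataset, batch_size=128, shuffle=False, collate_fn=collate_pad_sequence)`. Like the original, every segment in a batch is padded to the longest one, so the cost of padding is unchanged.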

vanIvan commented 1 month ago

Hello @debal-goyoyo, we are glad to hear the models are being used! I'll be able to look deeper on Monday. From a first look at your code: how many workers are used for data loading? It seems `num_workers` is left at its default (0, meaning batches are built in the main process). Could you set it to 8, for example?
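Here is a minimal sketch of the suggested change. It assumes a map-style dataset such as the `SegmentationDataset` above. `make_loader` is a hypothetical helper, while `num_workers`, `pin_memory` and `persistent_workers` are standard `torch.utils.data.DataLoader` arguments. `persistent_workers=True` is only valid when `num_workers > 0`, which is why it is derived from `num_workers`.

```python
import torch
from torch.utils.data import DataLoader, Dataset


def make_loader(dataset: Dataset, batch_size: int = 128, num_workers: int = 8, collate_fn=None) -> DataLoader:
    """Hypothetical helper: build a DataLoader that prepares batches in worker processes."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,                          # keep segment order so embeddings line up with segments
        num_workers=num_workers,                # 0 means batches are built in the main process
        pin_memory=torch.cuda.is_available(),   # page-locked host memory speeds up copies to the GPU
        persistent_workers=num_workers > 0,     # reuse workers between passes; invalid with 0 workers
        collate_fn=collate_fn,                  # None falls back to the default collate function
    )
```

For the code above this would be `dataloader = make_loader(dataset, batch_size=128, num_workers=8, collate_fn=collate_fn)`.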

debal-goyoyo commented 1 month ago

@vanIvan, here is the time it takes just to iterate over the whole dataset:


I don't think the bottleneck is in data loading, as GPU utilization stays at 100% throughout the process.
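One way to separate data-loading time from model time is to time each phase of the loop on its own. Here is a sketch with a hypothetical helper `time_phases`. It calls `torch.cuda.synchronize()` before reading the clock because CUDA kernels run asynchronously. Without that call, GPU work queued by one forward pass can show up in the timing of a later step.

```python
import time

import torch


def time_phases(model: torch.nn.Module, loader, device: torch.device):
    """Hypothetical helper: return (seconds waiting for batches, seconds in copy + forward)."""
    load_s, model_s = 0.0, 0.0
    model.eval()
    with torch.no_grad():
        t0 = time.perf_counter()
        for batch in loader:
            t1 = time.perf_counter()
            load_s += t1 - t0                   # time spent waiting for the next batch
            model(batch.to(device))
            if device.type == "cuda":
                torch.cuda.synchronize(device)  # wait for queued kernels before stopping the clock
            t0 = time.perf_counter()
            model_s += t0 - t1                  # host-to-device copy + forward time for this batch
    return load_s, model_s
```

If `load_s` is small compared to `model_s`, the data loader is not the bottleneck.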

vanIvan commented 1 month ago

I see. Batched processing should be faster; in our experience it is always the fastest way to process many utterances.

Could you please also check the max / min / std / mean of the segment lengths? You pad them inside the collate function, so if there are some very long segments, every short segment in the same batch gets padded to that length.
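Here is a pure-Python sketch of that check. It assumes segments are dicts with `'start'` and `'end'` in seconds, as in the first post. It also estimates how much padded audio the model processes when each batch is padded to its longest item, with the segments in their original order and sorted by length. The function names and the toy segment list are hypothetical.

```python
import statistics
from typing import Dict, List, Sequence


def segment_durations(segments: List[Dict]) -> List[float]:
    """Durations in seconds for segments shaped like {'start': s, 'end': e}."""
    return [seg["end"] - seg["start"] for seg in segments]


def padding_overhead(durations: Sequence[float], batch_size: int) -> float:
    """Ratio of padded audio to real audio when each batch is padded to its longest item."""
    padded = 0.0
    for i in range(0, len(durations), batch_size):
        chunk = durations[i:i + batch_size]
        padded += max(chunk) * len(chunk)
    return padded / sum(durations)


# Toy data: short and long segments alternate
segments = [{"start": 0.0, "end": 1.0}, {"start": 1.0, "end": 9.0},
            {"start": 9.0, "end": 10.0}, {"start": 10.0, "end": 18.0}]
d = segment_durations(segments)
print(min(d), max(d), statistics.mean(d), statistics.pstdev(d))

order = sorted(range(len(d)), key=lambda i: d[i])  # indices of segments sorted by duration
print(padding_overhead(d, batch_size=2))                        # 32/18, about 1.78
print(padding_overhead([d[i] for i in order], batch_size=2))    # 1.0, no padding
```

To batch in sorted order, one option is to wrap the dataset in `torch.utils.data.Subset(dataset, order)` and write the k-th embedding back to position `order[k]` so results line up with the original segments.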

debal-goyoyo commented 1 month ago

I wrote this short snippet to benchmark sequential versus batched processing on a synthetic input:

import time
from torch.utils.data import DataLoader
import torch

embedding_model = torch.hub.load('IDRnD/ReDimNet', 'b6', pretrained=True, finetuned=False)
embedding_model = embedding_model.to(torch.device("cuda"))
embedding_model.eval()

# `seconds` is the segment length; naming it `len` would shadow the built-in len()
for seconds in range(1, 7):
    total_length = 256   # number of synthetic segments
    sr = 16000           # sample rate (Hz)
    c = 1                # channels
    sample_inp = torch.randn(total_length, c, seconds * sr)
    with torch.no_grad():
        # Sequential: one segment per forward pass
        start = time.time()
        for item in sample_inp:
            embed = embedding_model(item.to("cuda"))
        seq_time = time.time() - start

        # Batched: 128 segments per forward pass
        dataloader = DataLoader(
            sample_inp,
            batch_size=128,
            shuffle=False,
            num_workers=4
            )
        start_b = time.time()
        for batch in dataloader:
            embed = embedding_model(batch.to("cuda"))
        batch_time = time.time() - start_b

    print(f'Time taken for audio length {seconds} seconds || Sequential Time: {seq_time} seconds || Batched Time: {batch_time} seconds')

Here are the results of the code above:

Time taken for audio length 1 seconds || Sequential Time: 8.393802404403687 seconds || Batched Time: 8.831442594528198 seconds
Time taken for audio length 2 seconds || Sequential Time: 14.619158744812012 seconds || Batched Time: 20.736278772354126 seconds
Time taken for audio length 3 seconds || Sequential Time: 21.047341108322144 seconds || Batched Time: 31.216009855270386 seconds
Time taken for audio length 4 seconds || Sequential Time: 28.143197298049927 seconds || Batched Time: 41.43461489677429 seconds
Time taken for audio length 5 seconds || Sequential Time: 36.43804621696472 seconds || Batched Time: 51.67773962020874 seconds
Time taken for audio length 6 seconds || Sequential Time: 44.87288522720337 seconds || Batched Time: 62.44116187095642 seconds
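When timing GPU code, it usually helps to run a few untimed warm-up passes first, because cuDNN selects algorithms and the allocator reserves memory on first use. It also helps to call `torch.cuda.synchronize()` before reading the clock. Here is a sketch of such a timer. `timed_forward` and its parameters are hypothetical, and any `nn.Module` can stand in for the ReDimNet model.

```python
import time

import torch


def timed_forward(model, inputs, device, warmup: int = 3) -> float:
    """Hypothetical helper: average seconds per forward pass over `inputs` (a list of tensors)."""
    sync = (lambda: torch.cuda.synchronize(device)) if device.type == "cuda" else (lambda: None)
    model.eval()
    with torch.no_grad():
        for x in inputs[:warmup]:   # warm-up passes are not timed
            model(x.to(device))
        sync()
        start = time.perf_counter()
        for x in inputs:
            model(x.to(device))
        sync()                      # make sure all queued kernels have finished
    return (time.perf_counter() - start) / len(inputs)
```

For the sequential case, pass `list(sample_inp)`. For the batched case, pass `list(sample_inp.split(128))`. In both cases, multiply the returned average by the number of inputs to get the total time.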

I also ran the code above under the profiler; here is the code for that:

from torch.profiler import profile, record_function, ProfilerActivity

# Sequential
with torch.no_grad():
    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        with record_function("model_inference"):
            for item in sample_inp:
                embed = embedding_model(item[None].to("cuda"))
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# Batching
with torch.no_grad():
    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        with record_function("model_inference"):
            for batch in dataloader:
                embed = embedding_model(batch.to("cuda"))
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
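Because `record_shapes=True` is already set, the profiler can also group its table by operator input shape. This makes it easier to see which operators and batch shapes account for the time. Here is a sketch: `profile_by_shape` is a hypothetical helper, and `group_by_input_shape` is an argument of `key_averages()` in `torch.profiler`. The CPU activity is always recorded because input shapes are attached to operator (`aten::`) events rather than to raw CUDA kernels.

```python
import torch
from torch.profiler import ProfilerActivity, profile


def profile_by_shape(model, x, row_limit: int = 10) -> str:
    """Hypothetical helper: profile one forward pass and return a table grouped by input shapes."""
    activities = [ProfilerActivity.CPU]
    sort_key = "cpu_time_total"
    if x.is_cuda:
        activities.append(ProfilerActivity.CUDA)
        sort_key = "cuda_time_total"
    with torch.no_grad(), profile(activities=activities, record_shapes=True) as prof:
        model(x)
    return prof.key_averages(group_by_input_shape=True).table(sort_by=sort_key, row_limit=row_limit)
```

For example, `print(profile_by_shape(embedding_model, batch.to("cuda")))` for a single batch from the loader.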

It produced the following output:

Sequential

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
void implicit_convolve_sgemm<float, float, 128, 5, 5...         0.00%       0.000us         0.00%       0.000us       0.000us        2.883s        26.72%        2.883s      58.664us         49152  
cudnn_infer_volta_scudnn_winograd_128x128_ldg1_ldg4_...         0.00%       0.000us         0.00%       0.000us       0.000us        2.579s        23.90%        2.579s      30.993us         83200  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     838.983ms         7.78%     838.983ms      18.516us         45312  
                                  volta_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us     745.927ms         6.91%     745.927ms      64.751us         11520  
void cudnn::bn_fw_inf_1C11_kernel_NCHW<float, float,...         0.00%       0.000us         0.00%       0.000us       0.000us     703.985ms         6.52%     703.985ms      25.700us         27392  
                                  volta_sgemm_128x32_nn         0.00%       0.000us         0.00%       0.000us       0.000us     480.049ms         4.45%     480.049ms      55.153us          8704  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     479.265ms         4.44%     479.265ms      22.556us         21248  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     477.829ms         4.43%     477.829ms      23.043us         20736  
                                 volta_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us     285.188ms         2.64%     285.188ms     123.780us          2304  
void cudnn::winograd::generateWinogradTilesKernel<0,...         0.00%       0.000us         0.00%       0.000us       0.000us     250.065ms         2.32%     250.065ms       3.006us         83200  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.527s
Self CUDA time total: 10.790s

Batched

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
void cudnn::ops::genericTranspose_kernel<float, floa...         0.00%       0.000us         0.00%       0.000us       0.000us       58.303s        87.10%       58.303s      28.247ms          2064  
cudnn_infer_volta_scudnn_winograd_128x128_ldg1_ldg4_...         0.00%       0.000us         0.00%       0.000us       0.000us        3.456s         5.16%        3.456s       3.342ms          1034  
     cudnn_infer_volta_scudnn_128x128_relu_medium_nn_v1         0.00%       0.000us         0.00%       0.000us       0.000us     982.978ms         1.47%     982.978ms       8.474ms           116  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     867.299ms         1.30%     867.299ms       2.450ms           354  
void cudnn::bn_fw_inf_1C11_kernel_NCHW<float, float,...         0.00%       0.000us         0.00%       0.000us       0.000us     775.614ms         1.16%     775.614ms       3.624ms           214  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     685.520ms         1.02%     685.520ms       4.232ms           162  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     553.244ms         0.83%     553.244ms       3.333ms           166  
      cudnn_infer_volta_scudnn_128x64_relu_medium_nn_v1         0.00%       0.000us         0.00%       0.000us       0.000us     390.257ms         0.58%     390.257ms       5.913ms            66  
void at::native::(anonymous namespace)::CatArrayBatc...         0.00%       0.000us         0.00%       0.000us       0.000us     172.726ms         0.26%     172.726ms      14.394ms            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     156.590ms         0.23%     156.590ms       4.893ms            32  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 66.536s
Self CUDA time total: 66.937s

My hunch is that the bottleneck is somewhere in the step where the tensors are reshaped from 2-D to 1-D and back. In the batched profile above, most of the CUDA time (58.3 s of 66.9 s, about 87%) is spent in the transpose kernel `cudnn::ops::genericTranspose_kernel`.
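One thing that may be worth checking is the cuDNN algorithm choice. This is an assumption, not something confirmed in this thread. cuDNN picks convolution algorithms heuristically, and some of them wrap the convolution in layout-conversion (transpose) kernels. Setting `torch.backends.cudnn.benchmark = True` makes cuDNN time the candidate algorithms for each new input shape and keep the fastest, which can change which kernels run. Here is a sketch that times the same batch with the flag off and on. `compare_cudnn_autotune` is a hypothetical helper, and on CPU the flag has no effect.

```python
import time

import torch


def compare_cudnn_autotune(model, batch, repeats: int = 5) -> dict:
    """Hypothetical helper: mean forward time with cuDNN autotuning off (False) and on (True)."""
    previous = torch.backends.cudnn.benchmark
    results = {}
    try:
        with torch.no_grad():
            for enabled in (False, True):
                torch.backends.cudnn.benchmark = enabled
                model(batch)                     # untimed call; autotuning for this shape happens here
                if batch.is_cuda:
                    torch.cuda.synchronize()
                start = time.perf_counter()
                for _ in range(repeats):
                    model(batch)
                if batch.is_cuda:
                    torch.cuda.synchronize()
                results[enabled] = (time.perf_counter() - start) / repeats
    finally:
        torch.backends.cudnn.benchmark = previous  # restore the global flag
    return results
```

Autotuning results are cached per input shape, so the flag helps most when batch shapes repeat, such as fixed-length or length-bucketed batches.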

vanIvan commented 1 month ago

I'll check it on Monday; thank you for providing code for replication.

vanIvan commented 1 month ago

I've checked the model's performance on the following VM (system info collected with this script: https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py):

PyTorch version: 2.1.1
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.18.4
Libc version: glibc-2.31

Python version: 3.9.0 (default, Nov 15 2020, 14:28:56)  [GCC 7.3.0] (64-bit runtime)
Python platform: Linux-5.10.0-32-cloud-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA L4
Nvidia driver version: 550.54.15
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Byte Order:                           Little Endian
Address sizes:                        46 bits physical, 48 bits virtual
CPU(s):                               8
On-line CPU(s) list:                  0-7
Thread(s) per core:                   2
Core(s) per socket:                   4
Socket(s):                            1
NUMA node(s):                         1
Vendor ID:                            GenuineIntel
CPU family:                           6
Model:                                85
Model name:                           Intel(R) Xeon(R) CPU @ 2.20GHz
Stepping:                             7
CPU MHz:                              2200.198
BogoMIPS:                             4400.39
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            128 KiB
L1i cache:                            128 KiB
L2 cache:                             4 MiB
L3 cache:                             38.5 MiB
NUMA node0 CPU(s):                    0-7
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat avx512_vnni md_clear arch_capabilities

Versions of relevant libraries:
[pip3] flake8==3.8.2
[pip3] flake8-bugbear==23.3.12
[pip3] flake8-comprehensions==3.14.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-pyi==20.5.0
[pip3] gpytorch==1.11
[pip3] numpy==1.24.4
[pip3] onnxruntime==1.16.3
[pip3] pytorch-lightning==1.6.5
[pip3] torch==2.1.1
[pip3] torch-audiomentations==0.11.1
[pip3] torch-pitch-shift==1.2.4
[pip3] torchaudio==2.1.1
[pip3] torchdata==0.7.1
[pip3] torchlibrosa==0.1.0
[pip3] torchmetrics==1.3.1
[pip3] torchnet==0.0.4
[pip3] torchvision==0.16.1
[pip3] torchviz==0.0.2
[pip3] triton==2.1.0
[conda] blas                      1.0                         mkl  
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] gpytorch                  1.11                     pypi_0    pypi
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
[conda] mkl                       2021.4.0           h06a4308_640  
[conda] mkl-service               2.4.0            py39h7f8727e_0  
[conda] mkl_fft                   1.3.1            py39hd3c417c_0  
[conda] mkl_random                1.2.2            py39h51133e4_0  
[conda] numpy                     1.24.4                   pypi_0    pypi
[conda] pytorch                   2.1.1           py3.9_cuda11.8_cudnn8.7.0_0    pytorch
[conda] pytorch-cuda              11.8                 h7e8668a_5    pytorch
[conda] pytorch-lightning         1.6.5                    pypi_0    pypi
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torch-audiomentations     0.11.1                   pypi_0    pypi
[conda] torch-pitch-shift         1.2.4                    pypi_0    pypi
[conda] torchaudio                2.1.1                py39_cu118    pytorch
[conda] torchdata                 0.7.1                    pypi_0    pypi
[conda] torchlibrosa              0.1.0                    pypi_0    pypi
[conda] torchmetrics              1.3.1                    pypi_0    pypi
[conda] torchnet                  0.0.4                    pypi_0    pypi
[conda] torchvision               0.16.1               py39_cu118    pytorch
[conda] torchviz                  0.0.2                    pypi_0    pypi
[conda] triton                    2.1.0                    pypi_0    pypi

And I'm getting different results:

Time taken for audio length 1 seconds || Sequential Time: 10.33 seconds || Batched Time: 2.29 seconds
Time taken for audio length 2 seconds || Sequential Time: 11.91 seconds || Batched Time: 4.17 seconds
Time taken for audio length 3 seconds || Sequential Time: 13.63 seconds || Batched Time: 5.27 seconds
Time taken for audio length 4 seconds || Sequential Time: 15.13 seconds || Batched Time: 8.23 seconds
Time taken for audio length 5 seconds || Sequential Time: 18.58 seconds || Batched Time: 8.19 seconds
Time taken for audio length 6 seconds || Sequential Time: 18.04 seconds || Batched Time: 9.59 seconds

I'll add the profiler output later, because I'm having issues with my Python environment.

For now, it looks like the main reason for the performance difference may be the different torch/CUDA versions we use.
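The full report comes from the `collect_env.py` script linked above (it can also be run as `python -m torch.utils.collect_env`). For a quick side-by-side comparison, a few lines are enough to print the versions and GPU that differ between setups. `runtime_summary` is a hypothetical helper; the attributes it reads are standard PyTorch.

```python
import torch


def runtime_summary() -> dict:
    """Hypothetical helper: the PyTorch, CUDA, cuDNN and GPU details relevant to this comparison."""
    return {
        "torch": torch.__version__,
        "cuda (build)": torch.version.cuda,  # None on CPU-only builds
        "cudnn": torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else None,
        "gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
    }


for key, value in runtime_summary().items():
    print(f"{key}: {value}")
```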

vanIvan commented 1 month ago

Attachments: seq_profile.txt, batch_profile.txt

I've generated CUDA profiler results, and no transpose operation is eating most of the time: it's the convolutions, as expected. I'm attaching the profiler results here.

There is also a way to speed up inference further, in both batched and sequential processing, by using half precision. The spectrogram front end stays in fp32 while the rest of the network runs in fp16:

import time
import torch
from torch.utils.data import DataLoader

device = torch.device("cuda")

embedding_model = torch.hub.load('IDRnD/ReDimNet', 'b6', pretrained=True, finetuned=False)
embedding_model = embedding_model.to(device)
# Detach the spectrogram front end so it stays in fp32; the rest of the network is cast to fp16
spectrogram = embedding_model.spec
embedding_model.spec = torch.nn.Identity()
embedding_model = embedding_model.half()
spectrogram.eval()
embedding_model.eval()

for seconds in range(1, 7):
    total_length = 256
    sr = 16000
    c = 1
    sample_inp = torch.randn(total_length, c, seconds * sr)
    with torch.no_grad():
        start = time.time()
        for item in sample_inp:
            spec = spectrogram(item.to(device)).half()  # fp32 spectrogram, cast to fp16 for the model
            embed = embedding_model(spec)
        seq_time = time.time() - start

        dataloader = DataLoader(
            sample_inp,
            batch_size=total_length // 2,
            shuffle=False,
            num_workers=4
            )
        start_b = time.time()
        for batch in dataloader:
            specs = spectrogram(batch.to(device)).half()
            embed = embedding_model(specs)
        batch_time = time.time() - start_b

    print(f'Time taken for audio length {seconds} seconds || Sequential Time: {seq_time} seconds || Batched Time: {batch_time} seconds')

Here are the results (comparable to the ones I shared previously):

Time taken for audio length 1 seconds || Sequential Time: 9.33 seconds || Batched Time: 1.84 seconds
Time taken for audio length 2 seconds || Sequential Time: 10.19 seconds || Batched Time: 2.29 seconds
Time taken for audio length 3 seconds || Sequential Time: 10.25 seconds || Batched Time: 2.90 seconds
Time taken for audio length 4 seconds || Sequential Time: 10.53 seconds || Batched Time: 3.54 seconds
Time taken for audio length 5 seconds || Sequential Time: 10.32 seconds || Batched Time: 4.14 seconds
Time taken for audio length 6 seconds || Sequential Time: 10.62 seconds || Batched Time: 5.04 seconds

So the best speedup you can get for 6-second segments is 18.04 s (fp32, sequential) -> 5.04 s (fp16, batched), roughly a 3.6x speedup.
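A different technique from the `.half()` conversion above is automatic mixed precision with `torch.autocast`. Here the weights stay in fp32, and eligible operations (such as convolutions and matrix multiplies) run in reduced precision. The sketch below is a hypothetical helper that has not been benchmarked on ReDimNet in this thread. Note that, unlike the code above, it does not keep the spectrogram front end in fp32 automatically, and whether that affects embedding quality is not verified here.

```python
import torch


def embed_autocast(model, waveforms: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: forward pass under autocast; eligible ops run in reduced precision."""
    device_type = waveforms.device.type
    # fp16 on CUDA; CPU autocast supports bfloat16
    dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
    with torch.no_grad(), torch.autocast(device_type=device_type, dtype=dtype):
        return model(waveforms)
```

With the original fp32 model this would be called as `embed_autocast(embedding_model, batch.to("cuda"))`.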

debal-goyoyo commented 1 month ago

@vanIvan, thanks for the suggestion; I will use half precision.

debal-goyoyo commented 1 month ago

@vanIvan, it seems to be an issue with the T4 GPU. After switching my GPU to an A10, I got the following results:

Time taken for audio length 1 seconds || Sequential Time: 7.962625741958618 seconds || Batched Time: 0.6435325145721436 seconds
Time taken for audio length 2 seconds || Sequential Time: 7.51838755607605 seconds || Batched Time: 1.0514562129974365 seconds
Time taken for audio length 3 seconds || Sequential Time: 7.9794416427612305 seconds || Batched Time: 1.5150020122528076 seconds
Time taken for audio length 4 seconds || Sequential Time: 8.454184293746948 seconds || Batched Time: 2.174407482147217 seconds
Time taken for audio length 5 seconds || Sequential Time: 9.016132354736328 seconds || Batched Time: 2.396376609802246 seconds
Time taken for audio length 6 seconds || Sequential Time: 9.463118553161621 seconds || Batched Time: 3.1102302074432373 seconds

vanIvan commented 1 month ago

Great! Closing the issue