huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Data loading is very slow when dataloader_num_workers=0 #34059

Open geekifan opened 1 month ago

geekifan commented 1 month ago

Who can help?

@muellerzr @SunMarc

Reproduction

My dataset:

import torch
import os
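import json
from datetime import datetime  # needed for the debug timestamp in __getitem__
from utils.video import read_frames_decord  # project-local helper (not shown in this issue)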
from torchvision.transforms import Compose, Resize, CenterCrop, RandomResizedCrop, RandomHorizontalFlip, ToTensor, Normalize
from PIL import Image

class DatasetForOfflineDistill(torch.utils.data.Dataset):
    def __init__(
            self, 
            anno_path: str | os.PathLike, 
            data_root: str | os.PathLike,
            feat_path: str | os.PathLike,
            tokenizer: torch.nn.Module | None = None,
            tokenize: bool = False,
            num_frames: int = 8,
            test: bool = False
        ):
        with open(anno_path) as f:
            self.anno = json.load(f)
        self.data_root = data_root
        # keys of each item: idx, text_embeds, video_embeds
        self.feat = torch.load(feat_path, weights_only=True)
        self.num_frames = num_frames
        self.transforms = self.build_transforms(test)
        self.tokenizer = tokenizer
        self.tokenize = tokenize

    def build_transforms(self, test: bool):
        image_mean =  [
            0.48145466,
            0.4578275,
            0.40821073
        ]
        image_std = [
            0.26862954,
            0.26130258,
            0.27577711
        ]
        size = 224
        normalize = (
            Normalize(mean=image_mean, std=image_std)
        )
        train_transforms = Compose(
            [
                RandomResizedCrop(size),
                RandomHorizontalFlip(),
                ToTensor(),
                normalize,
            ]
        )
        val_transforms = Compose(
            [
                Resize(size),
                CenterCrop(size),
                ToTensor(),
                normalize,
            ]
        )
        if test:
            return val_transforms
        return train_transforms

    def __len__(self):
        return len(self.anno)

    def __getitem__(self, idx):
        rank = int(os.environ.get("LOCAL_RANK") or 0)
        # HERE IS THE DEBUG MESSAGE
        now = datetime.now()
        dt_string = now.strftime("%d/%m/%Y %H:%M:%S")
        print(f'[{dt_string}] Rank {rank} is loading', idx)
        item = self.feat[idx]
        anno_idx = item['idx']
        # [teacher_dim] -> [1, teacher_dim]
        text_embeds = item['text_embeds']
        video_embeds = item['video_embeds']
        caption = self.anno[anno_idx]['caption']
        if self.tokenizer is not None and self.tokenize:
            tokenized_caption = self.tokenizer(caption)
            caption = {
                'input_ids': tokenized_caption['input_ids'],
                'attention_mask': tokenized_caption['attention_mask'],
            }
        video_path = os.path.join(self.data_root, self.anno[anno_idx]['video'])
        video = read_frames_decord(video_path, num_frames=self.num_frames).numpy()
        frames = [self.transforms(Image.fromarray(frame)) for frame in video]
        frames = torch.stack(frames)
        return {
            'caption': caption, 
            'video': frames, 
            'text_embeds': text_embeds, 
            'video_embeds': video_embeds
        }

Part of my training script:

import torch
import transformers
from transformers import Trainer

# `model`, `data_config`, `teacher_type`, `num_frames` and the other hyperparameters
# are defined elsewhere in the training script (not shown here).
train_data = DatasetForOfflineDistill(
    anno_path=data_config['anno_path'],
    data_root=data_config['data_root'],
    feat_path=data_config['feat_paths'][teacher_type],
    tokenize=False,  # captions are tokenized per batch in custom_collate_fn
    num_frames=num_frames,
)


def custom_collate_fn(batch):
    # batch is a list of dicts returned by DatasetForOfflineDistill.__getitem__
    collated_batch = {}
    for key in batch[0].keys():
        collated_batch[key] = [b[key] for b in batch]
    # collated_batch['video'] is a list of [num_frames, 3, 224, 224] tensors
    # collated_batch['caption'] is a list of strings
    tokenized_caption = model.student_caller.tokenizer(collated_batch['caption'], padding=True, return_tensors="pt")
    collated_batch['input_ids'] = tokenized_caption['input_ids']
    collated_batch['attention_mask'] = tokenized_caption['attention_mask']
    # Stack per-sample tensors into batch tensors.
    collated_batch['pixel_values'] = torch.stack(collated_batch['video'])  # [batch, num_frames, 3, 224, 224]
    collated_batch['video_embeds'] = torch.stack(collated_batch['video_embeds'])
    collated_batch['text_embeds'] = torch.stack(collated_batch['text_embeds'])
    return collated_batch


trainer = Trainer(
    model=model,
    train_dataset=train_data,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=micro_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_ratio=warmup_ratio,
        num_train_epochs=num_epochs,
        learning_rate=learning_rate,
        fp16=not bf16,  # fall back to fp16 when bf16 is disabled
        bf16=bf16,
        logging_steps=logging_steps,
        save_strategy="steps",
        eval_steps=None,
        save_steps=save_steps,
        output_dir=output_dir,
        save_total_limit=1,
        load_best_model_at_end=False,
        ddp_find_unused_parameters=False if ddp else None,
        run_name=run_name,
        report_to=None,
        deepspeed=deepspeed,
        gradient_checkpointing=grad_checkpoint,
        remove_unused_columns=False,  # keep all dataset keys so custom_collate_fn receives them
        dataloader_num_workers=0,  # data is loaded in the main process
        dataloader_pin_memory=True,
        # dataloader_prefetch_factor=10,
        # dataloader_persistent_workers=True,
    ),
    data_collator=custom_collate_fn,
)

The model is a simple CLIPModel.

With dataloader_num_workers=0 and dataloader_pin_memory=True, the CPU load is around 1000, yet the debug message (see the code above) prints only about 1 to 2 times per second. See the image below.

[Screenshots]

With dataloader_num_workers=4, dataloader_pin_memory=True, dataloader_prefetch_factor=2, and dataloader_persistent_workers=True, the CPU load is around 100 and the debug message prints more than 20 times per second.

[Screenshots]

Expected behavior

  1. The speed should be comparable regardless of the setting (or at least, dataloader_num_workers=0 should not be this much slower than dataloader_num_workers=4).
  2. The dataloader should prefetch data so the GPU does not sit idle waiting for batches.
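Conceptually, prefetching means a producer keeps a small buffer of upcoming batches filled in the background while the consumer (the training step) works on the current one. The sketch below is a toy model of that idea using a thread and a bounded `queue.Queue`; it is not how PyTorch implements prefetching (its `DataLoader` uses worker processes), and `prefetching_iter`, `buffer_size`, and `slow_load` are hypothetical names. `buffer_size` plays a role loosely similar to `num_workers * prefetch_factor`, the total number of batches PyTorch keeps in flight across workers.

```python
import queue
import threading
import time
from typing import Callable, Iterator, TypeVar

T = TypeVar("T")


def prefetching_iter(load: Callable[[int], T], num_items: int, buffer_size: int = 2) -> Iterator[T]:
    """Yield load(0), ..., load(num_items - 1) in order, loading ahead in a background thread."""
    buf: "queue.Queue[tuple]" = queue.Queue(maxsize=buffer_size)

    def producer() -> None:
        try:
            for i in range(num_items):
                buf.put(("item", load(i)))  # blocks while the buffer is full
        except Exception as exc:  # hand errors to the consumer instead of hanging it
            buf.put(("error", exc))
            return
        buf.put(("done", None))

    threading.Thread(target=producer, daemon=True).start()
    while True:
        kind, payload = buf.get()
        if kind == "done":
            return
        if kind == "error":
            raise payload
        yield payload


def slow_load(i: int) -> int:
    time.sleep(0.05)  # stands in for decoding a video and applying transforms
    return i


if __name__ == "__main__":
    start = time.perf_counter()
    seen = []
    for item in prefetching_iter(slow_load, num_items=10):
        time.sleep(0.05)  # stands in for the GPU training step
        seen.append(item)
    print(seen, f"{time.perf_counter() - start:.2f}s")
```

Run serially, ten load-then-train iterations of 0.05 s each would take about 1.0 s; here loading the next item overlaps with the current step, so the loop finishes in roughly half that time.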
muellerzr commented 1 month ago

The speed shouldn't be the same, no? You're working with images, which take much longer to load into RAM, especially on a single worker, unless I'm mistaken. You could try pillow-simd, which speeds up Pillow somewhat.

geekifan commented 1 month ago


Thanks for your reply!

Of course loading on a single worker is slower than loading on multiple workers. But the biggest problem is that when I load images on a single worker, CPU usage is much higher than when loading on 4 workers, and at the same time the single worker is 20x slower than 4 workers.

I think the expected behavior should be: loading on a single worker uses ~4x less CPU than loading on 4 workers, and is ~4x slower. CPU usage should MATCH the amount of loading work actually being done.
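The gap can be made concrete with the figures reported earlier in the issue: a CPU load of around 1000 at 1 to 2 items/s with `dataloader_num_workers=0`, versus around 100 at more than 20 items/s with four workers. Under the expectation stated above (about 4x less CPU at about 4x lower speed), CPU load divided by throughput, which is proportional to CPU time spent per item, would be the same in both settings. The back-of-the-envelope calculation below (the helper name `cpu_cost_per_item` is illustrative) shows it is instead at least 100x higher with `dataloader_num_workers=0`; because the reported rates are approximate, treat this as an order-of-magnitude figure.

```python
def cpu_cost_per_item(cpu_load: float, items_per_sec: float) -> float:
    """CPU load divided by loading rate; proportional to CPU time spent per item."""
    if items_per_sec <= 0:
        raise ValueError("items_per_sec must be positive")
    return cpu_load / items_per_sec


# dataloader_num_workers=0: CPU load ~1000 at 1-2 items/s
single_best = cpu_cost_per_item(1000, 2)   # 500.0 (most favorable reading)
single_worst = cpu_cost_per_item(1000, 1)  # 1000.0

# dataloader_num_workers=4: CPU load ~100 at more than 20 items/s,
# so 5.0 is an upper bound on the cost per item
multi_upper = cpu_cost_per_item(100, 20)   # 5.0

# Even comparing the most favorable num_workers=0 reading with the
# upper bound for four workers, the ratio is at least 100x.
ratio_lower_bound = single_best / multi_upper  # 100.0
print(f"num_workers=0 spends at least {ratio_lower_bound:.0f}x more CPU per item")
```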

Also, it seems that the dataloader is NOT PREFETCHING when loading with multiple workers.

geekifan commented 1 month ago

I find it really strange that the dataloader only starts loading data at every 1/100 of the total steps, and doesn't load anything while the GPU is running. Shouldn't the dataloader load data while the GPU is training?

techkang commented 3 weeks ago

The loading speed is only ~4x lower when you set dataloader_num_workers=1, not 0. With dataloader_num_workers=1, the dataloader keeps preparing data in a separate worker process while the GPU is training.
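This point can be captured with an idealized timing model. With `dataloader_num_workers=0`, each step loads its batch and then trains on it, so the two costs add up. With at least one worker, loading the next batch overlaps with training on the current one, so after the first batch each step costs roughly the larger of the two. The sketch ignores worker start-up, inter-process transfer, and memory pinning; `total_time`, `t_load`, and `t_compute` are hypothetical names and the example times are made up, not measurements from this issue.

```python
def total_time(num_steps: int, t_load: float, t_compute: float, overlapped: bool) -> float:
    """Idealized wall-clock time for `num_steps` steps of load-then-train."""
    if num_steps <= 0:
        return 0.0
    if not overlapped:
        # num_workers=0: the main process loads a batch, then trains on it, in sequence.
        return num_steps * (t_load + t_compute)
    # num_workers>=1: a two-stage pipeline. The first batch must be loaded before
    # training starts; afterwards loading batch k+1 overlaps training on batch k.
    return t_load + (num_steps - 1) * max(t_load, t_compute) + t_compute


# Hypothetical per-batch costs: 0.5 s to load, 0.5 s to train.
print(total_time(100, 0.5, 0.5, overlapped=False))  # 100.0
print(total_time(100, 0.5, 0.5, overlapped=True))   # 50.5
```

In this model a single worker does not make loading any faster per batch; the gain comes entirely from overlapping loading with training.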

github-actions[bot] commented 3 days ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.