NVIDIA / DALI

A GPU-accelerated library containing highly optimized building blocks and an execution engine for data processing to accelerate deep learning training and inference applications.
https://docs.nvidia.com/deeplearning/dali/user-guide/docs/index.html
Apache License 2.0

Can't get the same acc with the pytorch transform in iNaturalist2018 dataset #5721

Open lyyi599 opened 3 days ago

lyyi599 commented 3 days ago

Thanks for this awesome tool.

Describe the question.

For my experiments, I need to modify a piece of open-source code. Because the transforms are time-consuming and the iNaturalist2018 dataset is large (over 400,000 images), I plan to switch to DALI for data loading. Below is the original open-source code snippet (from: https://github.com/shijxcs/LIFT/blob/661ead9b78368f05ba79abe4672d63154467f823/trainer.py#L102):

    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(resolution),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    if cfg.tte:
        if cfg.tte_mode == "fivecrop":
            transform_test = transforms.Compose([
                transforms.Resize(resolution + expand),
                transforms.FiveCrop(resolution),
                transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                transforms.Normalize(mean, std),
            ])
        elif cfg.tte_mode == "tencrop":
            transform_test = transforms.Compose([
                transforms.Resize(resolution + expand),
                transforms.TenCrop(resolution),
                transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                transforms.Normalize(mean, std),
            ])
        elif cfg.tte_mode == "randaug":
            _resize_and_flip = transforms.Compose([
                transforms.RandomResizedCrop(resolution),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ])
            transform_test = transforms.Compose([
                transforms.Lambda(lambda image: torch.stack([_resize_and_flip(image) for _ in range(cfg.randaug_times)])),
                transforms.Normalize(mean, std),
            ])
    else:
        transform_test = transforms.Compose([
            transforms.Resize(resolution * 8 // 7),
            transforms.CenterCrop(resolution),
            transforms.Lambda(lambda crop: torch.stack([transforms.ToTensor()(crop)])),
            transforms.Normalize(mean, std),
        ])
    self.train_loader = DataLoader(train_dataset,
        batch_size=cfg.micro_batch_size, shuffle=True,
        num_workers=cfg.num_workers, pin_memory=True)

    self.test_loader = DataLoader(test_dataset,
        batch_size=64, shuffle=False,
        num_workers=cfg.num_workers, pin_memory=True)
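For reference when comparing the two training augmentations: `transforms.RandomResizedCrop(resolution)` uses torchvision's defaults `scale=(0.08, 1.0)` and `ratio=(3/4, 4/3)`, while the DALI pipeline in this issue uses `random_area=[0.1, 1.0]` and `random_aspect_ratio=[0.8, 1.25]`. The pure-Python sketch below restates the sampling logic of torchvision's `RandomResizedCrop.get_params` (uniform area fraction, log-uniform aspect ratio, 10 attempts, then a center-crop fallback) so both parameter sets can be compared on the same image size. `rrc_params` is a hypothetical helper, not part of either library, and it does not model how DALI itself samples crops.

```python
import math
import random


def rrc_params(height, width, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3), rng=random):
    """Sample (top, left, h, w) for a random resized crop.

    Restates the logic of torchvision's RandomResizedCrop.get_params: up to 10
    attempts with a uniform area fraction and a log-uniform aspect ratio (w / h),
    then a central crop with the aspect ratio clamped to the allowed range.
    """
    area = height * width
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(10):
        target_area = area * rng.uniform(scale[0], scale[1])
        aspect = math.exp(rng.uniform(log_lo, log_hi))
        w = int(round(math.sqrt(target_area * aspect)))
        h = int(round(math.sqrt(target_area / aspect)))
        if 0 < w <= width and 0 < h <= height:
            top = rng.randint(0, height - h)  # inclusive bounds
            left = rng.randint(0, width - w)
            return top, left, h, w
    # Fallback: central crop
    in_ratio = width / height
    if in_ratio < min(ratio):
        w, h = width, int(round(width / min(ratio)))
    elif in_ratio > max(ratio):
        h, w = height, int(round(height * max(ratio)))
    else:
        h, w = height, width
    return (height - h) // 2, (width - w) // 2, h, w


if __name__ == "__main__":
    rng = random.Random(0)
    for name, s, r in [
        ("torchvision defaults", (0.08, 1.0), (3 / 4, 4 / 3)),
        ("DALI settings in this issue", (0.1, 1.0), (0.8, 1.25)),
    ]:
        fracs = [h * w / (375 * 500) for _, _, h, w in
                 (rrc_params(375, 500, s, r, rng) for _ in range(10_000))]
        print(f"{name}: min area {min(fracs):.3f}, mean area {sum(fracs) / len(fracs):.3f}")
```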

Referring to the DALI example (https://github.com/NVIDIA/DALI/blob/4562f157a203bd17a1dbc9a0b07f05ba3a41c1fb/docs/examples/use_cases/pytorch/resnet50/main.py#L275), I have successfully adapted the code for ImageNet_LT and achieved results comparable to the original paper (77.0 with LA loss). The modified code is as follows:

  from nvidia.dali import pipeline_def
  import nvidia.dali.fn as fn
  import nvidia.dali.types as types

  # @pipeline_def lets the function accept batch_size, num_threads, device_id and seed
  # and return a Pipeline object, which is how it is called further down
  @pipeline_def
  def create_dali_pipeline(data_dir, data_list_dir, crop, size, shard_id, num_shards, dali_cpu=False, is_training=True):
      # data_list_dir is the path of the .txt index file, passed to the reader as file_list
      images, labels = fn.readers.file(file_root=data_dir,
                                       file_list=data_list_dir,
                                       shard_id=shard_id,
                                       num_shards=num_shards,
                                       random_shuffle=is_training,
                                       pad_last_batch=True,
                                       name="Reader")
      dali_device = 'cpu' if dali_cpu else 'gpu'
      decoder_device = 'cpu' if dali_cpu else 'mixed'
      # ask HW NVJPEG to allocate memory ahead for the biggest image in the data set to avoid reallocations in runtime
      preallocate_width_hint = 5980 if decoder_device == 'mixed' else 0
      preallocate_height_hint = 6430 if decoder_device == 'mixed' else 0
      if is_training:
          images = fn.decoders.image_random_crop(images,
                                                 device=decoder_device, output_type=types.RGB,
                                                 preallocate_width_hint=preallocate_width_hint,
                                                 preallocate_height_hint=preallocate_height_hint,
                                                 random_aspect_ratio=[0.8, 1.25],
                                                 random_area=[0.1, 1.0],
                                                 num_attempts=100)
          images = fn.resize(images,
                             device=dali_device,
                             resize_x=crop,
                             resize_y=crop,
                             interp_type=types.INTERP_TRIANGULAR)
          mirror = fn.random.coin_flip(probability=0.5)
      else:
          images = fn.decoders.image(images,
                                     device=decoder_device,
                                     output_type=types.RGB)
          images = fn.resize(images,
                             device=dali_device,
                             size=size,
                             mode="not_smaller",
                             interp_type=types.INTERP_TRIANGULAR)
          mirror = False

      images = fn.crop_mirror_normalize(images.gpu(),
                                        dtype=types.FLOAT,
                                        output_layout="CHW",
                                        crop=(crop, crop),
                                        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
                                        mirror=mirror)

      # for TTE: add a leading crop axis, matching torch.stack([...]) in the PyTorch test transform
      if not is_training:
          images = fn.stack(images, axis=0)
      labels = labels.gpu()
      return images, labels
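`fn.crop_mirror_normalize` computes `(x - mean) / std` on the decoded pixel values in [0, 255], which is why the pipeline multiplies the ImageNet statistics by 255; the PyTorch path applies `transforms.Normalize(mean, std)` after `ToTensor()` has scaled pixels to [0, 1]. Note that the DALI code hardcodes the ImageNet mean/std, whereas the PyTorch transforms use `mean` and `std` variables whose values are not shown in this snippet. A minimal sketch, assuming those same lists can be passed to the pipeline, derives DALI-scale values from them and checks that the two formulas agree on one pixel; `to_dali_norm` is a hypothetical helper name.

```python
def to_dali_norm(mean, std, max_value=255.0):
    """Rescale [0, 1]-based mean/std (as used by transforms.Normalize) to [0, 255].

    (x / 255 - m) / s == (x - 255 * m) / (255 * s), so normalizing raw uint8
    pixels with the rescaled values reproduces ToTensor() + Normalize(mean, std).
    """
    if len(mean) != len(std):
        raise ValueError("mean and std must have one value per channel")
    return [m * max_value for m in mean], [s * max_value for s in std]


def torch_style(pixel, mean, std):
    # ToTensor() maps uint8 to [0, 1]; Normalize then subtracts mean and divides by std
    return [(p / 255.0 - m) / s for p, m, s in zip(pixel, mean, std)]


def dali_style(pixel, mean, std):
    # crop_mirror_normalize with default scale/shift: (x - mean) / std on raw values
    return [(p - m) / s for p, m, s in zip(pixel, mean, std)]


# Example with the ImageNet statistics hardcoded in the pipeline; in the trainer
# one would pass the same `mean`, `std` that transforms.Normalize receives.
imagenet_mean, imagenet_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
dali_mean, dali_std = to_dali_norm(imagenet_mean, imagenet_std)
pixel = [200, 120, 30]
print(torch_style(pixel, imagenet_mean, imagenet_std))
print(dali_style(pixel, dali_mean, dali_std))
```

In the pipeline, the result would be used as `mean=dali_mean, std=dali_std` in `fn.crop_mirror_normalize`.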

  import torch
  from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy

  # (inside the trainer, where cfg, train_root_lt and test_root_lt are defined)
  crop_size = 224
  val_size = 256  # needs further changes for five-crop TTE
  train_pipe = create_dali_pipeline(
      batch_size=cfg.batch_size,
      num_threads=cfg.num_workers,
      device_id=torch.cuda.current_device(),
      seed=12 + torch.cuda.current_device(),
      data_dir=cfg.root,
      data_list_dir=train_root_lt,
      crop=crop_size,
      size=val_size,
      dali_cpu=cfg.dali_cpu,
      shard_id=torch.cuda.current_device(),
      num_shards=torch.cuda.device_count(),
      is_training=True
  )
  train_pipe.build()
  self.train_loader = DALIClassificationIterator(train_pipe, reader_name="Reader",
                                                 last_batch_policy=LastBatchPolicy.PARTIAL,
                                                 auto_reset=True)

  test_pipe = create_dali_pipeline(
      batch_size=cfg.batch_size,
      num_threads=cfg.num_workers,
      device_id=torch.cuda.current_device(),
      seed=12 + torch.cuda.current_device(),
      data_dir=cfg.root,
      data_list_dir=test_root_lt,
      crop=crop_size,
      size=val_size,
      dali_cpu=cfg.dali_cpu,
      shard_id=torch.cuda.current_device(),
      num_shards=torch.cuda.device_count(),
      is_training=False)
  test_pipe.build()
  self.test_loader = DALIClassificationIterator(test_pipe, reader_name="Reader",
                                                last_batch_policy=LastBatchPolicy.PARTIAL,
                                                auto_reset=True)
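The hardcoded `val_size = 256` corresponds to the PyTorch test path `transforms.Resize(resolution * 8 // 7)` for `resolution = 224`, since 224 * 8 // 7 = 256. With an integer size, torchvision resizes the shorter side to that value and scales the longer side by the same factor; `fn.resize(size=256, mode="not_smaller")` should likewise bring the shorter side to 256, although its rounding of the longer side may differ by a pixel. The sketch below (hypothetical helper names) computes the torchvision output size for comparison.

```python
def eval_short_side(resolution):
    """Shorter-side target of transforms.Resize(resolution * 8 // 7)."""
    return resolution * 8 // 7


def torch_resize_shape(height, width, short_side):
    """(h, w) after torchvision's Resize(int): shorter side -> short_side,
    longer side scaled by the same factor and truncated with int()."""
    if width <= height:
        return int(short_side * height / width), short_side
    return short_side, int(short_side * width / height)


print(eval_short_side(224))               # 256, same as val_size
print(torch_resize_shape(375, 500, 256))  # (256, 341)
```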

As mentioned earlier, the code has been validated on the ImageNet_LT dataset. However, when applied to the iNaturalist2018 dataset, it results in significant differences, specifically:

  1. Under the same settings, the DALI pipeline achieves only 67.6% accuracy, while using PyTorch transforms achieves the paper's reported result of 79.1%.
  2. In the experiment logs using PyTorch transforms, the model initially learns samples from head classes (categories with more samples that are easier to collect), as observed in the first 50 batches:

    epoch [1/20] batch [10/3419] time 0.761 (2.635) data 0.433 (2.208) loss 6.9438 (7.0018) acc 1.5625 (2.2394) (mean 0.2032 many 1.4897 med 0.0981 few 0.0000) lr 1.0000e-02 eta 2 days, 2:02:49
    epoch [1/20] batch [20/3419] time 0.334 (1.935) data 0.000 (1.556) loss 6.7982 (6.7163) acc 7.0312 (4.1451) (mean 0.4500 many 2.4381 med 0.3216 few 0.0931) lr 1.0000e-02 eta 1 day, 12:44:06
    epoch [1/20] batch [30/3419] time 0.331 (1.696) data 0.000 (1.332) loss 6.1816 (6.5292) acc 9.3750 (5.8511) (mean 0.6444 many 3.3323 med 0.5228 few 0.0962) lr 1.0000e-02 eta 1 day, 8:11:31
    epoch [1/20] batch [40/3419] time 0.331 (1.587) data 0.000 (1.232) loss 6.1855 (6.3015) acc 5.4688 (6.2589) (mean 0.7358 many 3.6615 med 0.5883 few 0.1582) lr 1.0000e-02 eta 1 day, 6:08:02
    epoch [1/20] batch [50/3419] time 0.329 (1.721) data 0.000 (1.370) loss 5.4283 (6.0245) acc 12.5000 (8.0597) (mean 0.9444 many 4.2423 med 0.7868 few 0.2823) lr 1.0000e-02 eta 1 day, 8:39:37

In contrast, with the DALI pipeline the model learns tail classes (categories with fewer samples that are harder to collect) first. The corresponding log is as follows:

    epoch [1/20] batch [10/3419] time 0.327 (0.372) data 0.001 (0.001) loss 9.2174 (9.8390) acc 0.0000 (0.0461) (mean 0.0123 many 0.0000 med 0.0000 few 0.0310) lr 1.0000e-02 eta 7:03:39
    epoch [1/20] batch [20/3419] time 0.326 (0.349) data 0.001 (0.001) loss 8.7870 (9.2437) acc 0.0000 (0.2400) (mean 0.0167 many 0.0000 med 0.0000 few 0.0422) lr 1.0000e-02 eta 6:37:59
    epoch [1/20] batch [30/3419] time 0.326 (0.342) data 0.001 (0.001) loss 8.3188 (8.8392) acc 0.0000 (0.1540) (mean 0.0163 many 0.0000 med 0.0000 few 0.0413) lr 1.0000e-02 eta 6:29:29
    epoch [1/20] batch [40/3419] time 0.327 (0.338) data 0.001 (0.001) loss 8.4082 (8.6090) acc 0.0000 (0.1240) (mean 0.0164 many 0.0000 med 0.0000 few 0.0415) lr 1.0000e-02 eta 6:25:10
    epoch [1/20] batch [50/3419] time 0.326 (0.336) data 0.001 (0.001) loss 8.5320 (8.5095) acc 0.0000 (0.2166) (mean 0.0190 many 0.0000 med 0.0000 few 0.0480) lr 1.0000e-02 eta 6:22:21

It is evident that the learning processes of the two methods differ significantly. However, such a large discrepancy is not observed on the ImageNet_LT dataset.
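To compare the head/tail trend between runs more systematically than by eye, a small parser for the log format shown above could extract the running accuracy fields and report `many - few` per logged batch (positive means head classes lead). The regex is written against these sample lines only; the trainer's full log format is an assumption, and the helper names are hypothetical.

```python
import re

# Matches the sample lines above; other log variants may need a different pattern.
LOG_RE = re.compile(
    r"batch \[(?P<batch>\d+)/\d+\].*?"
    r"acc (?P<acc>[\d.]+) \((?P<acc_avg>[\d.]+)\) "
    r"\(mean (?P<mean>[\d.]+) many (?P<many>[\d.]+) med (?P<med>[\d.]+) few (?P<few>[\d.]+)\)"
)


def parse_log_line(line):
    """Return batch index and accuracy fields from one training log line, or None."""
    m = LOG_RE.search(line)
    if m is None:
        return None
    fields = {k: float(m[k]) for k in ("acc", "acc_avg", "mean", "many", "med", "few")}
    fields["batch"] = int(m["batch"])
    return fields


def head_tail_gap(lines):
    """(batch, many - few) for every parsable line; > 0 means head classes lead."""
    parsed = (parse_log_line(line) for line in lines)
    return [(p["batch"], p["many"] - p["few"]) for p in parsed if p is not None]
```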

Of course, after a certain number of epochs, the logs for both the PyTorch transforms and the DALI pipeline eventually show lower accuracy for head classes and higher accuracy for tail classes. However, the final overall accuracy with DALI is more than 10 percentage points lower (67.6% vs. 79.1%), which is an unacceptable difference.

Possible Explanations

The DALI pipeline and the PyTorch DataLoader sample data differently. However, I don't think this alone would cause such a large discrepancy.
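One way to test this intuition offline is to simulate both sampling schemes on a synthetic long-tailed label list and compare the class mix of the first samples of an epoch. The buffered variant approximates the DALI reader behavior described later in this thread (an initial shuffle, then random draws from a 1024-sample buffer); it is an illustrative model, not DALI's actual implementation, and the label distribution is made up.

```python
import random


def full_shuffle_epoch(n, rng):
    """DataLoader(shuffle=True)-style epoch: one uniform random permutation."""
    order = list(range(n))
    rng.shuffle(order)
    return order


def buffered_shuffle_epoch(n, rng, buffer_size=1024, initial_shuffle=True):
    """Illustrative shuffle-buffer epoch: optional initial shuffle of the file
    order, then each output is a random draw from a fixed-size buffer whose
    freed slot is refilled from the stream."""
    stream = list(range(n))
    if initial_shuffle:
        rng.shuffle(stream)
    buffer = stream[:buffer_size]
    out = []
    for idx in stream[buffer_size:]:
        j = rng.randrange(len(buffer))
        out.append(buffer[j])
        buffer[j] = idx  # refill the freed slot
    rng.shuffle(buffer)  # drain what is left at the end of the epoch
    out.extend(buffer)
    return out


def head_share(order, labels, head_classes, k):
    """Fraction of the first k emitted samples that belong to head_classes."""
    return sum(labels[i] in head_classes for i in order[:k]) / k


# Synthetic long-tailed labels: class c has about 500 / (c + 1) samples
labels = [c for c in range(20) for _ in range(max(1, 500 // (c + 1)))]
head = set(range(5))
rng = random.Random(0)
for name, order in [
    ("full shuffle", full_shuffle_epoch(len(labels), rng)),
    ("shuffle buffer", buffered_shuffle_epoch(len(labels), rng)),
]:
    print(f"{name}: head share of first 256 samples = {head_share(order, labels, head, 256):.2f}")
```

Setting `initial_shuffle=False` with a small buffer shows why the initial shuffle matters: the first samples then all come from whichever class is listed first in the index file.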

Heeeelp

Thank you for any suggestions!


JanuszL commented 2 days ago

Hi @lyyi599,

Thank you for reaching out and using DALI.

The DALI file reader does the following:

  • initially shuffles the whole data set, so that it doesn't sample only from the first class at the start
  • reads N samples into a shuffling buffer, where N is 1024 by default
  • randomly samples from the buffer

So it should not favor less-represented samples. Can you gather statistics of the classes the DALI reader returns? Do they match the sample distribution in the dataset?

lyyi599 commented 2 days ago

Hi @JanuszL,

Thank you for your reply. I verified that each epoch samples every image in the training set exactly once, and that the sampled labels match the class distribution of the dataset. However, the code still produces the results mentioned above. For example, after the first epoch, the validation accuracy using DALI is many: 20.0%, med: 21.0%, few: 22.2%, average: 25.4%, while with PyTorch transforms it is many: 3.4%, med: 9.3%, few: 16.5%, average: 11.5%.
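For anyone repeating this check, a small sketch of the comparison: read the expected labels from the index file and compare them with the labels the iterator actually returned over one epoch. The helper names are hypothetical; collecting labels from `DALIClassificationIterator` is shown only as a comment (each batch is a list with one dict per pipeline), and with several shards the counts from all shards would need to be combined first.

```python
from collections import Counter


def read_file_list_labels(lines):
    """Integer labels from index lines of the form '<relative path> <label>'."""
    labels = []
    for line in lines:
        line = line.strip()
        if line:  # skip blank lines
            labels.append(int(line.rsplit(maxsplit=1)[1]))
    return labels


def compare_class_counts(expected_labels, seen_labels):
    """Return {class: (expected, seen)} for every class whose counts differ."""
    expected, seen = Counter(expected_labels), Counter(seen_labels)
    return {c: (expected[c], seen[c])
            for c in expected.keys() | seen.keys()
            if expected[c] != seen[c]}


# Usage sketch (labels collected over one epoch from the DALI iterator):
#   seen = []
#   for batch in self.train_loader:
#       seen.extend(batch[0]["label"].view(-1).tolist())
#   with open(train_root_lt) as f:
#       print(compare_class_counts(read_file_list_labels(f), seen))
```

An empty dictionary means every class appeared exactly as often as the index file lists it.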

This is very confusing to me. The main training code I'm currently using, including both the DALI pipeline and the PyTorch transforms, is attached: trainer.py.txt. Simply adding iNaturalist2018 to the dali_dataset flag (which switches its data loading from the PyTorch transforms to the DALI pipeline) results in the significant accuracy drop mentioned above. Could you help me check whether I am using DALI correctly?

It is worth mentioning that the iNaturalist2018 dataset uses a .txt file for indexing, so the create_dali_pipeline function takes a data_list_dir parameter and passes it to the reader's file_list argument.

Thank you for any suggestions you may have.

mzient commented 1 day ago

Hello @lyyi599, where do the iNaturalist18_train.txt and iNaturalist18_val.txt files come from? They're not part of the original dataset. Perhaps they were generated incorrectly and the samples are mislabeled in training?

lyyi599 commented 1 day ago

Hi @mzient ,

Thank you for your reply, and I apologize for not explaining the issue more clearly. Let me provide some context: this is about long-tail recognition. Many real-world problems exhibit a long-tailed distribution, meaning that the number of training samples varies greatly across classes. The iNaturalist2018 dataset is an example of this and has been widely studied. The index files I use are available here: https://github.com/shijxcs/LIFT/tree/661ead9b78368f05ba79abe4672d63154467f823/datasets/iNaturalist18. Therefore, the issue is likely not related to the .txt file used for indexing.

In the trainer.py code above, the only difference is how iNaturalist18 is loaded, specifically the use of the DALI pipeline and data loader. Since this difference causes the accuracy drop, I suspect that either I am using DALI incorrectly or there is a bug in DALI that I haven't identified. For comparison, ImageNet_LT is also indexed with a .txt file and loaded with a similar pipeline and data loader, and there I obtain results comparable to the paper.

JanuszL commented 1 day ago

We are not saying the index file is the cause, but we want to make sure we are looking at the same things. If the index file misses some samples, DALI will not return them, underrepresenting some classes and overrepresenting others. It would be great if you could also share how you generated these files, or the files themselves.
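A sketch of the sanity check suggested here, assuming the format `fn.readers.file` expects for `file_list` (one whitespace-separated relative path and integer label per line, resolved against `file_root`). It reports lines that do not parse, duplicate entries, listed files that do not exist, and images on disk that the list never mentions. `check_file_list` is a hypothetical helper and the scanned extensions are an assumption; when train and val images share one root, images from the other split will naturally show up as unlisted.

```python
import os


def check_file_list(list_path, file_root, exts=(".jpg", ".jpeg")):
    """Sanity-check a '<relative path> <integer label>' index file against file_root."""
    report = {"bad_lines": [], "duplicates": [], "missing": [], "unlisted": []}
    listed = set()
    with open(list_path) as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue  # ignore blank lines
            parts = line.rsplit(maxsplit=1)
            if len(parts) != 2 or not parts[1].isdigit():
                report["bad_lines"].append(lineno)
                continue
            rel_path = os.path.normpath(parts[0])
            if rel_path in listed:
                report["duplicates"].append(rel_path)
            listed.add(rel_path)
            if not os.path.isfile(os.path.join(file_root, rel_path)):
                report["missing"].append(rel_path)
    # Images present under file_root that the index never references
    for dirpath, _, filenames in os.walk(file_root):
        for name in filenames:
            if name.lower().endswith(exts):
                rel = os.path.normpath(os.path.relpath(os.path.join(dirpath, name), file_root))
                if rel not in listed:
                    report["unlisted"].append(rel)
    report["unlisted"].sort()
    return report
```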

lyyi599 commented 1 day ago


I understand. I downloaded the files from this repo: https://github.com/shijxcs/LIFT/tree/661ead9b78368f05ba79abe4672d63154467f823/datasets/iNaturalist18, and I verified that they correctly index the corresponding data.