open-mmlab / mmsegmentation

OpenMMLab Semantic Segmentation Toolbox and Benchmark.
https://mmsegmentation.readthedocs.io/en/main/
Apache License 2.0
8.3k stars 2.62k forks

Bug that may cause an out-of-index error in the cross_entropy function #3734

Open KaneiGi opened 4 months ago

KaneiGi commented 4 months ago

Thanks for your error report and we appreciate it a lot.

Checklist

  1. I have searched related issues but cannot get the expected help.
  2. The bug has not been fixed in the latest version.

Describe the bug

In mmseg's `mmseg/models/losses/cross_entropy_loss.py`, line 66:

    label_weights = torch.stack([class_weight[cls] for cls in label]).to(device=class_weight.device)

There is no check that each `cls` in `label` falls within the range of `class_weight`, so any label value outside that range, such as the ignore index, triggers an out-of-index error.
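A minimal sketch of how the failure can arise. It assumes (the report does not say so explicitly) that the label tensor still contains the ignore value 255 and that `class_weight` has one entry per Cityscapes class (19). The tensors and the helper name `stacked_weights` are made up for illustration.

```python
import torch

# Hypothetical values: 19 class weights and ignore_index = 255.
class_weight = torch.ones(19)
ignore_index = 255
label = torch.tensor([0, 3, ignore_index, 18])


def stacked_weights(label, class_weight):
    # Same per-element indexing pattern as the line quoted above.
    return torch.stack([class_weight[cls] for cls in label])


try:
    stacked_weights(label, class_weight)
    raised = False
except (IndexError, RuntimeError):
    # 255 is past the end of the 19-element weight tensor, so the lookup fails.
    raised = True

# Labels that stay within [0, 18] work as expected.
ok = stacked_weights(torch.tensor([0, 3, 18]), class_weight)
```

Under these assumptions the lookup fails only because of the ignored element; labels within the class range go through without error.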

Reproduction

  1. What command or script did you run?

    `python train.py pidnet-s_2xb6-120k_1024x1024-cityscapes.py`
  2. Did you make any modifications on the code or config? Did you understand what you have modified?

  3. What dataset did you use?

    cityscapes

    System environment:

    sys.platform: linux
    Python: 3.10.0 (default, May 11 2024, 13:44:05) [GCC 7.5.0]
    CUDA available: False
    MUSA available: False
    numpy_random_seed: 304
    GCC: gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
    PyTorch: 1.12.1+cu102
    PyTorch compiling details: PyTorch built with:

    • GCC 7.3
    • C++ Version: 201402
    • Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
    • Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
    • OpenMP 201511 (a.k.a. OpenMP 4.5)
  4. Please run python mmseg/utils/collect_env.py to collect necessary environment information and paste it here.

  5. You may add additional information that may be helpful for locating the problem, such as

    • How you installed PyTorch [e.g., pip, conda, source]
    • Other environment variables that may be related (such as $PATH, $LD_LIBRARY_PATH, $PYTHONPATH, etc.)

Error traceback

If applicable, paste the error traceback here.

A placeholder for traceback.
![Image](https://github.com/user-attachments/assets/d0226b05-aca3-4ab1-82db-28928f602455)

    if (avg_factor is None) and reduction == 'mean':
        if class_weight is None:
            if avg_non_ignore:
                avg_factor = label.numel() - (label == ignore_index).sum().item()
            else:
                avg_factor = label.numel()

        else:
            # the average factor should take the class weights into account
            mask = label == ignore_index
            masked_label = label * ~mask
            label_weights = class_weight[masked_label]
            label_weights[mask] = 0

            # label_weights = torch.stack([class_weight[cls] for cls in label
            #                              ]).to(device=class_weight.device)

            if avg_non_ignore:
                label_weights[label == ignore_index] = 0
            avg_factor = label_weights.sum()
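The masking idea in the snippet above can be tried in isolation. The following is only a sketch under stated assumptions: the function name `weighted_avg_factor` and the toy tensors are hypothetical, and it is not the library's implementation.

```python
import torch


def weighted_avg_factor(label, class_weight, ignore_index=255):
    """Sum the per-element class weights, counting ignored elements as 0.

    Hypothetical helper that mirrors the masking steps in the snippet above.
    """
    mask = label == ignore_index
    # Map ignored labels to class 0 so the lookup stays in range.
    masked_label = label * ~mask
    # Tensor indexing returns a copy, so class_weight itself is not modified.
    label_weights = class_weight[masked_label]
    # Discard the placeholder weight picked up for ignored elements.
    label_weights[mask] = 0
    return label_weights.sum()


# Toy example: 3 classes, one ignored pixel.
class_weight = torch.tensor([0.5, 2.0, 1.0])
label = torch.tensor([[0, 1], [255, 2]])
avg_factor = weighted_avg_factor(label, class_weight)
```

Here the three valid pixels contribute 0.5 + 2.0 + 1.0 and the ignored pixel contributes nothing. The sketch only guards against `ignore_index`; any other label value outside the range of `class_weight` would still fail the lookup.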

If you have already identified the reason, you can provide the information here. If you are willing to create a PR to fix it, please also leave a comment here and that would be much appreciated!

AreopagX commented 3 weeks ago

Had the same issue. Your fix works for me. Thank you