BBBBchan / CorrMatch

Official code for "CorrMatch: Label Propagation via Correlation Matching for Semi-Supervised Semantic Segmentation"

Question about training #18

Closed: macaulishxcoo closed this issue 3 months ago

macaulishxcoo commented 3 months ago

Does the code need special handling for each dataset? I set up training on the MSD dataset I downloaded from Kaggle (mobile phone defects, split into 3 defect classes), and training fails with this error:

0%| | 0/62 [00:00<?, ?it/s]
../aten/src/ATen/native/cuda/NLLLoss2d.cu:103: nll_loss2d_forward_kernel: block: [0,0,0], thread: [352,0,0] Assertion t >= 0 && t < n_classes failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:103: nll_loss2d_forward_kernel: block: [0,0,0], thread: [768,0,0] Assertion t >= 0 && t < n_classes failed.
../aten/src/ATen/native/cuda/NLLLoss2d.cu:103: nll_loss2d_forward_kernel: block: [0,0,0], thread: [160,0,0] Assertion t >= 0 && t < n_classes failed.
Traceback (most recent call last):
  File "corrmatch.py", line 323, in <module>
    main()
  File "corrmatch.py", line 243, in main
    loss_u_s1 = torch.sum(loss_u_s1) / torch.sum(ignore_mask_cutmixed1 != 255).item()
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

A dataset I built myself runs into a similar problem.
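The assertion in this log, `t >= 0 && t < n_classes`, is raised by the loss kernel when a target pixel holds a value that is neither a valid class index (0 to nclass - 1) nor the ignore value. Because the error is reported asynchronously, the line shown in the traceback may not be the real culprit, so checking the masks directly is a useful first step. Below is a minimal sketch, assuming (as the `!= 255` in the traceback suggests) that 255 marks ignored pixels; `find_invalid_labels` is a hypothetical helper, not part of CorrMatch.

```python
import numpy as np


def find_invalid_labels(mask, nclass, ignore_index=255):
    """Return the sorted label values in `mask` that the loss would reject.

    Valid targets are 0..nclass-1, plus `ignore_index`.
    Assumption: 255 marks ignored pixels, as in PASCAL VOC / Cityscapes.
    """
    values = np.unique(np.asarray(mask))  # distinct label values, sorted
    in_range = (values >= 0) & (values < nclass)
    valid = in_range | (values == ignore_index)
    return [int(v) for v in values[~valid]]


if __name__ == "__main__":
    # Hypothetical mask: background 0, three defect classes 1..3, one ignored pixel.
    mask = np.array([[0, 1, 2], [3, 255, 0]], dtype=np.uint8)
    print(find_invalid_labels(mask, nclass=3))  # value 3 is out of range
    print(find_invalid_labels(mask, nclass=4))  # every value is valid
```

To scan real mask files, load each one with Pillow (for example `np.array(Image.open(path))`) and pass the array in. As a purely hypothetical illustration: if background were stored as 0 and the three defect types as 1 to 3, nclass would have to be 4, and a config with nclass set to 3 would trigger exactly this assertion.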

BBBBchan commented 3 months ago

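Hello. This error looks like an index going out of bounds. The config files in the config directory specify the number of classes, the crop size, and other dataset settings. To train on your own dataset, you need to change at least nclass in the config. In addition, both PASCAL and Cityscapes contain an ignore class, and both label it as 255. If your dataset also has an ignore class but labels it differently (for example, some datasets use 0 or -1 for the ignore class), you may need to modify both the config and dataset/semi.py.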
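One way to follow this advice without touching the training loop is to convert each mask so that classes become contiguous indices 0 to nclass - 1 and the dataset's own ignore value becomes 255. The sketch below is an assumption-laden illustration: `remap_mask` and the example label scheme (0 = ignore, 1 = background, 2 to 4 = three defect types) are hypothetical, and it can be applied either offline, by rewriting the mask files once, or inside the mask-loading code of dataset/semi.py (not shown here and untested).

```python
import numpy as np


def remap_mask(mask, mapping, ignore_values=(), ignore_index=255):
    """Convert raw mask values into contiguous class indices.

    mapping:       dict {raw_value: class_index}, class indices 0..nclass-1
    ignore_values: raw values that should become `ignore_index`
    Raises ValueError on any raw value listed in neither, so labeling
    mistakes surface here instead of as a CUDA device-side assert.
    """
    mask = np.asarray(mask)
    known = np.isin(mask, list(mapping) + list(ignore_values))
    if not known.all():
        unknown = np.unique(mask[~known])
        raise ValueError(f"unmapped label values: {unknown.tolist()}")
    # Start with every pixel ignored; uint8 assumes indices and ignore_index <= 255.
    out = np.full(mask.shape, ignore_index, dtype=np.uint8)
    for raw, cls in mapping.items():
        out[mask == raw] = cls
    return out


if __name__ == "__main__":
    # Hypothetical scheme: 0 = ignore, 1 = background, 2..4 = defect types.
    raw = np.array([[0, 1, 2], [3, 4, 1]])
    print(remap_mask(raw, {1: 0, 2: 1, 3: 2, 4: 3}, ignore_values=(0,)))
```

With this scheme, nclass in the config would be 4 (the number of entries in `mapping`). Raising on unknown values, rather than silently mapping them to 255, is a deliberate choice: a stray value usually means the mapping is wrong.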