xypu98 / CWSAM

42 stars 0 forks source link

关于训练时的问题。 #14

Closed DevKPro closed 3 weeks ago

DevKPro commented 4 weeks ago

您好,作者。 非常感谢您开源了代码。我在使用开源代码训练时遇到以下问题,恳请您解惑~

  1. 我拉取了您的仓库并在本地训练,配置文件内容没有修改,除了数据集和sam权重文件的路径改为自己电脑下的路径。然后,我使用以下指令 python -m torch.distributed.launch --nnodes 1 --nproc_per_node 1 train.py 在我的WSL虚拟机上运行,然而报错显示损失函数中gt_mask的值必须为long类型,我使用类型转换解决这个报错并继续执行,随后又报另一个错误,我跟踪了相关代码,发现计算损失时传入的gt_mask是shape为(B,N,H,W)的图像,但是cross entropy传入的label不应该是(B,H,W)吗?
  2. 除此之外,我还发现一个问题。在代码中,LAB目录下的文件是标签文件,应该是多类别的标签,但是我调试时发现数据加载的label只有0和255(Image.open打开后转为ndarray查看得知),在transforms.toTensor()之后自动归一化为0和1,但这样的label能够用于多类别分割吗?读取的结果跟系统查看到的数据集Lab文件有很大差异,对比图片如下所示。 image

以上便是我执行代码时所遇到的问题,真切希望能得到作者的解答,十分感激! 祝生活愉快!

xypu98 commented 3 weeks ago
  1. 目前代码只兼容单卡,batch size=1的情况,可以直接调用python train.py ,bs=1,尝试一下是否能解决
  2. 代码对数据集label的处理是先将255归一化为1,然后需要进行one hot处理,再输入训练。具体实现可参考 datasets/wrappers.py文件
DevKPro commented 3 weeks ago
  1. 目前代码只兼容单卡,batch size=1的情况,可以直接调用python train.py ,bs=1,尝试一下是否能解决
  2. 代码对数据集label的处理是先将255归一化为1,然后需要进行one hot处理,再输入训练。具体实现可参考 datasets/wrappers.py文件

感谢作者的回复。 在 cross entropy 计算损失部分报错可能是版本或者系统问题导致,我后来在 windows 上执行该项目代码是没有任何报错的。 另外,其他关于 label 的问题也在调试 datasets 下相关代码中得到解决。 再次感谢作者的回复,祝您生活愉快!