Linfeng-Tang / SeAFusion

The code of " Image fusion in the loop of high-level vision tasks: A semantic-aware real-time infrared and visible image fusion network"
MIT License
184 stars 34 forks source link

train_seg #50

Open Zshuting opened 1 month ago

Zshuting commented 1 month ago

你好 我想问下我在代码里面看到lb的来源是: label = np.asarray(Image.fromarray(label), dtype=np.int64)
torch.tensor(label) label = Variable(label).cuda() lb = torch.squeeze(label, 1) 然后把lb送到交叉熵损失函数(CrossEntropyLoss)里面,但是我了解到的是这个损失函数的标签是目标的类别 然后我运行train.py 最先出现这种报错:RuntimeError: size mismatch (got input: [22118400], target: [7372800]) 后面我修改loss里面的代码 将logits中的各通道对应的像素点求均值 最后出现这种错误RuntimeError: Expected floating point type for target with class probabilities, got Long 能解释一下 lb指的是啥吗?

Linfeng-Tang commented 1 month ago

lb就是分割的标签类别 你可以检查一下dataloader是否有问题 我们之前训练的时候没有遇到你提的这个问题呢

Zshuting commented 1 month ago

谢谢您的解答 但是我尝试打印lb的形状 发现loss.py的lb的形状是[8, 480, 640, 3] 报错:输入的lb应该是三维的张量 而不应该是前面出现的形状 请问在你的训练过程中lb的形状应该是几维的?还有我是否应该使用经过visualize.py可视化后的label进行训练?

Linfeng-Tang commented 1 month ago

不需要使用经过visualize.py的label进行训练呢 应该是8 480 640 1

Zshuting @.***> 于2024年7月17日周三 15:05写道:

谢谢您的解答 但是我尝试打印lb的形状 发现loss.py的lb的形状是[8, 480, 640, 3] 报错:输入的lb应该是三维的张量 而不应该是前面出现的形状 请问在你的训练过程中lb的形状应该是几维的?还有我是否应该使用经过visualize.py可视化后的label进行训练?

— Reply to this email directly, view it on GitHub, or unsubscribe. You are receiving this because you commented.Message ID: @.***>

Zshuting commented 1 month ago

谢谢您的回复!我明白了