GeWu-Lab / OGM-GE_CVPR2022

The repo for "Balanced Multimodal Learning via On-the-fly Gradient Modulation", CVPR 2022 (ORAL)
MIT License
220 stars 18 forks source link

关于main.py中获取out_a与out_v相关操作的疑惑 #39

Open XuecWu opened 1 year ago

XuecWu commented 1 year ago

首先感谢贵团队的辛勤付出! 我的疑问存在于main.py中(line80-93)在获取模型的输出a, v, out之后,通过矩阵相乘(torch.mm)方法来获取到out_v以及out_a。具体如下: image 当我使用sum方法与concat方法时,该部分代码会正确运行。然而,当我使用FiLM以及Gated Fusion方法的时候,代码理应执行”else“部分的操作。但是却发生了报错,报错的主要原因是进行矩阵相乘操作的时候mat1与mat2的维度不相同。 以FiLM方法为例: image 设定目标种类数为n_classes=8 视觉特征张量v的shape为bs, 512 音频特征张量a的shape为bs, 512

out_v = (torch.mm(v, torch.transpose(model.module.fusion_module.fc_out.weight[:, weight_size // 2:], 0, 1)) + model.module.fusion_module.fc_out.bias / 2) out_a = (torch.mm(a, torch.transpose(model.module.fusion_module.fc_out.weight[:, :weight_size // 2], 0, 1)) + model.module.fusion_module.fc_out.bias / 2) model.module.fusion_module.fc_out.weight的shape为[n_classes, input_dim],即为8以及512。 经过上述的操作之后, torch.transpose(model.module.fusion_module.fc_out.weight[:, weight_size // 2:], 0, 1)的shape为256, 8 torch.transpose(model.module.fusion_module.fc_out.weight[:, :weight_size // 2], 0, 1)的shape亦为256, 8 然而v与a的shape均为bs, 512。因此会出现维度不匹配的相关报错。

基于上述观察,我将if opt.fusion_method == 'sum':时的代码移植到当前的情况中来,移植后的代码如下: out_v = (torch.mm(v, torch.transpose(model.module.fusion_module.fc_out.weight, 0, 1)) + model.module.fusion_module.fc_out.bias) out_a = (torch.mm(a, torch.transpose(model.module.fusion_module.fc_out.weight, 0, 1)) + model.module.fusion_module.fc_out.bias) 我经过测试后发现,这样代码是可以正常运行的,但是这样的视觉分支与音频分支的准确率很低,在整体准确率可以达到44.4的时候,视觉分支与音频分支的准确率只分别有14.6以及9.2。 这对于我造成了困惑,整体上就是我运行repo中的代码不正确,之后进行了相关更改。代码可以成功运行,但是两个分支的准确率很不理想的问题。 我想询问是哪里出现了问题还是我理解上出现了偏差,希望得到贵团队的回复!

谢谢!

XuecWu commented 1 year ago

补充说明下: 我对于数据集进行了更换,所报告的准确率是第一代的准确率。

Felix-fz commented 4 weeks ago

我遇到了相同的问题,请问您是否解决