yikaiw / CEN

[TPAMI 2023, NeurIPS 2020] Code release for "Deep Multimodal Fusion by Channel Exchanging"
MIT License

For fusion of 3 modalities #6

Closed by onat-dalmaz 3 years ago

onat-dalmaz commented 3 years ago

Hi, I was experimenting with your code on my own dataset and noticed that the image-to-image translation model supports fusion of only two modalities. Looking at the code in detail, the Exchange class appears to be implemented for exactly two sub-networks:

class Exchange(nn.Module):
    def __init__(self):
        super(Exchange, self).__init__()

    def forward(self, x, insnorm, insnorm_threshold):
        insnorm1, insnorm2 = insnorm[0].weight.abs(), insnorm[1].weight.abs()
        x1, x2 = torch.zeros_like(x[0]), torch.zeros_like(x[1])
        x1[:, insnorm1 >= insnorm_threshold] = x[0][:, insnorm1 >= insnorm_threshold]
        x1[:, insnorm1 < insnorm_threshold] = x[1][:, insnorm1 < insnorm_threshold]
        x2[:, insnorm2 >= insnorm_threshold] = x[1][:, insnorm2 >= insnorm_threshold]
        x2[:, insnorm2 < insnorm_threshold] = x[0][:, insnorm2 < insnorm_threshold]
        return [x1, x2]

As you can see, it handles only two inputs. Could you provide an Exchange class for more than two modalities? Thanks in advance.

yikaiw commented 3 years ago

Hi, thank you for the interest in our work.

In our previous implementation, we manually wrote separate exchanging code for 3 and for 4 modalities (based on Eq. 6 in the paper). An example for 3 modalities is:

def forward(self, x, insnorm, insnorm_threshold):
    insnorm0, insnorm1, insnorm2 = insnorm[0].weight.abs(), insnorm[1].weight.abs(), insnorm[2].weight.abs()
    x0, x1, x2 = torch.zeros_like(x[0]), torch.zeros_like(x[1]), torch.zeros_like(x[2])
    x0[:, insnorm0 >= insnorm_threshold] = x[0][:, insnorm0 >= insnorm_threshold]
    x0[:, insnorm0 < insnorm_threshold] = (x[1][:, insnorm0 < insnorm_threshold] + x[2][:, insnorm0 < insnorm_threshold]) / 2
    x1[:, insnorm1 >= insnorm_threshold] = x[1][:, insnorm1 >= insnorm_threshold]
    x1[:, insnorm1 < insnorm_threshold] = (x[0][:, insnorm1 < insnorm_threshold] + x[2][:, insnorm1 < insnorm_threshold]) / 2
    x2[:, insnorm2 >= insnorm_threshold] = x[2][:, insnorm2 >= insnorm_threshold]
    x2[:, insnorm2 < insnorm_threshold] = (x[0][:, insnorm2 < insnorm_threshold] + x[1][:, insnorm2 < insnorm_threshold]) / 2
    return [x0, x1, x2]
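
The 3-modality version above follows a pattern: each sub-network keeps its own channel where its scaling factor is at or above the threshold, and otherwise takes the average of that channel from the other modalities. A minimal sketch of the same rule for any number of modalities follows. It is not code from the CEN repository: the class name `ExchangeN` is hypothetical, and it assumes every feature map has the same shape and that each norm layer exposes a per-channel `weight`.

```python
import torch
import torch.nn as nn


class ExchangeN(nn.Module):
    """Sketch of channel exchanging for any number of modalities (>= 2).

    ExchangeN is a hypothetical name, not a class from the CEN repository.
    """

    def forward(self, x, insnorm, insnorm_threshold):
        # x: list of N feature maps, all shaped (batch, channels, ...).
        # insnorm: list of N norm layers exposing a per-channel `weight`.
        n = len(x)
        out = []
        for i in range(n):
            # Channels whose scaling factor is large enough are kept.
            keep = insnorm[i].weight.abs() >= insnorm_threshold
            # Reshape (channels,) -> (1, channels, 1, ...) so it broadcasts.
            keep = keep.view(1, -1, *([1] * (x[i].dim() - 2)))
            # Average of the same channels taken from every other modality.
            others = torch.stack([x[j] for j in range(n) if j != i]).mean(dim=0)
            out.append(torch.where(keep, x[i], others))
        return out


# Toy usage: 3 modalities with 4 channels each (values chosen for illustration).
torch.manual_seed(0)
norms = [nn.InstanceNorm2d(4, affine=True) for _ in range(3)]
with torch.no_grad():
    norms[0].weight.copy_(torch.tensor([1.0, 0.0, 1.0, 0.0]))
    norms[1].weight.copy_(torch.tensor([0.0, 1.0, 1.0, 0.0]))
    norms[2].weight.copy_(torch.tensor([1.0, 1.0, 0.0, 0.0]))
feats = [torch.randn(2, 4, 8, 8) for _ in range(3)]
fused = ExchangeN()(feats, norms, insnorm_threshold=0.5)
```

With two inputs the average over "the other modalities" reduces to the single other feature map, so the sketch should behave like the two-modality Exchange; with three inputs it matches the hand-written `/ 2` averaging shown above.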

In addition, for 3 modalities, these lines of code should be modified to split into three disjoint parts, as shown in our supplementary materials (Figure 11 and Figure 12).

onat-dalmaz commented 3 years ago

Hi, thank you very much for the quick response; that was really helpful. How can I modify those lines of code? If you have sample code for 3 modalities, I would appreciate it if you could share it. Thanks.

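yikaiw commented 3 years ago

# For 3 modalities: each matching parameter keeps one third of its entries,
# cycling through the first, middle, and last third.
if len(slim_params) % 3 == 0:
    slim_params.append(param[:len(param) // 3])
elif len(slim_params) % 3 == 1:
    slim_params.append(param[len(param) // 3: len(param) // 3 * 2])
else:
    slim_params.append(param[len(param) // 3 * 2:])
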
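
The three branches above differ only in which third of `param` they keep, so the same split can be written once for any number of modalities. The following is a sketch under assumptions rather than code from the repository: the helper name `split_slim_params` and its `num_modalities` argument are hypothetical, and it assumes the parameters arrive in the same cyclic order that the snippet relies on. It uses only `len` and slicing, so it should apply to 1-D tensors as well as to the plain lists used in the demo.

```python
def split_slim_params(params, num_modalities):
    """Hypothetical helper: the k-th parameter keeps chunk (k mod N) of N."""
    slim_params = []
    for param in params:
        # Which disjoint part this parameter contributes, cycling 0..N-1.
        part = len(slim_params) % num_modalities
        chunk = len(param) // num_modalities
        start = part * chunk
        # The last part runs to the end so leftover entries are not dropped.
        end = len(param) if part == num_modalities - 1 else start + chunk
        slim_params.append(param[start:end])
    return slim_params


# Toy usage: three 7-element "parameters", split for 3 modalities.
params = [list(range(7)) for _ in range(3)]
parts = split_slim_params(params, 3)
```

For `num_modalities=3` the slice bounds work out to the same ones as in the snippet above: `[:n // 3]`, `[n // 3 : n // 3 * 2]`, and `[n // 3 * 2:]`.
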
onat-dalmaz commented 3 years ago

Thank you very much.