qianyu-dlut / MVANet


Contributions of intermediate layer components to the loss are not converging #15

Open VMinB12 opened 1 month ago

VMinB12 commented 1 month ago

We are retraining MVANet on the ImageMatte dataset and are observing some undesired behaviour in the loss. We made a validation split and are individually logging the following components of the loss:

def compute_loss(
    generator: MVANet, images, gts, target_1, target_2, target_3, target_4, target_5
):
    (
        sideout5,
        sideout4,
        sideout3,
        sideout2,
        sideout1,
        final,
        glb5,
        glb4,
        glb3,
        glb2,
        glb1,
        tokenattmap4,
        tokenattmap3,
        tokenattmap2,
        tokenattmap1,
    ) = generator.forward(images)

    loss1 = structure_loss(sideout5, target_4)
    loss2 = structure_loss(sideout4, target_3)
    loss3 = structure_loss(sideout3, target_2)
    loss4 = structure_loss(sideout2, target_1)
    loss5 = structure_loss(sideout1, target_1)
    final_loss = structure_loss(final, gts)
    loss7 = structure_loss(glb5, target_5)
    loss8 = structure_loss(glb4, target_4)
    loss9 = structure_loss(glb3, target_3)
    loss10 = structure_loss(glb2, target_2)
    loss11 = structure_loss(glb1, target_2)
    loss12 = structure_loss(tokenattmap4, target_3)
    loss13 = structure_loss(tokenattmap3, target_2)
    loss14 = structure_loss(tokenattmap2, target_1)
    loss15 = structure_loss(tokenattmap1, target_1)
    Loss_loc = loss1 + loss2 + loss3 + loss4 + loss5
    Loss_glb = loss7 + loss8 + loss9 + loss10 + loss11
    Loss_map = loss12 + loss13 + loss14 + loss15

    loss = final_loss + Loss_loc + 0.3 * Loss_glb + 0.3 * Loss_map
    return {
        "loss": loss,
        "final_loss": final_loss,
        "Loss_loc": Loss_loc,
        "Loss_glb": Loss_glb,
        "Loss_map": Loss_map,
    }
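To put the grouped terms in perspective, it can help to look at how much each one contributes to the total after weighting. The following is a minimal sketch, not part of the repo: `weighted_shares` is a hypothetical helper, the weights are copied from the `loss = ...` line above, and the example numbers are purely illustrative rather than taken from our run.

```python
from typing import Dict

# Weights copied from the `loss = ...` line above.
# The helper itself is hypothetical and not part of MVANet.
WEIGHTS = {"final_loss": 1.0, "Loss_loc": 1.0, "Loss_glb": 0.3, "Loss_map": 0.3}


def weighted_shares(components: Dict[str, float]) -> Dict[str, float]:
    """Return each component's weighted fraction of the total loss."""
    # float() also accepts scalar tensors, so the dict returned by
    # compute_loss can be passed in directly (its "loss" key is ignored).
    weighted = {name: WEIGHTS[name] * float(components[name]) for name in WEIGHTS}
    total = sum(weighted.values())
    if total == 0:
        return {name: 0.0 for name in weighted}
    return {name: value / total for name, value in weighted.items()}


# Illustrative numbers only, not measurements.
example = {"final_loss": 0.2, "Loss_loc": 1.0, "Loss_glb": 2.0, "Loss_map": 1.0}
shares = weighted_shares(example)
for name, share in shares.items():
    print(f"{name}: {share:.1%}")
```

If the intermediate terms dominate the total in this breakdown, the optimizer is mostly being driven by them rather than by final_loss, which may be relevant when interpreting the curves.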

We observe that final_loss, which is the most important component, decreases nicely as expected:

[Screenshot 2024-10-03 at 11 06 17: final_loss curve]

However, Loss_loc, Loss_glb and Loss_map do not converge during training:

[Screenshots 2024-10-03 at 11 07 31, 11 07 46 and 11 07 08: curves for Loss_loc, Loss_glb and Loss_map]
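To make "not converging" less dependent on eyeballing the plots, one option is to compare the average of the first and last few logged values of each component. This is a rough sketch under our own assumptions: `relative_change` is a hypothetical helper, the window size is arbitrary, and the sample series are made up for illustration.

```python
from statistics import mean
from typing import Sequence


def relative_change(values: Sequence[float], window: int = 5) -> float:
    """Relative change between the mean of the first and last `window` values.

    Negative means the logged quantity went down over the run,
    values near zero mean it stayed roughly flat.
    """
    if window < 1 or len(values) < 2 * window:
        raise ValueError("need at least 2 * window logged values")
    start = mean(values[:window])
    end = mean(values[-window:])
    if start == 0:
        raise ValueError("mean of the first window is zero")
    return (end - start) / abs(start)


# Made-up series for illustration only.
decreasing = [10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]
flat = [1.0] * 10
print(relative_change(decreasing, window=2))
print(relative_change(flat, window=2))
```

Applied to the logged Loss_loc, Loss_glb and Loss_map values on the validation split, this gives one number per component that can be compared against the same number for final_loss.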

We are interested in any insights into why this could happen and how we could improve the situation. Furthermore, did you keep track of these losses during pretraining of the original MVANet, and if so, did you observe convergence in that case?

For context, we used the train.py script provided in the repo and changed only the dataset; all other settings are unmodified.

piercus commented 1 month ago

From DIS: "intermediate supervision plays a typical role of regularizer for reducing the probability of over-fitting".

VMinB12 commented 1 month ago

I see the motivation for adding intermediate supervision, but based on the loss curves we show, surely it is not working as intended. Is this really the expected behaviour?