grip-unina / TruFor


Detailed questions about training parameters #14

Closed CIawevy closed 6 days ago

CIawevy commented 10 months ago

Hi, your work is really fascinating to me; these days I am trying to implement and train the TruFor network on my device. However, I have run into some questions. As you mention in the paper, you divided the training into 3 phases, and I am wondering whether you froze the backbone and the anomaly decoder when training the 3rd phase, i.e. training the detector and the confidence decoder together with the hybrid loss L3. Furthermore, it seems that the localization results reported in the paper come only from the anomaly decoder, without the aid of the confidence map; am I right? Lastly, it is reported that you use 100 epochs to train phases 2 and 3 altogether, but how exactly did you do this? What is the number of sampled images in each epoch? I would appreciate it if you could reply, thanks!

fabrizioguillaro commented 10 months ago

Hello! Yes, during phase 3 the backbone and the anomaly decoder are frozen, while the detector and the confidence decoder are trained.

> Furthermore, it seems that the localization results reported in the paper come only from the anomaly decoder, without the aid of the confidence map; am I right?

Correct: during phase 3 the anomaly localization network is frozen, so its output does not change with respect to phase 2. In this work, the confidence map is used to reduce the effect of false positives on the final detection score; it does not "correct" the localization map directly.
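To illustrate the idea (this is not the actual TruFor code), one simple way a confidence map can down-weight false positives in a detection score is confidence-weighted pooling of the anomaly map; `detection_score` below is a hypothetical helper written as a minimal sketch:

```python
import torch

def detection_score(anomaly_map, conf_map):
    """Confidence-weighted pooling of the anomaly map into a scalar score.

    anomaly_map, conf_map: tensors of shape (B, H, W), conf_map in [0, 1].
    Pixels the network is unsure about contribute less to the pooled score,
    which suppresses isolated false positives in the localization map.
    """
    a = anomaly_map.flatten(1)
    w = conf_map.flatten(1)
    return (a * w).sum(dim=1) / w.sum(dim=1).clamp(min=1e-8)
```

With a single false-positive pixel, lowering its confidence lowers the pooled score, while the localization map itself is left untouched.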

For the question about the number of samples in each epoch, you can refer to issue #10.

CIawevy commented 10 months ago

Thanks a lot! By the way, could you release the code for the balanced cross-entropy loss of the detection head? Thank you!

fabrizioguillaro commented 10 months ago

The balanced cross-entropy loss (Ldet in the paper) is simply the cross-entropy with weighted classes. The weights depend on the rough proportions of the pristine/fake portions of the overall dataset.

    def forward(self, score, target):
        # image-level label: the crop is "fake" if more than 3 pixels of the
        # mask are tampered (values < 0 are masked out of the count)
        target_det = (torch.count_nonzero(target * (target >= 0), (-1, -2)) > 3).float().clamp(0, 1)
        # class weights: the dataset is roughly 70% fake / 30% pristine,
        # so each class is rescaled to contribute ~0.5 in expectation
        weights_det = target_det * 0.5 / 0.7 + (1 - target_det) * 0.5 / 0.3
        loss_det = F.binary_cross_entropy_with_logits(score[:, 0], target_det, reduction='mean', weight=weights_det)
        return loss_det

The first line of code ensures that an image is considered fake only if more than 3 of its pixels are tampered. This is just an arbitrary threshold we use to avoid labeling the current patch as fake when it contains too few fake pixels. Note that this can happen, since the input image is cropped during training and the tampered area may fall outside the crop. In that case the crop is to be considered pristine, even if the original image was manipulated.
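For reference, the snippet above can be wrapped into a self-contained module. Here is a minimal runnable sketch; the `BalancedDetectionLoss` name and the dummy data are illustrative assumptions, not the released code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BalancedDetectionLoss(nn.Module):
    """Balanced BCE for the detection head, mirroring the snippet above."""

    def forward(self, score, target):
        # score:  (B, 1) logits from the detection head
        # target: (B, H, W) pixel-level masks (1 = tampered; values < 0 ignored)
        # image-level label: a crop is "fake" if more than 3 pixels are tampered
        target_det = (torch.count_nonzero(target * (target >= 0), (-1, -2)) > 3).float()
        # re-weight so each class contributes ~0.5 in expectation
        # (assuming a dataset that is roughly 70% fake / 30% pristine)
        weights_det = target_det * 0.5 / 0.7 + (1 - target_det) * 0.5 / 0.3
        return F.binary_cross_entropy_with_logits(
            score[:, 0], target_det, reduction='mean', weight=weights_det)

# usage with dummy data
loss_fn = BalancedDetectionLoss()
score = torch.tensor([[-5.0], [5.0]])  # confident logits: "pristine", "fake"
target = torch.zeros(2, 64, 64)        # pixel-level masks
target[1, :8, :8] = 1.0                # second crop has a 64-pixel tampered region
loss = loss_fn(score, target)
```

Since both crops are classified confidently and correctly here, the resulting loss is close to zero.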