frederickszk / LRNet

Landmark Recurrent Network: An efficient and robust framework for Deepfakes detection
MIT License
90 stars 13 forks source link

LandmarkDropout 类 #37

Open jiaxinJ opened 3 weeks ago

jiaxinJ commented 3 weeks ago

大佬你好!还有个问题,model.py中的 LandmarkDropout 类是用来干啥的? `class LandmarkDropout(nn.Module): def init(self, p: float = 0.5): super(LandmarkDropout, self).init() if p < 0 or p > 1: raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) self.p = p

def generate_mask(self, landmark, frame):
    position_p = torch.bernoulli(torch.Tensor([1 - self.p]*(landmark//2)))
    position_p = position_p.repeat_interleave(2)
    return position_p.repeat(1, frame, 1)

def forward(self, x: torch.Tensor):
    if self.training:
        _, frame, landmark = x.size()
        landmark_mask = self.generate_mask(landmark, frame)
        scale = 1/(1-self.p)
        return x*landmark_mask.to(x.device)*scale
    else:
        return x`
frederickszk commented 3 weeks ago

就是字面意思,对landmark进行dropout 缓解过拟合的,这部分的原理在新论文里(还在审😭) ,具体效果的话可以打断点看一下x在输入这一层前后的变化就知道了。 这里边看起来比较奇怪的position_p是为了配合landmark向量的形状,就是每两个值为一个坐标,看做一个整体。