HaiyuWu / SOTA-Face-Recognition-Train-and-Test

MIT License

Error when sample_rate < 1 #2

Open TranThanh96 opened 2 weeks ago

TranThanh96 commented 2 weeks ago

I am trying to train on wf42m with sample_rate=0.3, using uniface. I think the error is at:

loss = one_hot * p_loss + (~one_hot) * n_loss

Can you check it?

TranThanh96 commented 2 weeks ago

I think you are missing this line: https://github.com/CVI-SZU/UniFace/blob/main/head.py#L116

one_hot = torch.index_select(one_hot, 1, partial_index)

HaiyuWu commented 2 weeks ago

    one_hot.scatter_(1, label.view(-1, 1).long(), 1)
    one_hot = torch.index_select(one_hot, 1, partial_index)

is equivalent to onehot.scatter(1, labels[index].view(-1, 1).long(), 1)
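
To make the comparison above concrete, here is a minimal sketch of the two ways of building the one-hot mask when only a subset of classes is sampled. It is not taken from either repository. The sizes, the sampling procedure, and the names one_hot_a, one_hot_b and remapped are hypothetical, and it assumes that every label's class is kept in partial_index and that the labels are remapped to their column positions inside the sampled set. Under those assumptions both variants give the same mask, whose width matches the sampled logits.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: 10 classes in total, 5 of them sampled for this step.
batch_size, num_classes, num_sampled = 4, 10, 5
label = torch.tensor([2, 7, 7, 9])

# Assumed sampling scheme: keep every positive class, fill up with random negatives.
positive = torch.unique(label)
is_positive = torch.zeros(num_classes, dtype=torch.bool)
is_positive[positive] = True
negatives = torch.nonzero(~is_positive).view(-1)
extra = negatives[torch.randperm(negatives.numel())[: num_sampled - positive.numel()]]
partial_index = torch.sort(torch.cat([positive, extra])).values

# Variant A: full-width one-hot, then keep only the sampled columns.
one_hot_a = torch.zeros(batch_size, num_classes)
one_hot_a.scatter_(1, label.view(-1, 1).long(), 1.0)
one_hot_a = torch.index_select(one_hot_a, 1, partial_index).bool()

# Variant B: remap each label to its column inside the sampled set,
# then scatter directly into a mask of sampled width.
remapped = torch.searchsorted(partial_index, label)
one_hot_b = torch.zeros(batch_size, num_sampled)
one_hot_b.scatter_(1, remapped.view(-1, 1).long(), 1.0)
one_hot_b = one_hot_b.bool()

# Both masks agree and have shape (batch_size, num_sampled).
assert torch.equal(one_hot_a, one_hot_b)

# Placeholder losses with the sampled width, only to check that shapes line up.
p_loss = torch.rand(batch_size, num_sampled)
n_loss = torch.rand(batch_size, num_sampled)
loss = one_hot_a * p_loss + (~one_hot_a) * n_loss
```

If the mask keeps the full num_classes width while p_loss and n_loss have the sampled width, the last line cannot broadcast, which may be what triggers the error when sample_rate < 1.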