TranThanh96 opened 2 weeks ago
I think you are missing this line: https://github.com/CVI-SZU/UniFace/blob/main/head.py#L116
```python
one_hot.scatter_(1, label.view(-1, 1).long(), 1)
one_hot = torch.index_select(one_hot, 1, partial_index)
```

This is equivalent to `one_hot.scatter(1, labels[index].view(-1, 1).long(), 1)`.
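A minimal PyTorch sketch of the equivalence described above, under stated assumptions: `partial_index` is sorted and contains every label in the batch (partial-FC-style sampling keeps all positive classes and fills the rest with random negatives). The helper names, the toy sizes, and the sampling snippet are hypothetical and for illustration only; this is not the UniFace or insightface implementation.

```python
import torch

def onehot_full_then_select(label, num_classes, partial_index):
    # UniFace-style: scatter into a one-hot over *all* classes,
    # then keep only the sampled class columns.
    one_hot = torch.zeros(label.size(0), num_classes)
    one_hot.scatter_(1, label.view(-1, 1).long(), 1)
    return torch.index_select(one_hot, 1, partial_index)

def onehot_remapped(label, partial_index):
    # Partial-FC-style (assumed): map each label to its column inside the
    # sorted sampled index, then scatter directly into a narrower one-hot.
    remapped = torch.searchsorted(partial_index, label)
    one_hot = torch.zeros(label.size(0), partial_index.numel())
    one_hot.scatter_(1, remapped.view(-1, 1).long(), 1)
    return one_hot

# Hypothetical toy setup: 10 classes, sample_rate=0.3 -> 3 sampled classes.
torch.manual_seed(0)
num_classes = 10
num_sample = 3
label = torch.tensor([7, 2, 7, 2])

# Assumed sampling: positives get a score above any torch.rand value,
# so topk always keeps them; the sampled index is then sorted.
positive = torch.unique(label, sorted=True)
perm = torch.rand(num_classes)
perm[positive] = 2.0
partial_index = torch.topk(perm, k=num_sample)[1].sort()[0]

a = onehot_full_then_select(label, num_classes, partial_index)
b = onehot_remapped(label, partial_index)
print(partial_index.tolist(), torch.equal(a, b))
```

Under these assumptions both paths yield the same `(batch, num_sample)` one-hot. If a batch label were missing from `partial_index`, the select version would give an all-zero row, while `searchsorted` would point the remapped version at a wrong (or out-of-range) column, so the two would no longer agree.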
I am trying to train on wf42m with sample_rate=0.3 using UniFace, and I think the error is at:
Can you check it?