cleardusk / 3DDFA_V2

The official PyTorch implementation of Towards Fast, Accurate and Stable 3D Dense Face Alignment, ECCV 2020.
MIT License

Question about the implementation of Meta-Joint Optimization #162

Open zero0kiriyu opened 8 months ago


I am trying to reimplement the meta-joint optimization part, but the loss always becomes inf after several iterations. @cleardusk

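import copy

import torch

# mobilenet(), lr, N, meta_joint_k and trainloader come from the rest of the training script
model_vdc = mobilenet()
model_wpdc = copy.deepcopy(model_vdc)  # both models start from identical weights

optimizer_vdc = torch.optim.SGD(params=model_vdc.parameters(), lr=lr)
optimizer_wpdc = torch.optim.SGD(params=model_wpdc.parameters(), lr=lr)

for epoch in range(N):
    for batch_idx, batch in enumerate(trainloader):
        if batch_idx == 0 or batch_idx % meta_joint_k != 0:
            # (the forward pass that computes loss_vdc and loss_wpdc is omitted)
            # update model_vdc with the VDC loss
            loss_vdc.backward()
            optimizer_vdc.step()
            optimizer_vdc.zero_grad()

            # update model_wpdc with the WPDC loss
            loss_wpdc.backward()
            optimizer_wpdc.step()
            optimizer_wpdc.zero_grad()
        else:
            # every meta_joint_k batches, compare the two models by their VDC loss
            model_vdc.eval(); model_wpdc.eval()
            # calculate the vdc loss for the two models
            ......

            if loss_vdc_vdc > loss_vdc_wpdc:
                # model_wpdc has the lower VDC loss: copy its weights and optimizer state into model_vdc
                model_vdc.load_state_dict(model_wpdc.state_dict())
                optimizer_vdc.load_state_dict(copy.deepcopy(optimizer_wpdc.state_dict()))
            else:
                # model_vdc has the lower (or equal) VDC loss: copy it into model_wpdc
                model_wpdc.load_state_dict(model_vdc.state_dict())
                optimizer_wpdc.load_state_dict(copy.deepcopy(optimizer_vdc.state_dict()))
            # switch both models back to training mode
            model_vdc.train(); model_wpdc.train()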
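A minimal, self-contained sketch of the same loop structure may help isolate where the inf appears. It rests on toy assumptions that are not part of 3DDFA_V2: `nn.Linear` stands in for `mobilenet()`, and the hypothetical `toy_vdc_loss` / `toy_wpdc_loss` stand in for the real VDC and WPDC losses. It keeps the rule from the post (every `meta_joint_k` batches, copy whichever model has the lower VDC loss into the other), compares on the current batch (an assumption), deep-copies the optimizer state so the two SGD optimizers do not share momentum buffers, and stops at the first non-finite loss. This is an illustration, not the authors' implementation.

```python
import copy

import torch
from torch import nn


# Hypothetical stand-ins for the real VDC / WPDC losses (not the 3DDFA_V2 ones).
def toy_vdc_loss(pred, target):
    return ((pred - target) ** 2).mean()


def toy_wpdc_loss(pred, target, weights):
    return (weights * (pred - target) ** 2).mean()


def sync(dst_model, dst_opt, src_model, src_opt):
    """Copy the winning model's weights and optimizer state into the other one."""
    # Module.load_state_dict copies the values into dst_model's own parameters.
    dst_model.load_state_dict(src_model.state_dict())
    # Deep-copy so the two optimizers never end up sharing momentum buffers.
    dst_opt.load_state_dict(copy.deepcopy(src_opt.state_dict()))


def meta_joint_train(steps=50, meta_joint_k=5, lr=0.05, seed=0):
    torch.manual_seed(seed)
    model_vdc = nn.Linear(4, 3)  # toy stand-in for mobilenet()
    model_wpdc = copy.deepcopy(model_vdc)
    opt_vdc = torch.optim.SGD(model_vdc.parameters(), lr=lr, momentum=0.9)
    opt_wpdc = torch.optim.SGD(model_wpdc.parameters(), lr=lr, momentum=0.9)
    true_w = torch.randn(4, 3)
    weights = torch.tensor([2.0, 1.0, 0.5])  # hypothetical per-output weights

    history = []
    for step in range(steps):
        x = torch.randn(16, 4)
        y = x @ true_w
        if step == 0 or step % meta_joint_k != 0:
            loss_vdc = toy_vdc_loss(model_vdc(x), y)
            loss_wpdc = toy_wpdc_loss(model_wpdc(x), y, weights)
            # Stop at the first inf/nan instead of silently training on it.
            if not (torch.isfinite(loss_vdc) and torch.isfinite(loss_wpdc)):
                raise FloatingPointError(f"non-finite loss at step {step}")
            for opt, loss in ((opt_vdc, loss_vdc), (opt_wpdc, loss_wpdc)):
                opt.zero_grad()
                loss.backward()
                opt.step()
            history.append(loss_vdc.item())
        else:
            model_vdc.eval(); model_wpdc.eval()
            with torch.no_grad():  # evaluation only, no graph needed
                vdc_of_vdc = toy_vdc_loss(model_vdc(x), y).item()
                vdc_of_wpdc = toy_vdc_loss(model_wpdc(x), y).item()
            if vdc_of_vdc > vdc_of_wpdc:
                sync(model_vdc, opt_vdc, model_wpdc, opt_wpdc)
            else:
                sync(model_wpdc, opt_wpdc, model_vdc, opt_vdc)
            model_vdc.train(); model_wpdc.train()
    return model_vdc, model_wpdc, opt_vdc, opt_wpdc, history
```

Calling `meta_joint_train(steps=46, meta_joint_k=5)` ends on a synchronization step, so both models come back with equal weights and equal (but separately stored) momentum buffers. The deep copy matters because some PyTorch versions only shallow-copy the dict passed to `Optimizer.load_state_dict`, which could leave both optimizers updating the same buffer tensors in place.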