polarisZhao / PFLD-pytorch

PFLD pytorch Implementation
798 stars, 197 forks

backward() Error in train function. #62

Open minimini-1 opened 2 years ago

minimini-1 commented 2 years ago
```python
# `device`, `args` and `AverageMeter` are defined elsewhere in the training script (not shown)
def train(train_loader, pfld_backbone, auxiliarynet, criterion, optimizer,
          epoch):
    losses = AverageMeter()

    weighted_loss, loss = None, None
    for img, landmark_gt, attribute_gt, euler_angle_gt in train_loader:
        img = img.to(device)
        attribute_gt = attribute_gt.to(device)
        landmark_gt = landmark_gt.to(device)
        euler_angle_gt = euler_angle_gt.to(device)
        # the models are moved to the device on every batch; once before the loop is enough
        pfld_backbone = pfld_backbone.to(device)
        auxiliarynet = auxiliarynet.to(device)
        features, landmarks = pfld_backbone(img)
        angle = auxiliarynet(features)
        weighted_loss, loss = criterion(attribute_gt, landmark_gt,
                                        euler_angle_gt, angle, landmarks,
                                        args.train_batchsize)
        optimizer.zero_grad()
        weighted_loss.backward()  # the reported error is raised here
        optimizer.step()

        losses.update(loss.item())
    # returns the last batch's losses, not an epoch average
    return weighted_loss, loss
```
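To separate problems in the loop structure from problems in the real model, loss or data, here is a self-contained sketch of the same two-network loop with toy stand-ins. `ToyBackbone`, `ToyAuxNet`, `toy_criterion` and `train_one_epoch` are hypothetical names; the toy loss is a simplified angle-weighted L2 loss, not the repository's `PFLDLoss`, attribute weights are omitted, and all shapes are illustrative. If a sketch like this runs in the same environment, the zero_grad(), backward(), step() sequence itself is fine and the fault more likely lies in the real criterion or its inputs.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class ToyBackbone(nn.Module):
    """Hypothetical stand-in for pfld_backbone: returns (features, landmarks)."""

    def __init__(self, num_points=98):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv2d(3, 8, 3, stride=2, padding=1), nn.ReLU())
        self.head = nn.Linear(8, num_points * 2)

    def forward(self, x):
        feats = self.conv(x)                     # (B, 8, H/2, W/2)
        pooled = feats.mean(3).mean(2)           # global average pool -> (B, 8)
        return feats, self.head(pooled)          # landmarks: (B, num_points * 2)


class ToyAuxNet(nn.Module):
    """Hypothetical stand-in for auxiliarynet: predicts 3 Euler angles."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 3)

    def forward(self, feats):
        return self.fc(feats.mean(3).mean(2))    # (B, 3)


def toy_criterion(angle, euler_angle_gt, landmarks, landmark_gt):
    """Simplified angle-weighted L2 loss (not the repository's PFLDLoss)."""
    weight_angle = torch.sum(1 - torch.cos(angle - euler_angle_gt), dim=1)  # (B,)
    l2 = torch.sum((landmark_gt - landmarks) ** 2, dim=1)                   # (B,)
    return torch.mean(weight_angle * l2), torch.mean(l2)  # both are scalars


def train_one_epoch(loader, backbone, auxnet, criterion, optimizer, device):
    backbone.train()
    auxnet.train()
    total, batches = 0.0, 0
    for img, landmark_gt, euler_angle_gt in loader:
        img = img.to(device)
        landmark_gt = landmark_gt.to(device)
        euler_angle_gt = euler_angle_gt.to(device)
        features, landmarks = backbone(img)
        angle = auxnet(features)
        weighted_loss, loss = criterion(angle, euler_angle_gt, landmarks, landmark_gt)
        optimizer.zero_grad()
        weighted_loss.backward()  # needs a scalar that is still attached to the graph
        optimizer.step()
        total += loss.item()
        batches += 1
    return total / batches        # epoch average, not just the last batch


torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
backbone, auxnet = ToyBackbone().to(device), ToyAuxNet().to(device)  # moved once
# one optimizer over both networks, so the auxiliary net is updated too
optimizer = torch.optim.Adam(
    list(backbone.parameters()) + list(auxnet.parameters()), lr=1e-3)
data = TensorDataset(torch.randn(16, 3, 32, 32),  # images
                     torch.rand(16, 196),         # 98 landmarks as (x, y)
                     torch.randn(16, 3))          # Euler angles
loader = DataLoader(data, batch_size=4, shuffle=True)
avg_loss = train_one_epoch(loader, backbone, auxnet, toy_criterion, optimizer, device)
print(f"average landmark L2 loss: {avg_loss:.4f}")
```

Compared with the original loop, the models are moved to the device once, `device` is passed in rather than read from a global, and the function returns the epoch average. These changes only keep the sketch self-contained; none of them is offered as the fix for the reported error.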

I get an error when calling weighted_loss.backward(). How can I solve it?
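Without the traceback the cause can't be pinned down, but two generic PyTorch diagnostics usually narrow a backward() failure: checking the loss before calling backward(), and running the forward and backward passes under `torch.autograd.detect_anomaly()`, which reports the forward operation whose backward step failed. The helper `check_before_backward` is hypothetical, and its checks are a sketch of common causes, not a diagnosis of this specific error.

```python
import torch


def check_before_backward(loss: torch.Tensor) -> None:
    """Hypothetical helper: cheap checks for common causes of backward() errors."""
    # Calling backward() with no arguments requires a single-element tensor.
    if loss.numel() != 1:
        raise ValueError(f"loss must be a scalar, got shape {tuple(loss.shape)}")
    # A loss that does not require grad has no graph to differentiate
    # (typical causes: .detach(), torch.no_grad(), or rebuilding it from .item()).
    if not loss.requires_grad:
        raise ValueError("loss does not require grad")
    # NaN/Inf in the loss would propagate into every gradient.
    if not torch.isfinite(loss).all():
        raise ValueError(f"loss is not finite: {loss.item()}")


w = torch.randn(3, requires_grad=True)
# Anomaly mode must wrap the forward pass too, so that a failing backward op
# can be traced back to the forward line that created it.
with torch.autograd.detect_anomaly():
    loss = (w ** 2).sum()
    check_before_backward(loss)
    loss.backward()
print(w.grad)  # the gradient of sum(w ** 2) is 2 * w
```

In the train function above, the same pattern would mean computing the features, landmarks, angle and criterion inside the `with` block and calling `check_before_backward(weighted_loss)` just before `weighted_loss.backward()`. Anomaly mode slows training noticeably, so it is best limited to a few debugging iterations.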

Package versions: torch==1.2.0+cu92, torchvision==0.4.0+cu92, opencv-python==4.1.0.25

Python version: 3.7