johndpope / MegaPortrait-hack

Using Claude Opus to reverse engineer code from MegaPortraits: One-shot Megapixel Neural Head Avatars
https://arxiv.org/abs/2207.07621

Sub-sampling on perceptual loss, trick from FSRT paper #41

Open · johndpope opened 2 weeks ago

johndpope commented 2 weeks ago

[Screenshot from 2024-06-11 12-26-16]

Drafted: https://github.com/johndpope/MegaPortrait-hack/tree/feat/sub-sampling

I don't see a massive speed-up.

It's possible this could be randomized: sometimes use the full frame, sometimes half or quarter resolution, etc. (see the sketch after the class below).


import numpy as np
import torch
import torch.nn as nn
from torchvision import models
from facenet_pytorch import InceptionResnetV1

# MPGazeLoss is defined elsewhere in this repo; adjust the import path to
# match its actual location.
from losses import MPGazeLoss


class PerceptualLoss(nn.Module):
    def __init__(self, device, weights={'vgg19': 20.0, 'vggface': 5.0, 'gaze': 4.0}):
        super(PerceptualLoss, self).__init__()
        self.device = device
        self.weights = weights

        # VGG19 network
        vgg19 = models.vgg19(pretrained=True).features
        self.vgg19 = nn.Sequential(*[vgg19[i] for i in range(30)]).to(device).eval()
        self.vgg19_layers = [1, 6, 11, 20, 29]

        # VGGFace network
        self.vggface = InceptionResnetV1(pretrained='vggface2').to(device).eval()
        self.vggface_layers = [4, 5, 6, 7]

        # Gaze loss
        self.gaze_loss = MPGazeLoss(device)

    # Memory-saving trick from FSRT Sec. 3.3: compute the perceptual losses
    # on a random crop instead of the full frame.
    # https://arxiv.org/pdf/2404.09736#page=5.58
    def forward(self, predicted, target, sub_sample_size=(128, 128), use_fm_loss=False):
        # Crop both images at the same random location so the losses compare
        # corresponding pixels
        predicted, target = self.sub_sample_pair(predicted, target, sub_sample_size)

        # Normalize input images
        predicted = self.normalize_input(predicted)
        target = self.normalize_input(target)

        # Compute VGG19 perceptual loss
        vgg19_loss = self.compute_vgg19_loss(predicted, target)

        # Compute VGGFace perceptual loss
        vggface_loss = self.compute_vggface_loss(predicted, target)

        # Gaze loss is currently disabled
        # gaze_loss = self.gaze_loss(predicted, target)

        # Compute total perceptual loss
        total_loss = (
            self.weights['vgg19'] * vgg19_loss +
            self.weights['vggface'] * vggface_loss
            # + self.weights['gaze'] * gaze_loss
        )

        if use_fm_loss:
            # Compute feature matching loss
            fm_loss = self.compute_feature_matching_loss(predicted, target)
            total_loss += fm_loss

        return total_loss

    def sub_sample_pair(self, predicted, target, sub_sample_size):
        assert predicted.ndim == 4 and target.ndim == 4, "Inputs should have 4 dimensions (batch_size, channels, height, width)"
        assert predicted.shape == target.shape, "Predicted and target must have the same shape"

        height, width = predicted.shape[-2:]
        assert height >= sub_sample_size[0] and width >= sub_sample_size[1], "Sub-sample size should not exceed the tensor dimensions"

        # Randomly sample the crop location so the whole image gets covered
        # over training. The +1 keeps randint's range non-empty when the crop
        # matches the full frame.
        offset_y = np.random.randint(0, height - sub_sample_size[0] + 1)
        offset_x = np.random.randint(0, width - sub_sample_size[1] + 1)

        rows = slice(offset_y, offset_y + sub_sample_size[0])
        cols = slice(offset_x, offset_x + sub_sample_size[1])
        return predicted[..., rows, cols], target[..., rows, cols]

    def compute_vgg19_loss(self, predicted, target):
        return self.compute_perceptual_loss(self.vgg19, self.vgg19_layers, predicted, target)

    def compute_vggface_loss(self, predicted, target):
        return self.compute_perceptual_loss(self.vggface, self.vggface_layers, predicted, target)

    def compute_feature_matching_loss(self, predicted, target):
        return self.compute_perceptual_loss(self.vgg19, self.vgg19_layers, predicted, target, detach=True)

    def compute_perceptual_loss(self, model, layers, predicted, target, detach=False):
        loss = 0.0
        predicted_features = predicted
        target_features = target
        last_layer = max(layers)

        for i, layer in enumerate(model.children()):
            if isinstance(layer, nn.Linear):
                # Flatten spatial dimensions before any fully connected layer
                predicted_features = predicted_features.view(predicted_features.size(0), -1)
                target_features = target_features.view(target_features.size(0), -1)
            predicted_features = layer(predicted_features)
            target_features = layer(target_features)

            if i in layers:
                # L1 distance between the feature maps at the selected layers
                if detach:
                    loss += torch.mean(torch.abs(predicted_features - target_features.detach()))
                else:
                    loss += torch.mean(torch.abs(predicted_features - target_features))

            # No layer past the last selected one contributes to the loss,
            # so stop early instead of running the rest of the network
            if i >= last_layer:
                break

        return loss

    def normalize_input(self, x):
        # ImageNet mean/std, as expected by the torchvision VGG19 weights
        mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)
        return (x - mean) / std
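
For the randomized variant mentioned above, here's a minimal sketch of what a training step could look like: each call picks the crop size at random from full / half / quarter of the frame, so the loss still occasionally sees the whole image while keeping the average memory footprint low. The frame size, candidate sizes, and sampling weights here are illustrative assumptions, not values from the FSRT paper.

import random

# Hypothetical values: a 512x512 training resolution and a bias toward
# smaller crops; tune both to taste.
FRAME_SIZE = 512
CROP_SIDES = [FRAME_SIZE, FRAME_SIZE // 2, FRAME_SIZE // 4]  # full / half / quarter
CROP_WEIGHTS = [0.1, 0.3, 0.6]

loss_fn = PerceptualLoss(device='cuda')

def randomized_perceptual_loss(predicted, target):
    # Pick full, half, or quarter resolution for this step
    side = random.choices(CROP_SIDES, weights=CROP_WEIGHTS, k=1)[0]
    return loss_fn(predicted, target, sub_sample_size=(side, side))

Biasing toward the quarter-size crop keeps most steps cheap, while the occasional full-frame pass still exposes the loss to global structure.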