Open johndpope opened 2 weeks ago
drafted https://github.com/johndpope/MegaPortrait-hack/tree/feat/sub-sampling
I don't see a massive speed-up.
It's possible the crop size could be randomized — sometimes using the full frame, sometimes half or quarter resolution, etc.
class PerceptualLoss(nn.Module):
    """Weighted perceptual loss combining VGG19 and VGGFace feature distances.

    Terms:
      - VGG19 L1 feature distance (generic perceptual loss),
      - VGGFace (InceptionResnetV1) L1 feature distance (identity-aware),
      - a gaze term, currently stubbed to a constant 1 (the MPGazeLoss call
        is commented out below),
      - an optional feature-matching term (detached targets).

    Memory/speed trick (MegaPortraits sec. 3.3): the loss can be computed on a
    random crop of both images instead of the full frames.
    https://arxiv.org/pdf/2404.09736#page=5.58
    """

    def __init__(self, device, weights=None):
        super(PerceptualLoss, self).__init__()
        self.device = device
        # Avoid the mutable-default-argument pitfall; fall back to the
        # original weighting when none is supplied.
        self.weights = weights if weights is not None else {
            'vgg19': 20.0, 'vggface': 5.0, 'gaze': 4.0,
        }

        # VGG19 backbone: first 30 feature layers, frozen in eval mode.
        vgg19 = models.vgg19(pretrained=True).features
        self.vgg19 = nn.Sequential(*[vgg19[i] for i in range(30)]).to(device).eval()
        self.vgg19_layers = [1, 6, 11, 20, 29]

        # VGGFace network for identity-aware features.
        self.vggface = InceptionResnetV1(pretrained='vggface2').to(device).eval()
        self.vggface_layers = [4, 5, 6, 7]

        # Gaze loss (currently unused in forward()).
        self.gaze_loss = MPGazeLoss(device)

    def forward(self, predicted, target, sub_sample_size=(128, 128), use_fm_loss=False):
        """Return the total weighted perceptual loss.

        Args:
            predicted, target: (B, 3, H, W) image batches — presumably in
                [0, 1] given the ImageNet normalization below; TODO confirm
                against the caller.
            sub_sample_size: (h, w) random crop on which the loss is computed;
                pass None to always use the full frames.
            use_fm_loss: if True, add a feature-matching (detached-target) term.
        """
        # Normalize with ImageNet statistics.
        predicted = self.normalize_input(predicted)
        target = self.normalize_input(target)

        # BUG FIX: the original defined sub_sample_tensor but never called it,
        # so sub_sample_size had no effect (hence "no massive speed up").
        # Crop BOTH tensors at the SAME random offsets so corresponding
        # pixels are still compared; +1 makes a full-size "crop" valid.
        if (sub_sample_size is not None
                and predicted.shape[-2] >= sub_sample_size[0]
                and predicted.shape[-1] >= sub_sample_size[1]):
            top = np.random.randint(0, predicted.shape[-2] - sub_sample_size[0] + 1)
            left = np.random.randint(0, predicted.shape[-1] - sub_sample_size[1] + 1)
            predicted = predicted[..., top:top + sub_sample_size[0],
                                  left:left + sub_sample_size[1]]
            target = target[..., top:top + sub_sample_size[0],
                            left:left + sub_sample_size[1]]

        vgg19_loss = self.compute_vgg19_loss(predicted, target)
        vggface_loss = self.compute_vggface_loss(predicted, target)
        # gaze_loss = self.gaze_loss(predicted, target)  # disabled upstream

        total_loss = (
            self.weights['vgg19'] * vgg19_loss
            + self.weights['vggface'] * vggface_loss
            + self.weights['gaze'] * 1  # gaze term stubbed to a constant
        )

        if use_fm_loss:
            total_loss += self.compute_feature_matching_loss(predicted, target)

        return total_loss

    def sub_sample_tensor(self, tensor, sub_sample_size):
        """Return a random (h, w) spatial crop of a (B, C, H, W) tensor."""
        assert tensor.ndim == 4, "Input tensor should have 4 dimensions (batch_size, channels, height, width)"
        assert tensor.shape[-2] >= sub_sample_size[0] and tensor.shape[-1] >= sub_sample_size[1], "Sub-sample size should not exceed the tensor dimensions"
        batch_size, channels, height, width = tensor.shape
        # Random offsets so training covers the whole image over time.
        # BUG FIX: np.random.randint's high bound is exclusive, so the original
        # raised ValueError when the crop exactly matched the tensor size
        # (randint(0, 0)); +1 makes the full-size crop legal.
        random_offset_x = np.random.randint(0, height - sub_sample_size[0] + 1)
        random_offset_y = np.random.randint(0, width - sub_sample_size[1] + 1)
        return tensor[..., random_offset_x:random_offset_x + sub_sample_size[0],
                      random_offset_y:random_offset_y + sub_sample_size[1]]

    def compute_vgg19_loss(self, predicted, target):
        """L1 distance between VGG19 features of predicted and target."""
        return self.compute_perceptual_loss(self.vgg19, self.vgg19_layers, predicted, target)

    def compute_vggface_loss(self, predicted, target):
        """L1 distance between VGGFace features of predicted and target."""
        return self.compute_perceptual_loss(self.vggface, self.vggface_layers, predicted, target)

    def compute_feature_matching_loss(self, predicted, target):
        """Feature-matching variant of the VGG19 loss: target features detached."""
        return self.compute_perceptual_loss(self.vgg19, self.vgg19_layers, predicted, target, detach=True)

    def compute_perceptual_loss(self, model, layers, predicted, target, detach=False):
        """Run both inputs through `model`'s children, accumulating the mean
        absolute feature difference at the indices listed in `layers`.

        Note: the original had three branches (Conv2d / Linear / else) where
        the Conv2d and else branches were identical; they are merged here —
        only Linear needs the extra flatten step.
        """
        loss = 0.0
        predicted_features = predicted
        target_features = target
        for i, layer in enumerate(model.children()):
            if isinstance(layer, nn.Linear):
                # Flatten spatial dims before a fully-connected layer.
                predicted_features = predicted_features.view(predicted_features.size(0), -1)
                target_features = target_features.view(target_features.size(0), -1)
            predicted_features = layer(predicted_features)
            target_features = layer(target_features)
            if i in layers:
                if detach:
                    loss += torch.mean(torch.abs(predicted_features - target_features.detach()))
                else:
                    loss += torch.mean(torch.abs(predicted_features - target_features))
        return loss

    def normalize_input(self, x):
        """Normalize a 3-channel batch with ImageNet mean/std
        (assumes input is in [0, 1] — TODO confirm)."""
        mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)
        return (x - mean) / std
drafted https://github.com/johndpope/MegaPortrait-hack/tree/feat/sub-sampling
I don't see a massive speed-up.
It's possible the crop size could be randomized — sometimes using the full frame, sometimes half or quarter resolution, etc.