primepake / wav2lip_288x288

Model copying reference frames as it is #122

Closed mudassirkhan19 closed 4 months ago

mudassirkhan19 commented 4 months ago

Hi @mldev-stack,

Thank you for your awesome work. I've been trying to train the Wav2Lip model (with SAM) from this repo, modified to work with 288x288 images (I don't have enough images at 384x384 resolution). However, the model doesn't seem to learn anything; it just copies the reference frames as they are. Have you seen this before, or can you point out where I might be going wrong?

(attachment: wav2lip_hd_train_old)
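
A quick way to confirm this (a rough sketch; `model` is the Wav2Lip_SAM instance below, and `x`, `indiv_mels_a`, `indiv_mels_b` are assumed to be one face batch and two different audio batches from the dataloader described further down): feed the same masked targets and references with two different audio windows. If the outputs barely differ, the generator is ignoring the audio embedding.

import torch

model.eval()
with torch.no_grad():
    out_a = model(indiv_mels_a, x)  # (B, 3, T, H, W)
    out_b = model(indiv_mels_b, x)  # same faces, different audio
# A near-zero difference means the output is driven by the reference frames, not the audio.
print("mean |out_a - out_b|:", (out_a - out_b).abs().mean().item())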

Modified SAM model


import torch
from torch import nn
import torch.nn.functional as F

# Conv2d, Conv2dTranspose and SAM are the repo's custom modules; their imports
# are omitted in this snippet.


class Wav2Lip_SAM(nn.Module):
    def __init__(self, audio_encoder=None):
        super(Wav2Lip_SAM, self).__init__()
        self.sam = SAM()
        self.face_encoder_blocks = nn.ModuleList(
            [
                nn.Sequential(Conv2d(6, 8, kernel_size=7, stride=1, padding=3)),  # 288, 288
                nn.Sequential(
                    Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
                    Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True),
                ),  # 144, 144
                nn.Sequential(
                    Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 72, 72
                    Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
                ),
                nn.Sequential(
                    Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 36, 36
                    Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                ),
                nn.Sequential(
                    Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # 18, 18
                    Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                ),
                nn.Sequential(
                    Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # 9, 9
                    Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                ),
                nn.Sequential(
                    Conv2d(256, 512, kernel_size=3, stride=2, padding=1),  # 5, 5
                    Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
                ),
                nn.Sequential(
                    Conv2d(512, 1024, kernel_size=3, stride=1, padding=0),  # 3, 3
                    Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
                ),
                nn.Sequential(
                    Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0),  # 1, 1
                    Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
                    Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
                ),
            ]
        )
        ##################

        if audio_encoder is None:
            self.audio_encoder = nn.Sequential(
                Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
                Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                ###################
                # Modified blocks
                ##################
                Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
                Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(512, 1024, kernel_size=3, stride=1, padding=0),
                Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
            )
            ##################
        else:
            self.audio_encoder = audio_encoder

        for p in self.audio_encoder.parameters():
            p.requires_grad = False
        self.audio_refine = nn.Sequential(
            Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
            Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
        )

        self.face_decoder_blocks = nn.ModuleList(
            [
                nn.Sequential(
                    Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
                ),  # + 1024
                ###################
                # Modified blocks
                ##################
                nn.Sequential(
                    Conv2dTranspose(2048, 1024, kernel_size=3, stride=1, padding=0),  # 3,3
                    Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
                ),  # + 1024
                nn.Sequential(
                    Conv2dTranspose(2048, 1024, kernel_size=3, stride=2, padding=1),
                    Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
                ),  # 5, 5  + 512
                nn.Sequential(
                    Conv2dTranspose(1536, 768, kernel_size=3, stride=2, padding=1),
                    Conv2d(768, 768, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(768, 768, kernel_size=3, stride=1, padding=1, residual=True),
                ),  # 9, 9  + 256
                nn.Sequential(
                    Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
                ),  # 18, 18  + 128
                ##################
                nn.Sequential(
                    Conv2dTranspose(640, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                ),  # 36, 36  + 64
                nn.Sequential(
                    Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                ),  # 72, 72  + 32
                nn.Sequential(
                    Conv2dTranspose(160, 80, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d(80, 80, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(80, 80, kernel_size=3, stride=1, padding=1, residual=True),
                ),  # 144, 144  + 16
                nn.Sequential(
                    Conv2dTranspose(96, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                    Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                ),  # 288, 288  + 8
            ]
        )

        self.output_block = nn.Sequential(
            Conv2d(72, 32, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
            nn.Tanh(),
        )

    def freeze_audio_encoder(self):
        for p in self.audio_encoder.parameters():
            p.requires_grad = False

    def forward(self, audio_sequences, face_sequences, noise=False):
        B = audio_sequences.size(0)

        input_dim_size = len(face_sequences.size())
        if input_dim_size > 4:
            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
            face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)

        audio_embedding = self.audio_encoder(audio_sequences)  # B, 1024, 1, 1

        audio_embedding = audio_embedding.detach()
        audio_embedding = self.audio_refine(audio_embedding)
        audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
        audio_embedding = torch.clamp(audio_embedding, -1, 1)

        feats = []
        x = face_sequences
        for f in self.face_encoder_blocks:
            x = f(x)
            feats.append(x)

        x = audio_embedding
        for f in self.face_decoder_blocks:
            x = f(x)
            try:
                x = self.sam(feats[-1], x)
                x = torch.cat((x, feats[-1]), dim=1)
            except Exception as e:
                print(x.size())
                print(feats[-1].size())
                raise e

            feats.pop()

        x = self.output_block(x)

        if input_dim_size > 4:
            x = torch.split(x, B, dim=0)  # [(B, C, H, W)]
            outputs = torch.stack(x, dim=2)  # (B, C, T, H, W)

        else:
            outputs = x

        return outputs
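
For sanity, the spatial sizes noted in the face-encoder comments can be reproduced with the standard conv output-size formula (a standalone sketch; plain arithmetic stands in for the repo's Conv2d wrappers, which use the same kernel/stride/padding geometry):

# out = floor((size + 2*padding - kernel) / stride) + 1
def out_size(size, kernel, stride, padding):
    return (size + 2 * padding - kernel) // stride + 1

size = 288
for kernel, stride, padding in [(7, 1, 3)] + [(3, 2, 1)] * 6 + [(3, 1, 0)] * 2:
    size = out_size(size, kernel, stride, padding)
    print(size)  # 288, 144, 72, 36, 18, 9, 5, 3, 1

The encoder bottoms out at 1x1, matching the (B, 1024, 1, 1) audio embedding, so the SAM skip connections should line up; a mismatch would already raise in the forward pass (which prints the offending sizes).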

Dataset

import random
from pathlib import Path
from random import shuffle
from typing import Dict, List, Optional

import cv2
import numpy as np
import torch

# Wav2LipDataset, read_images and resize_images come from elsewhere in my
# training code and are not shown in this snippet.


class Wav2LipHDDS(Wav2LipDataset):
    def __init__(
        self,
        target_size: int,
        ds_path: Path,
        fps: Optional[List[float]],
        min_face_size: Optional[int],
        ignore_cache: bool = False,
        av_offset: Optional[Dict[str, int]] = None
    ):
        super().__init__(
            ds_path=ds_path, fps=fps, min_face_size=min_face_size, ignore_cache=ignore_cache, av_offset=av_offset
        )
        self.target_size = target_size

    def read_frames(
        self, window_fnames: List[Path], h_flip: bool = False, randomize: bool = False
    ) -> Optional[List[np.ndarray]]:
        if randomize:
            shuffle(window_fnames)
            if random.random() > 0.5:
                window_fnames = [random.choice(window_fnames)] * len(window_fnames)
        images = read_images(window_fnames)
        images = resize_images(images, self.target_size)
        if h_flip and images is not None:
            images = [cv2.flip(img, 1) for img in images]
        return images

    def get_segmented_mels(self, orig_mel: np.ndarray, start_idx: int, video_name: str) -> Optional[np.ndarray]:
        mels = []
        assert self.num_ref_frames == 5
        if start_idx - 2 < 0:
            return None
        for i in range(start_idx, start_idx + self.num_ref_frames):
            m = self.get_audio_encoding_segment(orig_mel, i - 2, video_name)
            if m is None or m.shape[0] != self.audio_window_size:
                return None
            mels.append(m.T)
        mels = np.asarray(mels).astype(np.float32)
        return mels

    def __getitem__(self, idx):
        while True:
            video = random.choice(self.videos)
            img_names = video["faces"]
            if len(img_names) <= 3 * self.num_ref_frames:
                continue
            offset = self.av_offset.get(video["name"], 0) if self.av_offset else 0
            start_idx = 0 if offset >= 0 else abs(offset)
            end_idx = len(img_names) if offset <= 0 else len(img_names) - offset
            img_idx = random.randint(start_idx, end_idx - self.num_ref_frames - 1)
            ref_idx = random.randint(0, len(img_names) - self.num_ref_frames - 1)

            while (abs(ref_idx - img_idx) < self.num_ref_frames) or (
                (len(img_names) - 1 - ref_idx) < self.num_ref_frames
            ):
                ref_idx = random.randint(0, len(img_names) - self.num_ref_frames - 1)

            h_flip = random.random() > 0.7

            frame_paths = self.get_frames(img_names, img_idx)
            ref_frame_paths = self.get_frames(img_names, ref_idx)

            frames = self.read_frames(frame_paths, h_flip)
            if frames is None:
                continue
            ref_frames = self.read_frames(ref_frame_paths, h_flip, randomize=True)
            if ref_frames is None:
                continue

            mel_path = video["audio"].parent / "mel.npy"
            if not mel_path.exists():
                raise FileNotFoundError(f"Mel file not found: {mel_path}")
            orig_mel = np.load(mel_path)
            mel = self.get_audio_encoding_segment(orig_mel.copy(), img_idx, video["name"])
            if mel is None or mel.shape[0] != self.audio_window_size:
                continue

            indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_idx, video["name"])
            if indiv_mels is None:
                continue

            # ground truth images
            frames = self.normalize_and_rearrange(frames)
            y = frames.copy()
            frames[:, :, frames.shape[2] // 2 :, :] = 0

            # reference images
            ref_frames = self.normalize_and_rearrange(ref_frames)
            x = np.concatenate([frames, ref_frames], axis=0)
            x = torch.FloatTensor(x)
            mel = torch.FloatTensor(mel.T).unsqueeze(0)
            indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1)
            y = torch.FloatTensor(y)
            return x, indiv_mels, mel, y

Dataloader output shapes

Batch size: 16

x: torch.Size([16, 6, 5, 288, 288])
indiv_mels: torch.Size([16, 5, 1, 80, 16])
mel: torch.Size([16, 1, 80, 16])
y: torch.Size([16, 3, 5, 288, 288])
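
To rule out the dataloader itself, a minimal check like this can help (a sketch; `train_loader` is assumed to be a DataLoader over the dataset above, with the channel layout printed above: channels 0-2 are the masked target, 3-5 the reference):

x, indiv_mels, mel, y = next(iter(train_loader))

# The lower half of the target channels should be zeroed out by the dataset.
assert x[:, :3, :, x.shape[3] // 2:, :].abs().max() == 0

# If the references are nearly identical to the ground truth, copying them is a
# trivial way for the generator to minimize the reconstruction loss.
print("mean |ref - gt|:", (x[:, 3:] - y).abs().mean().item())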

Training PyTorch Lightning module

from typing import Tuple

import lightning as L  # or pytorch_lightning, depending on the version in use
import torch
import torch.nn.functional as F
import wandb
from einops import rearrange
from lpips import LPIPS
from torch import optim

# NLayerDiscriminator, SyncNet and Wav2LipHDConf come from elsewhere in my
# training code and are not shown in this snippet.


class Wav2LipHD(L.LightningModule):
    def __init__(self, model: torch.nn.Module, disc: NLayerDiscriminator, syncnet: SyncNet, hparams: Wav2LipHDConf):
        super().__init__()
        self.model = model
        self.disc = disc
        self.syncnet = syncnet
        self.conf = hparams
        self.automatic_optimization = False
        self.save_hyperparameters(hparams.model_dump(mode="json"))
        self.logloss = torch.nn.BCELoss()
        self.recon_loss = torch.nn.L1Loss()
        self.loss_fn_vgg = LPIPS(net="vgg").eval()
        self.d_weight = 0.025
        self.num_ref_frames = 5

    def forward(self, audio_enc: torch.Tensor, x: torch.Tensor):
        return self.model(audio_enc, x)

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int):
        x, indiv_mels, mel, gt = batch
        optimizer, disc_optimizer = self.optimizers()
        optimizer.zero_grad()
        disc_optimizer.zero_grad()
        g = self(indiv_mels, x)
        if batch_idx == 0 and self.trainer.is_global_zero:
            _ = self.video_tensor_to_gif(gt, g, "/tmp/train.gif")
            wandb.log({"train_gif": [wandb.Video("/tmp/train.gif", fps=4, format="gif")]})
        if self.global_step > self.conf.disc_iter_start:
            fake_output = self.disc(g)
            perceptual_loss = -torch.mean(fake_output)
        else:
            perceptual_loss = torch.tensor(0.0).to(self.device)

        l1loss = self.recon_loss(g, gt)
        # l1loss = torch.tensor(0.0).to(self.device)
        vgg_loss = self.loss_fn_vgg(
            torch.cat([g[:, :, i] for i in range(g.size(2))], dim=0),
            torch.cat([gt[:, :, i] for i in range(gt.size(2))], dim=0),
        )
        vgg_loss = vgg_loss.mean().to(self.device)
        nll_loss = l1loss + vgg_loss

        if self.global_step > self.conf.sync_iter_start and self.conf.syncnet_loss_wt > 0.0:
            g_reshaped = rearrange(g, "b c t h w -> (b t) c h w")
            interpolated = F.interpolate(
                g_reshaped, size=(self.conf.syncnet_img_size, self.conf.syncnet_img_size), mode="bilinear"
            )
            sync_tensor = rearrange(interpolated, "(b t) c h w -> b (c t) h w", t=self.num_ref_frames)

            y = torch.ones(sync_tensor.size(0), 1).float().to(self.device)
            a, v = self.syncnet(mel, sync_tensor[:, :, self.conf.syncnet_img_size // 2 :, :])
            sync_loss = self.cosine_loss(a, v, y)
        else:
            sync_loss = torch.tensor(0.0).to(self.device)

        loss = self.conf.syncnet_loss_wt * sync_loss + self.d_weight * perceptual_loss + nll_loss
        self.manual_backward(loss, retain_graph=True)
        optimizer.step()

        ### Remove all gradients before Training disc
        disc_optimizer.zero_grad()

        if self.global_step > self.conf.disc_iter_start:
            real_output = self.disc(gt)
            fake_output = self.disc(g.detach())
            disc_real_loss, disc_fake_loss = self.hinge_d_loss(real_output, fake_output)
            d_loss = 0.5 * (disc_fake_loss + disc_real_loss)
            self.manual_backward(d_loss)
            disc_optimizer.step()
        else:
            disc_real_loss = torch.tensor(0.0).to(self.device)
            disc_fake_loss = torch.tensor(0.0).to(self.device)
            d_loss = torch.tensor(0.0).to(self.device)
        return loss

    def hinge_d_loss(self, logits_real: torch.Tensor, logits_fake: torch.Tensor):
        loss_real = torch.mean(F.relu(1.0 - logits_real))
        loss_fake = torch.mean(F.relu(1.0 + logits_fake))
        return loss_real, loss_fake

    def cosine_loss(self, a: torch.Tensor, v: torch.Tensor, y: torch.Tensor):
        d = torch.nn.functional.cosine_similarity(a, v)
        loss = self.logloss(d.unsqueeze(1), y)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(
            [p for p in self.model.parameters() if p.requires_grad], lr=self.conf.lr, betas=(0.5, 0.999)
        )
        disc_optimizer = optim.Adam(
            [p for p in self.disc.parameters() if p.requires_grad], lr=self.conf.disc_lr, betas=(0.5, 0.999)
        )
        return optimizer, disc_optimizer
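
To quantify the copying behaviour during training, I'm thinking of logging how close the generated frames are to the references versus the ground truth, roughly like this inside training_step after g is computed (x[:, 3:] are the reference channels):

with torch.no_grad():
    l1_to_ref = self.recon_loss(g, x[:, 3:])
    l1_to_gt = self.recon_loss(g, gt)
# l1_to_ref much smaller than l1_to_gt would confirm the model is reproducing
# the references rather than synthesizing the mouth from audio.
self.log_dict({"l1_to_ref": l1_to_ref, "l1_to_gt": l1_to_gt}, prog_bar=True)
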
ghost commented 4 months ago

do you apply syncnet?

ghost commented 4 months ago

btw wav2lip is not good enough to train a general model, you have to figure it out yourself.

mudassirkhan19 commented 4 months ago

Hey @primepake,

do you apply syncnet?

Yes, I'm applying syncnet, but this happens even before the syncnet loss kicks in.

btw wav2lip is not good enough to train a general model, you have to figure it out yourself.

That's too bad; I was excited after seeing the Chinese demo video. I guess DINet is much better, and I'm looking forward to seeing your take on it with the mel-spectrogram attack. Thanks for your contribution though; your comments on various issues helped me a lot in training syncnet, which was a critical piece of the puzzle.

ghost commented 4 months ago

wav2lip is only good for a specific person; for sure you need at least 60 minutes of footage of that person to get a good result

see2run commented 1 month ago

wav2lip is only good for a specific person; for sure you need at least 60 minutes of footage of that person to get a good result

Hey, can you help me, please? The Percep, Fake, and Real losses are always 0.0 during Wav2Lip training.