johndpope / Emote-hack

Emote Portrait Alive — using AI to reverse-engineer code from the white paper (abandoned).
https://github.com/johndpope/VASA-1-hack
173 stars 9 forks source link

backbone network == diffused heads backbone #29

Closed — johndpope closed this issue 8 months ago.

johndpope commented 8 months ago

https://github.com/MStypulkowski/diffused-heads/blob/main/diffusion.py


import numpy as np
import torch
import torch.nn as nn
from tqdm import trange
from torchvision.transforms import Compose

class Diffusion(nn.Module):
    """Gaussian diffusion model used for frame-by-frame, audio-conditioned sampling.

    The schedule tensors (``beta``, ``alpha_bar``, posterior coefficients, ...)
    are stored as plain attributes rather than registered buffers, so
    ``module.to(device)`` does not move them. Per-step values are gathered on
    the schedule's own device and then moved to ``self.device`` by
    :meth:`expand`.

    Args:
        nn_backbone: callable invoked as
            ``nn_backbone(xt, timesteps, x_cond, motion_frames=..., audio_emb=...)``.
            When ``out_channels == in_channels`` its output is used directly as
            the noise prediction; otherwise it is split in half along dim 1
            into (noise prediction, variance interpolation value).
        device: device that gathered schedule values are moved to.
        n_timesteps: number of diffusion steps in the cosine schedule.
        in_channels: channel count of the sampled frames.
        image_size: height and width of the sampled frames.
        out_channels: channel count of the backbone output.
        motion_transforms: callable applied to conditioning / generated frames
            before they are stacked into ``motion_frames``; identity
            (``Compose([])``) when not given.
    """

    def __init__(
            self, nn_backbone, device, n_timesteps=1000, in_channels=3,
            image_size=128, out_channels=6, motion_transforms=None):
        super().__init__()

        self.nn_backbone = nn_backbone
        self.n_timesteps = n_timesteps
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.x_shape = (image_size, image_size)
        self.device = device

        # Fall back to an identity transform when none is supplied.
        self.motion_transforms = motion_transforms if motion_transforms else Compose([])

        # Timestep labels handed to the backbone; `space()` replaces this with
        # the list of retained step indices.
        self.timesteps = torch.arange(n_timesteps)
        self.beta = self.get_beta_schedule()
        self.set_params()

    def sample(self, x_cond, audio_emb, n_audio_motion_embs=2, n_motion_frames=2, motion_channels=3):
        """Autoregressively generate ``audio_emb.shape[1]`` frames.

        Each frame is produced by a full reverse-diffusion run
        (:meth:`sample_loop`). The run is conditioned on ``x_cond`` and on a
        sliding stack of the ``n_motion_frames`` most recent frames, which is
        initialised with transformed copies of ``x_cond``. It also receives a
        window of ``2 * n_audio_motion_embs + 1`` audio indices centred on the
        current frame. Those indices are clamped to ``[0, n_frames - 1]``.

        Args:
            x_cond: conditioning frame batch; ``x_cond.shape[0]`` is the batch
                size. Presumably (B, C, H, W) — TODO confirm against callers.
            audio_emb: audio embeddings indexed as ``audio_emb[:, ids]``;
                dim 1 is the frame axis.
            n_audio_motion_embs: audio frames of context on each side of the
                current frame.
            n_motion_frames: number of transformed frames kept in
                ``motion_frames``.
            motion_channels: channels dropped from the front of
                ``motion_frames`` each time a new frame is appended
                (presumably the channel count of one transformed frame).

        Returns:
            The generated frames stacked along dim 1, i.e. (B, n_frames, ...).
        """
        with torch.no_grad():
            n_frames = audio_emb.shape[1]
            last_frame = n_frames - 1

            # Drawn with the default generator and then moved to x_cond's device.
            xT = torch.randn(x_cond.shape[0], n_frames, self.in_channels, self.x_shape[0], self.x_shape[1]).to(x_cond.device)

            audio_ids = self._initial_audio_ids(n_audio_motion_embs, n_frames)

            motion_frames = torch.cat(
                [self.motion_transforms(x_cond) for _ in range(n_motion_frames)], dim=1)

            samples = []
            for i in trange(n_frames, desc='Sampling'):
                sample_frame = self.sample_loop(xT[:, i], x_cond, motion_frames, audio_emb[:, audio_ids])
                samples.append(sample_frame.unsqueeze(1))
                # Slide the motion stack: drop the oldest frame's channels, append the new frame.
                motion_frames = torch.cat([motion_frames[:, motion_channels:, :], self.motion_transforms(sample_frame)], dim=1)
                # Slide the audio window one frame forward, clamped at the last frame.
                audio_ids = audio_ids[1:] + [min(i + n_audio_motion_embs + 1, last_frame)]
        return torch.cat(samples, dim=1)

    @staticmethod
    def _initial_audio_ids(n_audio_motion_embs, n_frames):
        """Return the audio indices used as context for the first frame.

        The past half (``n_audio_motion_embs`` entries) is padded with frame 0.
        The current-and-future half is clamped to the last available frame,
        so clips shorter than ``n_audio_motion_embs + 1`` frames do not index
        past the end of ``audio_emb``.
        """
        last_frame = max(n_frames - 1, 0)
        return [0] * n_audio_motion_embs + [min(i, last_frame) for i in range(n_audio_motion_embs + 1)]

    def sample_loop(self, xT, x_cond, motion_frames, audio_emb):
        """Run the reverse diffusion chain for one frame, from noise ``xT`` to the final sample.

        Steps through ``self.timesteps`` from last to first. The backbone
        receives the timestep label ``t``. The schedule lookups use the
        position ``i`` within ``self.timesteps``; the two differ after
        :meth:`space`. No noise is added on the ``t == 0`` step.
        """
        xt = xT
        batch = xT.shape[0]
        for i, t in reversed(list(enumerate(self.timesteps))):
            t = int(t)
            timesteps = torch.full((batch,), t, dtype=torch.long, device=xT.device)
            timesteps_ids = torch.full((batch,), i, dtype=torch.long, device=xT.device)
            nn_out = self.nn_backbone(xt, timesteps, x_cond, motion_frames=motion_frames, audio_emb=audio_emb)
            mean, logvar = self.get_p_params(xt, timesteps_ids, nn_out)
            noise = torch.randn_like(xt) if t > 0 else torch.zeros_like(xt)
            xt = mean + noise * torch.exp(logvar / 2)

        return xt

    @staticmethod
    def _gather(arr, timesteps):
        """Index schedule tensor ``arr`` with ``timesteps`` moved onto ``arr``'s device.

        The schedule tensors are plain attributes (not buffers), so they stay
        where they were created. Moving the indices lets lookups work when
        ``timesteps`` lives on a different device (e.g. CUDA).
        """
        return arr[torch.as_tensor(timesteps, device=arr.device)]

    def get_p_params(self, xt, timesteps, nn_out):
        """Return mean and log-variance of p(x_{t-1} | x_t) from the backbone output.

        Args:
            xt: current noisy sample.
            timesteps: per-sample indices into the schedule tensors.
            nn_out: backbone output. If ``in_channels == out_channels`` it is
                the noise prediction, and the log-variance is ``log beta_t``.
                Otherwise it is split along dim 1 into the noise prediction
                and ``nu``. The log-variance is then interpolated between
                ``log beta_t`` and the clipped ``log beta_tilde_t`` (the
                Improved-DDPM learned-range parameterisation).

        Returns:
            ``(p_mean, p_logvar)``. ``p_mean`` is the posterior mean computed
            from the x0 implied by the noise prediction.
        """
        log_beta = self.expand(torch.log(self._gather(self.beta, timesteps)))
        if self.in_channels == self.out_channels:
            eps_pred = nn_out
            p_logvar = log_beta
        else:
            eps_pred, nu = nn_out.chunk(2, 1)
            # Map nu to an interpolation weight; assumes the raw output lies in [-1, 1].
            nu = (nu + 1) / 2
            log_beta_tilde = self.expand(self._gather(self.log_beta_tilde_clipped, timesteps))
            p_logvar = nu * log_beta + (1 - nu) * log_beta_tilde

        p_mean, _ = self.get_q_params(xt, timesteps, eps_pred=eps_pred)
        return p_mean, p_logvar

    def get_q_params(self, xt, timesteps, eps_pred=None, x0=None):
        """Return mean and log-variance of the posterior q(x_{t-1} | x_t, x_0).

        If ``x0`` is not given, it is reconstructed from ``xt`` and
        ``eps_pred`` as ``sqrt(1/alpha_bar) * xt - sqrt(1/alpha_bar - 1) * eps_pred``
        and clamped to [-1, 1].
        """
        if x0 is None:
            coef1_x0 = self.expand(self._gather(self.coef1_x0, timesteps))
            coef2_x0 = self.expand(self._gather(self.coef2_x0, timesteps))
            x0 = (coef1_x0 * xt - coef2_x0 * eps_pred).clamp(-1, 1)

        coef1_q = self.expand(self._gather(self.coef1_q, timesteps))
        coef2_q = self.expand(self._gather(self.coef2_q, timesteps))
        q_mean = coef1_q * x0 + coef2_q * xt

        q_logvar = self.expand(self._gather(self.log_beta_tilde_clipped, timesteps))

        return q_mean, q_logvar

    def get_beta_schedule(self, max_beta=0.999):
        """Return the cosine beta schedule as a float32 tensor of length ``n_timesteps``.

        Uses ``alpha_bar(t) = cos(((t + 0.008) / 1.008) * pi / 2) ** 2`` and
        ``beta_i = min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), max_beta)``,
        where ``T = self.n_timesteps``. The cap bounds betas near ``t = 1``,
        where ``alpha_bar`` approaches 0.
        """
        def alpha_bar(t):
            return np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2

        n = self.n_timesteps
        betas = [min(1 - alpha_bar((i + 1) / n) / alpha_bar(i / n), max_beta) for i in range(n)]
        return torch.tensor(betas).float()

    def set_params(self):
        """Derive all schedule-dependent coefficients from ``self.beta``.

        Sets ``alpha``, ``alpha_bar`` (the cumulative product) and
        ``alpha_bar_prev`` (shifted right with a leading 1). It also sets the
        posterior variance ``beta_tilde`` and its log, with entry 0 replaced
        by entry 1 (entry 0 is exactly 0 because ``alpha_bar_prev[0] == 1``).
        Finally it sets the coefficients for recovering x0 from a noise
        prediction and the posterior-mean coefficients of q(x_{t-1} | x_t, x_0).
        Requires ``len(self.beta) >= 2`` because of the ``beta_tilde[1]`` lookup.
        """
        self.alpha = 1 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.alpha_bar_prev = torch.cat([torch.ones(1,), self.alpha_bar[:-1]])

        self.beta_tilde = self.beta * (1.0 - self.alpha_bar_prev) / (1.0 - self.alpha_bar)
        self.log_beta_tilde_clipped = torch.log(torch.cat([self.beta_tilde[1, None], self.beta_tilde[1:]]))

        # To calculate x0 from eps_pred.
        self.coef1_x0 = torch.sqrt(1.0 / self.alpha_bar)
        self.coef2_x0 = torch.sqrt(1.0 / self.alpha_bar - 1)

        # Mean coefficients for q(x_{t-1} | x_t, x_0).
        self.coef1_q = self.beta * torch.sqrt(self.alpha_bar_prev) / (1.0 - self.alpha_bar)
        self.coef2_q = (1.0 - self.alpha_bar_prev) * torch.sqrt(self.alpha) / (1.0 - self.alpha_bar)

    def space(self, n_timesteps_new):
        """Respace the schedule to ``n_timesteps_new`` evenly spaced steps for faster sampling.

        ``self.timesteps`` becomes the list of retained step indices, taken
        relative to the current schedule. The betas are recomputed so that the
        cumulative ``alpha_bar`` at each retained step is unchanged.

        NOTE(review): calling this twice makes the labels relative to the
        first respaced schedule, not the original one.
        """
        self.timesteps = self.space_timesteps(self.n_timesteps, n_timesteps_new)
        self.n_timesteps = n_timesteps_new

        self.beta = self.get_spaced_beta()
        self.set_params()

    def space_timesteps(self, n_timesteps, target_timesteps):
        """Return ``target_timesteps`` evenly spaced indices in ``[0, n_timesteps - 1]``.

        Both endpoints are included when ``target_timesteps >= 2``.
        Intermediate positions are rounded with Python's ``round`` (ties to
        even). A single target step yields ``[0]``.
        """
        if target_timesteps <= 1:
            frac_stride = 1
        else:
            frac_stride = (n_timesteps - 1) / (target_timesteps - 1)
        taken_steps = []
        cur_idx = 0.0
        for _ in range(target_timesteps):
            taken_steps.append(round(cur_idx))
            cur_idx += frac_stride
        return taken_steps

    def get_spaced_beta(self):
        """Return betas for the respaced schedule derived from the current ``alpha_bar``.

        For consecutive retained steps ``s < s'`` this gives
        ``beta' = 1 - alpha_bar[s'] / alpha_bar[s]``, with the value before the
        first kept step taken as 1. The product of ``1 - beta'`` therefore
        telescopes back to ``alpha_bar`` at every kept step.
        """
        # Set of ints: O(1) membership, and works whether timesteps is a list or a tensor.
        kept = {int(t) for t in self.timesteps}
        last_alpha_cumprod = 1.0
        new_beta = []
        for i, alpha_cumprod in enumerate(self.alpha_bar):
            if i in kept:
                new_beta.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
        return torch.tensor(new_beta)

    def expand(self, arr, dim=4):
        """Append trailing singleton dims until ``arr`` has ``dim`` dims, then move it to ``self.device``.

        This lets a per-sample vector of shape (B,) broadcast against
        (B, C, H, W) tensors.
        """
        while arr.dim() < dim:
            arr = arr[:, None]
        return arr.to(self.device)
johndpope commented 8 months ago

This is nonsense — there's no Stable Diffusion here. Investigating reusing these trained models from AniPortrait instead. (Screenshot from 2024-03-27 12-51-48)