johndpope / Emote-hack

Emote Portrait Alive - using AI to reverse-engineer code from the white paper (abandoned).
https://github.com/johndpope/VASA-1-hack

It's just me (and ai) against Alibaba #26

Closed by johndpope 8 months ago.

johndpope commented 8 months ago

AtomoVideo: High Fidelity Image-to-Video Generation (24-march). From the paper: "More importantly, as we have fixed the weights of the spatial layers, our work can seamlessly integrate with existing plugins such as ControlNet [42], LoRAs [19], and stylized base models." https://atomo-video.github.io/

[Figure: AtomoVideo framework]
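For intuition on the quoted claim, here is a minimal NumPy sketch of what a LoRA "plugin" does to a frozen spatial weight: it adds a low-rank update, W' = W + (alpha / r) * up @ down, using the standard LoRA scaling. If the temporal layers were trained while W stayed fixed, swapping in W' (or a stylized base model's W) changes only the spatial features those layers receive, which is the compatibility argument the paper makes. The names `merge_lora`, `lora_down` and `lora_up` are hypothetical; this is not code from AtomoVideo or from any LoRA library.

```python
import numpy as np

def merge_lora(weight, lora_down, lora_up, alpha):
    """Fold a LoRA update into a frozen weight matrix (hypothetical helper).

    weight: (out, in), lora_down: (r, in), lora_up: (out, r).
    Returns W + (alpha / r) * lora_up @ lora_down, a change of rank at most r.
    """
    rank = lora_down.shape[0]
    return weight + (alpha / rank) * (lora_up @ lora_down)

rng = np.random.default_rng(0)
W = rng.standard_normal((16, 32))       # stands in for a frozen spatial projection
down = rng.standard_normal((4, 32)) * 0.01
up = np.zeros((16, 4))                  # LoRA's usual zero init for the "up" matrix
print(np.allclose(merge_lora(W, down, up, alpha=4.0), W))  # True: an untrained LoRA is a no-op
```

Because the update is low rank and additive, it leaves the layer's shape and interface untouched, so anything stacked on top of that layer keeps working.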

Surprised that ChatGPT found this paper; it's becoming a search engine.

[Screenshot from 2024-03-24 09-37-12]

AtomoVideo, by ClaudeAI (Opus), first pass:


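```python
# Sketch: assumes Stable Diffusion v1.4 weights from the Hugging Face Hub, 4-channel VAE latents,
# and a 4 (noisy) + 1 (mask) + 4 (image) = 9 channel UNet input.
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer


class TemporalConvBlock(nn.Module):
    """Convolution along the frame axis only (kernel k x 1 x 1) on (B, C, T, H, W) features."""

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=(kernel_size, 1, 1), padding=(padding, 0, 0))
        self.norm = nn.GroupNorm(num_groups=32, num_channels=out_channels)
        self.activation = nn.SiLU()

    def forward(self, x):
        # Residual connection (requires in_channels == out_channels): the frozen spatial
        # features pass through, plus a learned temporal correction.
        return x + self.activation(self.norm(self.conv(x)))


class TemporalAttentionBlock(nn.Module):
    """Self-attention across frames, computed independently at every spatial location."""

    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.query = nn.Conv3d(channels, channels, kernel_size=1)
        self.key = nn.Conv3d(channels, channels, kernel_size=1)
        self.value = nn.Conv3d(channels, channels, kernel_size=1)
        self.proj = nn.Conv3d(channels, channels, kernel_size=1)

    def forward(self, x):
        batch, channel, frames, height, width = x.shape
        h = self.norm(x)

        def to_seq(t):
            # (B, C, T, H, W) -> (B*H*W, T, C): one length-T sequence per pixel
            return t.permute(0, 3, 4, 2, 1).reshape(batch * height * width, frames, channel)

        q, k, v = to_seq(self.query(h)), to_seq(self.key(h)), to_seq(self.value(h))
        attention = torch.softmax(q @ k.transpose(1, 2) / channel ** 0.5, dim=-1)  # (B*H*W, T, T)
        out = (attention @ v).reshape(batch, height, width, frames, channel).permute(0, 4, 3, 1, 2)
        return x + self.proj(out)  # residual on the untouched 5D input


class AtomoVideoModel(nn.Module):
    def __init__(self, tokenizer, text_encoder, unet, vae, noise_scheduler):
        super().__init__()
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.unet = unet
        self.vae = vae
        self.noise_scheduler = noise_scheduler
        self.num_frames = 1  # set before every UNet call, read by the temporal hooks

        # UNet2DConditionModel only handles (N, C, H, W), so frames are folded into the batch and a
        # temporal conv + attention pair runs after each up block (channels follow the up path).
        up_channels = list(reversed(unet.config.block_out_channels))
        self.temporal_conv = nn.ModuleList([TemporalConvBlock(c, c) for c in up_channels])
        self.temporal_attn = nn.ModuleList([TemporalAttentionBlock(c) for c in up_channels])
        for i, block in enumerate(unet.up_blocks):
            block.register_forward_hook(self._temporal_hook(i))

    def _temporal_hook(self, idx):
        def hook(module, args, output):
            bt, c, h, w = output.shape
            f = self.num_frames
            x = output.reshape(bt // f, f, c, h, w).permute(0, 2, 1, 3, 4)  # (B, C, T, H, W)
            x = self.temporal_attn[idx](self.temporal_conv[idx](x))
            return x.permute(0, 2, 1, 3, 4).reshape(bt, c, h, w)  # returned value replaces the block output
        return hook

    def encode_frames(self, frames):
        # (B, 3, T, H, W) pixels in [-1, 1] -> (B, 4, T, H/8, W/8) scaled VAE latents
        b, _, f, h, w = frames.shape
        x = frames.permute(0, 2, 1, 3, 4).reshape(b * f, 3, h, w)
        latents = self.vae.encode(x).latent_dist.sample() * self.vae.config.scaling_factor
        return latents.reshape(b, f, *latents.shape[1:]).permute(0, 2, 1, 3, 4)

    def encode_prompt(self, prompt, device):
        tokens = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length,
                                truncation=True, return_tensors="pt")
        return self.text_encoder(tokens.input_ids.to(device)).last_hidden_state

    def build_condition(self, image_latents, video_length):
        # Image latent repeated on every frame plus an all-ones frame mask, as in the first pass
        b, _, h, w = image_latents.shape
        frame_mask = torch.ones(b, 1, video_length, h, w, device=image_latents.device, dtype=image_latents.dtype)
        image_latents = image_latents.unsqueeze(2).expand(-1, -1, video_length, -1, -1)
        return torch.cat([frame_mask, image_latents], dim=1)  # (B, 5, T, h, w)

    def denoise(self, noisy_latents, cond, timesteps, text_embeddings):
        # Predict the noise in (B, 4, T, h, w) latents with the per-frame UNet
        b, _, f, h, w = noisy_latents.shape
        self.num_frames = f
        x = torch.cat([noisy_latents, cond], dim=1)  # 4 noisy + 1 mask + 4 image = 9 channels
        x = x.permute(0, 2, 1, 3, 4).reshape(b * f, -1, h, w)
        timesteps = timesteps.reshape(-1).expand(b).repeat_interleave(f)
        text_embeddings = text_embeddings.repeat_interleave(f, dim=0)
        out = self.unet(x, timesteps, encoder_hidden_states=text_embeddings).sample
        return out.reshape(b, f, -1, h, w).permute(0, 2, 1, 3, 4)

    def training_loss(self, input_image, prompt, video_frames):
        # Standard epsilon-prediction objective on the encoded ground-truth frames
        with torch.no_grad():
            latents = self.encode_frames(video_frames)
            image_latents = self.encode_frames(input_image.unsqueeze(2))[:, :, 0]
            text_embeddings = self.encode_prompt(prompt, latents.device)
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps,
                                  (latents.shape[0],), device=latents.device)
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        cond = self.build_condition(image_latents, latents.shape[2])
        return F.mse_loss(self.denoise(noisy_latents, cond, timesteps, text_embeddings), noise)

    @torch.no_grad()
    def sample(self, input_image, prompt, video_length=24, num_inference_steps=50):
        device = input_image.device
        image_latents = self.encode_frames(input_image.unsqueeze(2))[:, :, 0]
        cond = self.build_condition(image_latents, video_length)
        text_embeddings = self.encode_prompt(prompt, device)
        b, c, h, w = image_latents.shape
        latents = torch.randn(b, c, video_length, h, w, device=device)
        self.noise_scheduler.set_timesteps(num_inference_steps, device=device)
        for t in self.noise_scheduler.timesteps:
            noise_pred = self.denoise(latents, cond, t, text_embeddings)
            latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample  # x_t -> x_{t-1}
        # Decode every frame with the VAE: (B, 3, T, H, W)
        flat = latents.permute(0, 2, 1, 3, 4).reshape(b * video_length, c, h, w)
        video = self.vae.decode(flat / self.vae.config.scaling_factor).sample
        return video.reshape(b, video_length, *video.shape[1:]).permute(0, 2, 1, 3, 4)


# Initialize the model components from pretrained Stable Diffusion (spatial layers stay frozen)
sd = "CompVis/stable-diffusion-v1-4"
tokenizer = CLIPTokenizer.from_pretrained(sd, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(sd, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(sd, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(sd, subfolder="unet")
noise_scheduler = DDPMScheduler.from_pretrained(sd, subfolder="scheduler")

# Widen conv_in from 4 to 9 input channels; the extra (mask + image) channels start at zero
old_conv = unet.conv_in
new_conv = nn.Conv2d(9, old_conv.out_channels, kernel_size=old_conv.kernel_size, padding=old_conv.padding)
with torch.no_grad():
    new_conv.weight.zero_()
    new_conv.weight[:, :old_conv.in_channels] = old_conv.weight
    new_conv.bias.copy_(old_conv.bias)
unet.conv_in = new_conv
unet.register_to_config(in_channels=9)

# Create the AtomoVideoModel
model = AtomoVideoModel(tokenizer, text_encoder, unet, vae, noise_scheduler).cuda()

# Load a previously trained checkpoint, if you have one (hypothetical file name)
# model.load_state_dict(torch.load("pretrained_weights.pt"))

# Freeze the spatial layers, VAE and text encoder; train only the temporal blocks and conv_in
for param in model.parameters():
    param.requires_grad_(False)
trainable = [*model.temporal_conv.parameters(), *model.temporal_attn.parameters(), *model.unet.conv_in.parameters()]
for param in trainable:
    param.requires_grad_(True)

# Training loop (num_epochs and train_dataloader are yours to define; video_frames is (B, 3, T, H, W) in [-1, 1])
optimizer = torch.optim.AdamW(trainable, lr=1e-4)
model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        loss = model.training_loss(batch["input_image"].cuda(), batch["prompt"], batch["video_frames"].cuda())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

# Inference
@torch.no_grad()
def generate_video(input_image, prompt, video_length=24):
    model.eval()
    return model.sample(input_image, prompt, video_length)
```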

In this implementation, the AtomoVideoModel class combines the encoders, the UNet, the VAE and the noise scheduler, and adds temporal convolution and temporal attention blocks so that features can exchange information across frames.
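A NumPy sketch of what "temporal attention" means here: every spatial location gets its own length-T sequence of channel vectors, and attention mixes values only along that frame axis. The function name and the plain (C, C) matrices standing in for the 1x1x1 query/key/value convolutions are illustrative assumptions; the sketch leaves out the norm, the output projection and the residual.

```python
import numpy as np

def temporal_self_attention(x, wq, wk, wv):
    """Self-attention along the frame axis only.

    x: (B, C, T, H, W) features; wq, wk, wv: (C, C) matrices standing in for the
    1x1x1 query/key/value convolutions. Norm, output projection and residual are omitted.
    """
    b, c, t, h, w = x.shape
    seq = x.transpose(0, 3, 4, 2, 1).reshape(b * h * w, t, c)    # one length-T sequence per pixel
    q, k, v = seq @ wq, seq @ wk, seq @ wv
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(c)               # (B*H*W, T, T)
    scores -= scores.max(axis=-1, keepdims=True)                 # subtract row max for a stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    out = weights @ v                                            # (B*H*W, T, C)
    return out.reshape(b, h, w, t, c).transpose(0, 4, 3, 1, 2)   # back to (B, C, T, H, W)

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 8, 4, 3, 3))                         # 4 frames of 3x3 features, 8 channels
wq, wk, wv = (rng.standard_normal((8, 8)) * 0.1 for _ in range(3))
y = temporal_self_attention(x, wq, wk, wv)
print(y.shape)  # (1, 8, 4, 3, 3)
```

Perturbing the features at one pixel changes the outputs at that pixel only; context within a frame is left to the spatial layers.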

To generate a video, the model encodes the input image and the prompt, concatenates the image latents and a frame mask with randomly initialized noise latents, and denoises them step by step with the UNet and the noise scheduler. The temporal convolution and attention blocks run at every denoising step.
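The "denoise step by step" part can be sketched with the DDPM update in the epsilon parameterization: given a noise prediction eps, x_{t-1} = (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t) + sigma_t * z. This NumPy version assumes the DDPM paper's linear schedule (beta from 1e-4 to 0.02 over 1000 steps) and uses the posterior variance for sigma_t^2, one of the two options in that paper. It only illustrates the math a scheduler's `step` performs; it is not diffusers' implementation.

```python
import numpy as np

def linear_beta_schedule(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Linear noise schedule; returns beta_t, alpha_t and the cumulative product abar_t."""
    betas = np.linspace(beta_start, beta_end, num_steps)
    alphas = 1.0 - betas
    return betas, alphas, np.cumprod(alphas)

def ddpm_step(x_t, eps_pred, t, betas, alphas, alpha_bars, rng):
    """One ancestral update x_t -> x_{t-1} from a predicted noise eps_pred."""
    mean = (x_t - betas[t] / np.sqrt(1.0 - alpha_bars[t]) * eps_pred) / np.sqrt(alphas[t])
    if t == 0:
        return mean  # the last step adds no noise
    # Posterior variance: (1 - abar_{t-1}) / (1 - abar_t) * beta_t
    var = (1.0 - alpha_bars[t - 1]) / (1.0 - alpha_bars[t]) * betas[t]
    return mean + np.sqrt(var) * rng.standard_normal(x_t.shape)

betas, alphas, alpha_bars = linear_beta_schedule()
rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 2, 8, 8))                   # e.g. (C, T, h, w) video latents
eps = rng.standard_normal(x0.shape)
x_t = np.sqrt(alpha_bars[0]) * x0 + np.sqrt(1.0 - alpha_bars[0]) * eps
# With a perfect noise prediction, the final step recovers x0 up to rounding
print(np.allclose(ddpm_step(x_t, eps, 0, betas, alphas, alpha_bars, rng), x0))  # True
```

A sampler runs this for t = T-1, ..., 0 with the UNet's prediction at each step; with 50 inference steps a scheduler walks a strided subset of timesteps and adjusts the coefficients, which this sketch does not show.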

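Training minimizes a mean squared error, and in the standard diffusion setup that error is taken between the noise added to the encoded ground-truth frames (at a randomly sampled timestep) and the UNet's prediction of that noise, with backpropagation updating only the trainable temporal parameters. Backpropagating a pixel-space loss through the full sampling loop would be prohibitively expensive.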
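A minimal NumPy sketch of that objective, assuming video latents shaped (B, C, T, h, w) and a linear schedule; `add_noise`, `noise_prediction_loss` and the `predict_noise(x_t, t)` callable are hypothetical names standing in for the scheduler and the UNet.

```python
import numpy as np

def add_noise(x0, noise, t, alpha_bars):
    """Closed-form forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    ab = alpha_bars[t].reshape(-1, *([1] * (x0.ndim - 1)))  # one timestep per sample, broadcast
    return np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * noise

def noise_prediction_loss(predict_noise, x0, alpha_bars, rng):
    """Epsilon-prediction objective: MSE between the added noise and the predicted noise."""
    t = rng.integers(0, len(alpha_bars), size=x0.shape[0])  # a random timestep per video
    noise = rng.standard_normal(x0.shape)
    x_t = add_noise(x0, noise, t, alpha_bars)
    return np.mean((predict_noise(x_t, t) - noise) ** 2)

alpha_bars = np.cumprod(1.0 - np.linspace(1e-4, 0.02, 1000))
rng = np.random.default_rng(0)
latents = rng.standard_normal((2, 4, 8, 16, 16))           # (B, C, T, h, w) video latents
zero_model = lambda x_t, t: np.zeros_like(x_t)             # a predictor that has learned nothing
print(noise_prediction_loss(zero_model, latents, alpha_bars, rng))  # roughly 1.0, the noise variance
```

The loss never touches decoded pixels or the sampling loop: one forward pass per training example is enough.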

For inference, the generate_video function takes an input image, prompt, and desired video length and generates the corresponding video frames using the trained model.

Note that this implementation depends on pretrained weights for the text encoder, the VAE and the spatial UNet layers (the paper keeps the spatial layers fixed). You may need to change the model IDs and checkpoint loading to match your setup.
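Reusing pretrained spatial weights gets awkward when the UNet input grows from 4 latent channels to 9 (assumed here: 4 noisy + 1 mask + 4 image). A common trick, also used in diffusers' InstructPix2Pix training example, is to copy the pretrained input-convolution kernel into a wider one and zero the new channels, so the widened layer initially behaves exactly like the pretrained one. Whether AtomoVideo initializes its input layer this way is an assumption. The NumPy sketch below checks the property with a naive convolution; `conv2d` and `widen_input_channels` are hypothetical helpers.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def conv2d(x, weight, bias):
    """Naive 'same' 2D cross-correlation. x: (C_in, H, W); weight: (C_out, C_in, k, k), k odd."""
    k = weight.shape[-1]
    xp = np.pad(x, ((0, 0), (k // 2, k // 2), (k // 2, k // 2)))
    windows = sliding_window_view(xp, (k, k), axis=(1, 2))      # (C_in, H, W, k, k)
    return np.einsum("chwij,ocij->ohw", windows, weight) + bias[:, None, None]

def widen_input_channels(weight, new_in):
    """Copy a pretrained (C_out, C_in, k, k) kernel into a wider one; extra channels start at zero."""
    out_c, in_c, kh, kw = weight.shape
    wide = np.zeros((out_c, new_in, kh, kw), dtype=weight.dtype)
    wide[:, :in_c] = weight
    return wide

rng = np.random.default_rng(0)
w, b = rng.standard_normal((8, 4, 3, 3)), rng.standard_normal(8)  # stand-in for a pretrained conv_in
noisy = rng.standard_normal((4, 16, 16))                          # 4 noisy latent channels
cond = rng.standard_normal((5, 16, 16))                           # 1 mask + 4 image latent channels
before = conv2d(noisy, w, b)
after = conv2d(np.concatenate([noisy, cond]), widen_input_channels(w, 9), b)
print(np.allclose(before, after))  # True: the conditioning channels start with no influence
```

Training then only has to learn how much to use the new channels, instead of re-learning the pretrained input features.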

UPDATE: Taobao launches AtomoVideo, turning static images into creative short videos in one click. http://www.aastocks.com/en/stocks/news/aafn-con/NOW.1334576/positive-news/AAFN