AtomoVideo: High Fidelity Image-to-Video Generation - 24 March.
From the paper: "More importantly, as we have fixed the weights of the spatial layers, our work can seamlessly integrate with existing plugins such as ControlNet [42], LoRAs [19], and stylized base models."
https://atomo-video.github.io/
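To make the plugin point concrete: because only the temporal modules are trained, the spatial UNet weights can be swapped for a stylized base model (or a LoRA-merged one) without touching the temporal layers. A minimal sketch, assuming the AtomoVideoModel assembled further below and a hypothetical checkpoint of stylized spatial UNet weights:

import torch

# "stylized_base_unet.pt" is a hypothetical file holding spatial UNet weights from a
# stylized base model; `model` is the AtomoVideoModel assembled further below.
stylized_spatial_weights = torch.load("stylized_base_unet.pt", map_location="cpu")
# strict=False tolerates extra/missing keys; the temporal modules live outside the
# UNet and are not affected by this swap.
model.unet.load_state_dict(stylized_spatial_weights, strict=False)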
Surprised that ChatGPT found this paper - it's becoming a search engine.
AtomoVideo - by Claude AI (Opus) - first pass
import torch
import torch.nn as nn
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
class TemporalConvBlock(nn.Module):
    """1D temporal convolution (over the frame axis) applied at every spatial location."""
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=(kernel_size, 1, 1), padding=(padding, 0, 0))
        self.norm = nn.GroupNorm(num_groups=32, num_channels=out_channels)
        self.activation = nn.SiLU()

    def forward(self, x):
        # x: (batch, channels, frames, height, width)
        return self.activation(self.norm(self.conv(x)))
class TemporalAttentionBlock(nn.Module):
    """Self-attention across frames, computed independently at each spatial location."""
    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.query = nn.Conv3d(channels, channels, kernel_size=1)
        self.key = nn.Conv3d(channels, channels, kernel_size=1)
        self.value = nn.Conv3d(channels, channels, kernel_size=1)
        self.proj = nn.Conv3d(channels, channels, kernel_size=1)
    def forward(self, x):
        # x: (batch, channels, frames, height, width)
        batch, channels, frames, height, width = x.shape
        residual = x
        h = self.norm(x)
        # Project to q/k/v while the tensor is still 5D, then flatten the spatial
        # axes into the batch so attention runs over the frame axis only
        q = self.query(h).permute(0, 3, 4, 2, 1).reshape(batch * height * width, frames, channels)
        k = self.key(h).permute(0, 3, 4, 2, 1).reshape(batch * height * width, frames, channels)
        v = self.value(h).permute(0, 3, 4, 2, 1).reshape(batch * height * width, frames, channels)
        attention = torch.softmax(q @ k.transpose(-1, -2) / (channels ** 0.5), dim=-1)
        out = attention @ v
        out = out.reshape(batch, height, width, frames, channels).permute(0, 4, 3, 1, 2)
        return self.proj(out) + residual
class AtomoVideoModel(nn.Module):
    def __init__(self, image_encoder, text_encoder, unet, vae_decoder, noise_scheduler):
        super().__init__()
        # Frozen, pretrained components
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.unet = unet
        self.vae_decoder = vae_decoder
        self.noise_scheduler = noise_scheduler
        # Trainable temporal modules. In the paper the temporal layers are interleaved
        # with the frozen spatial layers inside the UNet; this first pass applies them
        # as a residual refinement of the per-step noise prediction, so their widths
        # match the 4-channel latent video they operate on.
        latent_channels, hidden = 4, 128
        self.temporal_in = nn.Conv3d(latent_channels, hidden, kernel_size=1)
        self.temporal_conv = nn.ModuleList([
            TemporalConvBlock(hidden, hidden),
            TemporalConvBlock(hidden, hidden),
        ])
        self.temporal_attn = nn.ModuleList([
            TemporalAttentionBlock(hidden),
            TemporalAttentionBlock(hidden),
        ])
        self.temporal_out = nn.Conv3d(hidden, latent_channels, kernel_size=1)
    def forward(self, input_image, prompt, video_length, num_inference_steps=50):
        # Encode the conditioning image (VAE latents) and the prompt
        image_latents = self.image_encoder(input_image)
        text_embeddings = self.text_encoder(prompt)
        batch, latent_channels, height, width = image_latents.shape
        device = image_latents.device

        # Frame mask: 1 marks the given first frame, 0 the frames to be generated
        frame_mask = torch.zeros(batch, 1, video_length, height, width, device=device)
        frame_mask[:, :, 0] = 1.0
        # Broadcast the image latents across the temporal axis
        image_latents = image_latents.unsqueeze(2).repeat(1, 1, video_length, 1, 1)

        latents = torch.randn(batch, latent_channels, video_length, height, width, device=device)
        text_embeddings = text_embeddings.repeat_interleave(video_length, dim=0)
        self.noise_scheduler.set_timesteps(num_inference_steps, device=device)
        for t in self.noise_scheduler.timesteps:
            # Concatenate noise latents, frame mask and image latents -> 9 input channels
            model_input = torch.cat([latents, frame_mask, image_latents], dim=1)
            # The spatial UNet is 2D, so fold the frame axis into the batch axis
            model_input = model_input.permute(0, 2, 1, 3, 4).reshape(batch * video_length, -1, height, width)
            noise_pred = self.unet(model_input, t, encoder_hidden_states=text_embeddings).sample
            noise_pred = noise_pred.reshape(batch, video_length, -1, height, width).permute(0, 2, 1, 3, 4)
            # Residual temporal refinement of the noise prediction
            h = self.temporal_in(noise_pred)
            for conv, attn in zip(self.temporal_conv, self.temporal_attn):
                h = attn(conv(h))
            noise_pred = noise_pred + self.temporal_out(h)
            # Compute the previous noisy sample x_t -> x_{t-1}
            latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample

        # Decode each frame back to pixel space (fold frames into the batch again)
        latents = latents.permute(0, 2, 1, 3, 4).reshape(batch * video_length, latent_channels, height, width)
        frames = self.vae_decoder(latents)
        return frames.reshape(batch, video_length, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
# Initialize the frozen, pretrained components. The torch.hub entry points in the
# first pass do not exist; the Hugging Face equivalents are used instead, wrapped
# so they take raw inputs and return plain tensors as AtomoVideoModel expects.
clip_text = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch16").cuda()
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16")
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").cuda()

def image_encoder(pixel_values):
    # VAE-encode the conditioning image into 4-channel latents so it can be
    # concatenated with the noise latents (the paper also injects higher-level
    # image semantics via cross-attention, which this first pass omits)
    return vae.encode(pixel_values.cuda()).latent_dist.sample() * vae.config.scaling_factor

def text_encoder(prompt):
    tokens = clip_tokenizer(prompt, padding="max_length", max_length=clip_tokenizer.model_max_length,
                            truncation=True, return_tensors="pt")
    return clip_text(tokens.input_ids.cuda()).last_hidden_state

def vae_decoder(latents):
    return vae.decode(latents / vae.config.scaling_factor).sample
# Spatial UNet. Channel bookkeeping: 4 noise latents + 1 frame mask + 4 image
# latents = 9 input channels; the output predicts noise for the 4 latent channels.
unet = UNet2DConditionModel(
    sample_size=64,            # latent resolution for 512x512 frames
    in_channels=9,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(128, 256, 512, 1024),
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"),
    cross_attention_dim=512,   # hidden size of the CLIP ViT-B/16 text encoder
)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
# Create the AtomoVideoModel
model = AtomoVideoModel(
    image_encoder=image_encoder,
    text_encoder=text_encoder,
    unet=unet,
    vae_decoder=vae_decoder,
    noise_scheduler=noise_scheduler,
)
# Load pretrained weights (assuming they are available)
model.load_state_dict(torch.load("pretrained_weights.pt"))
model.cuda()
# Freeze the spatial UNet; only its new 9-channel input convolution and the
# model-level temporal modules remain trainable
for name, param in model.unet.named_parameters():
    if "conv_in" not in name:
        param.requires_grad = False
# Training loop. `num_epochs` and `train_dataloader` are assumed to be defined
# elsewhere; only the parameters left trainable above receive gradients.
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
for epoch in range(num_epochs):
    for batch in train_dataloader:
        input_image = batch["input_image"].cuda()
        prompt = batch["prompt"]
        video_frames = batch["video_frames"].cuda()  # (batch, channels, frames, height, width)
        # Generate video frames. NOTE: a full sampling rollout per training step is a
        # first-pass placeholder; diffusion models are normally trained with a
        # noise-prediction loss at randomly sampled timesteps instead.
        generated_video = model(input_image, prompt, video_length=video_frames.shape[2])
        # Pixel-space MSE against the ground-truth clip
        loss = torch.mean((generated_video - video_frames) ** 2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
# Inference
@torch.no_grad()
def generate_video(input_image, prompt, video_length=24):
    model.eval()
    generated_video = model(input_image, prompt, video_length)
    return generated_video
In this implementation, the AtomoVideoModel class combines the image encoder, text encoder, UNet, VAE decoder, and noise scheduler, and adds temporal convolution and attention blocks to capture motion across frames.
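As a quick sanity check of the temporal blocks defined above, a dummy latent video should pass through both with its shape preserved (the sizes here are arbitrary):

import torch

x = torch.randn(2, 128, 8, 16, 16)          # (batch, channels, frames, height, width)
conv_block = TemporalConvBlock(128, 128)
attn_block = TemporalAttentionBlock(128)
out = attn_block(conv_block(x))
assert out.shape == x.shape                 # temporal layers preserve the latent video shape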
The forward method of AtomoVideoModel takes an input image, a prompt, and the desired video length. It encodes the image and prompt, concatenates the noise latents with a frame mask and the broadcast image latents, and denoises iteratively with the UNet and noise scheduler; the temporal convolution and attention blocks refine the UNet's noise prediction at each denoising step.
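The channel bookkeeping behind that conditioning concat is easiest to see in isolation; the shapes below assume 512x512 frames and the 4-channel Stable Diffusion latent space:

import torch

batch, frames, h, w = 1, 16, 64, 64                    # 64x64 latents for 512x512 frames
noise_latents = torch.randn(batch, 4, frames, h, w)    # what the model denoises
frame_mask = torch.zeros(batch, 1, frames, h, w)       # 1 = given frame, 0 = frame to generate
frame_mask[:, :, 0] = 1.0
image_latents = torch.randn(batch, 4, 1, h, w).repeat(1, 1, frames, 1, 1)  # first-frame latents, broadcast in time
unet_input = torch.cat([noise_latents, frame_mask, image_latents], dim=1)
print(unet_input.shape)                                # torch.Size([1, 9, 16, 64, 64]) -> matches in_channels=9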
During training, the generated video frames are compared with the ground-truth frames using a mean squared error loss and the trainable parameters are optimized by backpropagation (a first-pass simplification; latent diffusion models are normally trained with a noise-prediction objective at randomly sampled timesteps).
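For reference, a minimal sketch of that conventional noise-prediction objective; `video_latents` (VAE-encoded ground-truth frames), `cond` (the concatenated frame mask and image latents), `text_embeddings`, and the `predict_noise` wrapper around the UNet plus temporal layers are all assumed to exist and are not part of the code above:

import torch

noise = torch.randn_like(video_latents)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                          (video_latents.shape[0],), device=video_latents.device)
noisy_latents = noise_scheduler.add_noise(video_latents, noise, timesteps)
noise_pred = predict_noise(torch.cat([noisy_latents, cond], dim=1), timesteps, text_embeddings)
loss = torch.nn.functional.mse_loss(noise_pred, noise)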
For inference, the generate_video function takes an input image, prompt, and desired video length and generates the corresponding video frames using the trained model.
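A usage sketch for generate_video; the file name and prompt are placeholders, and the plain resize/normalize preprocessing is an assumption, not the paper's pipeline:

import numpy as np
import torch
from PIL import Image

img = Image.open("first_frame.png").convert("RGB").resize((512, 512))    # placeholder input frame
pixel_values = torch.from_numpy(np.array(img)).float().permute(2, 0, 1)  # HWC uint8 -> CHW float
pixel_values = (pixel_values / 127.5 - 1.0).unsqueeze(0)                 # scale to [-1, 1], add batch dim
video = generate_video(pixel_values, ["a ship sailing through stormy seas"], video_length=24)
print(video.shape)                                                       # expected (1, 3, 24, 512, 512)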
Note that this implementation assumes the availability of pretrained weights for the image encoder, text encoder, and VAE decoder. You may need to modify the code to load the appropriate pretrained weights based on your specific setup.
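One practical consequence of the frozen spatial layers (a note of mine, not something from the paper): only the trained parameters need to be checkpointed, since the frozen CLIP/VAE/UNet weights can always be re-fetched. A sketch, with a hypothetical file name:

import torch

# Keep only the temporal modules and the new 9-channel input convolution
trainable_state = {k: v for k, v in model.state_dict().items()
                   if k.startswith("temporal") or "conv_in" in k}
torch.save(trainable_state, "atomovideo_temporal_only.pt")

# Later: rebuild the model from the pretrained components, then load the small checkpoint
model.load_state_dict(torch.load("atomovideo_temporal_only.pt"), strict=False)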
UPDATE - Taobao Launches AtomoVideo, Turning Static Images into Creative Short Videos in One-click http://www.aastocks.com/en/stocks/news/aafn-con/NOW.1334576/positive-news/AAFN