eisneim opened 1 month ago
For comparison, here are the results of the plain UNet2D without temporal Conv3d at 900k iterations:
https://github.com/user-attachments/assets/8d5ff37c-0494-4a02-9d97-e06b3fe55ce1
https://github.com/user-attachments/assets/175c53fc-09ff-4ba6-baf5-699c2286ba37
https://github.com/user-attachments/assets/70340098-a97b-41d0-a9bf-f465674e4bd0
The Conv3d version at 129k iterations is already pretty close to the original UNet2D at 900k, so I think I should add more Conv3d layers.
@eisneim I am also trying to reimplement Generative Image Dynamics. I've noticed that the training data matters a lot. Some results are shown below. The background is stable because I manually collect videos where the background is nearly static. Could you share the script you use to download videos from online footage websites? I guess more data could lead to further performance improvement.
https://github.com/user-attachments/assets/877e3d0a-5668-43fb-b7c8-32935e1d4fe4
https://github.com/user-attachments/assets/7b9e2e3b-2516-4e0f-ba00-e3d976d6bb01
https://github.com/user-attachments/assets/169fee32-4637-42c9-9442-c4bf6c87c0ee
https://github.com/user-attachments/assets/c3f4fe62-698b-4790-a76f-4adf5fd24721
3 temporal Conv3d blocks, at the beginning, middle, and end:
```python
import torch
import torch.nn.functional as F
import torch.nn as nn
from typing import Optional, Tuple, Union
from einops import rearrange

from diffusers import UNet2DModel
# from diffusers.models.unets.unet_2d import UNet2DOutput


class Freq_UNet2d(UNet2DModel):
    def __init__(self, tdropout=0.5, *args, **kwargs):
        super().__init__(*args, **kwargs)
        drop_p = tdropout

        # temporal conv right after conv_in
        # self.temporal_conv_in = nn.Conv3d(self.config.block_out_channels[0], self.config.block_out_channels[0], (3, 1, 1), padding=(1, 0, 0))
        self.temporal_conv_in = nn.Sequential(
            nn.Conv3d(self.config.block_out_channels[0], self.config.block_out_channels[0], (3, 1, 1), padding=(1, 0, 0)),
            nn.ReLU(),
            nn.Dropout(p=drop_p),
        )

        # temporal conv after the mid block
        mid_channel_count = self.config.block_out_channels[-1]
        # self.temporal_conv_mid = nn.Conv3d(mid_channel_count, mid_channel_count, (3, 1, 1), padding=(1, 0, 0))
        self.temporal_conv_mid = nn.Sequential(
            nn.Conv3d(mid_channel_count, mid_channel_count, (3, 1, 1), padding=(1, 0, 0)),
            nn.ReLU(),
            nn.Dropout(p=drop_p),
        )

        # temporal conv after the last up block, before conv_norm_out
        self.temporal_conv_final = nn.Sequential(
            nn.Conv3d(self.config.block_out_channels[0], self.config.block_out_channels[0], (3, 1, 1), padding=(1, 0, 0)),
            nn.ReLU(),
            nn.Dropout(p=drop_p),
        )

    def temp_conv(self, sample, module, group_size):
        B, C, H, W = sample.shape
        if group_size is not None:
            assert B % group_size == 0, "group_size must divide the batch size"
            # (n*g, c, h, w) -> (n, c, g, h, w): run the 3d conv across the group (time) axis
            sample_5d = rearrange(sample, "(n g) c h w -> n g c h w", g=group_size, n=B // group_size)
            sample_t = module(sample_5d.permute(0, 2, 1, 3, 4)).permute(0, 2, 1, 3, 4)
            sample_t = rearrange(sample_t, "n g c h w -> (n g) c h w")
        else:
            # treat the whole batch as one clip: (b, c, h, w) -> (c, b, h, w), i.e. unbatched Conv3d input
            sample_t = module(sample.permute(1, 0, 2, 3)).permute(1, 0, 2, 3)
        return sample_t

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        class_labels: Optional[torch.Tensor] = None,
        group_size=None,
    ) -> torch.Tensor:
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0
        B, C, H, W = sample.shape

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when doing class conditioning")
            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)
            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb
        elif self.class_embedding is None and class_labels is not None:
            raise ValueError("class_embedding needs to be initialized in order to use class conditioning")

        # 2. pre-process
        skip_sample = sample
        sample = self.conv_in(sample)

        # temporal mixing right after the input conv (residual)
        sample_t = self.temp_conv(sample, self.temporal_conv_in, group_size)
        sample = sample + sample_t

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, emb)

        # ------------------------
        # 3d conv on the time dimension
        # sample.shape: (B, C_mid, H/8, W/8), e.g. [2, 896, 8, 4]
        # ------------------------
        temporal_mid = self.temp_conv(sample, self.temporal_conv_mid, group_size)
        sample = sample + temporal_mid

        # 5. up
        skip_sample = None
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
            else:
                sample = upsample_block(sample, res_samples, emb)

        # temporal mixing after the last up block (residual)
        temporal_final = self.temp_conv(sample, self.temporal_conv_final, group_size)
        sample = sample + temporal_final

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        if self.config.time_embedding_type == "fourier":
            timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
            sample = sample / timesteps

        return sample
```
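For completeness, a minimal usage sketch; the config values (channel counts, latent size, group size) below are placeholders just to show the grouped forward pass, not the ones from my actual training run:

```python
# Illustrative only: instantiate Freq_UNet2d and run one forward pass with grouped frames
import torch

model = Freq_UNet2d(
    tdropout=0.5,
    sample_size=(20, 36),          # latent H, W (placeholder)
    in_channels=4,
    out_channels=4,
    block_out_channels=(128, 256, 512),
    down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
    layers_per_block=2,
)

group = 16                                  # frames / frequency maps per sample
x = torch.randn(2 * group, 4, 20, 36)       # two grouped samples, concatenated along batch
t = torch.randint(0, 1000, (2 * group,))    # one timestep per latent
out = model(x, t, group_size=group)         # temporal convs mix only within each group of 16
print(out.shape)                            # torch.Size([32, 4, 20, 36])
```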
I use the CogVideoX 5b i2v model to create synthetic videos. @Unified-Robots I download videos from Pond5.com, envato.com, Pexels, and YouTube with Chrome browser extensions. I don't have a script for batch data scraping, but I wrote Chrome extensions to batch-download from a webpage, then use tiny-CLIP to filter out videos with big movements, and then manually filter the rest!
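Something along these lines can do the motion filtering (a sketch only, not my actual script; a standard CLIP checkpoint stands in for tiny-CLIP, and the frame sampling and threshold are arbitrary):

```python
# Sketch: drop clips with large motion by comparing CLIP embeddings of sampled frames.
# Checkpoint, frame stride, and threshold are placeholders.
import torch
import torch.nn.functional as F
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

@torch.no_grad()
def clip_motion_score(frames):  # frames: list of PIL images sampled from one clip
    inputs = processor(images=frames, return_tensors="pt")
    feats = model.get_image_features(pixel_values=inputs["pixel_values"])
    feats = F.normalize(feats, dim=-1)
    # cosine similarity between consecutive sampled frames; low similarity = big movement
    sims = (feats[:-1] * feats[1:]).sum(dim=-1)
    return sims.min().item()

def keep_clip(frames, threshold=0.92):  # threshold is a guess, tune per dataset
    return clip_motion_score(frames) >= threshold
```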
Recently I successfully replaced the UNet2D with U-DiT, which claims to need less compute while giving better performance, and also changed the VAE to stabilityai/sd-vae-ft-ema to get 8x latent compression.
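For reference, loading that VAE with diffusers and encoding frames gives the 8x-smaller latents (a minimal sketch; the input shape is just an example):

```python
# Sketch: stabilityai/sd-vae-ft-ema gives 4-channel latents at 1/8 spatial resolution
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").eval()

@torch.no_grad()
def encode_frames(frames):  # frames: (N, 3, H, W), values in [-1, 1]
    latents = vae.encode(frames).latent_dist.sample()
    return latents * vae.config.scaling_factor

x = torch.randn(2, 3, 160, 288)   # two 288x160 frames (H, W order)
z = encode_frames(x)              # -> torch.Size([2, 4, 20, 36])
```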
The loss dropped from 1.0 to 0.15 very quickly, but then it got stuck in that range. Getting past it would probably take a huge amount of compute that my single RTX 4090 can't handle, so I switched back to UNet2D with some modifications.
Add Conv3d on the temporal dimension
The paper uses attention over the different frequency channels; I was wondering whether a simple Conv3d over that dimension would do the same thing.
First, change the input to (bs * 16, vae_encoded_chs, h, w), i.e. a grouped sample input. For example, with 3 samples [flower, portrait, trees], the input would be [16, vae_ch, h, w] + [16, vae_ch, h, w] + [16, vae_ch, h, w] concatenated. The idea is that when doing the Conv3d we reshape back to (bs, 16, vae_chs, h, w) => conv3d, as sketched below.
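A quick shape sketch of that regrouping (the sizes here are arbitrary):

```python
# Sketch: regroup the concatenated frames so a (3, 1, 1) Conv3d mixes only across the group axis
import torch
import torch.nn as nn
from einops import rearrange

bs, g, c, h, w = 3, 16, 4, 20, 36        # 3 samples, 16 maps each (arbitrary sizes)
x = torch.randn(bs * g, c, h, w)         # the concatenated input described above

x5d = rearrange(x, "(n g) c h w -> n c g h w", g=g)        # (bs, c, 16, h, w) for Conv3d
conv3d = nn.Conv3d(c, c, kernel_size=(3, 1, 1), padding=(1, 0, 0))
y = rearrange(conv3d(x5d), "n c g h w -> (n g) c h w")     # back to the 4D layout
print(y.shape)                           # torch.Size([48, 4, 20, 36])
```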
For experiment purposes I only added two Conv3d layers, one at the beginning and one in the middle. Maybe I should add one in each block? (a rough sketch of that idea follows)
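I haven't trained this, but per-block temporal convs could look something like the following (the helper name and placement are hypothetical):

```python
# Hypothetical: one temporal conv per down block, sized by that block's output channels.
# The forward pass would still need to apply temporal_convs_down[i] after down_blocks[i].
import torch.nn as nn

def build_per_block_temporal_convs(block_out_channels, drop_p=0.5):
    return nn.ModuleList(
        nn.Sequential(
            nn.Conv3d(ch, ch, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
            nn.ReLU(),
            nn.Dropout(p=drop_p),
        )
        for ch in block_out_channels
    )

# e.g. inside Freq_UNet2d.__init__:
# self.temporal_convs_down = build_per_block_temporal_convs(self.config.block_out_channels, drop_p)
```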
and here is the model:
Here the kernel size is (3, 1, 1), which maps to (time, h, w), so it only works on the time dimension. Currently at iteration 129k on a single RTX 4090 24G; dataset: 288x160, 150 frames per clip; batch size 4 × 16; lr 4~5e-5; AdamW.
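A quick sanity check of that claim: with kernel (3, 1, 1) the weights have no spatial extent, so each output frame only sees its temporal neighbours at the same pixel (toy shapes):

```python
# Sketch: (3, 1, 1) kernel = temporal mixing only, spatial dims untouched
import torch
import torch.nn as nn

conv = nn.Conv3d(8, 8, kernel_size=(3, 1, 1), padding=(1, 0, 0))
print(conv.weight.shape)            # torch.Size([8, 8, 3, 1, 1]) -- no h/w extent
x = torch.randn(1, 8, 16, 40, 72)   # (batch, channels, time=16, h, w)
print(conv(x).shape)                # torch.Size([1, 8, 16, 40, 72])
```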
https://github.com/user-attachments/assets/981aa78a-f02e-4693-9413-b24f15f3e6f3
https://github.com/user-attachments/assets/a326f62f-9f00-42e7-9bf9-1dbe54f03356
https://github.com/user-attachments/assets/9ac1c838-a050-4710-a1ad-7f42339f7c9d
https://github.com/user-attachments/assets/76341ba6-8f74-4c8d-8c94-18ccc5f9717b
Adding Conv3d took only 129k steps to beat the plain UNet2D at 600k steps.