fltwr / generative-image-dynamics


Sharing some of my experiment results #8

eisneim opened this issue 1 month ago

eisneim commented 1 month ago

Recently I successfully replaced the UNet2D with U-DiT, which claims lower compute and better performance, and also switched the VAE to stabilityai/sd-vae-ft-ema for 8x latent compression (see the udit_loss plot).

The loss dropped from 1.0 to 0.15 very quickly, but then it got stuck in that range. Getting past it would probably take a huge amount of compute that my single RTX 4090 can't handle, so I switched back to UNet2D with some modifications.
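
For reference, the VAE swap mentioned above is just diffusers' AutoencoderKL with 8x spatial downsampling; a minimal sketch (dummy sizes and orientation assumed, not my exact pipeline):

import torch
from diffusers import AutoencoderKL

# stabilityai/sd-vae-ft-ema downsamples 8x spatially and produces 4 latent channels
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").eval()

with torch.no_grad():
    # dummy frames; real frames should be normalized to [-1, 1].
    # assuming "288x160" means 288 wide x 160 tall
    frames = torch.randn(16, 3, 160, 288)
    latents = vae.encode(frames).latent_dist.sample()
    latents = latents * vae.config.scaling_factor   # usual SD latent scaling
    print(latents.shape)                            # torch.Size([16, 4, 20, 36])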

Adding Conv3d on the temporal dimension

The paper uses attention across the different frequency channels; I was wondering whether a simple Conv3d across that dimension would do roughly the same thing.

First, change the input to (bs * 16, vae_encoded_chs, h, w), i.e. grouped samples. For example, with 3 samples [flower, portrait, trees], the input is [16, vae_ch, h, w] + [16, vae_ch, h, w] + [16, vae_ch, h, w] concatenated. The idea is that for the Conv3d we reshape back to (bs, 16, vae_chs, h, w) and then apply the 3D convolution, as in the sketch below.
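
To make the grouping concrete, here is a minimal sketch of that reshape (illustrative sizes, with a bare Conv3d standing in for the temporal layer):

import torch
import torch.nn as nn
from einops import rearrange

bs, frames, vae_ch, h, w = 3, 16, 4, 20, 36     # 3 samples x 16 slices, made-up latent size
x = torch.randn(bs * frames, vae_ch, h, w)      # grouped input: (bs*16, vae_ch, h, w)

temporal_conv = nn.Conv3d(vae_ch, vae_ch, (3, 1, 1), padding=(1, 0, 0))

# regroup the slices and move channels in front of time, so Conv3d sees (N, C, T, H, W)
x5d = rearrange(x, "(n g) c h w -> n c g h w", g=frames)
y = temporal_conv(x5d)                          # mixes information only across the 16 slices
y = rearrange(y, "n c g h w -> (n g) c h w")    # back to the UNet's (bs*16, c, h, w) layout
print(y.shape)                                  # torch.Size([48, 4, 20, 36])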

For experimental purposes I only added two Conv3d layers, one at the beginning and one in the middle; maybe I should add one in each block?

And here is the model:

import torch
import torch.nn.functional as F
import torch.nn as nn
from typing import Optional, Tuple, Union
from einops import rearrange

from diffusers import UNet2DModel

class Freq_UNet2d(UNet2DModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.temporal_conv_in = nn.Conv3d(self.config.block_out_channels[0], self.config.block_out_channels[0], (3, 1, 1), padding=(1, 0, 0))

        mid_channel_count = self.config.block_out_channels[-1]
        self.temporal_conv_mid = nn.Conv3d(mid_channel_count, mid_channel_count, (3, 1, 1), padding=(1, 0, 0))

    def temp_conv(self, sample, module, group_size):
        B, C, H, W = sample.shape

        if group_size is not None:
            assert B % group_size == 0, "batch size must be divisible by group_size"
            # regroup the frames: (n*g, c, h, w) -> (n, g, c, h, w)
            sample_4d = rearrange(sample, "(n g) c h w -> n g c h w", g=group_size, n=B // group_size)
            # Conv3d expects (N, C, T, H, W), so swap the group/time and channel axes
            sample_t = module(sample_4d.permute(0, 2, 1, 3, 4)).permute(0, 2, 1, 3, 4)
            sample_t = rearrange(sample_t, "n g c h w -> (n g) c h w")
        else:
            # treat the whole batch as one temporal group (unbatched Conv3d input: (C, T, H, W))
            sample_t = module(sample.permute(1, 0, 2, 3)).permute(1, 0, 2, 3)
        return sample_t

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        class_labels: Optional[torch.Tensor] = None,
        group_size=None,
    ) -> torch.Tensor:
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        B, C, H, W = sample.shape

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when doing class conditioning")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb
        elif self.class_embedding is None and class_labels is not None:
            raise ValueError("class_embedding needs to be initialized in order to use class conditioning")

        # 2. pre-process
        skip_sample = sample
        sample = self.conv_in(sample)
        sample_t = self.temp_conv(sample, self.temporal_conv_in, group_size)
        sample = sample + sample_t

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, emb)
        # ------------------------
        #  3d conv on time dimension
        # sample.shape: (B, ?, H/8, w/8)  [2, 896, 8, 4]
        # ------------------------
        temporal_mid = self.temp_conv(sample, self.temporal_conv_mid, group_size)
        sample = sample + temporal_mid

        # 5. up
        skip_sample = None
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
            else:
                sample = upsample_block(sample, res_samples, emb)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        if self.config.time_embedding_type == "fourier":
            timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
            sample = sample / timesteps

        return sample

Here the kernel size (3, 1, 1) is for (time, h, w), so it only works on the time dimension. Currently at iteration 129k (see the unet_v1_loss curve). Setup: single RTX 4090 24G, dataset of 288x160 clips with 150 frames each, batch size 4 x 16, lr 4e-5 to 5e-5, AdamW.
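
Since compute is the constraint on a single 4090, it is worth noting how cheap these temporal layers are; a rough parameter count (assuming the default UNet2DModel widths of 224 and 896 channels for conv_in and the mid block):

import torch.nn as nn

for c in (224, 896):   # conv_in width and mid-block width (diffusers defaults)
    conv = nn.Conv3d(c, c, (3, 1, 1), padding=(1, 0, 0))
    n_params = sum(p.numel() for p in conv.parameters())
    print(c, n_params)  # 224 -> 150752, 896 -> 2409344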

https://github.com/user-attachments/assets/981aa78a-f02e-4693-9413-b24f15f3e6f3

https://github.com/user-attachments/assets/a326f62f-9f00-42e7-9bf9-1dbe54f03356

https://github.com/user-attachments/assets/9ac1c838-a050-4710-a1ad-7f42339f7c9d

https://github.com/user-attachments/assets/76341ba6-8f74-4c8d-8c94-18ccc5f9717b

With the Conv3d added, 129k steps are enough to beat the plain UNet2D at 600k steps.

eisneim commented 1 month ago

For comparison, here are the results of UNet2D without the temporal Conv3d, at 900k iterations:

https://github.com/user-attachments/assets/8d5ff37c-0494-4a02-9d97-e06b3fe55ce1

https://github.com/user-attachments/assets/175c53fc-09ff-4ba6-baf5-699c2286ba37

https://github.com/user-attachments/assets/70340098-a97b-41d0-a9bf-f465674e4bd0

The Conv3d version at 129k is pretty close to the original UNet2D at 900k, so I think I should add more Conv3d layers.

Unified-Robots commented 1 month ago

@eisneim I am also trying to reimplement Generative Image Dynamics. I notice the training data matters a lot. Some results are shown below; the background is stable because I manually collected videos where the background is nearly static. Could you share the script you use to download videos from online footage websites? I guess more data could lead to further performance improvements.

https://github.com/user-attachments/assets/877e3d0a-5668-43fb-b7c8-32935e1d4fe4

https://github.com/user-attachments/assets/7b9e2e3b-2516-4e0f-ba00-e3d976d6bb01

https://github.com/user-attachments/assets/169fee32-4637-42c9-9442-c4bf6c87c0ee

https://github.com/user-attachments/assets/c3f4fe62-698b-4790-a76f-4adf5fd24721

eisneim commented 1 month ago

Updated model

Three temporal Conv3d layers: at the beginning, in the middle, and at the end.

import torch
import torch.nn.functional as F
import torch.nn as nn
from typing import Optional, Tuple, Union
from einops import rearrange

from diffusers import UNet2DModel
# from diffusers.models.unets.unet_2d import UNet2DOutput

class Freq_UNet2d(UNet2DModel):
    def __init__(self, tdropout=0.5, *args, **kwargs):
        super().__init__(*args, **kwargs)

        drop_p = tdropout

        # self.temporal_conv_in = nn.Conv3d(self.config.block_out_channels[0], self.config.block_out_channels[0], (3, 1, 1), padding=(1, 0, 0))
        self.temporal_conv_in = nn.Sequential(
          nn.Conv3d(self.config.block_out_channels[0], self.config.block_out_channels[0], (3, 1, 1), padding=(1, 0, 0)),
          nn.ReLU(),
          nn.Dropout(p=drop_p),
        )

        mid_channel_count = self.config.block_out_channels[-1]
        # self.temporal_conv_mid = nn.Conv3d(mid_channel_count, mid_channel_count, (3, 1, 1), padding=(1, 0, 0))
        self.temporal_conv_mid = nn.Sequential(
            nn.Conv3d(mid_channel_count, mid_channel_count, (3, 1, 1), padding=(1, 0, 0)),
            nn.ReLU(),
            nn.Dropout(p=drop_p),
        )

        self.temporal_conv_final = nn.Sequential(
            nn.Conv3d(self.config.block_out_channels[0], self.config.block_out_channels[0], (3, 1, 1), padding=(1, 0, 0)),
            nn.ReLU(),
            nn.Dropout(p=drop_p),
        )

    def temp_conv(self, sample, module, group_size):
        B, C, H, W = sample.shape

        if group_size is not None:
            assert B % group_size == 0, "batch size must be divisible by group_size"
            # regroup the frames: (n*g, c, h, w) -> (n, g, c, h, w)
            sample_4d = rearrange(sample, "(n g) c h w -> n g c h w", g=group_size, n=B // group_size)
            # Conv3d expects (N, C, T, H, W), so swap the group/time and channel axes
            sample_t = module(sample_4d.permute(0, 2, 1, 3, 4)).permute(0, 2, 1, 3, 4)
            sample_t = rearrange(sample_t, "n g c h w -> (n g) c h w")
        else:
            # treat the whole batch as one temporal group (unbatched Conv3d input: (C, T, H, W))
            sample_t = module(sample.permute(1, 0, 2, 3)).permute(1, 0, 2, 3)
        return sample_t

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        class_labels: Optional[torch.Tensor] = None,
        group_size=None,
    ) -> torch.Tensor:
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        B, C, H, W = sample.shape

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when doing class conditioning")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb
        elif self.class_embedding is None and class_labels is not None:
            raise ValueError("class_embedding needs to be initialized in order to use class conditioning")

        # 2. pre-process
        skip_sample = sample
        sample = self.conv_in(sample)
        sample_t = self.temp_conv(sample, self.temporal_conv_in, group_size)
        sample = sample + sample_t

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = downsample_block(
                    hidden_states=sample, temb=emb, skip_sample=skip_sample
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample, emb)
        # ------------------------
        #  3d conv on time dimension
        # sample.shape: (B, ?, H/8, w/8)  [2, 896, 8, 4]
        # ------------------------
        temporal_mid = self.temp_conv(sample, self.temporal_conv_mid, group_size)
        sample = sample + temporal_mid

        # 5. up
        skip_sample = None
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
            else:
                sample = upsample_block(sample, res_samples, emb)

        temporal_final = self.temp_conv(sample, self.temporal_conv_final, group_size)
        sample = sample + temporal_final

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        if self.config.time_embedding_type == "fourier":
            timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
            sample = sample / timesteps

        return sample
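
For anyone who wants to try this class, a minimal usage sketch (continues from the definition above; the config values are just the diffusers UNet2DModel defaults plus a toy latent size, not my actual training setup):

model = Freq_UNet2d(
    tdropout=0.5,
    sample_size=32,      # toy latent resolution for this example
    in_channels=4,       # VAE latent channels
    out_channels=4,
)

frames_per_sample = 16
x = torch.randn(2 * frames_per_sample, 4, 32, 32)    # 2 grouped samples of 16 slices each
timesteps = torch.randint(0, 1000, (x.shape[0],))
out = model(x, timesteps, group_size=frames_per_sample)
print(out.shape)                                     # torch.Size([32, 4, 32, 32])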

Data I collected

(screenshot: 2024-10-21 20-27-14) I use the CogVideoX-5B i2v model to create synthetic videos. @Unified-Robots I download videos from Pond5.com, envato.com, Pexels, and YouTube with Chrome browser extensions. I don't have a script for batch scraping, but I wrote Chrome extensions to batch-download from a webpage, then use tiny-CLIP to filter out videos with big movements, and then manually filter the rest!
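
For the CLIP-based motion filter, the idea is roughly the following (a sketch using the standard openai/clip-vit-base-patch32 checkpoint from transformers as a stand-in for tiny-CLIP; the threshold and frame sampling are made up):

import torch
from transformers import CLIPModel, CLIPImageProcessor

clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

@torch.no_grad()
def has_small_motion(frames, threshold=0.97):
    """Keep a clip only if embeddings of consecutive frames stay very similar."""
    inputs = processor(images=frames, return_tensors="pt")           # frames: list of PIL images
    feats = clip.get_image_features(pixel_values=inputs.pixel_values)
    feats = torch.nn.functional.normalize(feats, dim=-1)
    sims = (feats[:-1] * feats[1:]).sum(dim=-1)                      # cosine similarity of neighbours
    return sims.min().item() > threshold                             # big movement -> low similarity -> reject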