Vchitect / SEINE

[ICLR 2024] SEINE: Short-to-Long Video Diffusion Model for Generative Transition and Prediction
https://vchitect.github.io/SEINE-project/
Apache License 2.0

Transformer3DModel forward in training #22

Open · yptang5488 opened 10 months ago

yptang5488 commented 10 months ago

Hi, I have a question: in Transformer3DModel.forward(), you split the input encoder_hidden_states into two parts, encoder_hidden_states_video and encoder_hidden_states_image, before sending it to BasicTransformerBlock. How do you set this up during training? Is this design meant to distinguish between training on images and training on videos, or to allow both at the same time? Another question: where should the text condition be placed as input during training?

(models/attention.py : class Transformer3DModel)

def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, use_image_num=None, return_dict: bool = True):
        # Input: hidden_states is a 5-D latent (batch, channels, frames, height, width)
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        if self.training:
            # Joint video + image training: the last `use_image_num` entries on the frame
            # axis are independent still images; the remaining entries are frames of one video.
            video_length = hidden_states.shape[2] - use_image_num
            hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
            # Split the text conditions the same way: the leading entries are the video
            # caption(s), repeated across all video frames; the trailing `use_image_num`
            # entries are one caption per appended image.
            encoder_hidden_states_length = encoder_hidden_states.shape[1]
            encoder_hidden_states_video = encoder_hidden_states[:, :encoder_hidden_states_length - use_image_num, ...]
            encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b m n c -> b (m f) n c', f=video_length).contiguous()
            encoder_hidden_states_image = encoder_hidden_states[:, encoder_hidden_states_length - use_image_num:, ...]
            encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
            encoder_hidden_states = rearrange(encoder_hidden_states, 'b m n c -> (b m) n c').contiguous()
        else:
            # Inference: a single text condition is broadcast to every frame.
            video_length = hidden_states.shape[2]
            hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
            encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous()

        batch, channel, height, weight = hidden_states.shape
        residual = hidden_states
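
For reference, here is a minimal shape sketch of how I currently read the training branch: it assumes the latent video and the extra still images are concatenated along the frame axis before this forward call, and that encoder_hidden_states carries one video caption plus one caption per appended image. This is not the repo's training code; all sizes and variable names below are made up for illustration.

# Minimal shape sketch (assumption, not the repo's training code): the video latents
# and `use_image_num` still-image latents are stacked on the frame axis, and the text
# conditions are stacked as [video caption, per-image captions]. Sizes are illustrative.
import torch
from einops import rearrange, repeat

b, c, h, w = 2, 4, 32, 32          # batch, latent channels, latent height/width
video_length = 16                   # real video frames
use_image_num = 4                   # extra still images appended for joint training
num_tokens, text_dim = 77, 768      # text-token count and embedding dim (illustrative)

# Latents: video frames plus image "frames" on dim 2 -> (b, c, f + use_image_num, h, w)
hidden_states = torch.randn(b, c, video_length + use_image_num, h, w)

# Text conditions: index 0 is the video caption, indices 1..use_image_num are per-image captions
encoder_hidden_states = torch.randn(b, 1 + use_image_num, num_tokens, text_dim)

# Reproduce the split from Transformer3DModel.forward (training branch)
f = hidden_states.shape[2] - use_image_num                      # = video_length
ehs_len = encoder_hidden_states.shape[1]                        # = 1 + use_image_num
ehs_video = encoder_hidden_states[:, :ehs_len - use_image_num]  # (b, 1, n, c) video caption
ehs_video = repeat(ehs_video, 'b m n c -> b (m f) n c', f=f)    # broadcast caption to every video frame
ehs_image = encoder_hidden_states[:, ehs_len - use_image_num:]  # (b, use_image_num, n, c) image captions
ehs = torch.cat([ehs_video, ehs_image], dim=1)                  # (b, f + use_image_num, n, c)
ehs = rearrange(ehs, 'b m n c -> (b m) n c')

hs = rearrange(hidden_states, 'b c f h w -> (b f) c h w')
assert hs.shape[0] == ehs.shape[0]   # one text condition per (video or image) frame
print(hs.shape, ehs.shape)           # torch.Size([40, 4, 32, 32]) torch.Size([40, 77, 768])

Under this reading, every flattened frame ends up paired with its own text condition, so video frames see the video caption and the appended images see their own captions. Please correct me if the actual training setup differs.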