guoyww / AnimateDiff

Official implementation of AnimateDiff.
https://animatediff.github.io
Apache License 2.0

⭐⭐⭐⭐⭐ Claude 3 - architectural redesign POC suggestion: 100X faster training / inference times / reduced parameters / mem footprint #308

Open johndpope opened 6 months ago

johndpope commented 6 months ago

https://paperswithcode.com/method/depthwise-separable-convolution#:~:text=While%20standard%20convolution%20performs%20the,a%20linear%20combination%20of%20the

Current setup: Screenshot from 2024-03-16 06-33-04

Screenshot from 2024-03-16 06-22-18 https://youtu.be/vVaRhZXovbw

Among the architecture redesign options mentioned, using efficient blocks, specifically depthwise separable convolutions, is probably the easiest to implement as a proof of concept (POC). Depthwise separable convolutions are a simple yet effective technique for reducing the number of parameters in convolutional layers while maintaining the model's expressiveness.

In a standard convolutional layer, the convolution operation is performed across all input channels, resulting in a large number of parameters. Depthwise separable convolutions, on the other hand, break down the convolution into two separate steps:

1. Depthwise convolution: each input channel is convolved independently with its own filter. Each filter therefore operates on a single channel, which reduces the number of parameters.
2. Pointwise convolution: a 1x1 convolution is applied to combine the outputs of the depthwise step. This mixes information across channels and produces the final output.

By separating the spatial and channel-wise convolutions, depthwise separable convolutions can significantly reduce the number of parameters compared to standard convolutions. This reduction leads to faster training and inference times as well as lower memory requirements.
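As a back-of-the-envelope illustration (the 320-channel, 3x3(x3) sizes below are assumptions for illustration only, not the exact AnimateDiff block dimensions; biases are ignored):

```python
# Weight counts for a standard conv vs a depthwise separable conv,
# for an assumed 320-in / 320-out layer (illustrative numbers only).
c_in = c_out = 320
k = 3

std_2d = c_in * c_out * k * k           # standard 2D conv
dws_2d = c_in * k * k + c_in * c_out    # depthwise + pointwise, 2D

std_3d = c_in * c_out * k ** 3          # standard 3D conv
dws_3d = c_in * k ** 3 + c_in * c_out   # depthwise + pointwise, 3D

print(f"2D: {std_2d:,} -> {dws_2d:,} weights ({std_2d / dws_2d:.1f}x fewer)")
print(f"3D: {std_3d:,} -> {dws_3d:,} weights ({std_3d / dws_3d:.1f}x fewer)")
# 2D: 921,600 -> 105,280 weights (8.8x fewer)
# 3D: 2,764,800 -> 111,040 weights (24.9x fewer)
```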

To implement depthwise separable convolutions in AnimateDiff as a POC, you can start by replacing the standard convolutional layers with depthwise separable ones, either by modifying the InflatedConv3d class or by creating a new class for them.

Here's a simple example of how you can modify the InflatedConv3d class to use depthwise separable convolutions:

Original `InflatedConv3d`:

```python
class InflatedConv3d(nn.Conv2d):
    def forward(self, x):
        video_length = x.shape[2]

        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)

        return x
```

Proposed `InflatedDWSConv3d`:

```python
class InflatedDWSConv3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.depthwise_conv = nn.Conv3d(in_channels, in_channels, kernel_size, stride, padding, groups=in_channels)
        self.pointwise_conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
        return x
```

In this modified version, the InflatedDWSConv3d class performs a depthwise convolution followed by a pointwise convolution. The groups parameter in the depthwise convolution is set to the number of input channels, ensuring that each channel is convolved independently.
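A quick sanity check (a minimal sketch, assuming the `InflatedDWSConv3d` class above is defined in scope) to confirm that output shapes match a plain `nn.Conv3d` while the parameter count drops:

```python
import torch
import torch.nn as nn

std = nn.Conv3d(320, 320, kernel_size=3, padding=1)
dws = InflatedDWSConv3d(320, 320, kernel_size=3, padding=1)

x = torch.randn(1, 320, 8, 32, 32)   # (batch, channels, frames, height, width)
print(std(x).shape, dws(x).shape)    # both torch.Size([1, 320, 8, 32, 32])

n_params = lambda m: sum(p.numel() for p in m.parameters())
print(n_params(std), n_params(dws))  # roughly 2.77M vs 0.11M parameters
```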

To use this modified class in AnimateDiff, you can replace the occurrences of InflatedConv3d with InflatedDWSConv3d in the relevant parts of the code, such as the UNet3DConditionModel and the various block classes.
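One low-touch way to wire this in for a POC (a sketch only; `swap_to_dws` is a hypothetical helper, and the exact module layout of `UNet3DConditionModel` may differ) is to recursively swap the layers after the model is built. Note that the swapped layers have differently shaped weights, so existing pretrained checkpoints cannot be loaded into them and retraining is needed:

```python
# Hypothetical helper (not part of the AnimateDiff codebase): recursively swap
# InflatedConv3d modules for InflatedDWSConv3d after the UNet has been built.
def swap_to_dws(module):
    for name, child in module.named_children():
        if isinstance(child, InflatedConv3d):
            setattr(module, name, InflatedDWSConv3d(
                child.in_channels,
                child.out_channels,
                child.kernel_size[0],
                stride=child.stride[0],
                padding=child.padding[0],
            ))
        else:
            swap_to_dws(child)

# usage sketch, after constructing the model:
# swap_to_dws(unet)
```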

By making this change, you can create a POC of AnimateDiff with reduced parameters in the convolutional layers. You can then evaluate the impact of using depthwise separable convolutions on the model's performance, generation quality, and efficiency.

Keep in mind that this is just one example of how you can modify the architecture for simplification. Depending on your specific requirements and constraints, you may need to explore other techniques or combinations of techniques to achieve the desired balance between model complexity and performance.

PART 2 - I asked if the InflatedGroupNorm could be fixed too (it looks quite similar). Could this also be upgraded?

```python
class InflatedGroupNorm(nn.GroupNorm):
    def forward(self, x):
        video_length = x.shape[2]
        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
        return x
```

Yes, you can certainly upgrade the InflatedGroupNorm class to make it more efficient and compatible with the depthwise separable convolutions. Here's an updated version of the InflatedGroupNorm class:

```python
class InflatedGroupNorm(nn.GroupNorm):
    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
        super().__init__(num_groups, num_channels, eps, affine)

    def forward(self, x):
        batch_size, num_channels, video_length, height, width = x.size()
        x = x.view(batch_size, num_channels, -1)
        x = super().forward(x)
        x = x.view(batch_size, num_channels, video_length, height, width)
        return x
```

In this upgraded version, we make the following changes:

- We remove the explicit rearrangement of the input tensor with the `rearrange` function from the einops library and instead reshape the tensor using the `view` method.
- We first reshape the input tensor `x` from `(batch_size, num_channels, video_length, height, width)` to `(batch_size, num_channels, -1)`, which flattens the spatial dimensions (`video_length`, `height`, `width`) into a single dimension.
- We then apply the group normalization operation via `super().forward(x)`, which normalizes the flattened spatial dimensions independently for each group of channels.
- After the group normalization, we reshape the tensor back to its original shape `(batch_size, num_channels, video_length, height, width)` using `view`.

By reshaping the tensor this way, we avoid the explicit rearrangement with `rearrange`, which can be more efficient in terms of memory and computation.

This upgraded InflatedGroupNorm class is compatible with the depthwise separable convolutions used in the modified InflatedDWSConv3d class. The group normalization operation is applied independently to each group of channels, which aligns well with the depthwise convolution step where each channel is convolved separately.

You can replace the occurrences of the original InflatedGroupNorm class with this upgraded version in the AnimateDiff codebase. This change, along with the introduction of depthwise separable convolutions, can contribute to simplifying the model and reducing the number of parameters.

Remember to test the upgraded InflatedGroupNorm class to ensure that it produces the expected results and integrates well with the rest of the AnimateDiff model.
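A minimal test sketch along those lines (written with a plain `nn.GroupNorm` so both forward paths can be compared side by side; the shapes are arbitrary). Note that the view-based version computes statistics over all frames of a clip at once, whereas the original rearrange-based version normalizes each frame independently, so their outputs will generally differ:

```python
import torch
import torch.nn as nn
from einops import rearrange

torch.manual_seed(0)
x = torch.randn(2, 64, 8, 16, 16)   # (batch, channels, frames, height, width)
gn = nn.GroupNorm(num_groups=32, num_channels=64)

# original behaviour: per-frame normalization statistics
y_frame = rearrange(gn(rearrange(x, "b c f h w -> (b f) c h w")),
                    "(b f) c h w -> b c f h w", f=x.shape[2])

# upgraded behaviour: statistics computed over all frames at once
y_clip = gn(x.view(2, 64, -1)).view_as(x)

print(torch.allclose(y_frame, y_clip, atol=1e-5))  # typically False
```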

UPDATE -- the YouTube video above claims a 9x improvement for 2D; I asked Claude 3 to work out the improvement for 3D.

Screenshot from 2024-03-16 07-07-56

For Block 0, the computational complexity is reduced by approximately 34.1 times, while for Blocks 2 and 3, it is reduced by approximately 134.8 times compared to using standard 3D convolutions.

Screenshot from 2024-03-16 07-07-44

It's saying the overall improvement would be 100X. Screenshot from 2024-03-16 07-11-49

UPDATE - it seems this repo is the only one that can support this, as this repo doesn't provide training scripts (unless I'm mistaken - there's no checkpoint saved in this repo): https://github.com/tumurzakov/AnimateDiff

johndpope commented 6 months ago

UPDATE - I asked Claude to provide a simple test case (I think it removed the 4D channel, so this test may only reflect the 3D ~9x gains in file size): https://gist.github.com/johndpope/7ec726fb38e8df808256ccce6b6a70d9

Updated with saved .pth files:

11M Mar 18 09:39 optimized_model_checkpoint.pth
92M Mar 18 09:39 original_model_checkpoint.pth

CPU

Original AnimateDiff time: 19.8443 seconds
Optimized AnimateDiff time: 6.1337 seconds
Original output shape: torch.Size([1, 1280, 109, 32, 32])
Optimized output shape: torch.Size([1, 1280, 109, 32, 32])

https://gist.githubusercontent.com/johndpope/7ec726fb38e8df808256ccce6b6a70d9/raw/12a31a620b715d82ab481a282cf28367907dfa7c/test.py

GPU

https://gist.github.com/johndpope/d71077f511d242af0e51266263b80e74

Original AnimateDiff time for 27 batches: 0.6368 seconds
Optimized AnimateDiff time for 27 batches: 0.1000 seconds
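For anyone reproducing the GPU numbers, a minimal timing loop along these lines (a sketch, not the exact code in the linked gists; `time_module` is a made-up helper) should wrap the runs in `torch.cuda.synchronize()` so the measured times reflect actual kernel execution:

```python
import time
import torch

def time_module(model, x, runs=27):
    # Hypothetical benchmark helper, not the gist's exact code.
    model.eval()
    with torch.no_grad():
        if x.is_cuda:
            torch.cuda.synchronize()
        start = time.time()
        for _ in range(runs):
            model(x)
        if x.is_cuda:
            torch.cuda.synchronize()
    return time.time() - start
```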

UPDATES over here -> https://github.com/tumurzakov/AnimateDiff/issues/18