Aidenzich / road-to-master

A repo to store our research footprint on AI

Scalable Diffusion Models with Transformers #53

Open Aidenzich opened 1 month ago

Aidenzich commented 1 month ago

Scalable Diffusion Models with Transformers

Google's AlphaFold 3 has achieved remarkable results, and it also uses DiT, which combines diffusion models with Transformers. So let's take notes on this technology:

Screenshot 2024-05-20 at 4 18 44 AM

Here, let’s review the content of DDPM:

Screenshot 2024-05-20 at 4 43 59 AM

Here $x_T$ denotes pure noise, the endpoint of the forward process, and $\Sigma$ is the covariance of the reverse-process distribution.

Screenshot 2024-05-20 at 4 46 10 AM

That is:

Screenshot 2024-05-20 at 4 48 02 AM
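As a text version of this review, here is a sketch of the usual DDPM formulation. The symbols $\beta_t$, $\bar\alpha_t$, $\mu_\theta$ and $\Sigma_\theta$ follow the common DDPM convention and are an assumption here; the notation in the screenshots may differ.

```latex
% Forward (noising) process: one step, and the closed form from x_0
q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\right)
\qquad
q(x_t \mid x_0) = \mathcal{N}\!\left(x_t;\ \sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t) I\right),
\quad \bar\alpha_t = \prod_{s=1}^{t} (1-\beta_s)

% Reverse (denoising) process, starting from pure noise x_T
x_T \sim \mathcal{N}(0, I)
\qquad
p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\!\left(x_{t-1};\ \mu_\theta(x_t, t),\ \Sigma_\theta(x_t, t)\right)
```

In this notation, the network's job is to predict the quantities that define $\mu_\theta$ (typically the added noise) and, optionally, $\Sigma_\theta$.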

Patchify

The Patchify module in the figure follows the design of Vision Transformer (ViT): it splits the input into fixed-size patches (for example $16 \times 16$ in ViT), flattens each patch into a vector, and treats each one as a token. Specifically, for an input of shape $H \times W \times C$ and patch size $P$, each patch has size $P \times P \times C$, giving $N = \frac{H}{P} \times \frac{W}{P}$ tokens. A smaller $P$ therefore yields a longer sequence: halving $P$ quadruples $N$.
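The splitting step can be sketched in plain Python as follows. This is a minimal illustration on nested lists, not the DiT implementation: the function name `patchify` and the toy $4 \times 4 \times 3$ input are assumptions, and a real model would follow this with a learned linear projection of each flattened patch.

```python
def patchify(image, patch_size):
    """Split an H x W x C image (nested lists) into flattened patch vectors.

    Returns N = (H / P) * (W / P) lists, each of length P * P * C.
    """
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("H and W must be divisible by the patch size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            # Flatten one P x P x C block in row-major order.
            patch = [
                value
                for row in image[top:top + patch_size]
                for pixel in row[left:left + patch_size]
                for value in pixel
            ]
            patches.append(patch)
    return patches


# Toy 4 x 4 x 3 "image" whose values encode their own position (hypothetical data).
H, W, C, P = 4, 4, 3, 2
image = [[[(r * W + c) * C + ch for ch in range(C)] for c in range(W)]
         for r in range(H)]
tokens = patchify(image, P)
print(len(tokens), len(tokens[0]))  # number of tokens N, values per token
```

With $H = W = 4$ and $P = 2$, the formula above gives $N = 2 \times 2 = 4$ tokens of length $2 \times 2 \times 3 = 12$.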

Embed

The Embed module in this architecture plays a role similar to positional embedding in the Vision Transformer (ViT): it maps specific information into fixed-length vectors that the transformer can process together with the input tokens. The two differ in what they embed:

Positional Embedding in ViT

In ViT, the image is divided into small patches, and each patch is flattened and embedded into a fixed-length vector. Because the transformer itself has no notion of token order, a positional embedding is added to represent each patch's position in the original image. These positional embedding vectors are added element-wise to the patch embedding vectors to form the final input sequence.
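The addition step can be sketched as below. The randomly initialised `position_table` is a hypothetical stand-in for whatever position embeddings a model actually uses (learned or fixed sinusoidal), and the sizes are toy values.

```python
import random


def add_positional_embedding(patch_embeddings, position_table):
    """Add the embedding for position i to the i-th patch embedding, element-wise."""
    if len(patch_embeddings) > len(position_table):
        raise ValueError("position table is shorter than the sequence")
    return [
        [p + e for p, e in zip(patch, pos)]
        for patch, pos in zip(patch_embeddings, position_table)
    ]


rng = random.Random(0)
num_patches, dim = 4, 8
# Four patches with identical content: without position info they are indistinguishable.
patch_embeddings = [[1.0] * dim for _ in range(num_patches)]
# Hypothetical position table, one vector per patch index.
position_table = [[rng.gauss(0.0, 0.02) for _ in range(dim)]
                  for _ in range(num_patches)]

tokens = add_positional_embedding(patch_embeddings, position_table)
print(tokens[0] != tokens[1])  # identical patches now differ by position
```

The sequence length and embedding width are unchanged; only the values shift, which is what lets attention tell otherwise identical patches apart.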

Embed in DiT

In the DiT (Diffusion Transformer) architecture, the Embed module similarly embeds conditional information (such as the timestep and class label) into fixed-length vectors. How these vectors reach the transformer depends on the block design: they can be appended as extra tokens to the input sequence, attended to through cross-attention, or used to regress layer-normalization parameters. In every case, the transformer uses this conditional information to guide the generation process.

Working Principle of embed

  1. Positional Embedding in ViT:

    • Purpose: Provide positional information.
    • Operation: Add positional embedding vectors to the patch embedding vectors.
  2. Embed Module in DiT:

    • Purpose: Embed conditional information (such as timestep and label) into vector space so the transformer model can use this information.
    • Operation: Convert conditional information into embedding vectors and process them along with other input tokens.

This embedding step gives the transformer access to additional information (such as timesteps and labels), so its prediction at each denoising step can depend on the current noise level and the target class.
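A minimal sketch of how a timestep and a label could be turned into embedding vectors is shown below. The sinusoidal form is a common choice for timesteps; the function name, the randomly initialised `label_table`, and the toy sizes are assumptions for illustration, and real models typically pass the timestep embedding through a small MLP as well.

```python
import math
import random


def timestep_embedding(t, dim, max_period=10000.0):
    """Sinusoidal embedding of a scalar timestep t into `dim` values (dim even)."""
    if dim % 2:
        raise ValueError("dim must be even")
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return [math.cos(t * f) for f in freqs] + [math.sin(t * f) for f in freqs]


dim, num_classes = 8, 10
rng = random.Random(0)
# Hypothetical label table: one vector per class (learned in a real model).
label_table = [[rng.gauss(0.0, 0.02) for _ in range(dim)]
               for _ in range(num_classes)]

t_emb = timestep_embedding(t=250, dim=dim)
y_emb = label_table[3]  # embedding for class index 3

# Option A: sum both into a single conditioning vector.
cond = [a + b for a, b in zip(t_emb, y_emb)]

# Option B: keep them separate and append them to the sequence as extra tokens.
patch_tokens = [[0.0] * dim for _ in range(4)]  # placeholder patch tokens
sequence = patch_tokens + [t_emb, y_emb]
print(len(cond), len(sequence))
```

Option A keeps the sequence length fixed, while Option B lengthens the sequence by two tokens; which one is used depends on how the block consumes the condition.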

Aidenzich commented 1 month ago

DiT Block

Screenshot 2024-05-20 at 5 21 36 AM

The table below compares the different DiT block designs by their features, how they handle conditioning information, their computational cost, and a specific description of each:

| Block Type | Features | Handling of Conditioning Information | Computational Cost (Gflops) | Specific Description |
|---|---|---|---|---|
| DiT Block with adaLN-Zero | - Contains multi-head self-attention and pointwise feedforward network<br>- Uses adaptive layer normalization (adaLN-Zero)<br>- Initialized as identity function | - Handles conditioning information through adaptive layer normalization<br>- Conditioning information determines the scale and shift parameters of layer normalization | Lowest | Regresses the layer-normalization scale (γ) and shift (β) parameters from the conditioning information (such as timestep and label) instead of learning them directly. The "Zero" variant also regresses a scaling factor (α) applied just before each residual connection and initializes it to zero, so every residual block starts as the identity function. The cost is low because only these per-channel parameters need to be regressed, which suits settings with limited computational resources. |
| DiT Block with Cross-Attention | - Contains multi-head self-attention and pointwise feedforward network<br>- Adds multi-head cross-attention layer | - Handles conditioning information as a separate sequence processed through multi-head cross-attention | Highest (about 15% overhead) | Adds a multi-head cross-attention layer after the multi-head self-attention layer. Conditioning information is treated as a separate sequence and interacts with the main input sequence through the cross-attention layer. The extra layer increases computational cost by about 15%. |
| DiT Block with In-Context Conditioning | - Contains multi-head self-attention and pointwise feedforward network<br>- Appends conditioning information in the sequence dimension | - Appends conditioning information to the end of the input sequence, similar to cls tokens in ViT<br>- Conditioning information is treated as additional tokens | Negligible | Appends the conditioning embeddings to the end of the input sequence, similar to the cls tokens in Vision Transformer, and processes them like any other token inside a standard transformer block. Because only a few tokens are added, the extra computational cost is negligible. |
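To make the adaLN-Zero idea concrete, here is a minimal plain-Python sketch of one residual branch. The helper names (`layer_norm`, `linear`, `adaln_zero_branch`), the `x * (1 + γ) + β` form of the modulation, and the toy `double` sublayer standing in for attention or the feedforward network are all assumptions for illustration, not the reference implementation.

```python
import math


def layer_norm(x, eps=1e-6):
    """Normalize a vector to zero mean and unit variance (no learned affine)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def linear(weights, bias, x):
    """Dense layer: one output per weight row."""
    return [sum(w * v for w, v in zip(row, x)) + b
            for row, b in zip(weights, bias)]


def adaln_zero_branch(x, cond, weights, bias, sublayer):
    """Compute x + alpha * sublayer(LN(x) * (1 + gamma) + beta).

    beta, gamma and alpha are regressed from the conditioning vector `cond`.
    """
    d = len(x)
    params = linear(weights, bias, cond)  # 3 * d regressed values
    beta, gamma, alpha = params[:d], params[d:2 * d], params[2 * d:]
    h = [n * (1.0 + g) + s for n, g, s in zip(layer_norm(x), gamma, beta)]
    out = sublayer(h)
    return [xi + a * o for xi, a, o in zip(x, alpha, out)]


def double(h):
    """Toy stand-in for an attention or feedforward sublayer."""
    return [2.0 * v for v in h]


d = 4
x = [0.5, -1.0, 2.0, 0.0]
cond = [1.0, -2.0]

# Zero-initialised regression layer: alpha = 0, so the branch is the identity.
zero_w = [[0.0] * len(cond) for _ in range(3 * d)]
zero_b = [0.0] * (3 * d)
at_init = adaln_zero_branch(x, cond, zero_w, zero_b, double)

# After (hypothetical) training the weights are non-zero and the condition matters.
trained_w = [[0.1] * len(cond) for _ in range(3 * d)]
after_training = adaln_zero_branch(x, cond, trained_w, zero_b, double)
print(at_init == x, after_training != x)
```

The zero-initialised case returns the input unchanged, which is the "initialized as identity function" property listed for adaLN-Zero; once the regression weights move away from zero, the conditioning vector changes both the normalization and the strength of the residual update.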

Performance Comparison

Screenshot 2024-05-20 at 5 39 14 AM