google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

Refactor layers for CLIP text encoder of SD model #30

Closed: yichunk closed this PR 1 month ago

yichunk commented 1 month ago

Refactor the layers of the CLIP text encoder in the SD model. Tested: the refactored model still generates a proper image, as it did before the refactoring.

This PR includes the following updates:

  1. Clean up the split UNet implementation in SD.
  2. Refactor CLIP to use the layers module.
  3. Add a GATED_SHARED type to FeedForwardType, which is used in CLIP.
  4. Add a qkv_transpose_before_split field to AttentionConfig, which is used in CLIP.
  5. Add a GELU_QUICK type to ActivationType, which is used in CLIP.
  6. Add attn_fused_qkv_proj to TensorNames. It names the single tensor that bundles the Q, K and V projection weights, as used in CLIP.
  7. Add embedding_position to TensorNames. It names the learned position embedding used in CLIP.
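For reference, GELU_QUICK presumably corresponds to the QuickGELU approximation from the original OpenAI CLIP code, `x * sigmoid(1.702 * x)`. The plain-Python sketch below compares it with exact GELU; the function names are hypothetical and this is not ai-edge-torch's implementation.

```python
import math


def _sigmoid(z: float) -> float:
    """Numerically stable logistic sigmoid (avoids overflow in math.exp)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def gelu_quick(x: float) -> float:
    """QuickGELU as in OpenAI CLIP: x * sigmoid(1.702 * x). Hypothetical helper name."""
    return x * _sigmoid(1.702 * x)


def gelu_exact(x: float) -> float:
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


# Print both activations side by side at a few sample points.
for x in (-3.0, -1.0, 0.0, 1.0, 3.0):
    print(f"x={x:+.1f}  quick={gelu_quick(x):+.4f}  exact={gelu_exact(x):+.4f}")
```

The two curves are close but not identical, so a converted model should keep whichever variant its checkpoint was trained with.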
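To illustrate what bundling the Q, K and V projections into one tensor means, here is a plain-Python sketch that assumes the fused weight stacks the Q, K and V rows in that order along the output dimension (the convention PyTorch's `nn.MultiheadAttention` uses for `in_proj_weight`). One matrix-vector product with the fused weight, followed by a split, gives the same result as three separate projections. The helper names are hypothetical, and the sketch does not model how `qkv_transpose_before_split` changes the split, which this PR description does not spell out.

```python
from typing import List, Tuple

Matrix = List[List[float]]  # row-major: out_features x in_features
Vector = List[float]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Compute w @ x for a row-major weight matrix."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def fuse_qkv(w_q: Matrix, w_k: Matrix, w_v: Matrix) -> Matrix:
    """Hypothetical helper: stack Q, K, V weight rows into one fused tensor."""
    return w_q + w_k + w_v  # concatenates rows along the output dimension


def split_qkv(y: Vector, q_dim: int, kv_dim: int) -> Tuple[Vector, Vector, Vector]:
    """Hypothetical helper: slice the fused projection output back into Q, K, V."""
    q = y[:q_dim]
    k = y[q_dim:q_dim + kv_dim]
    v = y[q_dim + kv_dim:q_dim + 2 * kv_dim]
    return q, k, v


# Toy weights: hidden size 3, Q/K/V size 2 each.
w_q = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
w_k = [[0.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
w_v = [[2.0, 0.0, 0.0], [0.0, 0.0, 3.0]]
x = [1.0, 2.0, 3.0]

fused = fuse_qkv(w_q, w_k, w_v)              # 6 x 3 fused weight
q, k, v = split_qkv(matvec(fused, x), 2, 2)  # one projection, then split

print(q, k, v)  # [1.0, 2.0] [3.0, 3.0] [2.0, 9.0]
```

Storing one fused tensor lets the converted graph run a single larger matmul instead of three smaller ones, at the cost of having to know the exact row layout when slicing.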
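The `embedding_position` tensor holds a learned absolute position embedding: in CLIP's text encoder, row `i` of a `(max_seq_len, hidden_dim)` table is added to the token embedding at position `i`. The plain-Python sketch below shows that addition with tiny made-up tables; the function name and table values are hypothetical.

```python
from typing import List

Table = List[List[float]]  # rows of embedding vectors


def embed(token_ids: List[int], token_table: Table, position_table: Table) -> Table:
    """Hypothetical helper: token embedding + learned position embedding."""
    if len(token_ids) > len(position_table):
        raise ValueError("sequence is longer than the learned position table")
    return [
        [t + p for t, p in zip(token_table[tok], position_table[pos])]
        for pos, tok in enumerate(token_ids)
    ]


# Toy tables: vocabulary of 3 tokens, hidden size 2, max sequence length 4.
token_table = [[0.1, 0.2], [1.0, 1.0], [2.0, 0.0]]
position_table = [[0.0, 0.0], [0.5, 0.5], [1.0, 1.0], [1.5, 1.5]]

# Token 1 appears twice but gets different vectors at positions 1 and 2.
print(embed([2, 1, 1], token_table, position_table))
# [[2.0, 0.0], [1.5, 1.5], [2.0, 2.0]]
```

Because the table has a fixed number of rows, a sequence longer than it cannot be encoded, which is why the sketch raises an error instead of extrapolating.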

BUG=b/311216181

yichunk commented 1 month ago

@haozha111 could you re-approve the PR? I added another commit for formatting.