jungwonguk-up / KD-Font

0 stars 0 forks source link

cross attention 시 context tensor shape 문제 #14

Closed gih0109 closed 1 year ago

gih0109 commented 1 year ago

개요

Unet 내 cross attention 구현 시 내부 feature tensor 의 shape 와 condition 으로 들어오는 context tensor 간의 shape 를 일치시켜야 하는 문제

내용

unet 내 feature tensor 는 *(batch channel width height) 인 4d shape 를 지닌다. 그리고 unet block 을 거치면서 channel 이 64, 128, 256, 512 로 변경된다.

attention 연산 시, q와 k, v 로 들어오는 두 벡터간 마지막 차원 (weight * height) 은 nn.linear 또는 nn.Conv2d 를 사용하기 때문에 일치시킬 필요가 없으나 batch, channel 은 차원을 일치시켜야 한다.

context tensor 가 unet 에 들어갈때 어떤 shape 를 갖는지, channel 이 변경되어도 attetion 연산이 가능하게 context tensor shape 를 변경시켜야 하는데, 어떻게 변경시켜야 하는지 알아봐야 한다.

참고

Stabel Diffusion (paper, git) Perceiver General Perception with Iterative Attetion (paper)

gih0109 commented 1 year ago

Stable Diffusion 공식 git 에서 확인한 것:

다른 사용자가 만든 scrach 에서: https://scholar.harvard.edu/binxuw/classes/machine-learning-scratch/materials/stable-diffusion-scratch

Stable diffusion git 내 구조가 복잡해서 정확하게 어떻게 변경하는지 더 찾아봐야한다.

gih0109 commented 1 year ago

Frozen CLIP Text Encoder 는 다음 코드에 의해서 작동된다. 위치: stable-diffusion/ldm/modules/encoders

/modules.py

class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)

여기서 return 값은 CLIPTextModel.last_hidden_state 이다.

공식문서 를 참고하면 Input 은 torch.LongTensor of shape (batch_size, sequence_length) 로 되어있으며, last_hidden_state 는 torch.FloatTensor of shape (batch_size, sequence_length, hidden_size) 로 적혀있다.

즉 CLIP 에서 나오는 context shape 는 batch_size 가 포함된 3d vector 인 것을 확인할 수 있다

gih0109 commented 1 year ago

개요

cross attention 시 context tensor 의 shape 가 어떻게 입력되고 변경 후 계산하는지 stable diffusion git을 통해 검증할 수 있었다.

내용

stable diffusion Unet 내 Spatial Transformer Block 에서 latent tensor 와 context tensor 입력시 shape 는 다음과 같다.

shape 변경 및 attention 계산을 위해 사용된 패키지는 다음과 같다.

from torch import einsum
from einops import rearrange, repeat


Unet 내 tensor 의 shape 변경 및 계산 흐름은 다음과 같다.

1. latent tensor shape 변환

class SpatialTransformer 에서

x = rearrange(x, 'b c h w -> b (h w) c')

에 의해 latent tensor 의 shape 는 (batch, channel, height, weight) -> (batch, height X weight , channel) 이 된다. 예시) (4, 16, 32, 32) -> (4, 1024, 16)

2. query, key, value 변환

class CrossAttention 에서

inner_dim = dim_head * heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

에 의해 latent tensor 는 query 로 변환되고, context tensor 는 key, value 로 변환된다. shape 는 다음과 같다

예시) dim_head = 64, heads = 8 일 경우

3. query, key, value 의 shape 변환

class CrossAttetion 에서

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

에 의해 query 와 key, value 의 shape 가 변경된다

예시)

4. 유사도 계산

class CrossAttetion 에서

sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = sim.softmax(dim=-1)

에 의해 Scaled Dot Product Attention 연산이 수행되며 유사도 vector 'sim' 이 나온다.

예시)

5. Attention value 계산

`class CrossAttention' 에서

out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)

에 의해 Attention value 를 계산한 뒤 shape 를 변경한다.

예시)

6. latent tensor 의 크기에 맞게 출력 크기 변경

다시 latenct tensor 의 크기에 맞게 Linear layer 를 통과시켜 출력 크기를 맞춘다. class CrossAttention 에서

self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

에 의해 (batch, height X weight, inner_dim) -> (batch, height X weight, channel) 로 shape 가 변경된다.

class SpatialTransformer 에서

x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

에 의해 out 의 shape 는 (batch, height X weight , channel) -> (batch, channel, height, weight) 으로 처음 latent tensor 와 shape 가 동일하게 변경된다.

예시) out = (4, 1024, 512) -> (4, 1024, 16) -> (4, 16, 32, 32)