Open long8v opened 2 years ago
https://github.com/facebookresearch/SLIP
트랜스포머는 torch.nn.MultiheadAttention 사용함
class ResidualAttentionBlock(nn.Module):
    """One transformer block: normalised self-attention and a normalised MLP, each added residually.

    The input is normalised *before* each sub-layer and the sub-layer output
    is added back onto the un-normalised input.

    Args:
        d_model: Feature width of the block; the MLP expands it to 4 * d_model.
        n_head: Number of heads passed to ``nn.MultiheadAttention``.
        attn_mask: Optional mask forwarded to the attention layer on every call.

    NOTE(review): ``LayerNorm`` and ``QuickGELU`` are not defined in this
    block; they are expected to come from elsewhere in the module.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        hidden_dim = d_model * 4

        # Sub-modules are created in a fixed order so that parameter
        # registration (and therefore state-dict layout) stays stable.
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        mlp_layers = [
            ("c_fc", nn.Linear(d_model, hidden_dim)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(hidden_dim, d_model)),
        ]
        self.mlp = nn.Sequential(OrderedDict(mlp_layers))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Run self-attention on ``x`` (query = key = value) and return only the output tensor."""
        mask = self.attn_mask
        if mask is not None:
            # Match the mask to the activations' dtype/device.
            mask = mask.to(dtype=x.dtype, device=x.device)
        # The converted mask is stored back on the instance, replacing the previous one.
        self.attn_mask = mask

        attn_result = self.attn(x, x, x, need_weights=False, attn_mask=mask)
        return attn_result[0]

    def forward(self, x: torch.Tensor):
        """Apply the attention sub-layer, then the MLP sub-layer, each with a residual add."""
        attended = self.attention(self.ln_1(x))
        x = x + attended
        transformed = self.mlp(self.ln_2(x))
        x = x + transformed
        return x
class Transformer(nn.Module):
    """A stack of ``layers`` residual attention blocks applied one after another.

    Args:
        width: Feature width handed to every block.
        layers: Number of blocks in the stack.
        heads: Number of attention heads per block.
        attn_mask: Optional mask; the very same tensor object is passed to
            every block.
    """

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers

        blocks = []
        for _ in range(layers):
            blocks.append(ResidualAttentionBlock(width, heads, attn_mask))
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        """Feed ``x`` through every block in order and return the result."""
        out = self.resblocks(x)
        return out
이해가 좀 안되지만 CLIP 논문 읽을때 다시 보자
class CLIP(nn.Module):
    """Image/text dual encoder that projects both modalities to ``embed_dim`` features.

    The vision side wraps a caller-supplied module; the text side is a token
    embedding plus learned positional embedding, a masked ``Transformer`` and
    a final layer norm. Each side ends in a bias-free matrix projection.

    Args:
        embed_dim: Output width of both projections.
        vision_width: Row count of the image projection matrix
            (presumably the feature width produced by ``vision_model`` — confirm).
        vision_model: Module used as the image encoder.
        context_length: Number of text positions; sizes the positional
            embedding and the attention mask.
        vocab_size: Number of rows in the token embedding table.
        transformer_width: Feature width of the text transformer.
        transformer_heads: Attention heads per text transformer block.
        transformer_layers: Number of text transformer blocks.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 embed_dim: int,
                 # vision
                 vision_width: int,
                 vision_model: nn.Module,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 **kwargs,
                 ):
        super().__init__()

        # Plain bookkeeping attributes. context_length must be set before
        # build_attention_mask() is called below.
        self.context_length = context_length
        self.vision_width = vision_width
        self.vocab_size = vocab_size

        # Image tower.
        self.visual = vision_model

        # Text tower.
        text_mask = self.build_attention_mask()
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=text_mask,
        )
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, transformer_width)
        )
        self.ln_final = LayerNorm(transformer_width)

        # Projects the image and caption embeddings down to embed_dim
        # (original note: 512 in the paper). These are raw parameter matrices
        # applied with `@`, not nn.Linear — the original note asks why
        # nn.Linear was not used here.
        self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))

        # Scalar parameter, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        """Fill embeddings, transformer weights and both projections with normal noise.

        The standard deviations depend on the transformer width and depth; the
        calls are made in a fixed order so random draws are reproducible.
        """
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        width = self.transformer.width
        depth = self.transformer.layers
        proj_std = (width ** -0.5) * ((2 * depth) ** -0.5)
        attn_std = width ** -0.5
        fc_std = (2 * width) ** -0.5

        for block in self.transformer.resblocks:
            weight_std_pairs = (
                (block.attn.in_proj_weight, attn_std),
                (block.attn.out_proj.weight, proj_std),
                (block.mlp.c_fc.weight, fc_std),
                (block.mlp.c_proj.weight, proj_std),
            )
            for weight, std in weight_std_pairs:
                nn.init.normal_(weight, std=std)

        nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5)
        nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        """Return a (context_length, context_length) additive causal mask.

        Entries strictly above the diagonal are -inf; the diagonal and
        everything below it are 0.
        """
        size = self.context_length
        # Start fully blocked, then keep -inf only strictly above the diagonal.
        causal = torch.full((size, size), float("-inf"))
        causal.triu_(1)
        return causal

    def encode_image(self, image):
        """Encode ``image`` with the vision model and apply the image projection."""
        features = self.visual(image)
        return features @ self.image_projection

    def encode_text(self, text):
        """Encode token ids ``text`` and return one projected vector per sequence."""
        tokens = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
        hidden = tokens + self.positional_embedding

        # The transformer is fed sequence-first tensors.
        hidden = hidden.permute(1, 0, 2)  # NLD -> LND
        hidden = self.transformer(hidden)
        hidden = hidden.permute(1, 0, 2)  # LND -> NLD
        hidden = self.ln_final(hidden)  # [batch_size, n_ctx, transformer.width]

        # Pick, per sequence, the position holding the largest token id
        # (per the original comment, the EOT token has the highest id).
        batch_index = torch.arange(hidden.shape[0])
        eot_index = text.argmax(dim=-1)
        pooled = hidden[batch_index, eot_index]
        return pooled @ self.text_projection
vision model로 임베딩 받고 linear -> bn / relu -> linear -> bn / relu -> linear로 내보냄
class SIMCLR(nn.Module):
    """Vision backbone followed by a three-layer MLP head for two augmented views.

    Args:
        vision_width: Input width of the MLP head
            (presumably the feature width produced by ``vision_model`` — confirm).
        vision_model: Module used as the image encoder.
        ssl_mlp_dim: Hidden width of the MLP head.
        ssl_emb_dim: Output width of the MLP head.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 # vision
                 vision_width: int,
                 vision_model: nn.Module,
                 # ssl
                 ssl_mlp_dim: int,
                 ssl_emb_dim: int,
                 **kwargs,
                 ):
        super().__init__()
        self.vision_width = vision_width
        self.visual = vision_model
        self.image_mlp = self._build_mlp(
            in_dim=vision_width,
            mlp_dim=ssl_mlp_dim,
            out_dim=ssl_emb_dim,
        )

    def _build_mlp(self, in_dim, mlp_dim, out_dim):
        """Build Linear -> SyncBatchNorm -> ReLU, twice, followed by a final Linear."""
        stages = [
            ("layer1", nn.Linear(in_dim, mlp_dim)),
            ("bn1", nn.SyncBatchNorm(mlp_dim)),
            ("relu1", nn.ReLU(inplace=True)),
            ("layer2", nn.Linear(mlp_dim, mlp_dim)),
            ("bn2", nn.SyncBatchNorm(mlp_dim)),
            ("relu2", nn.ReLU(inplace=True)),
            ("layer3", nn.Linear(mlp_dim, out_dim)),
        ]
        return nn.Sequential(OrderedDict(stages))

    def encode_image(self, image):
        """Return the backbone features for ``image``; the MLP head is not applied."""
        return self.visual(image)

    def forward(self, aug1, aug2):
        """Embed both augmented views and return them under 'aug1_embed' / 'aug2_embed'."""
        # Both backbone passes run before either head pass.
        feats_1 = self.visual(aug1)
        feats_2 = self.visual(aug2)

        result = {}
        result['aug1_embed'] = self.image_mlp(feats_1)
        result['aug2_embed'] = self.image_mlp(feats_2)
        return result
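이미지 임베딩과 텍스트 임베딩을 받은 뒤 L2 정규화(p=2, dim=-1이므로 마지막 차원(=임베딩 차원)을 따라 각 벡터를 제곱합의 제곱근, 즉 L2 norm으로 나눈 것)를 적용한다. 현재 GPU의 이미지 임베딩과 모든 GPU에서 모은 텍스트 임베딩의 내적을 구하면 코사인 유사도가 되고, 여기에 logit_scale을 곱한 값을 이미지의 logit으로 사용한다(텍스트 쪽도 반대로 동일, vice versa). loss는 이미지 쪽 logit과 텍스트 쪽 logit 각각에 대해 레이블과의 cross entropy를 구해 평균한 값이다. 이때 label은 첫 번째 GPU라면 [0, 1, 2, 3, 4, ...], 두 번째 GPU라면 [배치사이즈, 배치사이즈 + 1, ...] 식으로 달리는데, 처음에는 무슨 의미인지 몰랐다. 입력을 셔플 없이 클래스 1, 클래스 2, 클래스 3, ... 순으로 넣어서 0번째 샘플이면 0번째 클래스가 되는 건가 싶었고, 그러려면 하나의 배치에 모든 클래스가 들어가야 할 것 같았다. 하지만 코드상 label은 데이터셋의 클래스가 아니라 '모든 GPU에서 모은 배치 안에서 짝이 되는 샘플의 위치'다. 즉 local_batch_size * rank + i 는 i번째 이미지와 짝인 텍스트가 gather된 텍스트 임베딩에서 놓이는 인덱스이고(gather가 rank 순서대로 이어 붙인다고 가정), 그래서 label 값을 외부에서 가져오지 않고 인덱스로 직접 만든다.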
class CLIPLoss(nn.Module):
    """Symmetric image<->text cross-entropy loss over embeddings gathered from all ranks.

    Target labels are cached on the instance and rebuilt only when the local
    batch size changes.
    """

    def __init__(self):
        super().__init__()
        self.labels = None
        self.last_local_batch_size = None

    def forward(self, outputs):
        """Compute the loss and image->text top-1 accuracy.

        Args:
            outputs: Mapping with keys 'image_embed', 'text_embed' and
                'logit_scale'.

        Returns:
            Dict with 'loss' and 'clip_loss' (the same value) and 'clip_acc',
            the accuracy in percent over the local batch.
        """
        image_embed = outputs['image_embed']
        text_embed = outputs['text_embed']
        logit_scale = outputs['logit_scale']

        local_batch_size = image_embed.size(0)
        if local_batch_size != self.last_local_batch_size:
            # Target for local sample i is (rank * local_batch_size + i).
            # NOTE(review): assumes utils.all_gather_batch concatenates the
            # per-rank batches in rank order — confirm against utils.
            rank_offset = local_batch_size * utils.get_rank()
            local_index = torch.arange(local_batch_size, device=image_embed.device)
            self.labels = rank_offset + local_index
            self.last_local_batch_size = local_batch_size
        targets = self.labels

        # Unit-normalise so that dot products are cosine similarities.
        image_embed = F.normalize(image_embed, dim=-1, p=2)
        text_embed = F.normalize(text_embed, dim=-1, p=2)

        # Collect the normalised features from all ranks.
        gathered = utils.all_gather_batch([image_embed, text_embed])
        image_embed_all, text_embed_all = gathered

        # Scaled similarities of local samples against the gathered other modality.
        logits_per_image = (logit_scale * image_embed) @ text_embed_all.t()
        logits_per_text = (logit_scale * text_embed) @ image_embed_all.t()

        image_loss = F.cross_entropy(logits_per_image, targets)
        text_loss = F.cross_entropy(logits_per_text, targets)
        loss = (image_loss + text_loss) / 2

        # Accuracy is a metric only; keep it out of the autograd graph.
        with torch.no_grad():
            predicted = torch.argmax(logits_per_image, dim=-1)
            n_correct = predicted.eq(targets).sum()
            acc = 100 * n_correct / local_batch_size

        return {'loss': loss, 'clip_loss': loss, 'clip_acc': acc}
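augment된 이미지 1과 이미지 2의 임베딩을 각각 L2 정규화하고, 모든 GPU의 임베딩을 gather한 다음 transpose 후 matmul로 코사인 유사도를 구해 tau(temperature)로 나눠준다. 이걸 logit으로 두고 label과 cross entropy loss를 계산한다. masks가 무엇인지는 처음에 자기 자신/padding된 것을 label이 아니게 하는 게 아닐까(마이너스 무한대로 mask돼 있을 듯) 추측했는데, 코드상 masks는 label이 아니라 logit에서 빼는 값이고, 자기 자신의 위치만 1인 one-hot에 1e9를 곱한 것이다. 즉 같은 view끼리의 유사도(logits_aa, logits_bb)에서 자기 자신과의 유사도에 -1e9를 더해(마이너스 무한대 대신 아주 작은 값) softmax에서 사실상 제외하는 역할이며, padding과는 관계가 없다.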
class SIMCLRLoss(nn.Module):
    """
    SimCLR contrastive loss (https://arxiv.org/abs/2002.05709) over two augmented views.

    The two views arrive as separate tensors, ``outputs['aug1_embed']`` and
    ``outputs['aug2_embed']``. Each view is compared against both views
    gathered from all ranks; a sample's similarity to itself inside the same
    view is pushed down by a large constant.

    Config params:
        temperature (float): the temperature to be applied on the logits
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.tau = temperature
        self.labels = None
        self.masks = None
        self.last_local_batch_size = None

    def forward(self, outputs):
        """Compute the loss and view-a top-1 accuracy.

        Args:
            outputs: Mapping with keys 'aug1_embed' and 'aug2_embed'.

        Returns:
            Dict with 'loss' and 'ssl_loss' (the same value) and 'ssl_acc',
            the accuracy in percent over the local batch.
        """
        # Unit-normalise both views so that dot products are cosine similarities.
        q_a = F.normalize(outputs['aug1_embed'], dim=-1, p=2)
        q_b = F.normalize(outputs['aug2_embed'], dim=-1, p=2)

        local_batch_size = q_a.size(0)

        # Keys: both views collected from all ranks.
        k_a, k_b = utils.all_gather_batch_with_grad([q_a, q_b])

        if local_batch_size != self.last_local_batch_size:
            # Target for local sample i is (rank * local_batch_size + i).
            # NOTE(review): assumes the gather concatenates per-rank batches
            # in rank order — confirm against utils.
            rank_offset = local_batch_size * utils.get_rank()
            self.labels = rank_offset + torch.arange(local_batch_size, device=q_a.device)
            total_batch_size = local_batch_size * utils.get_world_size()
            # 1e9 at each sample's own column, 0 elsewhere.
            self.masks = F.one_hot(self.labels, total_batch_size) * 1e9
            self.last_local_batch_size = local_batch_size

        def scaled_similarity(queries, keys):
            # Pairwise dot products divided by the temperature.
            return torch.matmul(queries, keys.transpose(0, 1)) / self.tau

        # Same-view logits: subtract the mask so a sample never selects itself.
        logits_aa = scaled_similarity(q_a, k_a) - self.masks
        logits_bb = scaled_similarity(q_b, k_b) - self.masks
        # Cross-view logits: these contain the positive pair.
        logits_ab = scaled_similarity(q_a, k_b)
        logits_ba = scaled_similarity(q_b, k_a)

        # Cross-view columns come first, so the cached labels index the positives.
        all_logits_a = torch.cat([logits_ab, logits_aa], dim=1)
        all_logits_b = torch.cat([logits_ba, logits_bb], dim=1)

        loss_a = F.cross_entropy(all_logits_a, self.labels)
        loss_b = F.cross_entropy(all_logits_b, self.labels)
        loss = (loss_a + loss_b) / 2  # divide by 2 to average over all samples

        # Accuracy is a metric only; keep it out of the autograd graph.
        with torch.no_grad():
            predicted = torch.argmax(all_logits_a, dim=-1)
            n_correct = predicted.eq(self.labels).sum()
            acc = 100 * n_correct / local_batch_size

        return {'loss': loss, 'ssl_loss': loss, 'ssl_acc': acc}
CLIP + SIMCLR의 합
class SLIPLoss(nn.Module):
    """Sum of the CLIP loss and a scaled self-supervised loss.

    Args:
        ssl_loss: Callable loss module; given ``outputs`` it must return a
            dict containing 'ssl_loss' and 'ssl_acc'.
        ssl_scale: Weight applied to the self-supervised loss before adding.
    """

    def __init__(self, ssl_loss, ssl_scale):
        super().__init__()
        self.clip_loss = CLIPLoss()
        self.ssl_loss = ssl_loss
        self.ssl_scale = ssl_scale

    def forward(self, outputs):
        """Run both losses on ``outputs`` and return the total plus each part's loss and accuracy."""
        # The CLIP loss runs first, then the self-supervised loss.
        clip_result = self.clip_loss(outputs)
        clip_loss = clip_result['clip_loss']
        clip_acc = clip_result['clip_acc']

        ssl_result = self.ssl_loss(outputs)
        ssl_loss = ssl_result['ssl_loss']
        ssl_acc = ssl_result['ssl_acc']

        total = clip_loss + self.ssl_scale * ssl_loss
        return {
            'loss': total,
            'clip_loss': clip_loss,
            'clip_acc': clip_acc,
            'ssl_loss': ssl_loss,
            'ssl_acc': ssl_acc,
        }
아래와 같이 prompt를 정의
{
"food101": [
"a photo of {}, a type of food."
],
"cifar10": [
"a photo of a {}.",
"a blurry photo of a {}.",
"a black and white photo of a {}.",
"a low contrast photo of a {}.",
"a high contrast photo of a {}.",
"a bad photo of a {}.",
"a good photo of a {}.",
...
],
arxiv, code
problem : self-supervised learning이 CLIP 모델 구조에서도 잘 작동하는지 실험. solution : 이미지와 텍스트에 대해서는 contrastive learning을 하고, 이미지에 대해서는 self-supervised learning을 하여 두 loss를 합해 학습. result : linear probing(= representation 위에 linear classifier를 붙여서 분류하는 것. linear layer만 학습함. BERT에서 언급하는 'feature-based' 방식), zero-shot learning, end-to-end finetuning에서 SOTA