Open long8v opened 7 months ago
원래 논문은 CLIP은 다루지 않았는데 누가 기말과제로 올려놨다는 듯. 로직은 대충 이렇다.
이걸 positive pair에 대해 "one-hot" label을 준 뒤에 logits_per_image와 곱해줌 (아래 과정을 text, image 차원에 대해 진행)
def interpret(image, texts, model, device, start_layer=start_layer, start_layer_text=start_layer_text):
batch_size = texts.shape[0]
images = image.repeat(batch_size, 1, 1, 1)
logits_per_image, logits_per_text = model(images, texts)
probs = logits_per_image.softmax(dim=-1).detach().cpu().numpy()
index = [i for i in range(batch_size)]
one_hot = np.zeros((logits_per_image.shape[0], logits_per_image.shape[1]), dtype=np.float32)
one_hot[torch.arange(logits_per_image.shape[0]), index] = 1.0
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
one_hot = torch.sum(one_hot.cuda() * logits_per_image)
model.zero_grad()
image_attn_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values())
if start_layer == -1:
start_layer = len(image_attn_blocks) - 1
num_tokens = image_attn_blocks[0].attn_probs.shape[-1] R = torch.eye(num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype).to(device) R = R.unsqueeze(0).expand(batch_size, num_tokens, num_tokens) for i, blk in enumerate(image_attn_blocks): if i < start_layer: continue grad = torch.autograd.grad(one_hot, [blk.attn_probs], retain_graph=True)[0].detach() cam = blk.attn_probs.detach() cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1]) grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1]) cam = grad * cam cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1]) cam = cam.clamp(min=0)mean(dim=1) R = R + torch.bmm(cam, R) image_relevance = R[:, 0, 1:]
text_attn_blocks = list(dict(model.transformer.resblocks.named_children()).values())
if start_layer_text == -1:
start_layer_text = len(text_attn_blocks) - 1
num_tokens = text_attn_blocks[0].attn_probs.shape[-1] R_text = torch.eye(num_tokens, num_tokens, dtype=text_attn_blocks[0].attn_probs.dtype).to(device) R_text = R_text.unsqueeze(0).expand(batch_size, num_tokens, num_tokens) for i, blk in enumerate(text_attn_blocks): if i < start_layer_text: continue grad = torch.autograd.grad(one_hot, [blk.attn_probs], retain_graph=True)[0].detach() cam = blk.attn_probs.detach() cam = cam.reshape(-1, cam.shape[-1], cam.shape[-1]) grad = grad.reshape(-1, grad.shape[-1], grad.shape[-1]) cam = grad * cam cam = cam.reshape(batch_size, -1, cam.shape[-1], cam.shape[-1]) cam = cam.mean(dim=1) R_text = R_text + torch.bmm(cam, R_text) text_relevance = R_text
return text_relevance, image_relevance
정리하면 CLIP model에 나오는 logit에 대해 output label을 1로 두고 이에 대한 "attention map"에 대한 미분값을 attention map과 hadamard product하면서 summationg하는 과정이라고 보면 됨!
더 이해하기 쉬운 pseudo-code
def interpret(self, image, texts, model, CLS_idx, device):
batch_size = 1
inputs = self.preprocess(text=texts, images=image, padding="max_length", return_tensors="pt")
inputs = inputs.to(device)
outputs = model(**inputs, output_attentions=True)
clip_score = outputs.logits_per_image
image_attn_blocks = outputs.vision_model_output.attentions
text_attn_blocks = outputs.text_model_output.attentions
index = [i for i in range(batch_size)]
model.zero_grad()
num_tokens = text_attn_blocks[0].shape[-1]
R_text = torch.eye(num_tokens, num_tokens, dtype=text_attn_blocks[0].dtype).to(device)
R_text = R_text.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
for i, attn_map in enumerate(text_attn_blocks):
attn_map_grad = torch.autograd.grad(logits, [attn_map], retain_graph=True)[0].detach()
attn_map = attn_map.detach()
attn_map = attn_map.reshape(-1, cam.shape[-1], cam.shape[-1])
attn_map_grad = attn_map_grad.reshape(-1, grad.shape[-1], grad.shape[-1])
attn_map = attn_map * attn_map_grad
attn_map = attn_map.reshape(batch_size, -1, attn_map.shape[-1], attn_map.shape[-1])
attn_map = attn_map.clamp(min=0).mean(dim=1)
R_text = R_text + torch.bmm(cam, R_text)
text_relevance = R_text
return R_text[CLS_idx, 1:CLS_idx]
paper, code
TL;DR
Details
some notation
Relevancy initialization
relevancy map을 초기화 / 업데이트 할 거임
SA 전에는 서로 상호작용이 없어서 $R^{ii}$, $R^{tt}$는 identity. $R^{it}$는 zero tensor.
Relevancy update rules
attention map A를 가지고 relavancy를 update할 것임 전작에 따라 head 간 평균을 구하고 gradient를 사용
여기서 $\delta A$는 우리가 시각화하고 싶은 class t에 대한 output인 $y_t$를 A로 미분한 것. 평균을 취하기 전에 positive만 남겨줌(clamp)(이에 대한 이유는 딱히 없고 전작을 따라줌)
self attention에 대한 relevance 업데이트 방식은 아래와 같음 여기서 s는 query token, q는 key token임.
여기서 $R^{xx}$는 두개로 분리할 수 있는데 처음에 초기화한 $I$랑 $I$를 뺀 residual인 $\hat{R}^{xx}$임. $\hat{R}^{xx}$는 gradient를 사용하기 때문에 숫자가 절대적으로 작음. 이를 해결하기 위해 row의 합이 1이 되도록 정규화 해줌.
co-attention / cross-attention의 경우 update rule을 아래와 같이 정의해줌
Obtaining classification relevancies
[CLS] 토큰의 row에 해당하는 relevancy map을 보면 되는데 text 에 대한걸 보려면 $R^{tt}$의 첫번째 row를 보면 되고 image에 대한걸 보려면 $R^{ti}$의 첫번째 row를 보면 됨
Adaptation to attention type
Result