long8v / PTIR

Paper Today I Read
19 stars 0 forks source link

[77] Interpretable Image Classification with Differentiable Prototype Assignment #85

Open long8v opened 1 year ago

long8v commented 1 year ago
image

paper

TL;DR

Details

Preliminaries: ProtoPNet(prototypical part network)

This Looks Like That: Deep Learning for Interpretable Image Recognition 이 이미지가 왜 이 클래스로 분류됐는가를 visualize 하고 싶음.

image

이미지 x가 주어졌을 때 CNN으로 f(x)를 뽑고 CNN output으로 H x W x D가 나옴 동시에 m개의 prototype은 $H_1$ x $W_1$ x D shape을 가지는데 이 prototype은 H, W보다 작아야 함. 이때 D차원은 같은데 height, width가 작으므로 각 prototype이 CNN patch처럼 사용되어 activation map을 구할 수 있음

image

전체적인 ProtoPNet 구조는 위와 같음. 학습은 3단계로 나누어지는데 (1) Stochastice gradient descent(SGD) of layers before last layer prototype P와 convolution filter를 학습하는 loss. 마지막 분류 loss와 prototype과 convolution output 내의 patch들의 최소거리가 같은 클래스일 경우 가까워지도록, 다른 클래스일 경우 멀어지도록 학습함

image

(2) Projection of prototypes prototype이 같은 클래스 내 가장 가까운 패치가 프로토타입이 되도록 할당함

image

(3) Convex optimization of last layer prototype과 CNN은 freeze 시키고 h에 대한 matrix를 학습

image

motivation

image

Architecture

image

각 class들은 K개의 slot을 가지고 있어서 거기에 shared prototype을 할당할 수 있음

이미지 x가 주어졌을 때 CNN(=f(x))으로 output H x W x D를 뽑음 이는 D 차원의 벡터가 H x W개 있는 것으로 해석 할 수 있음. 그 D차원의 벡터를 k번째 slot에 대해 유사도를 구해서 할당할 수 있음

Focal similarity

ProtoPNet과 같은 이전 연구들은 유사도를 아래와 같이 구했음

image image

근데 이렇게 들어가면 (1) f(x)의 patch들인 z가 모두 prototype으로 유사하도록, 즉 background에만 집중하도록 될 수 있고 (2) 이미지에서 activated된 요소에만 gradient가 가는 효과가 있음. 이를 방지하기 위해 focal similiarity를 제안함

image image

Assigning one prototype per slot

prototype을 hard하게 assign하지 않고 soft하게 주고 gradient가 흐를 수 있게 하도록 gumbel-softmax로 구함

image

이때 한 slot에 여러 prototype이 들어가지 않도록 loss를 추가적으로 줌

image