dhkim0225 / 1day_1paper

read 1 paper everyday (only weekday)
54 stars 1 forks source link

[99] Fisher SAM: Information Geometry and Sharpness Aware Minimisation #129

Open dhkim0225 opened 2 years ago

dhkim0225 commented 2 years ago

paper

ASAM (#128) 은 다음과 같은 문제가 있다고 한다.

However, this approach to determining the flatness ellipsoid of interest is heuristic 
and might severely degrade the neighborhood structure.

읭...?

we consider the information geometry of the model parameter space 읭...?

SAM 의 euclidean ball 을 Fisher information 를 유도한 ellipsoids로 대체시킨다고 한다.

Fisher Information

수학부터 해야한다. 참고자료

  1. https://www.youtube.com/watch?v=pneluWj-U-o&t=261s
  2. https://yukyunglee.github.io/Fisher-Information-Matrix/
  3. https://www.youtube.com/watch?v=m62I5_ow3O8
  4. https://stats.stackexchange.com/questions/174600/help-with-taylor-expansion-of-log-likelihood-function

Fisher information matrix 는 likelihood function 을 2번 미분한 값. likelihood 가 얼마나 curvature 를 갖고 있는 지 알 수 있다. Fisher information 이 크다면,

  1. curvature 가 크고,
  2. peak 값이 크고,
  3. 더 많은 constraining data 가 있다는 뜻이다.

가우시안에서 Fisher matrix 의 역수가 covariance matrix 와 같다. 딥러닝에서 hessian matrix 와 같다고 봐도 무방하기 때문에 효율을 위해 Fisher matrix 로 근사해서 문제를 자주 푼다고 한다.

SAM (#127)

image image

ASAM (#128)

image image

Fisher SAM

두 weight $\theta$ $\theta^{'}$ 사이의 거리를 잴 때, L2 로 재면 문제가 많다 그래서, KL div 를 쓴다고 한다. image

KL Div 는 $\epsilon$ 값이 작을 때, ${\huge d(\theta + \epsilon, \theta) \approx \epsilon^T F(\theta)\epsilon}$ 이다. $F(\theta)$ 는 Fisher Information matrix이다. (appendix B 참고)

그래서 해당 문제는, 다음과 같이 정리된다. image

요 친구를 first-order approximated objective 를 이용하여 풀어낸다 ${\huge l(\theta + \epsilon) \approx l(\theta) + \Delta l(\theta)^{\top} \epsilon }$ (SAM 과 비슷). 이렇게 하면 quadratic constraineed linear programming 문제가 나오는데, 라그랑지안은 다음과 같이 나오고, image

epsilon 값에 대해 미분취한 값이 0이다 놓고 전개를 하면, epsilon 값을 다음과 같이 뽑아낼 수 있다. image

해당 값을 ellipsoidal constraint (from KKT condition) 에 넣어주면, optimal lambda 값이 결정되고, 최종 수식은 다음과 같다. image

요 값을 이용해서 파리미터를 업데이트 해주면 된다. image

문제는 아직 남아있다. $F(\theta)$ 는 여전히 비싼 computation 을 요구한다. 해서, minibatch를 활용하여 다음과 같이 사용한다. image image

image

Results

image image