eubinecto / train-of-thoughts

learn by issue tracking
6 stars 0 forks source link

RNN의 기울기 소실 & 폭주문제의 근원이 무엇일까? #20

Closed eubinecto closed 2 years ago

eubinecto commented 2 years ago

Why?

말로 설명할 수는 있다. RNN의 순환적 구조가 기울기 소실 / 폭주문제의 근원이다. 오차역전파를 위해선 각 시간대에서의 국소 기울기를 계산해야한다 (dHt/dH(t-1)). 순환구조를 가지므로 각 국소기울기는 동일한 RNN 함수의 편미분이다. 떄문에 오차역전파과정에서 동일한 W가 길이에 비례하여 제곱이된다. 이는 입력 나열의 길이가 길어질수록 국소 기울기의 값이 기하급수적으로 작아지거나 커질 수 있음을 의미한다. 그래서 기울기 소실 & 폭주문제가 발생한다.

그런데 막상 수식으로 설명해보세요~ 라고 부탁한다면 잘 못하겠다. 오늘은 이걸 한번 고민해보자.

eubinecto commented 2 years ago

일단 RNN의 수식을 한번 써보자.

Ht = tanh(H(t-1) * Whh + Xt * Wxh)

근원인 순환구조에만 집중하기 위해 tanh는 잠시 수식에서 빼보겠다.

Ht = H(t-1) * Whh + Xt * Wxh

만약 감성분석 테스크가 목표라면, 주로 Ht를 문장의 최종 벡터표현으로 쓴다. 이로부터 binary cross entropy loss를 계산하게 될 것이다. 이 로스함수는 그냥 O라고 하겠다. 그렇다면 최종 로스는

L = O(Ht)

우리의 목표는 두 편미분: dL/dWhh, dL/Wxh를 구해서 경사도 하강으로 각 가중치를 최적화하는 것에 있다:

Whh = Whh - lr * dL/dWhh
Wxh = Wxh - lr * dL/dWxh

그렇다면 저 둘은 어떻게 계산할 수 있을까? 연쇄법칙을 통한 오차역전파를 통해 계산할 수 있다.

dL/dWhh = dL/dHt * dHt / dH(t-1) * dH(t-1) / dH(t-2) * dH(t-2) / dH(t-3) * ... dH2 / dH1 * dH1 / dWhh

즉 ...

dL/dWhh = dL/dHt * PI^{n=t}_{n=1 }(dH_n /d H_{n-1}) * dH1 / dWhh

이때,

dHn / dH{n-1} =  Whh

이므로

dL/dWhh = dL/dHt * Whh^t * dH1 / dWhh

가 된다.

이때, 동일한 가중치 W_hh가 나열의 길이 t 만큼 재곱이 된다. 그래서 기울기 폭주 / 소실이 발생하는 것이다.

나중에..

Wxh도 비슷한 이유로 인해 기울기 폭주/소실이 일어날 것. 이건 나중에 정리해보자.

eubinecto commented 2 years ago

이걸 보다 더 쉽게 설명할 수 있는 방법이 없을까?

다시 설명으로 돌아와서... 수학을 제거하고 비유로 설명할 수는 없나?

음... 핵심은 순환구조로 인한 dHt/dH(t-1)의 반복, 나열의 길이에 비례하여 제곱이 되는 W에 있다.

음.. 시간여행에 비유를 해보자.

미래에서 과거로 시간여행을 할 수 있는 타임머신을 발견했다. 들뜬 마음에 과거로 여행을 하려고 한다. 하지만 타임머신을 자세히 들여다보니 큼지막한 경고문구가 적혀있다. 1년 뒤로 시간여행을 할때마다 "망각의 터널"을 통과해야한다고 한다. 이 터널을 한번 통과할 때마다 기억의 절반을 망각하게된다고 한다.

예를 들어, 중학생 때로 돌아가기 위해 2012년으로 시간 여행을 하고 싶다면, 현재 내 기억의 (1 - 1/2^10) * 100 = 99%를 망각하게된다. , 2022년에서 2021년으로 여행할때 한번 기억의 절반을, 2021년에서 2020년으로 여행할때 또 그 기억의 절반을 .. 이런식으로, 절반을 망각하는 과정을 10번 반복하게되므로, 남게되는 기억은 현재 기억의 1 / 1024만 남게된다.

그런 타임머신이 있다면, 1년, 2년정도 되돌릴 때는 쓸모가 있겠지만, 조선시대로 시간을 되돌리기에는..가능은 하더라도 거의 모든 기억을 잃어버릴테니 별 쓸모가 없을 것이다.

RNN또한 마찬가지다. 나열의 길이가 길어짐에 따라 "망각의 터널"을 지나가는 횟수가 증가한다. 그 횟수가 증가함에 따라 최종 로스의 W에 대한 편미분값이 빠르게 줄어들거나 커진다. 이 편미분의 인수로 W^time이 포함되어 있기 때문이다.

그래서 나열의 길이가 20을 넘어가면 슬슬 학습능력이 떨어지기 시작한다. 이 논문에서 실험을 통해 밝혀냈듯이...
image
eubinecto commented 2 years ago

음.. 적절한 비유가 될지는 모르겠다. 일단 설명에 성공을 한 것 같으니 이건 여기에서 마무리.

eubinecto commented 2 years ago

망각의 터널을 지나면 왜 망각을 하는 것인가? 이것도 비유로 설명해벌 수 있을까?

eubinecto commented 2 years ago

도함수에 W^time이 포함되어서.. 인데. 이것까지 비유로 설명하는 건 힘들듯