HanNayeoniee / papers-with-code

3 stars 0 forks source link

[week 2] RNN : truncated BPTT 구현 #3

Open HanNayeoniee opened 2 years ago

HanNayeoniee commented 2 years ago

BPTT(Backpropagation Through Time)는 '시간 방향으로 펼친 신경망의 오차역전파법' 뜻이라고 합니다.

시계열 데이터를 처리할 때 sequence 길이가 길어지면 GPU 메모리에 모든 정보를 올리기 어려워 truncated BPTT 기법을 사용한다고 합니다. 이 때 순전파는 끊지 않고, 역전파에서만 적당한 길이로 끊어 처리해준다고 합니다. 이 부분을 어떻게 구현하셨는지 궁금합니다....!

기존에 사용하던 backward() 연산으로 가능한 부분인가요?

image

왼쪽이 BPTT, 오른쪽이 truncated BPTT에 해당합니다.

sujeongim commented 2 years ago

추가 질문 : truncated BPTT가 gradient vanishing 문제도 완화할 수 있나요?