Open HanNayeoniee opened 2 years ago
BPTT(Backpropagation Through Time)는 '시간 방향으로 펼친 신경망의 오차역전파법' 뜻이라고 합니다.
시계열 데이터를 처리할 때 sequence 길이가 길어지면 GPU 메모리에 모든 정보를 올리기 어려워 truncated BPTT 기법을 사용한다고 합니다. 이 때 순전파는 끊지 않고, 역전파에서만 적당한 길이로 끊어 처리해준다고 합니다. 이 부분을 어떻게 구현하셨는지 궁금합니다....!
기존에 사용하던 backward() 연산으로 가능한 부분인가요?
backward()
왼쪽이 BPTT, 오른쪽이 truncated BPTT에 해당합니다.
추가 질문 : truncated BPTT가 gradient vanishing 문제도 완화할 수 있나요?
BPTT(Backpropagation Through Time)는 '시간 방향으로 펼친 신경망의 오차역전파법' 뜻이라고 합니다.
시계열 데이터를 처리할 때 sequence 길이가 길어지면 GPU 메모리에 모든 정보를 올리기 어려워 truncated BPTT 기법을 사용한다고 합니다. 이 때 순전파는 끊지 않고, 역전파에서만 적당한 길이로 끊어 처리해준다고 합니다. 이 부분을 어떻게 구현하셨는지 궁금합니다....!
기존에 사용하던
backward()
연산으로 가능한 부분인가요?