fd873630 / RNN-Transducer

RNN-Transducer for korean
38 stars 3 forks source link

아마 rnnt_loss는 bptt를 걱정하지 않아도 될지 모르겠습니다. #6

Closed YooSungHyun closed 1 year ago

YooSungHyun commented 1 year ago

일단 bptt는 근간이, input lengths와 target lengths를 알아야합니다. '나는 밥을 먹었다' 일때, '나는'에서 역전파 1번 '나는 밥을'에서 역전파 1번 '나는 밥을 먹었다' 에서 역전파 1번 총 3번을 진행해야합니다.

하지만 rnnt_loss를 사용하면, 위에 3개가 모두 다 적용된 최종 loss 1개만 return됩니다. (물론 batch_reduction mean or sum 해서요)

그러면 이론상 input lengths가 1000개짜리를 rnnt_loss로 역전파 하려면, batch size 1 logits가 (1,1000,500,101)을 가정할때 1000번에 대한 for문을 돌면서 각각 backward해야하는데, 그러려면 backward 속도가 너무 느려져서 실제로 논문 구현 그대로를 해야될지 모릅니다. rnnt_loss가 논문구현체로 알고있는데, 좀 납득이 안되서 열심히 찾아봤는데,

image

이런 소스가 있더군요? 내부에 trans_net와 pred_net의 gradient가 memory에 적재되어있으면, 그라디언트를 활용하는 어떠한 소스였습니다. time sequence의 전체를 집어넣어야 하므로 어쩌면, 각 sequence에 맞는 trans_net과 pred_net의 각각의 grad를 계산해서 cuda로 메모리에 직적용 시키는 무언가가 있는지도 모르겠습니다. torch 구현체를 보면, backward도 재정의되어있는 것 같구요...

어쩌면 제가 너무 앞서서 걱정하지 않았나 싶은지도 모르겠습니다. 제가 C++를 탁월하게 잘하진 않아서 눈대중으로 분석한 결과이니, 참고만 하시기 바랍니다.

YooSungHyun commented 1 year ago

확인하시고 댓글 달아주시면 해당 이슈는 이슈가 아니라 정보공유 차원이었으므로 close 하겠습니다.

fd873630 commented 1 year ago

공유 감사합니다!