modulabs / beyondBERT

11.5기의 beyondBERT의 토론 내용을 정리하는 repository입니다.
MIT License
60 stars 6 forks source link

Mask-Predict: Parallel Decoding of Conditional Masked Language Models #13

Closed seopbo closed 4 years ago

seopbo commented 4 years ago

어떤 내용의 논문인가요? 👋

Abstract (요약) 🕵🏻‍♂️

Most machine translation systems generate text autoregressively from left to right. We, instead, use a masked language modeling objective to train a model to predict any subset of the target words, conditioned on both the input text and a partially masked target translation. This approach allows for efficient iterative decoding, where we first predict all of the target words non-autoregressively, and then repeatedly mask out and regenerate the subset of words that the model is least confident about. By applying this strategy for a constant number of iterations, our model improves state-of-the-art performance levels for nonautoregressive and parallel decoding translation models by over 4 BLEU on average. It is also able to reach within about 1 BLEU point of a typical left-to-right transformer model, while decoding significantly faster.

이 논문을 읽어서 무엇을 배울 수 있는지 알려주세요! 🤔

Arch.

IMG_0153


Key Points

Multi-modality Problem

스크린샷 2020-07-18 오전 2 54 02 스크린샷 2020-07-18 오전 2 58 46

Decoding Speed

스크린샷 2020-07-18 오전 3 00 53

Sentence Length

스크린샷 2020-07-18 오전 3 21 15 스크린샷 2020-07-18 오전 3 10 04

Distillation

스크린샷 2020-07-18 오전 3 07 08
fairseq-train \
    --task translation_lev \
    --criterion nat_loss \
    --arch nonautoregressive_transformer \
    ...
    --decoder-learned-pos \
    --encoder-learned-pos \
        ...

같이 읽어보면 좋을 만한 글이나 이슈가 있을까요?

레퍼런스의 URL을 알려주세요! 🔗

Taekyoon commented 4 years ago

[질문]

우선 제가 이해한 내용에 대해서 설명을 하고 질문을 하겠습니다.

이해한 내용

논문에서 설명하는 바에 따르면 기존 transformer 방식에서 학습 하는 것과 같이 source와 target 데이터를 가지고 학습을 하는 방법입니다. 여기서 다른줌은 decoder 학습 방식을 기존의 autoregressive 방식이 아닌 mlm 방식으로 하는데 여기에 encoder 정보가 들어오니 conditional mlm (c-mlm)을 한다고 이해를 했습니다. 그리고 예측 시에는 decoder에 입력할 정보가 없으니 임의의 시퀀스 길이 N을 정의하고 mask 값을 100% => 90% => 80% 이렇게 줄여가면서 시퀀스 토큰들을 예측하여 non-autoregressive 모델 방식으로 예측한다(?) 라고 이해를 했습니다. (우선 시퀀스에 대한 batch inference를 하지 않는 방식이어서 parallel이라는 의미는 잘 모르겠습니다. 단, mlm형태로 inference를 하다보니 independent assumption에 따른 generation방식을 parallel이라 이야기 한다 정도로 이해하고 있습니다.)

질문 1

이해한 내용을 토대로만 봤을 때 제가 궁금한 부분은 학습 상황에서 alignment map(어텐션 맵) 형성을 어떻게 검증했을 지 궁금합니다. (seq2seq 계열 모델 학습 시에는 training loss와 별개로 alignment map 형성이 어떻게 되었는지 보는게 중요하다고 볼 수 있다고 알고 있습니다.). 왜 이런 궁금점이 생겼나면 decoder 학습 시에 mask 정보가 input으로 들어오면서 alignment 정보를 어떻게 학습할 수 있는지 예상이 되질 않는다고 봤었기 때문입니다. 물론 GLEU 스코어를 통해서 output에 대한 성능은 문제가 없습니다만, 모델 학습이나 예측 도중에 디버깅을 하는 입장에서는 저 alignment map 정보 형성이 중요할 수 있다고 보고 있거든요.

질문 2

alignment에 관한 또 다른 질문은 예측 상황에서 보이는 alignment map이 있을 텐데 이 정보에 대해서는 어떻게 해석을 해야할까요? 보통 encoder에 입력 토큰 정보와 decoder의 순차적인 입력 토큰 정보를 토대로 alignment map을 해석할 수 있는데 cmlm 방식으로 시퀀스를 생성하고 mask 정보가 개입이 되면서 기존의 alignment map 정보의 의미와는 다를 것 같습니다. 이 경우에는 어떻게 보고 해석할 수 있어야 할지 궁금합니다.

질문 3

C-MLM 방식의 모델에서는 alignment map에 대한 고려를 하지 않는게 맞을까요?

모델이 autoregressive에서 벗어난 모델이 되면서 source와 target간의 정보 연결이 어떻게 되어있는지 설명이 논문에 보이지 않았습니다. 이 부분에 대해서 공감하시면 이야기를 나눠봤으면 좋겠습니다! ^^

soeque1 commented 4 years ago

[질문]

Length를 정확히 예측하는 것도 sequence 생성에 중요한 역할을 하는 것 같습니다.

@Huffon comment => pos encoding과도 관련!!

warnikchow commented 4 years ago

[질문]

Length에 대한 prediction이 예측의 정확도에 영향을 미치는 것은 패러프레이징 계열 (번역, 스타일 변환, 음성인식 등) 에서는 효과가 있을 수 있지만, 요약이나 대화 생성 등의 태스크에서는 어떤 일관된 경향성을 배우기 어려울 수도 있을 것 같습니다. 그러한 경우, utterance 단위의 입력 뿐 아닌, context에 관한 정보가 조금 더 많이 들어간다면, 여전히 length에 관한 부분이 효과적으로 작용할 수 있을까요?

(스터디 참석을 하지 못하여 죄송합니다 ( ))

kh-kim commented 4 years ago

[질문]

이해한 내용: 기존의 auto regressive (AR) 방식은 left to right이므로, 이전 time step의 예측 품질과 상관없이 다음 time step을 이에 기반해서 해야한다는 점이 있었는데, 이 방법은 잘 예측한 단어들을 기반으로 다음 예측 iteration을 수행하기 때문에 손실을 최소화할 수 있는 것 같습니다.

질문1: 어쨌든 다음 예측 interation은 입력과 이전까지 잘 예측된 단어들을 기반으로 예측이 진행되는데, 단지 방향이 없을 뿐 AR 아닌가요? 즉, 이 논문은 AR로 인한 loss를 최소화 하기보단, 그냥 속도 개선에 중점인 논문인가요?

질문2: MLM 방식을 통해 학습한 decoder가 훨씬 expressive할 것 같은데요.. 이로 인한 posterior collapse 경향을 addressing하지 않았는지? — 즉, encoder로부터의 정보를 잘 활용하도록 학습되지 않는지? 이를 방지하기 위한 constraint나 후속 연구는 없었는지?

질문3: Masking을 기계적으로 하지 않고, 네트워크를 통해 mask token을 선정하는 방법?

예전에 본 비슷한 논문: [Xia et al., 2017]

simonjisu commented 4 years ago

[질문]

약간 alignment와 연관된 질문일 수도 있는데, 논문의 가정중에 하나가 마스크된 토큰들(Y_mask)이 입력(src)과 마스크가 안된 토큰들(Y_obs)에 대해서 조건부 독립이라고 했는데요, 이걸 어떻게 확인 할 수 있을지 궁금합니다. 각 마스크된 토큰끼리는 연관이 없는 것일지..?

seopbo commented 4 years ago

pre-trained language mode이 붙었다고 했을 때, 1같은 경우 엄청 잘되지 않을까요?

  1. mass + mask-predict
  2. bart + mask-predict
  3. mass + vanilla transformer
  4. bart + vanilla transformer
kh-kim commented 4 years ago

pre-trained language mode이 붙었다고 했을 때, 1같은 경우 엄청 잘되지 않을까요?

  1. mass + mask-predict
  2. bart + mask-predict
  3. mass + vanilla transformer
  4. bart + vanilla transformer

Enc/dec 둘 다 bert를 쓰면 되는거 아닐까요?