MiuLab / PLM-ICD

PLM-ICD: Automatic ICD Coding with Pretrained Language Models
Apache License 2.0

Segment Pooling Implementation #2

Closed: mgh1 closed this issue 2 years ago.

mgh1 commented 2 years ago

Dear Authors,

First, I would like to thank you for your excellent contribution to the literature.

My question is: where in this repo is the "segment pooling" mechanism from Section 4.2 of the paper implemented?

The description of this technique in the paper is:

The segment pooling mechanism first splits the whole document into segments that are shorter than the maximum length, and encodes them into segment representations with PLMs. After encoding segments, the segment representations are aggregated as the representations for the full document.
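The three steps in this description (split, encode each segment, aggregate) can be illustrated with a minimal pure-Python sketch. It is not the repo's code: `toy_encode` is a hypothetical stand-in for a PLM encoder, and element-wise mean pooling is only a simple stand-in for the aggregation the paper actually uses.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def split_into_segments(token_ids: Sequence[int], max_len: int) -> List[List[int]]:
    """Split a long token sequence into consecutive segments of at most max_len tokens."""
    if max_len <= 0:
        raise ValueError("max_len must be positive")
    return [list(token_ids[i:i + max_len]) for i in range(0, len(token_ids), max_len)]


def mean_pool(vectors: List[Vector]) -> Vector:
    """Aggregate segment vectors by element-wise averaging (a simple stand-in)."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]


def segment_pooling(
    token_ids: Sequence[int],
    max_len: int,
    encode: Callable[[List[int]], Vector],
) -> Vector:
    """Encode each segment separately, then aggregate into one document vector."""
    segments = split_into_segments(token_ids, max_len)
    segment_vectors = [encode(segment) for segment in segments]
    return mean_pool(segment_vectors)


def toy_encode(segment: List[int]) -> Vector:
    """Hypothetical stand-in for a PLM encoder: returns [length, mean token id]."""
    return [float(len(segment)), sum(segment) / len(segment)]


if __name__ == "__main__":
    doc = list(range(1, 11))  # a 10-"token" document
    print(split_into_segments(doc, 4))  # [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10]]
    print(segment_pooling(doc, 4, toy_encode))
```

Because each segment is encoded on its own, no single encoder call ever sees more than `max_len` tokens, which is the point of segment pooling for long clinical notes.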

In modeling_bert.py I do see code for pooled_output, but I don't see exactly how it matches the description above. Could you help me understand how this is implemented?

chaoweihuang commented 2 years ago

Hi,

Thank you for your interest in our work! The implementation actually resides in run_icd.py. Specifically, if you look at the function data_collator (https://github.com/MiuLab/PLM-ICD/blob/764ca73473df3f948857fb52f4db2e65b5d8c995/src/run_icd.py#L315-L350), you'll notice that we split input_ids into segments of length args.chunk_size, so the shape of input_ids is actually (batch_size, num_chunks, chunk_size). We feed these input_ids into BERT, which encodes each segment separately, and then aggregate the segment outputs with our attention mechanisms.
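The shape bookkeeping described above can be sketched in NumPy. This is a hedged illustration, not the repo's code: `chunk_batch`, `encode_chunks`, and `label_attention` are hypothetical names; the pad id of 0 is an assumption; an embedding lookup stands in for BERT; and a plain dot-product label-wise attention stands in for PLM-ICD's actual attention layer.

```python
import numpy as np

PAD_ID = 0  # assumption: id of the padding token


def chunk_batch(batch_ids, chunk_size, pad_id=PAD_ID):
    """Pad all sequences to a common multiple of chunk_size, then reshape
    ids and mask to (batch_size, num_chunks, chunk_size)."""
    max_len = max(len(ids) for ids in batch_ids)
    max_len = -(-max_len // chunk_size) * chunk_size  # round up to a multiple of chunk_size
    padded = np.full((len(batch_ids), max_len), pad_id, dtype=np.int64)
    mask = np.zeros((len(batch_ids), max_len), dtype=bool)
    for row, ids in enumerate(batch_ids):
        padded[row, :len(ids)] = ids
        mask[row, :len(ids)] = True
    shape = (len(batch_ids), -1, chunk_size)
    return padded.reshape(shape), mask.reshape(shape)


def encode_chunks(input_ids, embeddings):
    """Hypothetical stand-in for BERT: fold chunks into the batch dimension so
    each chunk is encoded independently, then re-join chunks per document."""
    batch_size, num_chunks, chunk_size = input_ids.shape
    flat = input_ids.reshape(batch_size * num_chunks, chunk_size)
    hidden = embeddings[flat]  # (batch_size * num_chunks, chunk_size, hidden)
    return hidden.reshape(batch_size, num_chunks * chunk_size, -1)


def label_attention(hidden, mask, label_queries):
    """Dot-product label-wise attention over all tokens of all chunks.
    hidden: (batch, tokens, hidden); mask: (batch, tokens) with True for real tokens.
    Assumes every document has at least one real token.
    Returns one vector per label: (batch, num_labels, hidden)."""
    scores = np.einsum("bth,lh->blt", hidden, label_queries)
    scores = np.where(mask[:, None, :], scores, -np.inf)  # padding gets zero weight
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.einsum("blt,bth->blh", weights, hidden)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    vocab_size, hidden_size, num_labels, chunk_size = 50, 8, 3, 4
    embeddings = rng.normal(size=(vocab_size, hidden_size))
    label_queries = rng.normal(size=(num_labels, hidden_size))

    batch = [list(range(1, 11)), list(range(1, 6))]  # documents of 10 and 5 tokens
    input_ids, mask = chunk_batch(batch, chunk_size)
    print(input_ids.shape)  # (2, 3, 4)
    hidden = encode_chunks(input_ids, embeddings)
    print(hidden.shape)  # (2, 12, 8)
    doc_repr = label_attention(hidden, mask.reshape(len(batch), -1), label_queries)
    print(doc_repr.shape)  # (2, 3, 8)
```

The key step in this sketch is the reshape from (batch_size, num_chunks, chunk_size) to (batch_size * num_chunks, chunk_size): the encoder treats every chunk as an independent sequence, so none exceeds the model's maximum length, and reshaping back lets the attention layer weigh tokens from all chunks of a document together.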

Hope this answers your question. I'll close this for now, but feel free to reopen it if you have any further questions.

Chao-Wei