NewPlus / Multi_View_Summarization

Improving Dialogue Summarization with Speaker-aware and Topic-aware Contrastive Learning
https://www.dbpia.co.kr/journal/articleDetail?nodeId=NODE11488055

Speaker-Aware Coding #2

Closed · NewPlus closed this issue 1 year ago

NewPlus commented 1 year ago

## Speaker-Aware

```Python
enc_speaker = encoder_outputs[0][0][1]
for i in range(1, encoder_outputs[0].shape[0]):
    enc_speaker = torch.row_stack((enc_speaker, encoder_outputs[0][i][1]))
self.speaker_aware(enc_speaker=enc_speaker,             # list of speaker representations
                   speaker_input_ids=speaker_input_ids,  # list of speaker id values
                   ctr_margin=1,                         # margin used for contrastive learning
                   bench_speaker=0                       # P01 as the anchor = the 0th speaker
                   )
```

- speaker_input_ids : stores the input_ids of the speaker tokens as an int list
- unique_speaker_ids : the speaker token input_ids with duplicates removed
- enc_speaker : stores the speaker representations ([][][1]) taken from the encoder hidden states -> a list holding the representation of every speaker appearing in the turns (see the sketch after this list)
- self.speaker_aware : function that computes the contrastive learning loss over the speaker representations
> enc_speaker : list of speaker representations
> speaker_input_ids : list of speaker id values
> ctr_margin : margin value used for contrastive learning
> bench_speaker : index of the anchor speaker, e.g. the 0th speaker
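
For reference, below is a minimal, self-contained sketch of how these inputs could look. The tensor shapes, the example speaker token ids, and the `unique_speaker_ids` construction are assumptions for illustration, not taken from the repository code.

```Python
import torch

# Hypothetical speaker token ids, one per turn (values are made up for illustration).
speaker_input_ids = [32000, 32001, 32000, 32002]             # turns 0 and 2 share speaker P01
unique_speaker_ids = list(dict.fromkeys(speaker_input_ids))  # duplicates removed, order kept

# encoder_outputs[0] plays the role of the encoder's last hidden state:
# shape (num_turns, seq_len, hidden_size); index [i][1] is the speaker token of turn i.
encoder_outputs = (torch.randn(len(speaker_input_ids), 16, 768),)

enc_speaker = encoder_outputs[0][0][1]
for i in range(1, encoder_outputs[0].shape[0]):
    enc_speaker = torch.row_stack((enc_speaker, encoder_outputs[0][i][1]))

print(enc_speaker.shape)  # torch.Size([4, 768]) -> one representation per turn
```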
## Speaker-Aware Function
```Python
    # Assumes `import random`, `import torch`, and `from torch import nn` at module level.
    def speaker_aware(self, enc_speaker, speaker_input_ids, ctr_margin, bench_speaker):
        # Start with dummy rows from torch.empty(1); they are removed after the loop.
        negative_sim, positive_sim = torch.empty(1), torch.empty(1)
        num_turn = len(enc_speaker)

        for i in range(1, num_turn):
            # L2 distance between turn i's speaker representation and the anchor speaker's.
            sim = torch.dist(enc_speaker[i], enc_speaker[bench_speaker], p=2.0)
            # sim = torch.dist(enc_speaker[i], enc_speaker[bench_speaker], p=1.0)
            if speaker_input_ids[i] == speaker_input_ids[bench_speaker]:
                positive_sim = torch.row_stack((positive_sim, sim))  # same speaker as the anchor
            else:
                negative_sim = torch.row_stack((negative_sim, sim))  # different speaker

        # Drop the dummy first entries created by torch.empty(1).
        i = 0
        positive_sim = torch.cat([positive_sim[0:i], positive_sim[i+1:]])
        negative_sim = torch.cat([negative_sim[0:i], negative_sim[i+1:]])
        positive_sim_idx = positive_sim.shape[0]
        negative_sim_idx = negative_sim.shape[0]

        # Normalize all distances together so positive and negative scores are comparable.
        softmax_sim = torch.cat([positive_sim, negative_sim])
        softmax_sim_out = nn.functional.softmax(softmax_sim, dim=0)

        relu = nn.ReLU()
        # Randomly sample one positive and one negative normalized score.
        positive_softmax = softmax_sim_out[random.randrange(positive_sim_idx)]
        negative_softmax = softmax_sim_out[random.randrange(positive_sim_idx, positive_sim_idx+negative_sim_idx)]

        # Margin-based contrastive loss: max(0, margin - (positive - negative)).
        res = relu(ctr_margin - (positive_softmax - negative_softmax))
        return res
```
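
To make the final loss term concrete, here is a tiny self-contained check of the margin computation with made-up scores; the numbers stand in for the softmax-normalized values produced above and are purely illustrative.

```Python
import torch
from torch import nn

# Toy numbers only: two normalized similarity scores as produced by the softmax above.
positive_softmax = torch.tensor([0.40])   # score of a same-speaker pair
negative_softmax = torch.tensor([0.15])   # score of a different-speaker pair
ctr_margin = 1

relu = nn.ReLU()
loss = relu(ctr_margin - (positive_softmax - negative_softmax))
print(loss)  # tensor([0.7500]) -> stays non-zero until the score gap reaches the margin
```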