# Speaker-Aware

## Loading BartModel
```Python
import random
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers.modeling_utils import PreTrainedModel
from transformers.models.bart.modeling_bart import (
    BartPretrainedModel,
    Seq2SeqModelOutput,
    BaseModelOutput,
    BartConfig,
    BartEncoder,
    BartDecoder,
    BART_INPUTS_DOCSTRING,
    _CHECKPOINT_FOR_DOC,
    _CONFIG_FOR_DOC,
    _EXPECTED_OUTPUT_SHAPE,
    add_code_sample_docstrings,
    add_start_docstrings_to_model_forward,
    shift_tokens_right,
)
```
This way, only the necessary parts of the BartModel code are customized, while the rest of the Transformers code is used as-is from the library.
## Copying the BartModel Class Code

```Python
class BartModel(BartPretrainedModel):
    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: BartConfig):
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)

        self.encoder = BartEncoder(config, self.shared)
        self.decoder = BartDecoder(config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, value):
        self.shared = value
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Seq2SeqModelOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Seq2SeqModelOutput]:
        # different to other models, Bart automatically creates decoder_input_ids from
        # input_ids if no decoder_input_ids are provided
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )

            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
```
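Because the class keeps the library's module names and state-dict keys, a pretrained checkpoint should load into it directly. Here is a minimal loading sketch; the `facebook/bart-base` checkpoint and the input sentence are just examples, not from the original post:

```Python
# Minimal loading sketch: only parts of BartModel are customized, so the
# pretrained weights still map onto the class one-to-one.
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")  # example checkpoint
model = BartModel.from_pretrained("facebook/bart-base")          # the class defined above

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # torch.Size([1, seq_len, 768]) for bart-base
```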
## Passing the Speaker Token Index and Representation to the Speaker-Aware Function

```Python
speaker_input_ids = [int(input_ids[i][2]) for i in range(len(input_ids))]
unique_speaker_ids = list(set(speaker_input_ids))
```
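For reference, here is what that indexing does on dummy ids. The layout below (one turn per row of `input_ids`, speaker token at position 2) is an assumption for illustration; the post does not show the actual tokenization, and the ids 50266/50267 are made-up added-token ids:

```Python
# Illustration with dummy ids: position 2 of each row is assumed to hold the
# speaker token, matching the [i][2] indexing in the snippet above.
import torch

input_ids = torch.tensor([
    [0, 50265, 50266, 123, 456, 2],  # turn 0, speaker token 50266
    [0, 50265, 50267, 789, 111, 2],  # turn 1, speaker token 50267
    [0, 50265, 50266, 222, 333, 2],  # turn 2, same speaker as turn 0
])

speaker_input_ids = [int(input_ids[i][2]) for i in range(len(input_ids))]
unique_speaker_ids = list(set(speaker_input_ids))
print(speaker_input_ids)   # [50266, 50267, 50266]
print(unique_speaker_ids)  # [50266, 50267] (set order is arbitrary)
```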
```Python
enc_speaker = encoder_outputs[0][0][1]
for i in range(1, encoder_outputs[0].shape[0]):
    enc_speaker = torch.row_stack((enc_speaker, encoder_outputs[0][i][1]))

self.speaker_aware(enc_speaker=enc_speaker,              # list of speaker representations
                   speaker_input_ids=speaker_input_ids,  # list of speaker id values
                   ctr_margin=1,                         # margin for contrastive learning
                   bench_speaker=0                       # P01 as the anchor = the 0th speaker
                   )
```
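`encoder_outputs[0]` is the encoder's `last_hidden_state`, so with one turn per batch row the loop gathers the hidden state at position 1 of every turn. A quick shape check with random tensors (the sizes are illustrative, not from the post):

```Python
# Shape check: encoder_outputs[0] is assumed here to be (num_turns, seq_len, d_model).
import torch

num_turns, seq_len, d_model = 4, 16, 768
last_hidden_state = torch.randn(num_turns, seq_len, d_model)

enc_speaker = last_hidden_state[0][1]
for i in range(1, last_hidden_state.shape[0]):
    enc_speaker = torch.row_stack((enc_speaker, last_hidden_state[i][1]))

print(enc_speaker.shape)                                     # torch.Size([4, 768])
assert torch.equal(enc_speaker, last_hidden_state[:, 1, :])  # same as one slice
```

Note that the loop is equivalent to the single slice `last_hidden_state[:, 1, :]`.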
- speaker_input_ids : stores the speaker tokens' input_ids as a list of ints
- unique_speaker_ids : the speaker tokens' input_ids with duplicates removed
- enc_speaker : stores the speaker representations (`[][][1]`) from the hidden states that passed through the encoder -> a list holding the representations of every speaker across the turns
- self.speaker_aware : the function that computes the contrastive learning loss over the speaker representations
  > enc_speaker : list of speaker representations
  > speaker_input_ids : list of speaker id values
  > ctr_margin : margin value for contrastive learning
  > bench_speaker : index of the anchor speaker, e.g., the 0th speaker

## Speaker-Aware Function

```Python
def speaker_aware(self, enc_speaker, speaker_input_ids, ctr_margin, bench_speaker):
    negative_sim, positive_sim = torch.empty(1), torch.empty(1)
    num_turn = len(enc_speaker)
    for i in range(1, num_turn):
        # L2 distance between turn i's speaker and the anchor speaker
        sim = torch.dist(enc_speaker[i], enc_speaker[bench_speaker], p=2.0)
        # sim = torch.dist(enc_speaker[i], enc_speaker[bench_speaker], p=1.0)
        if speaker_input_ids[i] == speaker_input_ids[bench_speaker]:
            positive_sim = torch.row_stack((positive_sim, sim))
        else:
            negative_sim = torch.row_stack((negative_sim, sim))

    # drop the uninitialized first row left over from torch.empty(1)
    i = 0
    positive_sim = torch.cat([positive_sim[0:i], positive_sim[i+1:]])
    negative_sim = torch.cat([negative_sim[0:i], negative_sim[i+1:]])

    positive_sim_idx = positive_sim.shape[0]
    negative_sim_idx = negative_sim.shape[0]
    softmax_sim = torch.cat([positive_sim, negative_sim])
    softmax_sim_out = nn.functional.softmax(softmax_sim, dim=0)

    # sample one positive and one negative distance and apply a margin hinge
    relu = nn.ReLU()
    positive_softmax = softmax_sim_out[random.randrange(positive_sim_idx)]
    negative_softmax = softmax_sim_out[random.randrange(positive_sim_idx, positive_sim_idx + negative_sim_idx)]
    res = relu(ctr_margin - (positive_softmax - negative_softmax))
    return res
```
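To see the loss end to end, here is a toy run with dummy representations, treating `speaker_aware` as a free function (i.e., with `self` dropped); the tensor sizes and speaker ids are made up:

```Python
# Toy run of the contrastive loss. Assumes speaker_aware is defined without
# `self` (or bound to a model instance) and that the imports from the top of
# the post (torch, nn, random) are in scope.
enc_speaker = torch.randn(4, 768)                 # 4 turns of speaker representations
speaker_input_ids = [50266, 50267, 50266, 50267]  # turns 0 and 2 share a speaker

loss = speaker_aware(enc_speaker=enc_speaker,
                     speaker_input_ids=speaker_input_ids,
                     ctr_margin=1,
                     bench_speaker=0)
print(loss)  # a non-negative tensor, thanks to the ReLU hinge
```

The result is a hinge-style objective: one positive and one negative softmax-normalized distance to the anchor speaker are sampled at random, and `ReLU(margin - (positive - negative))` is taken as the loss.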