Hi,

Thanks a lot for sharing this work. I have a question, please.

I do not understand the function `get_ids_for_local_context_extractor(self, text_indices)` or how it works. Also, why do you use `== 102` in this line, and `5` in the get labels function?

`sep_index = np.argmax((text_ids[text_i] == 102))`
```python
def get_ids_for_local_context_extractor(self, text_indices):
    # convert BERT-SPC input to BERT-BASE format
    text_ids = text_indices.detach().cpu().numpy()
    for text_i in range(len(text_ids)):
        sep_index = np.argmax((text_ids[text_i] == 102))
        text_ids[text_i][sep_index + 1:] = 0
    return torch.tensor(text_ids).to(self.args.device)
```
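Here is a minimal NumPy-only sketch of what the quoted function appears to do. It assumes the `bert-base-uncased` vocabulary, where `[CLS]` = 101, `[SEP]` = 102 and `[PAD]` = 0. A BERT-SPC input presumably looks like `[CLS] sentence [SEP] aspect [SEP] [PAD] ...`. The function keeps everything up to and including the first `[SEP]` and zeroes the rest, which leaves a plain `[CLS] sentence [SEP]` (BERT-BASE style) input. The name `to_bert_base_format` is hypothetical and the token IDs in the example are illustrative. In the quoted code, `.numpy()` on a CPU tensor shares memory with that tensor, so the in-place zeroing can also change `text_indices`. This sketch works on a copy to avoid that.

```python
import numpy as np

SEP_ID = 102  # [SEP] in the bert-base-uncased vocabulary (assumed tokenizer)
PAD_ID = 0    # [PAD] in the same vocabulary


def to_bert_base_format(text_ids: np.ndarray) -> np.ndarray:
    """Hypothetical helper: drop everything after the first [SEP] in each row."""
    out = text_ids.copy()  # avoid modifying the caller's array in place
    for row in out:  # each row is a view into `out`, so writes stick
        # argmax over a boolean mask returns the index of the first True,
        # i.e. the position of the first [SEP]; note it returns 0 if none is found
        sep_index = np.argmax(row == SEP_ID)
        row[sep_index + 1:] = PAD_ID  # pad out the aspect segment and beyond
    return out


# [CLS] sentence [SEP] aspect [SEP] [PAD] [PAD]  (illustrative IDs)
ids = np.array([[101, 7, 8, 102, 9, 102, 0, 0]])
print(to_bert_base_format(ids).tolist())  # [[101, 7, 8, 102, 0, 0, 0, 0]]
```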