Open xiehou-design opened 5 months ago
您好,读了您的论文深受启发,所以想读一下代码如何实现的,但是在阅读代码过程中出现了一点问题,在这里想问一下是不是有解决方案?具体问题如下,我使用vscode ssh连接云服务器进行调试,但是在调试的过程中,不明白代码中的多进程为什么会死锁,无法继续debug调试下去。具体代码如下所示:
def _compute_extended_attention_mask(self, attention_mask, context_count, prompt_number): if not self.prompt_individual_attention and not self.sentence_individual_attention: # #batch x seq_len extended_attention_mask = attention_mask else: # #batch x seq_len x seq_len extended_attention_mask = attention_mask.unsqueeze(1).expand(-1, attention_mask.size(-1), -1).clone() # attention_mask=attention_mask.unsqueeze(1).expand(-1, attention_mask.size(-1), -1) # extended_attention_mask=torch.empty_like(attention_mask) # extended_attention_mask.copy_(attention_mask) for mask, c_count in zip(extended_attention_mask, context_count): # mask seq_len x seq_len # mask prompt for sentence encoding if self.prompt_individual_attention: # encode for each prompt for p in range(prompt_number): mask[p*self.prompt_length: p*self.prompt_length + self.prompt_length, :prompt_number*self.prompt_length] = 0 mask[p*self.prompt_length: p*self.prompt_length + self.prompt_length, p*self.prompt_length: p*self.prompt_length + self.prompt_length] = 1 if self.sentence_individual_attention: for c in range(c_count): mask[c+self.prompt_length*prompt_number, :self.prompt_length*prompt_number] = 0 return extended_attention_mask def _common_forward( self, encodings: torch.tensor, context_masks: torch.tensor, raw_context_masks: torch.tensor, inx4locator: torch.tensor, pos_encoding: torch.tensor, seg_encoding: torch.tensor, context2token_masks:torch.tensor, token_masks:torch.tensor, image_inputs: dict = None, meta_doc = None): batch_size = encodings.shape[0] context_masks = context_masks.float() token_count = token_masks.long().sum(-1,keepdim=True) context_count = context_masks.long().sum(-1,keepdim=True) raw_context_count = raw_context_masks.long().sum(-1,keepdim=True) pos = None tgt = None tgt2 = None # pdb.set_trace() context_masks = self._compute_extended_attention_mask(context_masks, raw_context_count, self.prompt_number) # print(context_masks.shape) (n,len,len) # print(context_masks.shape) # self = self.eval() if self.model_type == "bert": model = self.bert if self.model_type == "roberta": model = self.roberta # model.embeddings.position_embeddings outputs = model( input_ids=encodings, attention_mask=context_masks, # token_type_ids=seg_encoding, # position_ids=pos_encoding, output_hidden_states=True) # last_hidden_state, pooler_output, hidden_states ....
在context_masks = self._compute_extended_attention_mask(context_masks, raw_context_count, self.prompt_number)中的extended_attention_mask = attention_mask.unsqueeze(1).expand(-1, attention_mask.size(-1), -1).clone()位置前后打断点之后就会产生多进程死锁,无法调试的情况。我看内部函数实现也没有涉及到多进程冲突,想问一下是什么原因导致呢?
context_masks = self._compute_extended_attention_mask(context_masks, raw_context_count, self.prompt_number)
extended_attention_mask = attention_mask.unsqueeze(1).expand(-1, attention_mask.size(-1), -1).clone()
您好,读了您的论文深受启发,所以想读一下代码如何实现的,但是在阅读代码过程中出现了一点问题,在这里想问一下是不是有解决方案?具体问题如下,我使用vscode ssh连接云服务器进行调试,但是在调试的过程中,不明白代码中的多进程为什么会死锁,无法继续debug调试下去。具体代码如下所示:
在
context_masks = self._compute_extended_attention_mask(context_masks, raw_context_count, self.prompt_number)
中的extended_attention_mask = attention_mask.unsqueeze(1).expand(-1, attention_mask.size(-1), -1).clone()
位置前后打断点之后就会产生多进程死锁,无法调试的情况。我看内部函数实现也没有涉及到多进程冲突,想问一下是什么原因导致呢?