tricktreat / PromptNER

Code for the paper "PromptNER: Prompt Locating and Typing for Named Entity Recognition", accepted at ACL 2023.

Multithreading deadlock in debug mode #14

Open xiehou-design opened 5 months ago

xiehou-design commented 5 months ago

Hello, I found your paper very inspiring and wanted to read the code to understand how it is implemented, but I ran into a problem while doing so and would like to ask whether there is a known solution. Specifically: I debug on a cloud server via the VS Code SSH extension, and during debugging I cannot understand why the multiprocessing in the code deadlocks, which makes it impossible to keep stepping through. The relevant code is shown below:

    def _compute_extended_attention_mask(self, attention_mask, context_count, prompt_number):
        if not self.prompt_individual_attention and not self.sentence_individual_attention:
            # #batch x seq_len
            extended_attention_mask = attention_mask
        else:
            # #batch x seq_len x seq_len
            extended_attention_mask = attention_mask.unsqueeze(1).expand(-1, attention_mask.size(-1), -1).clone()
            # attention_mask=attention_mask.unsqueeze(1).expand(-1, attention_mask.size(-1), -1)
            # extended_attention_mask=torch.empty_like(attention_mask)
            # extended_attention_mask.copy_(attention_mask)
            for mask, c_count in zip(extended_attention_mask, context_count):
                # mask seq_len x seq_len
                # mask prompt for sentence encoding
                if self.prompt_individual_attention:
                    # encode for each prompt
                    for p in range(prompt_number):
                        mask[p*self.prompt_length: p*self.prompt_length + self.prompt_length, :prompt_number*self.prompt_length] = 0
                        mask[p*self.prompt_length: p*self.prompt_length + self.prompt_length, p*self.prompt_length: p*self.prompt_length + self.prompt_length] = 1
                if self.sentence_individual_attention:
                    for c in range(c_count):
                        mask[c+self.prompt_length*prompt_number, :self.prompt_length*prompt_number] = 0

        return extended_attention_mask

    def _common_forward(
        self, 
        encodings: torch.tensor, 
        context_masks: torch.tensor, 
        raw_context_masks: torch.tensor, 
        inx4locator: torch.tensor, 
        pos_encoding: torch.tensor, 
        seg_encoding: torch.tensor, 
        context2token_masks:torch.tensor,
        token_masks:torch.tensor,
        image_inputs: dict = None,
        meta_doc = None):

        batch_size = encodings.shape[0]
        context_masks = context_masks.float()
        token_count = token_masks.long().sum(-1,keepdim=True)
        context_count = context_masks.long().sum(-1,keepdim=True)
        raw_context_count = raw_context_masks.long().sum(-1,keepdim=True)
        pos = None
        tgt = None
        tgt2 = None

        # pdb.set_trace()

        context_masks = self._compute_extended_attention_mask(context_masks, raw_context_count, self.prompt_number)
        # print(context_masks.shape) (n,len,len)
        # print(context_masks.shape)
        # self = self.eval()
        if self.model_type == "bert":
            model = self.bert
        if self.model_type == "roberta":
            model = self.roberta
        # model.embeddings.position_embeddings
        outputs = model(
                    input_ids=encodings,
                    attention_mask=context_masks,
                    # token_type_ids=seg_encoding,
                    # position_ids=pos_encoding,
                    output_hidden_states=True)
        # last_hidden_state, pooler_output, hidden_states
....
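For reference, the masking pattern above can be reproduced outside the full pipeline. The following is a minimal standalone sketch: the prompt_length / prompt_number values and the tiny randomly initialised BertModel are made up for illustration and are not the repository's actual configuration. It only shows the shape of the extended mask and that a 3-D (batch, seq_len, seq_len) attention mask is accepted by the HuggingFace encoder call.

import torch
from transformers import BertConfig, BertModel

# Toy values, chosen only for illustration; not the repository's settings.
prompt_length, prompt_number = 2, 2
sentence_tokens = 4
seq_len = prompt_length * prompt_number + sentence_tokens

attention_mask = torch.ones(1, seq_len)
context_count = torch.tensor([sentence_tokens])

# Same pattern as _compute_extended_attention_mask: batch x seq_len x seq_len.
extended = attention_mask.unsqueeze(1).expand(-1, attention_mask.size(-1), -1).clone()
for mask, c_count in zip(extended, context_count):
    for p in range(prompt_number):
        # each prompt attends only to itself among the prompt positions
        mask[p * prompt_length: (p + 1) * prompt_length, :prompt_number * prompt_length] = 0
        mask[p * prompt_length: (p + 1) * prompt_length, p * prompt_length: (p + 1) * prompt_length] = 1
    for c in range(int(c_count)):
        # sentence tokens do not attend to the prompt positions
        mask[c + prompt_length * prompt_number, :prompt_length * prompt_number] = 0

# A 3-D (batch, seq_len, seq_len) mask is accepted by HuggingFace encoders,
# which is how the forward pass above passes context_masks to self.bert / self.roberta.
config = BertConfig(hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = BertModel(config)  # randomly initialised, just to check shapes
outputs = model(input_ids=torch.randint(0, config.vocab_size, (1, seq_len)),
                attention_mask=extended,
                output_hidden_states=True)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 8, 32])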

When I set a breakpoint just before or after the line extended_attention_mask = attention_mask.unsqueeze(1).expand(-1, attention_mask.size(-1), -1).clone() inside context_masks = self._compute_extended_attention_mask(context_masks, raw_context_count, self.prompt_number), a multiprocessing deadlock occurs and debugging cannot continue. As far as I can tell, the function's implementation does not involve any multiprocessing at all, so what could be causing this?
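One guess (only an assumption, since the training script and its configuration are not shown here) is that the hang comes from DataLoader worker processes that stay blocked while the main process is paused at the breakpoint, rather than from this function itself. If that is the case, producing batches in the main process while debugging usually avoids the hang; the snippet below is a sketch with a toy dataset standing in for the repository's real dataset and collate_fn.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in data; the repository's actual dataset and collate_fn are not shown here.
dataset = TensorDataset(torch.arange(16).view(8, 2))

# num_workers=0 keeps batch loading in the main process, so a breakpoint in the
# model code cannot leave worker subprocesses blocked on their result queue.
debug_loader = DataLoader(dataset, batch_size=2, num_workers=0)

for (batch,) in debug_loader:
    print(batch.shape)  # torch.Size([2, 2])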