xiangking / ark-nlp

A private NLP package that quickly implements SOTA solutions.

RuntimeError: sparse tensors do not have strides #45

Open yysirs opened 2 years ago

yysirs commented 2 years ago

Running the GlobalPointer model on a single machine with multiple GPUs raises the error above; the other models run fine with the same multi-GPU code.

yysirs commented 2 years ago

[screenshot of the error traceback]

yysirs commented 2 years ago

I dug into this a bit: the problem is that the labels are stored as sparse tensors. Sparsifying them probably keeps memory from blowing up, but it breaks multi-GPU runs. I'd suggest moving label generation into collate_fn and building the labels per batch; that way there should be no problem.

jimme0421 commented 2 years ago

It is indeed caused by the sparsification. The core issue is that scattering the tensors across multiple GPUs runs into sparse matrix multiplication, and torch does not support multiplying sparse with dense (or sparse) matrices.

The fix is, as you said, to modify collate_fn. But there is no need to move the whole label-generation step into it; it is enough to densify the sparse label tensors inside collate_fn.
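
To illustrate the mechanism, here is a minimal self-contained sketch (the shapes are toy values, and the exact error wording may vary between torch versions): sparse COO tensors are not strided, which is what DataParallel's scatter step trips over, and .to_dense() turns them back into ordinary strided tensors.

import torch

# toy GlobalPointer-style label: (num_classes, seq_len, seq_len), almost all zeros
dense = torch.zeros(3, 8, 8)
dense[0, 2, 4] = 1
sparse = dense.to_sparse()

try:
    sparse.stride()  # sparse COO tensors have no strides
except RuntimeError as e:
    print(e)  # e.g. "sparse tensors do not have strides"

# densifying before the batch is scattered across GPUs avoids the problem
assert torch.equal(sparse.to_dense(), dense)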

jimme0421 commented 2 years ago

The fix is as follows: add the following to GlobalPointerNERTask() in global_pointer_bert_named_entity_recognition.py

def _train_collate_fn(self, batch):
    # requires: from torch.utils.data._utils.collate import default_collate
    input_ids = default_collate([f['input_ids'] for f in batch])
    attention_mask = default_collate([f['attention_mask'] for f in batch])
    token_type_ids = default_collate([f['token_type_ids'] for f in batch])
    # densify the sparse label tensors before DataParallel scatters the batch
    label_ids = default_collate([f['label_ids'].to_dense() for f in batch])

    tensors = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'token_type_ids': token_type_ids,
        'label_ids': label_ids,
    }
    return tensors

def _evaluate_collate_fn(self, batch):
    return self._train_collate_fn(batch)

After that, two existing .to_dense() calls will raise errors; simply delete them.
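
For readers looking for them: as far as I can tell (the exact locations are an assumption and may differ between versions), these are the spots inside GlobalPointerNERTask where inputs['label_ids'] is densified, roughly:

# before: labels arrive sparse from the dataset and are densified inside the task
loss = self.loss_function(logits, inputs['label_ids'].to_dense())

# after: labels are already dense thanks to the collate_fn above
loss = self.loss_function(logits, inputs['label_ids'])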

jimme0421 commented 2 years ago

This bug will be fixed in the next release.

yysirs commented 2 years ago

GlobalPointer's labels are 3-D tensors. If you generate all the labels up front and keep them in memory, and the corpus is large with many entities, training gets painful on ordinary hardware. To recycle memory more efficiently, I think it is more appropriate to generate the labels inside collate_fn. Just a small suggestion for the author to consider~
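
To put rough numbers on this (the figures below are illustrative assumptions, not measurements):

# back-of-the-envelope estimate with assumed sizes
class_num, max_seq_len = 10, 128      # assumed label space and sequence length
num_examples = 20_000                 # assumed corpus size
bytes_per_elem = 4                    # float32

per_example = class_num * max_seq_len * max_seq_len * bytes_per_elem
total_gib = per_example * num_examples / 1024**3
print(f"{per_example / 2**20:.2f} MiB per example, {total_gib:.1f} GiB in total")
# ≈ 0.62 MiB per example, ≈ 12.2 GiB if every dense label is materialized up front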

yysirs commented 2 years ago

import torch
import warnings
from ark_nlp.factory.loss_function import get_loss
from ark_nlp.factory.utils import conlleval
from ark_nlp.factory.task.base._token_classification import TokenClassificationTask
from ark_nlp.factory.utils.ema import EMA
from torch.utils.data._utils.collate import default_collate

class GlobalPointerNERTask(TokenClassificationTask):
    """
    GlobalPointer的命名实体识别Task

    Args:
        module: 深度学习模型
        optimizer: 训练模型使用的优化器名或者优化器对象
        loss_function: 训练模型使用的损失函数名或损失函数对象
        class_num (:obj:`int` or :obj:`None`, optional, defaults to None): 标签数目
        scheduler (:obj:`class`, optional, defaults to None): scheduler对象
        n_gpu (:obj:`int`, optional, defaults to 1): GPU数目
        device (:obj:`class`, optional, defaults to None): torch.device对象,当device为None时,会自动检测是否有GPU
        cuda_device (:obj:`int`, optional, defaults to 0): GPU编号,当device为None时,根据cuda_device设置device
        ema_decay (:obj:`int` or :obj:`None`, optional, defaults to None): EMA的加权系数
        **kwargs (optional): 其他可选参数
    """  # noqa: ignore flake8"
    def __init__(self,
        module,
        tokenizer,
        optimizer,
        loss_function,
        class_num=None,
        scheduler=None,
        n_gpu=1,
        device_ids=None,
        device=None,
        cuda_device=0,
        ema_decay=None,
        ):

        super(TokenClassificationTask).__init__()
        self.module = module
        self.optimizer = optimizer
        self.tokenizer = tokenizer
        self.loss_function = get_loss(loss_function)
        self.class_num = class_num
        self.scheduler = scheduler
        self.device_ids = device_ids
        self.n_gpu = n_gpu
        self.cuda_device = cuda_device
        self.device = device
        self.ema_decay = ema_decay

        if self.device is None:
            if torch.cuda.is_available():
                if self.cuda_device == -1:
                    self.device = torch.device("cuda")
                else:
                    self.device = torch.device(f"cuda:{self.cuda_device}")
            else:
                self.device = "cpu"

        if self.n_gpu > 1:
            self.module.cuda()
            self.module = torch.nn.DataParallel(self.module, device_ids=self.device_ids)
        else:
            self.module.to(self.device)

        if self.ema_decay:
            self.ema = EMA(self.module.parameters(), decay=self.ema_decay)

    def _get_module_inputs_on_train(
        self,
        inputs,
        **kwargs
    ):
        # print(inputs)
        self.train_to_device_cols = list(inputs.keys())
        for col in self.train_to_device_cols:
            if type(inputs[col]) is torch.Tensor:
                inputs[col] = inputs[col].to(self.device)
            else:
                warnings.warn(f"The {col} is not Tensor.\n")

        return inputs

    def _get_module_inputs_on_eval(
        self,
        inputs,
        **kwargs
    ):
        self.evaluate_to_device_cols = list(inputs.keys())
        for col in self.evaluate_to_device_cols:
            if type(inputs[col]) is torch.Tensor:
                inputs[col] = inputs[col].to(self.device)
            else:
                warnings.warn(f"The {col} is not Tensor.\n")

        return inputs

    def _compute_loss(
        self,
        inputs,
        logits,
        verbose=True,
        **kwargs
    ):
        loss = self.loss_function(logits, inputs['label_ids'])

        return loss

    def _on_evaluate_begin_record(self, **kwargs):

        self.evaluate_logs['eval_loss'] = 0
        self.evaluate_logs['eval_step'] = 0
        self.evaluate_logs['eval_example'] = 0

        self.evaluate_logs['labels'] = []
        self.evaluate_logs['logits'] = []
        self.evaluate_logs['input_lengths'] = []

        self.evaluate_logs['numerate'] = 0
        self.evaluate_logs['denominator'] = 0

    def _on_evaluate_step_end(self, inputs, outputs, **kwargs):

        with torch.no_grad():

            # compute loss
            logits, loss = self._get_evaluate_loss(inputs, outputs, **kwargs)

            numerate, denominator = conlleval.global_pointer_f1_score(
                inputs['label_ids'].cpu(),
                logits.cpu()
            )
            self.evaluate_logs['numerate'] += numerate
            self.evaluate_logs['denominator'] += denominator

        self.evaluate_logs['eval_example'] += len(inputs['label_ids'])
        self.evaluate_logs['eval_step'] += 1
        self.evaluate_logs['eval_loss'] += loss.item()

    def _on_evaluate_epoch_end(
        self,
        validation_data,
        epoch=1,
        is_evaluate_print=True,
        id2cat=None,
        **kwargs
    ):

        if id2cat is None:
            id2cat = self.id2cat

        if is_evaluate_print:
            # note: 'numerate' and 'denominator' are the raw global-pointer F1 counts;
            # this log line (kept in the library's original format) prints those counts
            # in the precision/recall slots, while the last field is the actual F1
            print('eval loss is {:.6f}, precision is:{}, recall is:{}, f1_score is:{}'.format(
                self.evaluate_logs['eval_loss'] / self.evaluate_logs['eval_step'],
                self.evaluate_logs['numerate'],
                self.evaluate_logs['denominator'],
                2*self.evaluate_logs['numerate']/self.evaluate_logs['denominator'])
            )

    def _train_collate_fn(self, batch):
        input_ids_list = []
        input_mask_list = []
        segment_ids_list = []
        global_label_list = []
        for (index_, row_) in enumerate(batch):
            tokens = self.tokenizer.tokenize(row_['text'])[:self.tokenizer.max_seq_len-2]
            token_mapping = self.tokenizer.get_token_mapping(row_['text'], tokens)

            start_mapping = {j[0]: i for i, j in enumerate(token_mapping) if j}
            end_mapping = {j[-1]: i for i, j in enumerate(token_mapping) if j}

            input_ids = self.tokenizer.sequence_to_ids(tokens)

            input_ids, input_mask, segment_ids = input_ids

            global_label = torch.zeros((
                self.class_num,
                self.tokenizer.max_seq_len,
                self.tokenizer.max_seq_len)
            )

            for info_ in row_['label']:
                if info_['start_idx'] in start_mapping and info_['end_idx'] in end_mapping:
                    start_idx = start_mapping[info_['start_idx']]
                    end_idx = end_mapping[info_['end_idx']]
                    if start_idx > end_idx or info_['entity'] == '':
                        continue
                    global_label[self.cat2id[info_['type']], start_idx+1, end_idx+1] = 1

            input_ids_list.append(torch.tensor(input_ids).long())
            input_mask_list.append(torch.tensor(input_mask).long())
            segment_ids_list.append(torch.tensor(segment_ids).long())
            # global_label is already a tensor, so just cast it instead of re-wrapping it
            global_label_list.append(global_label.long())

        batch_input_ids = torch.stack(input_ids_list, dim=0)
        batch_attention_mask = torch.stack(input_mask_list, dim=0)
        batch_token_type_ids = torch.stack(segment_ids_list, dim=0)
        batch_labels = torch.stack(global_label_list, dim=0)
        features = {
            'input_ids': batch_input_ids,
            'attention_mask': batch_attention_mask,
            'token_type_ids': batch_token_type_ids,
            'label_ids': batch_labels
        }
        return features

    def _evaluate_collate_fn(self, batch):
        return self._train_collate_fn(batch)

This is the version I modified last night. It runs without issues and does not blow up memory; sharing it as a reference for others.
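
If you want to sanity-check the collate function on its own before kicking off training, something like the following works (train_dataset and task are placeholders: a dataset whose rows carry the 'text' and 'label' fields that the collate_fn above expects, and the task defined above; during training the Task builds its own DataLoader around _train_collate_fn):

from torch.utils.data import DataLoader

# `train_dataset` and `task` are placeholders, see the note above
loader = DataLoader(train_dataset, batch_size=16, collate_fn=task._train_collate_fn)
batch = next(iter(loader))
print(batch['label_ids'].shape)  # (16, class_num, max_seq_len, max_seq_len)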

jimme0421 commented 2 years ago

Thanks for the suggestion and for the low-memory adaptation. This change does cut memory consumption effectively, but it slows down training.

We are considering introducing a lazy mechanism in v0.1.0 to handle this class of problem.

yysirs commented 2 years ago

Haha, I just want to contribute something back instead of only freeloading 😂😂😂