Tongjilibo / bert4torch

An elegant PyTorch implementation of transformers
https://bert4torch.readthedocs.io/
MIT License

Error when DataLoader's num_workers is set greater than 0 #149

Closed zhouyiyuan-mt closed 1 year ago

zhouyiyuan-mt commented 1 year ago

When asking a question, please provide the following information where possible:

Basic information

Core code

# bert4torch/examples/sequence_labeling/task_sequence_labeling_ner_crf.py

...

def collate_fn(batch):
    batch_token_ids, batch_labels = [], []
    for d in batch:
        tokens = tokenizer.tokenize(d[0], maxlen=maxlen)
        mapping = tokenizer.rematch(d[0], tokens)
        start_mapping = {j[0]: i for i, j in enumerate(mapping) if j}
        end_mapping = {j[-1]: i for i, j in enumerate(mapping) if j}
        token_ids = tokenizer.tokens_to_ids(tokens)
        labels = np.zeros(len(token_ids))
        for start, end, label in d[1:]:
            if start in start_mapping and end in end_mapping:
                start = start_mapping[start]
                end = end_mapping[end]
                labels[start] = categories_label2id['B-'+label]
                labels[start + 1:end + 1] = categories_label2id['I-'+label]
        batch_token_ids.append(token_ids)
        batch_labels.append(labels)
    batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
    batch_labels = torch.tensor(sequence_padding(batch_labels), dtype=torch.long, device=device)
    return batch_token_ids, batch_labels

# Convert the dataset
train_dataloader = DataLoader(MyDataset('/home/zhouyiyuan/bert_data/china-people-daily-ner-corpus/example.train'), batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=1) 
valid_dataloader = DataLoader(MyDataset('/home/zhouyiyuan/bert_data/china-people-daily-ner-corpus/example.dev'), batch_size=batch_size, collate_fn=collate_fn) 

...

if __name__ == '__main__':

    evaluator = Evaluator()
    model.fit(train_dataloader, epochs=20, steps_per_epoch=None, callbacks=[evaluator])

else: 

    model.load_weights('best_model.pt')

Output

# python task_sequence_labeling_ner_crf.py 
[INFO] Global seed set to 42
2023-09-07 10:59:12 - Start Training

2023-09-07 10:59:12 - Epoch: 1/20
Traceback (most recent call last):
  File "/home/zhouyiyuan/test_bert/task_sequence_labeling_ner_crf.py", line 184, in <module>
    model.fit(train_dataloader, epochs=20, steps_per_epoch=None, callbacks=[evaluator])
  File "/opt/conda/lib/python3.10/site-packages/torch4keras/trainer.py", line 352, in fit
    train_X, train_y = self._prepare_nextbatch()  # 获取下一个batch的训练数据
  File "/opt/conda/lib/python3.10/site-packages/torch4keras/trainer.py", line 298, in _prepare_nextbatch
    batch = next(self.train_dataloader_iter)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
    return self._process_data(data)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
    data.reraise()
  File "/opt/conda/lib/python3.10/site-packages/torch/_utils.py", line 644, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/home/zhouyiyuan/test_bert/task_sequence_labeling_ner_crf.py", line 76, in collate_fn
    batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
  File "/opt/conda/lib/python3.10/site-packages/torch/cuda/__init__.py", line 235, in _lazy_init
    raise RuntimeError(
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

What I tried

When the DataLoader's num_workers is set greater than 0, training fails with the error above. Cause: collate_fn performs a to-device operation. With num_workers > 0, the main process forks several worker subprocesses to load data; each worker calls collate_fn, the to-device operation triggers CUDA initialization inside the forked worker, and that raises the "re-initialize CUDA" error. The official PyTorch DataLoader documentation (https://pytorch.org/docs/stable/data.html) advises: "It is generally not recommended to return CUDA tensors in multi-process loading because of many subtleties in using CUDA and sharing CUDA tensors in multiprocessing (see CUDA in multiprocessing). Instead, we recommend using automatic memory pinning (i.e., setting pin_memory=True), which enables fast data transfer to CUDA-enabled GPUs."
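For reference, a minimal self-contained sketch of that recommended pattern (ToyDataset and the tensor shapes are illustrative, not from the example script): collate_fn builds plain CPU tensors, pin_memory=True is enabled, and each batch is moved to the GPU in the main process.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    # Hypothetical stand-in for MyDataset, only to make the sketch runnable.
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        return [idx, idx + 1, idx + 2], [0, 1, 0]

def collate_fn(batch):
    # Build plain CPU tensors; no device= here, so worker processes
    # never initialize CUDA and num_workers > 0 is safe.
    token_ids = torch.tensor([x for x, _ in batch], dtype=torch.long)
    labels = torch.tensor([y for _, y in batch], dtype=torch.long)
    return token_ids, labels

device = 'cuda' if torch.cuda.is_available() else 'cpu'
loader = DataLoader(ToyDataset(), batch_size=4, shuffle=True,
                    collate_fn=collate_fn, num_workers=2, pin_memory=True)

for token_ids, labels in loader:
    # The transfer happens in the main process, one batch at a time.
    token_ids, labels = token_ids.to(device), labels.to(device)
```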

Suggestion: don't do the to-device in collate_fn; instead, move the batch to the device inside the fit function, right after self._prepare_nextbatch() fetches the next training batch (see the sketch below).
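A sketch of the two lines that suggestion refers to inside torch4keras's fit() (based on the traceback above; whether a device variable is available at that point in trainer.py is an assumption):

```python
# Inside torch4keras/trainer.py, fit() -- proposed change (sketch)
train_X, train_y = self._prepare_nextbatch()  # fetch the next training batch
# Move the batch to the target device in the main process,
# so collate_fn can return plain CPU tensors.
train_X, train_y = train_X.to(device), train_y.to(device)
```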

Tongjilibo commented 1 year ago

Could you try setting model.move_to_model_device=True before model.fit()? That way collate_fn doesn't need to call to(device) at all.
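In the example script that would look roughly like this (a minimal sketch; it assumes collate_fn no longer passes device= when building the tensors):

```python
model.move_to_model_device = True  # let the trainer move each batch to the model's device
model.fit(train_dataloader, epochs=20, steps_per_epoch=None, callbacks=[evaluator])
```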

Tongjilibo commented 1 year ago

Also, the next release will set move_to_model_device=True by default.

zhouyiyuan-mt commented 1 year ago

OK, thanks!