Lightning-AI / torchmetrics

Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

Missing `Metric.reset()` leads to DataLoader iterator crashing #1560

Closed awaelchli closed 1 year ago

awaelchli commented 1 year ago

🐛 Bug

I have a very peculiar bug in which the torchmetrics Accuracy interacts with the tensors in such a way that the DataLoader iterator crashes. A .reset() call on the metric fixes this issue, but I don't understand why.

To Reproduce

I minimized the following code as much as possible. The dataset file needs to be in the CWD: test.csv


import os

import torch
import torchmetrics
import torch.distributed
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, Dataset

from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

class IMDBDataset(Dataset):
    def __init__(self, dataset_dict, partition_key="train"):
        self.partition = dataset_dict[partition_key]

    def __getitem__(self, index):
        return self.partition[index]

    def __len__(self):
        return self.partition.num_rows

def train(model, train_loader, device):
    train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)
    input_ids, mask, labels = next(iter(train_loader))
    input_ids, mask, labels = input_ids.to(device), mask.to(device), labels.to(device)

    outputs = model(input_ids, attention_mask=mask, labels=labels)
    predicted_labels = torch.argmax(outputs["logits"].clone(), 1)
    train_acc.update(predicted_labels, labels)
    train_acc.compute()

    # *****************************************************************
    # WHY DO THE FOLLOWING LINES OF CODE PREVENT THE DATALOADER CRASH??
    # *****************************************************************

    # for attr, default in train_acc._defaults.items():
    #     current_val = getattr(train_acc, attr)
    #     setattr(train_acc, attr, default.to(current_val.device))

def run():
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(local_rank)

    torch.distributed.init_process_group("nccl", rank=local_rank, world_size=world_size)

    imdb_dataset = load_dataset("csv", data_files={"test": "test.csv"})
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    def tokenize_text(x):
        return tokenizer(x["text"], truncation=True, padding=True)

    imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None)
    imdb_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])

    torch.distributed.barrier()

    train_dataset = torch.utils.data.TensorDataset(
        torch.zeros(100, 512, dtype=torch.int64),
        torch.zeros(100, 512, dtype=torch.int64),
        torch.zeros(100, dtype=torch.int64),
    )
    test_dataset = IMDBDataset(imdb_tokenized, partition_key="test")

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=12,
        num_workers=4,
    )
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=12,
        num_workers=2,
    )

    model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
    model = DistributedDataParallel(model.to(device), device_ids=[local_rank])

    train(model=model, train_loader=train_loader, device=device)

    torch.distributed.barrier()

    for _ in test_loader:
        pass

    torch.distributed.barrier()
    print("completed without errors")

if __name__ == "__main__":
    run()

Run this script with:

torchrun --nproc_per_node 2 --standalone crashes.py

to reproduce the error:


Traceback (most recent call last):
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1120, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.9/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.9/site-packages/torch/multiprocessing/reductions.py", line 305, in rebuild_storage_fd
    fd = df.detach()
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.9/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.9/multiprocessing/resource_sharer.py", line 86, in get_connection
    c = Client(address, authkey=process.current_process().authkey)
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 513, in Client
    answer_challenge(c, authkey)
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 757, in answer_challenge
    message = connection.recv_bytes(256)         # reject large message
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 221, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 419, in _recv_bytes
    buf = self._recv(4)
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 384, in _recv
    chunk = read(handle, remaining)
ConnectionResetError: [Errno 104] Connection reset by peer

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/adrian/repositories/lightning/crashes.py", line 230, in <module>
    for idx, batch in enumerate(test_loader):
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1316, in _next_data
    idx, data = self._get_data()
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1282, in _get_data
    success, data = self._try_get_data()
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1120, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.9/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 2339965) is killed by signal: Aborted.

Expected behavior

No crash. In the code above, you will find the commented-out lines for the metric reset:

for attr, default in train_acc._defaults.items():
    current_val = getattr(train_acc, attr)
    setattr(train_acc, attr, default.to(current_val.device))

Why do these lines (which are part of Metric.reset) prevent the DataLoader crash?
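
As noted above, a .reset() call on the metric avoids the crash. For reference, here is a minimal sketch (my addition, not part of the original script) of the same train() function using the public API instead of copying part of Metric.reset()'s internals:

def train(model, train_loader, device):
    train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)
    input_ids, mask, labels = next(iter(train_loader))
    input_ids, mask, labels = input_ids.to(device), mask.to(device), labels.to(device)

    outputs = model(input_ids, attention_mask=mask, labels=labels)
    predicted_labels = torch.argmax(outputs["logits"].clone(), 1)
    train_acc.update(predicted_labels, labels)
    train_acc.compute()
    # Public API call; re-assigns the metric states to their defaults on the
    # current device, which includes the manual re-assignment shown above.
    train_acc.reset()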

Observations

The issue only occurs with the device on CUDA and in a distributed setting. I only observed the problem when combining this dataset with torchmetrics. The problem might very well be with HF transformers or datasets, but since torchmetrics is involved, I am not sure where the fix needs to go. The code above is admittedly contrived, but it is the result of minimizing a real training script as much as possible while still reproducing the error.

Environment

torch                    1.13.1
datasets                 2.9.0
transformers             4.25.1
torchmetrics             0.11.1

Additional context

SkafteNicki commented 1 year ago

Hi @awaelchli, I tried to take a stab at this over the last couple of hours. Here is what I found:

All of the above is not really an answer, so I looked at which line in TM causes this. Removing this line fixes the script: https://github.com/Lightning-AI/metrics/blob/78e9571e5e41e8ae924cd10c8200fa5d53d198e4/src/torchmetrics/utilities/distributed.py#L93-L94 However, that is the line that takes care of the distributed synchronization, so it is pretty essential, and there does not seem to be anything wrong with that particular function. That said, I can get it to work if I manually cast to another dtype, like:

gathered_result = [torch.zeros_like(result).float() for _ in range(world_size)] 
torch.distributed.all_gather(gathered_result, result.float(), group) 

The default dtype for the Accuracy metric states is torch.long, so maybe it is a problem between torch.distributed.all_gather and torch.long on CUDA?
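
To check that hypothesis in isolation, a standalone sketch along these lines could be run with torchrun --nproc_per_node 2 --standalone all_gather_long.py (the file name and tensor shapes are mine; note that the original report only crashes with the DataLoader workers and the HF dataset in the mix, so this alone may well not reproduce it):

import os

import torch
import torch.distributed

def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group("nccl", rank=local_rank, world_size=world_size)
    device = torch.device("cuda", local_rank)

    # torch.long tensor, matching the default dtype of the Accuracy metric states
    result = torch.zeros(1, dtype=torch.long, device=device)
    gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
    torch.distributed.all_gather(gathered_result, result)
    print(local_rank, gathered_result)

if __name__ == "__main__":
    main()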

When the script fails, the first part of the traceback I get contains:

terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: initialization error
Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:31 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x3e (0x7fadfd03d86e in /home/nsde/.conda/envs/metrics/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x5c (0x7fadfd0083a8 in /home/nsde/.conda/envs/metrics/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(std::string const&, std::string const&, int, bool) + 0xb4 (0x7fae286ed584 in /home/nsde/.conda/envs/metrics/lib/python3.8/site-packages/torch/lib/libc10_cuda.so)
frame #3: <unknown function> + 0x1ebd5 (0x7fae286c5bd5 in /home/nsde/.conda/envs/metrics/lib/python3.8/site-packages/torch/lib/libc10_cuda.so)
frame #4: c10::cuda::CUDACachingAllocator::raw_delete(void*) + 0x265 (0x7fae286c80b5 in /home/nsde/.conda/envs/metrics/lib/python3.8/site-packages/torch/lib/libc10_cuda.so)
...

I think the important part is that it fails on a delete operation: c10::cuda::CUDACachingAllocator::raw_delete(void*). My best guess is that the reason why

for attr, default in train_acc._defaults.items():
  current_val = getattr(train_acc, attr)
  setattr(train_acc, attr, default.to(current_val.device))

fixes the script is that it manually overwrites/deletes the synchronized result (the output of the line that causes the problem) instead of relying on the deallocation happening at the end of the script. I am not an expert in multiprocessing, so this may be complete gibberish. Based on this issue from torch: https://github.com/pytorch/pytorch/issues/67978 it seems that this error exists for others as well, but there is no clarification of what causes it (though some also indicate that it has to do with a specific combination of dtypes).

awaelchli commented 1 year ago

Hello @SkafteNicki

Sorry for the (very) late reply. I couldn't spend more time on it, and so it got forgotten. Thank you for documenting this and digging deeper than I could; nice find with the CUDA caching allocator. I am fine with dropping this investigation, as it is not a high priority and we don't really know what to fix or where. If this happens to more users, we can pick it up again.

Thanks again, your time is appreciated!

anhnami commented 1 year ago

This could be related to grpcio 1.53 (https://github.com/ray-project/ray/issues/34194). I faced the same bug, and downgrading grpcio to 1.51.3 seemed to fix the problem.

wzf03 commented 6 months ago

I encountered this error as well, and torchmetrics turned out to be the culprit in my case too. I discovered that by attaching the metrics to the LightningModule and letting Lightning handle the move to the GPU, instead of managing it manually, the error magically disappeared! It seems that because torchmetrics.Metric is actually a torch.nn.Module, it needs to be treated as one.
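
A rough sketch of that pattern, assuming a recent Lightning release (the class, hook, and optimizer choices below are illustrative, not from the original report): assigning the metric as an attribute registers it as a child module, so Lightning moves its states to the correct device together with the model.

import torch
import torchmetrics
import lightning.pytorch as pl

class Classifier(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # Assigned as an attribute -> registered as a submodule, so Lightning
        # handles the device placement; no manual .to(device) on the metric.
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)

    def training_step(self, batch, batch_idx):
        input_ids, mask, labels = batch
        outputs = self.model(input_ids, attention_mask=mask, labels=labels)
        preds = torch.argmax(outputs["logits"], dim=1)
        self.train_acc.update(preds, labels)
        # Logging the Metric object lets Lightning call compute()/reset() for us.
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=True)
        return outputs["loss"]

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-5)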