ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[Train/Tune] Setting `export CUDA_VISIBLE_DEVICES=0` leads to error `ValueError: '0' is not in list`. List of GPUs is made of integers but checks for a string member. #28467

Closed: orcunderscore closed this 2 years ago

orcunderscore commented 2 years ago

What happened + What you expected to happen

The issue occurs when using GPUs. I need to run `export CUDA_VISIBLE_DEVICES=0` before running my code, and I then get the error `ValueError: '0' is not in list`.

A small example that reproduces this error is provided further below.

I traced this error back to Ray's `train_loop_utils.TorchWorkerProfile.get_device` function. There, the line `gpu_ids = ray.get_gpu_ids()` returns a list of strings. However, the code further down compares these strings against a list of integers, which is the bug (see the code comments):

    gpu_id = gpu_ids[0]  # This is a string

    cuda_visible_str = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    if cuda_visible_str and cuda_visible_str != "NoDevFiles":
        cuda_visible_list = [
            int(dev) for dev in cuda_visible_str.split(",")
        ]  # This is a list of integers
        # Looks for the position of a string in a list full of integers
        # --> ValueError: '0' is not in list
        device_id = cuda_visible_list.index(gpu_id)
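The type mismatch can be reproduced without Ray or a GPU. The sketch below is a hypothetical stand-alone version of the snippet above: the function names `find_device_index_buggy` and `find_device_index` are made up for illustration and are not part of Ray's API. The second function shows one possible fix, comparing both sides as strings; this is an assumption about a reasonable fix, not necessarily what Ray ended up doing.

```python
from typing import List, Union


def find_device_index_buggy(gpu_ids: List[str], cuda_visible_str: str) -> int:
    # Mirrors the snippet above: the visible devices are parsed as ints,
    # while the GPU id is a string, so list.index() never finds a match.
    cuda_visible_list = [int(dev) for dev in cuda_visible_str.split(",")]
    return cuda_visible_list.index(gpu_ids[0])


def find_device_index(gpu_ids: List[Union[str, int]], cuda_visible_str: str) -> int:
    # Hypothetical fix: normalize both sides to stripped strings, so "0" matches "0".
    # Keeping strings (instead of calling int()) also avoids crashing on
    # non-integer entries such as GPU UUIDs in CUDA_VISIBLE_DEVICES.
    cuda_visible_list = [dev.strip() for dev in cuda_visible_str.split(",")]
    return cuda_visible_list.index(str(gpu_ids[0]).strip())


if __name__ == "__main__":
    visible = "0"  # what `export CUDA_VISIBLE_DEVICES=0` puts in the environment
    try:
        find_device_index_buggy(["0"], visible)
    except ValueError as err:
        print(f"buggy version raised ValueError: {err}")
    print(find_device_index(["0"], visible))  # prints 0
```

With the string comparison, the device index is found for the single-GPU case from this issue as well as for multi-GPU settings such as `CUDA_VISIBLE_DEVICES=0,1`.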

Note that the error does not occur when `CUDA_VISIBLE_DEVICES` is not set; however, in that case all GPUs are used rather than only the one I specify.

Versions / Dependencies

Conda env yaml

# To create a new environment from this file, run: "conda env create --file conda_environment_gpu.yaml".
name: ray_test
channels:
  - pytorch
  - conda-forge
  - defaults
dependencies:
  - python=3.9.5
  - cudatoolkit=11.3
  - pip
  - pytorch::pytorch=1.12.1
  - pip:
      - ray==2.0.0
      - pandas
      - pyarrow
      - tabulate
      - torchvision

conda list

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
aiosignal                 1.2.0                    pypi_0    pypi
attrs                     22.1.0                   pypi_0    pypi
blas                      1.0                         mkl  
ca-certificates           2022.9.14            ha878542_0    conda-forge
certifi                   2022.9.14                pypi_0    pypi
charset-normalizer        2.1.1                    pypi_0    pypi
click                     8.0.4                    pypi_0    pypi
cudatoolkit               11.3.1              h9edb442_10    conda-forge
distlib                   0.3.6                    pypi_0    pypi
filelock                  3.8.0                    pypi_0    pypi
frozenlist                1.3.1                    pypi_0    pypi
grpcio                    1.43.0                   pypi_0    pypi
idna                      3.4                      pypi_0    pypi
jsonschema                4.16.0                   pypi_0    pypi
ld_impl_linux-64          2.36.1               hea4e1c9_2    conda-forge
libffi                    3.3                  h58526e2_2    conda-forge
libgcc-ng                 12.1.0              h8d9b700_16    conda-forge
libsqlite                 3.39.3               h753d276_0    conda-forge
libstdcxx-ng              12.1.0              ha89aaad_16    conda-forge
libzlib                   1.2.12               h166bdaf_3    conda-forge
llvm-openmp               14.0.4               he0ac6c6_0    conda-forge
mkl                       2022.1.0           h84fe81f_915    conda-forge
msgpack                   1.0.4                    pypi_0    pypi
ncurses                   6.3                  h27087fc_1    conda-forge
numpy                     1.23.3                   pypi_0    pypi
openssl                   1.1.1q               h166bdaf_0    conda-forge
pandas                    1.4.4                    pypi_0    pypi
pillow                    9.2.0                    pypi_0    pypi
pip                       22.2.2             pyhd8ed1ab_0    conda-forge
platformdirs              2.5.2                    pypi_0    pypi
protobuf                  3.20.2                   pypi_0    pypi
pyarrow                   9.0.0                    pypi_0    pypi
pyrsistent                0.18.1                   pypi_0    pypi
python                    3.9.5           h49503c6_0_cpython    conda-forge
python-dateutil           2.8.2                    pypi_0    pypi
pytorch                   1.12.1          py3.9_cuda11.3_cudnn8.3.2_0    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pytz                      2022.2.1                 pypi_0    pypi
pyyaml                    6.0                      pypi_0    pypi
ray                       2.0.0                    pypi_0    pypi
readline                  8.1.2                h0f457ee_0    conda-forge
requests                  2.28.1                   pypi_0    pypi
setuptools                65.3.0             pyhd8ed1ab_1    conda-forge
six                       1.16.0                   pypi_0    pypi
sqlite                    3.39.3               h4ff8645_0    conda-forge
tabulate                  0.8.10                   pypi_0    pypi
tbb                       2021.5.0             h924138e_2    conda-forge
tk                        8.6.12               h27826a3_0    conda-forge
torchvision               0.13.1                   pypi_0    pypi
typing_extensions         4.3.0              pyha770c72_0    conda-forge
tzdata                    2022c                h191b570_0    conda-forge
urllib3                   1.26.12                  pypi_0    pypi
virtualenv                20.16.5                  pypi_0    pypi
wheel                     0.37.1             pyhd8ed1ab_0    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
zlib                      1.2.12               h166bdaf_3    conda-forge

pip list

Package            Version
------------------ ---------
aiosignal          1.2.0
attrs              22.1.0
certifi            2022.9.14
charset-normalizer 2.1.1
click              8.0.4
distlib            0.3.6
filelock           3.8.0
frozenlist         1.3.1
grpcio             1.43.0
idna               3.4
jsonschema         4.16.0
msgpack            1.0.4
numpy              1.23.3
pandas             1.4.4
Pillow             9.2.0
pip                22.2.2
platformdirs       2.5.2
protobuf           3.20.2
pyarrow            9.0.0
pyrsistent         0.18.1
python-dateutil    2.8.2
pytz               2022.2.1
PyYAML             6.0
ray                2.0.0
requests           2.28.1
setuptools         65.3.0
six                1.16.0
tabulate           0.8.10
torch              1.12.1
torchvision        0.13.1
typing_extensions  4.3.0
urllib3            1.26.12
virtualenv         20.16.5
wheel              0.37.1

Reproduction script

Example taken from https://docs.ray.io/en/latest/train/examples/torch_fashion_mnist_example.html with minimal changes (only the argparse handling was removed). Run `export CUDA_VISIBLE_DEVICES=0` before running the script.

from typing import Dict
from ray.air import session

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

import ray.train as train
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="~/data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="~/data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

def train_epoch(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset) // session.get_world_size()
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def validate_epoch(dataloader, model, loss_fn):
    size = len(dataloader.dataset) // session.get_world_size()
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n "
        f"Accuracy: {(100 * correct):>0.1f}%, "
        f"Avg loss: {test_loss:>8f} \n"
    )
    return test_loss

def train_func(config: Dict):
    batch_size = config["batch_size"]
    lr = config["lr"]
    epochs = config["epochs"]

    worker_batch_size = batch_size // session.get_world_size()

    # Create data loaders.
    train_dataloader = DataLoader(training_data, batch_size=worker_batch_size)
    test_dataloader = DataLoader(test_data, batch_size=worker_batch_size)

    train_dataloader = train.torch.prepare_data_loader(train_dataloader)
    test_dataloader = train.torch.prepare_data_loader(test_dataloader)

    # Create model.
    model = NeuralNetwork()
    model = train.torch.prepare_model(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    loss_results = []

    for _ in range(epochs):
        train_epoch(train_dataloader, model, loss_fn, optimizer)
        loss = validate_epoch(test_dataloader, model, loss_fn)
        loss_results.append(loss)
        session.report(dict(loss=loss))

    # return required for backwards compatibility with the old API
    # TODO(team-ml) clean up and remove return
    return loss_results

def train_fashion_mnist(num_workers=2, use_gpu=False):
    trainer = TorchTrainer(
        train_func,
        train_loop_config={"lr": 1e-3, "batch_size": 64, "epochs": 4},
        scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
    )
    result = trainer.fit()
    print(f"Results: {result.metrics}")

if __name__ == "__main__":
    train_fashion_mnist(num_workers=1, use_gpu=True)

Issue Severity

High: It blocks me from completing my task.

orcunderscore commented 2 years ago

I changed the title and added explicit dependencies for easier reproducibility. Can someone please confirm if this is a bug or if I am doing something wrong on my end? Thank you!

amogkam commented 2 years ago

Hey @mr-abc-xyz, yes this is a bug! I'm taking a look right now!