ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

RLlib: Fractional GPU setup #46176

Open DavidAkinpelu opened 3 months ago

DavidAkinpelu commented 3 months ago

I am trying to use Ray RLlib to run multiple environments that require GPU resources. My goal is to allocate a fraction of the GPU (e.g., 0.05) to the learner worker (policy) and share the remaining fraction among the environment runners. I got this error:

  File "/home/david/anaconda3/envs/ray/lib/python3.10/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/home/david/anaconda3/envs/ray/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/home/david/anaconda3/envs/ray/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/home/david/anaconda3/envs/ray/lib/python3.10/site-packages/ray/_private/worker.py", line 2613, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/home/david/anaconda3/envs/ray/lib/python3.10/site-packages/ray/_private/worker.py", line 861, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(IndexError): ray::PPO.train() (pid=154230, ip=192.168.0.28, actor_id=7c0a2f83028de76a8a91951401000000, repr=PPO)
  File "/home/david/anaconda3/envs/ray/lib/python3.10/site-packages/ray/tune/trainable/trainable.py", line 331, in train
    raise skipped from exception_cause(skipped)
  File "/home/david/anaconda3/envs/ray/lib/python3.10/site-packages/ray/tune/trainable/trainable.py", line 328, in train
    result = self.step()
  File "/home/david/anaconda3/envs/ray/lib/python3.10/site-packages/ray/rllib/algorithms/algorithm.py", line 873, in step
    train_results, train_iter_ctx = self._run_one_training_iteration()
  File "/home/david/anaconda3/envs/ray/lib/python3.10/site-packages/ray/rllib/algorithms/algorithm.py", line 3154, in _run_one_training_iteration
    results = self.training_step()
  File "/home/david/anaconda3/envs/ray/lib/python3.10/site-packages/ray/rllib/algorithms/ppo/ppo.py", line 428, in training_step
    return self._training_step_old_and_hybrid_api_stacks()
  File "/home/david/anaconda3/envs/ray/lib/python3.10/site-packages/ray/rllib/algorithms/ppo/ppo.py", line 594, in _training_step_old_and_hybrid_api_stacks
    train_results = multi_gpu_train_one_step(self, train_batch)
  File "/home/david/anaconda3/envs/ray/lib/python3.10/site-packages/ray/rllib/execution/train_ops.py", line 152, in multi_gpu_train_one_step
    num_loaded_samples[policy_id] = local_worker.policy_map[
  File "/home/david/anaconda3/envs/ray/lib/python3.10/site-packages/ray/rllib/policy/torch_policy_v2.py", line 802, in load_batch_into_buffer
    return len(slices[0])
IndexError: list index out of range

Simply increasing sample_timeout_s does not change anything.

Versions / Dependencies

Ray 2.24, Python 3.10, Ubuntu 22.04

Reproduction script

config.resources(
    num_cpus_per_worker=1,
    num_gpus_per_worker=(1 - 0.05) / 10 if torch.cuda.is_available() else 0,
    num_gpus_per_learner_worker=0.05 if torch.cuda.is_available() else 0,
    num_cpus_per_learner_worker=1,
)

config.env_runners(
    num_env_runners=10,
    rollout_fragment_length=200,
    num_envs_per_env_runner=1,
)
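
For reference, those fractions are chosen so that the learner plus the ten env runners exactly fill one physical GPU (a quick sanity check of the arithmetic, not part of the original report):

learner_frac = 0.05
runner_frac = (1 - learner_frac) / 10     # 0.095 per env runner
total = learner_frac + 10 * runner_frac   # 0.05 + 0.95 = 1.0
assert abs(total - 1.0) < 1e-9            # exactly one GPU requested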

Issue Severity

High: It blocks me from completing my task.

fireuse commented 3 months ago

I have run into the same issue. You can work around it by using only one env runner and setting create_env_on_local_worker to True, although it will be much slower.
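
A minimal sketch of that workaround (assuming Ray 2.24, where env_runners() still accepts the create_env_on_local_worker flag):

config = (
    PPOConfig()
    .env_runners(
        num_env_runners=1,                # a single env runner
        create_env_on_local_worker=True,  # sample on the local (driver) worker
    )
)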

simonsays1980 commented 3 months ago

@DavidAkinpelu thanks for raising this. Could you provide a reproducible script?

At first sight it looks like you want to use our new API stack (the learners) and at the same time the old stack (torch_policy_v2). To activate the new API stack, use this in the config:


config.api_stack(
    enable_rl_module_and_learner=True,
    enable_env_runner_and_connector_v2=True,
)
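
Combining that switch with the fractional GPU settings from the original post could look like this (a sketch only; these resource keys match Ray 2.24, and later releases move them to env_runners()/learners()):

config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .resources(
        # fractions taken from the original report
        num_gpus_per_learner_worker=0.05 if torch.cuda.is_available() else 0,
        num_gpus_per_worker=(1 - 0.05) / 10 if torch.cuda.is_available() else 0,
    )
)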
DavidAkinpelu commented 3 months ago

@simonsays1980 Thanks for your response. Here is a reproducible script.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import gymnasium as gym  # Ray 2.x expects gymnasium-style envs (5-tuple step() returns)
from gymnasium import spaces
import numpy as np
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env
from ray import tune
import ray

class ResNet(nn.Module):
    def __init__(self, model_name='resnet18', num_classes=10, pretrained=True):
        super(ResNet, self).__init__()
        if model_name == 'resnet18':
            self.model = models.resnet18(pretrained=pretrained)
        elif model_name == 'resnet34':
            self.model = models.resnet34(pretrained=pretrained)
        else:
            raise ValueError(f"Unsupported model: {model_name}")

        # Modify the last layer for the number of classes
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        return self.model(x)

class SimpleLearningRateEnv(gym.Env):
    def __init__(self, model, train_loader, test_loader, criterion, initial_lr, optimizer, device, max_steps):
        super(SimpleLearningRateEnv, self).__init__()

        self.model = model.to(device)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.criterion = criterion
        self.initial_lr = initial_lr
        self.lr = initial_lr
        self.device = device
        self.max_steps = max_steps
        self.current_step = 0

        if optimizer == 'sgd':
            self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9)
        elif optimizer == 'adam':
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        else:
            raise ValueError(f"Unsupported optimizer: {optimizer}")

        self.action_space = spaces.Discrete(3)  # 0: decrease, 1: maintain, 2: increase
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float32)

    def step(self, action):
        if action == 0:
            self.lr = max(1e-7, self.lr * 0.9)
        elif action == 2:
            self.lr = min(0.1, self.lr * 1.1)

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

        train_loss = self.train_epoch()
        eval_loss = self.eval_epoch()

        reward = -eval_loss
        self.current_step += 1
        terminated = False
        truncated = self.current_step >= self.max_steps

        self.state = np.array([train_loss, eval_loss, self.lr], dtype=np.float32)

        info = {
            'train_loss': train_loss,
            'val_loss': eval_loss,
            'lr': self.lr
        }

        return self.state, reward, terminated, truncated, info

    def train_epoch(self):
        self.model.train()
        running_loss = 0.0
        for inputs, targets in self.train_loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()
        return running_loss / len(self.train_loader)

    def eval_epoch(self):
        self.model.eval()
        running_loss = 0.0
        with torch.no_grad():
            for inputs, targets in self.test_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)
                running_loss += loss.item()
        return running_loss / len(self.test_loader)

    def reset(self, seed=None, options=None):
        if seed is not None:
            torch.manual_seed(seed)

        self.lr = self.initial_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

        self.current_step = 0
        train_loss = self.train_epoch()
        eval_loss = self.eval_epoch()

        self.state = np.array([train_loss, eval_loss, self.lr], dtype=np.float32)

        return self.state, {}

def get_cifar10_data(batch_size):
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    return train_loader, test_loader

def env_creator(env_config):
    train_loader, test_loader = get_cifar10_data(batch_size=env_config["batch_size"])

    model = ResNet(model_name=env_config["model_name"], 
                   num_classes=env_config["num_classes"], 
                   pretrained=env_config["pretrained"])

    criterion = nn.CrossEntropyLoss()

    return SimpleLearningRateEnv(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        criterion=criterion,
        initial_lr=env_config["initial_lr"],
        optimizer=env_config["optimizer"],
        device=env_config["device"],
        max_steps=env_config["max_steps"]
    )

env_config = {
    "batch_size": 128,
    "model_name": "resnet18",
    "num_classes": 10,
    "pretrained": True,
    "initial_lr": 0.1,
    "optimizer": "sgd",
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "max_steps": 100
}

# Build one env eagerly on the driver (this also downloads CIFAR-10).
env = env_creator(env_config)

register_env("learning_rate_env", lambda cfg: env_creator(cfg))
config = (
    PPOConfig()
    .env_runners(num_env_runners=10)
    .training(
        train_batch_size=1024,
        train_batch_size_per_learner=1024,
        lr=1e-4,
        gamma=0.95,
        lambda_=0.9,
        use_gae=True,
        clip_param=0.4,
        grad_clip=None,
        entropy_coeff=0.1,
        vf_loss_coeff=0.5,
        sgd_minibatch_size=256,
        num_sgd_iter=4,
    )
    .environment(
        env="learning_rate_env",
        env_config=env_config,
    )
    .debugging(log_level="ERROR")
    .framework(framework="torch")
    .resources(
      num_gpus=0.1 if torch.cuda.is_available() else 0,
      num_cpus_per_worker=1,
      num_gpus_per_worker=(1 - 0.1) / 10 if torch.cuda.is_available() else 0,
      num_gpus_per_learner_worker=0.05 if torch.cuda.is_available() else 0,
      num_cpus_per_learner_worker=1,
      num_cpus_for_main_process=1,
      )
)

result = tune.run(
    "PPO",
    name="PPO",
    stop={"timesteps_total": 5000},
    config=config.to_dict(),
    )
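
As an aside, the GPU fractions this script requests per trial add up to more than one physical GPU (a quick arithmetic check, not part of the original post; whether the learner fraction is actually reserved depends on which API stack ends up active):

num_gpus = 0.1                        # local worker (old stack)
num_gpus_per_worker = (1 - 0.1) / 10  # 0.09 for each of the 10 env runners
num_gpus_per_learner_worker = 0.05    # learner
total = num_gpus + 10 * num_gpus_per_worker + num_gpus_per_learner_worker
print(total)  # 1.05 -- oversubscribes a single-GPU machine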
DavidAkinpelu commented 3 months ago

@fireuse Can you post a sample config here?

fireuse commented 2 months ago

# Imports needed to run this snippet; MyCallbacks and "my_torch_model"
# are defined elsewhere in my setup.
from ray import air, tune
from ray.tune import CLIReporter
from ray.tune.schedulers.pb2 import PB2

pbt = PB2(time_attr="training_iteration",
          perturbation_interval=4,
          metric="env_runners/episode_reward_mean",
          mode="max",
          hyperparam_bounds={
              "lambda": [0.9, 1.0],
              "clip_param": [0.1, 0.5],
              "lr": [1e-5, 1e-3],
              #"entropy_coeff": [0, 0.01],
              "num_sgd_iter": [3, 10],
              #"gamma": [0.9, 0.99]
          })
stopping_criteria = {"training_iteration": 1000}
tuner = tune.Tuner(
    "PPO",
    tune_config=tune.TuneConfig(
        scheduler=pbt,
        num_samples=4,
    ),
    param_space={
        "_disable_preprocessor_api": True,
        "name": "defaultRun",
        "env": "test",
        "callbacks": MyCallbacks,
        "create_env_on_local_worker": True,
        "framework": "torch",
        "num_workers": 1,
        "num_cpus": 1,  # number of CPUs to use per trial
        "num_gpus": 0.25,
        "num_env_per_worker": 8,
        "rollout_fragment_length": 50,
        "lambda": tune.uniform(0.9, 1.0),
        "clip_param": tune.uniform(0.1, 0.5),
        "lr": tune.loguniform(1e-5, 1e-3),
        #"entropy_coeff": tune.uniform(0, 0.001),
        "num_sgd_iter": tune.randint(3, 10),
        #"gamma": tune.uniform(0.9, 0.99),
        "sgd_minibatch_size": 1024,
        "train_batch_size": 30720,
        "grad_clip": 4,
        "model": {
            "custom_model": "my_torch_model",
            # Extra kwargs to be passed to your model's c'tor.
            "custom_model_config": {},

        },
    },
    run_config=air.RunConfig(stop=stopping_criteria, verbose=3,
                             progress_reporter=CLIReporter(max_report_frequency=300)),
)

This uses Tune with the old API stack. If you have more CPUs, you can increase num_samples and decrease num_gpus.
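
With those numbers, the four trials pack exactly onto one physical GPU (a quick arithmetic check, not part of the original comment):

num_samples = 4
num_gpus_per_trial = 0.25                  # "num_gpus" above
total = num_samples * num_gpus_per_trial   # 4 * 0.25 = 1.0
assert total == 1.0                        # all four trials fit on one GPU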