ashleve / lightning-hydra-template

PyTorch Lightning + Hydra. A very user-friendly template for ML experimentation. ⚡🔥⚡

Adapting the template to reinforcement learning applications #203

Open 00krishna opened 2 years ago

00krishna commented 2 years ago

Hello @ashleve thanks so much for creating this template. I am finding it so helpful in a lot of my development. I can't thank you enough.

I had a question about adapting this template for reinforcement learning workflows. I was trying to port some code from a PyTorch Lightning DQN demo to your template, and just had a couple of questions. Here is a link to the PL docs with this code. I pasted the code below for your convenience as well.

One thing I noticed in this code is that the DataLoader is defined inside the model rather than in a separate file, as in the template. Since train.py calls something like trainer.fit(model=model, datamodule=datamodule), this could cause a problem, so I have to think about how to adapt the code to fit the current template. I also have to think about how to handle the CPU/GPU device calls in this code: some of the classes are not nn.Modules, so PL might not be able to control where their tensors are created the way it does in other PL modules.
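One possible way to sidestep the device bookkeeping, sketched here under assumptions: a plain Python helper can read the device from the network's own parameters instead of receiving a device string, as `Agent.get_action` does in the code pasted below. The name `greedy_action` is hypothetical and is not part of the template or the PL example, and the sketch assumes all parameters of `net` live on a single device.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def greedy_action(net: nn.Module, state) -> int:
    """Return the argmax action for one observation (hypothetical helper)."""
    # Assumption: every parameter of `net` sits on the same device.
    device = next(net.parameters()).device
    # Build the observation tensor directly on that device and add a batch dimension.
    obs = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    q_values = net(obs)
    return int(q_values.argmax(dim=1).item())


# Usage with a stand-in network: 4 observation features, 2 actions.
net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))
action = greedy_action(net, [0.1, -0.2, 0.0, 0.3])
```

Since Lightning moves the LightningModule and its submodules to the target device, a helper written this way should not need to distinguish `cpu` from `cuda` itself.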

I was just wondering if you had any suggestions about how to port this code. I imagine I would need to create a separate file for the DataModule and carry on like any other model. But if you can think of an easier way that I might not have considered, then please suggest away.
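A rough sketch of that separation, assuming the replay buffer is created once and shared between the model and the data code. The simplified `ReplayBuffer` and the factory `make_replay_loader` are illustrative names, not the template's or PL's API; in the template, the same `DataLoader` could presumably be returned from a `LightningDataModule.train_dataloader` method.

```python
import random
from collections import deque
from typing import Iterator, Tuple

import torch
from torch.utils.data import DataLoader, IterableDataset


class ReplayBuffer:
    """Simplified buffer holding (state, action, reward, done, new_state) tuples."""

    def __init__(self, capacity: int) -> None:
        self.buffer = deque(maxlen=capacity)

    def __len__(self) -> int:
        return len(self.buffer)

    def append(self, experience: Tuple) -> None:
        self.buffer.append(experience)

    def sample(self, n: int) -> list:
        # Sample without replacement, capped at the current buffer size.
        return random.sample(list(self.buffer), min(n, len(self.buffer)))


class RLDataset(IterableDataset):
    """Yields freshly sampled experiences each time the loader is iterated."""

    def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
        self.buffer = buffer
        self.sample_size = sample_size

    def __iter__(self) -> Iterator[Tuple]:
        yield from self.buffer.sample(self.sample_size)


def make_replay_loader(buffer: ReplayBuffer, sample_size: int, batch_size: int) -> DataLoader:
    """Hypothetical factory that could live in its own data file."""
    return DataLoader(RLDataset(buffer, sample_size), batch_size=batch_size)


# Usage: fill the shared buffer with dummy transitions, then draw one batch.
buffer = ReplayBuffer(capacity=100)
for step in range(20):
    state = torch.zeros(4)
    buffer.append((state, step % 2, 1.0, False, state + 1.0))

loader = make_replay_loader(buffer, sample_size=8, batch_size=4)
states, actions, rewards, dones, new_states = next(iter(loader))
```

The design choice to note is that the buffer object is shared: the model keeps appending experiences to it during training while the dataset samples from it, so the data file only needs a reference to the buffer, not to the model.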

Here is the full code, in case it is more convenient to look at it here rather than going to the PL repo.


import argparse
from collections import deque, namedtuple, OrderedDict
from typing import Iterator, List, Tuple

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset

import pytorch_lightning as pl
from pl_examples import cli_lightning_logo

class DQN(nn.Module):
    """Simple MLP network.
    >>> DQN(10, 5)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    DQN(
      (net): Sequential(...)
    )
    """

    def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
        """
        Args:
            obs_size: observation/state size of the environment
            n_actions: number of discrete actions available in the environment
            hidden_size: size of hidden layers
        """
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, n_actions))

    def forward(self, x):
        return self.net(x.float())

# Named tuple for storing experience steps gathered in training
Experience = namedtuple("Experience", field_names=["state", "action", "reward", "done", "new_state"])

class ReplayBuffer:
    """Replay Buffer for storing past experiences allowing the agent to learn from them.
    >>> ReplayBuffer(5)  # doctest: +ELLIPSIS
    <...reinforce_learn_Qnet.ReplayBuffer object at ...>
    """

    def __init__(self, capacity: int) -> None:
        """
        Args:
            capacity: size of the buffer
        """
        self.buffer = deque(maxlen=capacity)

    def __len__(self) -> int:
        return len(self.buffer)

    def append(self, experience: Experience) -> None:
        """Add experience to the buffer.
        Args:
            experience: tuple (state, action, reward, done, new_state)
        """
        self.buffer.append(experience)

    def sample(self, batch_size: int) -> Tuple:
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, dones, next_states = zip(*(self.buffer[idx] for idx in indices))

        return (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(dones, dtype=bool),  # builtin bool; the np.bool alias is missing from some NumPy releases
            np.array(next_states),
        )

class RLDataset(IterableDataset):
    """Iterable Dataset containing the ExperienceBuffer which will be updated with new experiences during training.
    >>> RLDataset(ReplayBuffer(5))  # doctest: +ELLIPSIS
    <...reinforce_learn_Qnet.RLDataset object at ...>
    """

    def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
        """
        Args:
            buffer: replay buffer
            sample_size: number of experiences to sample at a time
        """
        self.buffer = buffer
        self.sample_size = sample_size

    def __iter__(self) -> Iterator:
        states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)
        for i in range(len(dones)):
            yield states[i], actions[i], rewards[i], dones[i], new_states[i]

class Agent:
    """Base Agent class handling the interaction with the environment.
    >>> env = gym.make("CartPole-v1")
    >>> buffer = ReplayBuffer(10)
    >>> Agent(env, buffer)  # doctest: +ELLIPSIS
    <...reinforce_learn_Qnet.Agent object at ...>
    """

    def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
        """
        Args:
            env: training environment
            replay_buffer: replay buffer storing experiences
        """
        self.env = env
        self.replay_buffer = replay_buffer
        self.reset()  # resets the environment and sets self.state

    def reset(self) -> None:
        """Resets the environment and updates the state."""
        self.state = self.env.reset()

    def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:
        """Using the given network, decide what action to carry out using an epsilon-greedy policy.
        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device
        Returns:
            action
        """
        if np.random.random() < epsilon:
            action = self.env.action_space.sample()
        else:
            state = torch.tensor([self.state])

            if device not in ["cpu"]:
                state = state.cuda(device)

            q_values = net(state)
            _, action = torch.max(q_values, dim=1)
            action = int(action.item())

        return action

    @torch.no_grad()
    def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = "cpu") -> Tuple[float, bool]:
        """Carries out a single interaction step between the agent and the environment.
        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device
        Returns:
            reward, done
        """

        action = self.get_action(net, epsilon, device)

        # do step in the environment
        new_state, reward, done, _ = self.env.step(action)

        exp = Experience(self.state, action, reward, done, new_state)

        self.replay_buffer.append(exp)

        self.state = new_state
        if done:
            self.reset()
        return reward, done

class DQNLightning(pl.LightningModule):
    """Basic DQN Model.
    >>> DQNLightning(env="CartPole-v1")  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    DQNLightning(
      (net): DQN(
        (net): Sequential(...)
      )
      (target_net): DQN(
        (net): Sequential(...)
      )
    )
    """

    def __init__(
        self,
        env: str,
        replay_size: int = 200,
        warm_start_steps: int = 200,
        gamma: float = 0.99,
        eps_start: float = 1.0,
        eps_end: float = 0.01,
        eps_last_frame: int = 200,
        sync_rate: int = 10,
        lr: float = 1e-2,
        episode_length: int = 50,
        batch_size: int = 4,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.replay_size = replay_size
        self.warm_start_steps = warm_start_steps
        self.gamma = gamma
        self.eps_start = eps_start
        self.eps_end = eps_end
        self.eps_last_frame = eps_last_frame
        self.sync_rate = sync_rate
        self.lr = lr
        self.episode_length = episode_length
        self.batch_size = batch_size

        self.env = gym.make(env)
        obs_size = self.env.observation_space.shape[0]
        n_actions = self.env.action_space.n

        self.net = DQN(obs_size, n_actions)
        self.target_net = DQN(obs_size, n_actions)

        self.buffer = ReplayBuffer(self.replay_size)
        self.agent = Agent(self.env, self.buffer)
        self.total_reward = 0
        self.episode_reward = 0
        self.populate(self.warm_start_steps)

    def populate(self, steps: int = 1000) -> None:
        """Carries out several random steps through the environment to initially fill up the replay buffer with
        experiences.
        Args:
            steps: number of random steps to populate the buffer with
        """
        for i in range(steps):
            self.agent.play_step(self.net, epsilon=1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Passes in a state `x` through the network and gets the `q_values` of each action as an output.
        Args:
            x: environment state
        Returns:
            q values
        """
        output = self.net(x)
        return output

    def dqn_mse_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Calculates the mse loss using a mini batch from the replay buffer.
        Args:
            batch: current mini batch of replay data
        Returns:
            loss
        """
        states, actions, rewards, dones, next_states = batch

        state_action_values = self.net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

        with torch.no_grad():
            next_state_values = self.target_net(next_states).max(1)[0]
            next_state_values[dones] = 0.0
            next_state_values = next_state_values.detach()

        expected_state_action_values = next_state_values * self.gamma + rewards

        return nn.MSELoss()(state_action_values, expected_state_action_values)

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> OrderedDict:
        """Carries out a single step through the environment to update the replay buffer. Then calculates loss
        based on the minibatch received.
        Args:
            batch: current mini batch of replay data
            nb_batch: batch number
        Returns:
            Training loss and log metrics
        """
        device = self.get_device(batch)
        epsilon = max(self.eps_end, self.eps_start - (self.global_step + 1) / self.eps_last_frame)

        # step through environment with agent
        reward, done = self.agent.play_step(self.net, epsilon, device)
        self.episode_reward += reward

        # calculates training loss
        loss = self.dqn_mse_loss(batch)

        if done:
            self.total_reward = self.episode_reward
            self.episode_reward = 0

        # Hard update: copy the online network's weights into the target network every sync_rate steps
        if self.global_step % self.sync_rate == 0:
            self.target_net.load_state_dict(self.net.state_dict())

        log = {
            "total_reward": torch.tensor(self.total_reward).to(device),
            "reward": torch.tensor(reward).to(device),
            "steps": torch.tensor(self.global_step).to(device),
        }

        return OrderedDict({"loss": loss, "log": log, "progress_bar": log})

    def configure_optimizers(self) -> List[Optimizer]:
        """Initialize Adam optimizer."""
        optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
        return [optimizer]

    def __dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences."""
        dataset = RLDataset(self.buffer, self.episode_length)
        dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size, sampler=None)
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader."""
        return self.__dataloader()

    def get_device(self, batch) -> str:
        """Retrieve device currently being used by minibatch."""
        return batch[0].device.index if self.on_gpu else "cpu"

    @staticmethod
    def add_model_specific_args(parent_parser):  # pragma: no-cover
        parser = parent_parser.add_argument_group("DQNLightning")
        parser.add_argument("--batch_size", type=int, default=16, help="size of the batches")
        parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
        parser.add_argument("--env", type=str, default="CartPole-v1", help="gym environment tag")
        parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
        parser.add_argument("--sync_rate", type=int, default=10, help="how many frames between target network updates")
        parser.add_argument("--replay_size", type=int, default=1000, help="capacity of the replay buffer")
        parser.add_argument(
            "--warm_start_steps",
            type=int,
            default=1000,
            help="how many samples do we use to fill our buffer at the start of training",
        )
        parser.add_argument("--eps_last_frame", type=int, default=1000, help="what frame should epsilon stop decaying")
        parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon")
        parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon")
        parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode")
        return parent_parser

def main(args) -> None:
    model = DQNLightning(**vars(args))

    trainer = pl.Trainer(gpus=1, strategy="dp", val_check_interval=100)

    trainer.fit(model)

if __name__ == "__main__":
    cli_lightning_logo()
    torch.manual_seed(0)
    np.random.seed(0)

    parser = argparse.ArgumentParser(add_help=False)
    parser = DQNLightning.add_model_specific_args(parser)
    args = parser.parse_args()

    main(args)

ashleve commented 2 years ago

@00krishna Thank you for using the repo!

If you don't want to create a datamodule, then simply don't: it should be easy to remove it from the pipeline and delete its configs.

You could also consider replacing the datamodule config with a dataloader config and passing the loader as train_dataloaders:

trainer.fit(model, train_dataloaders=train_loader)

00krishna commented 2 years ago

@ashleve Thanks for the suggestions. Yes, I was actually able to get this working :). I found a different repo with example code for RL with PyTorch Lightning, and it kept the dataloader and model in separate files, so I was able to adapt that code to work inside the framework that you developed :). So perfecto. It was not too difficult.

Thanks again for creating and sharing the template. It just makes it so much easier to work on the model itself, instead of writing all that code for hyperparameter tuning, managing experiments, and everything else. Now I can focus on getting data into the system and then quickly testing different model variations.

If you like I can share the code with you on adapting your template to the RL algorithms I found. So far I just have the DQN model in there; that was my test case. But today I will adapt a few more algorithms now that DQN works. Let me know.

ashleve commented 2 years ago

> If you like I can share the code with you on adapting your template to the RL algorithms I found.

Sure, if you want, you can post it here for others' reference :)

I personally prefer using Ray instead of Lightning when doing RL, since it's easy to scale to multiple concurrent environments, but for now I don't really have an example of using Ray with the template.

ZeguanXiao commented 2 years ago

@00krishna Can you share the code?

00krishna commented 2 years ago

Hello @ZeguanXiao which code would you like me to share? I have been using this template for most of my usual NN code, like image segmentation, text stuff, etc. For reinforcement learning I have been looking into the Ray library, as @ashleve said. There are now some very nice connectors between Ray and PyTorch Lightning, so I will further adjust this template to leverage those connections for RL.

ZeguanXiao commented 2 years ago

> Hello @ZeguanXiao which code would you like me to share?

Hi, I mean your RL code adapted from this template. I thought you had finished an RL flavor of this template. You can share it if it's convenient for you.

00krishna commented 2 years ago

@ZeguanXiao yeah, I can do that. I still have a bit of work to do on the template, but it is not that bad. There is a plugin for the Ray/RLlib library and PyTorch Lightning, so that should make it pretty easy to connect the two libraries. I have some simple demo code for a custom environment that I can include, just to get things started. It will probably take a week, but I can post the link here when it is set.

realjoenguyen commented 2 years ago

@00krishna Can I ask for your template? Thanks

ashleve commented 2 years ago

This is a nice example (Lightning + Hydra + RL): https://github.com/belerico/lightning-rl