Lightning-Universe / lightning-bolts

Toolbox of models, callbacks, and datasets for AI/ML researchers.
https://lightning-bolts.readthedocs.io
Apache License 2.0

Shared replay buffer #280

Open MihaiAnca13 opened 3 years ago

MihaiAnca13 commented 3 years ago

🚀 Feature

The RL implementations that were added do not expose the num_workers option. I suspect this is because the code doesn't support a shared replay buffer.

Motivation

Adding this would enable distributed training, which is very important in RL. Some algorithms, e.g. A3C, would not work at all without it.

Pitch

When the num_workers argument is specified and greater than 1, the replay buffer should be shared among all workers, roughly as in the sketch below.
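
A rough sketch of what such a buffer could look like. Everything here (the SharedReplayBuffer name, the capacity/state_dim parameters, the mp.Value write counter) is illustrative, not existing bolts API:

import torch
import torch.multiprocessing as mp

class SharedReplayBuffer:
    """Sketch: fixed-size buffer whose storage lives in shared memory,
    so every DataLoader worker reads and writes the same data."""

    def __init__(self, capacity: int, state_dim: int):
        self.capacity = capacity
        # Preallocated tensors, moved to shared memory so that writes from
        # one worker process are visible to all the others.
        self.states = torch.zeros(capacity, state_dim).share_memory_()
        self.rewards = torch.zeros(capacity).share_memory_()
        # Process-safe counter for the next write position.
        self._pos = mp.Value('l', 0)

    def append(self, state: torch.Tensor, reward: float) -> None:
        with self._pos.get_lock():
            idx = self._pos.value % self.capacity
            self._pos.value += 1
        self.states[idx] = state
        self.rewards[idx] = reward

    def sample(self, batch_size: int):
        # Assumes at least one transition has been stored.
        filled = min(self._pos.value, self.capacity)
        idx = torch.randint(0, filled, (batch_size,))
        return self.states[idx], self.rewards[idx]

Preallocating fixed-size tensors is what makes this possible: shared storage cannot be resized, so the usual list-of-transitions buffer has to become a ring buffer.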

Alternatives

I believe OpenAI does this using mpi4py, but that would probably defeat the point here, since PyTorch already handles multiprocessing itself (see the sketch below).
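
For comparison, the pure-PyTorch route needs nothing beyond torch.multiprocessing. A minimal sketch (the worker function and tensor size are made up for illustration):

import torch
import torch.multiprocessing as mp

def worker(rank, buffer):
    # Each spawned process writes its rank into its own slot of the shared tensor.
    buffer[rank] = rank

if __name__ == '__main__':
    buffer = torch.zeros(4).share_memory_()
    mp.spawn(worker, args=(buffer,), nprocs=4, join=True)
    print(buffer)  # all four writes are visible in the parent process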

Additional context

In the PyTorch docs, they say: "When num_workers > 0, each worker process will have a different copy of the dataset object". In practice this is only partly true: I tested it with torch.utils.data.get_worker_info() and every worker reported the same address, because (with the default fork start method on Linux) the workers are forked from the parent process and inherit its memory copy-on-write. This is good, because it means the replay buffer can be initialised in the IterableDataset subclass, and a single share_memory_() call makes it genuinely shared between the workers.

The code below is what I've used for testing. The commented-out self.replay_buffer.share_memory_() call in RLDataset.__init__ is where the magic happens. Try running the script with that line commented and then uncommented to see the difference. I'm happy to provide more explanation if needed, or to help implement this.

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, IterableDataset, default_convert

class MLP(nn.Module):
    """Tiny network, only here to give the training loop something to do."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, 5),
            nn.ReLU(),
            nn.Linear(5, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)

class RLDataset(IterableDataset):
    """Dummy dataset whose 'replay buffer' is a single tensor.
    Each worker writes its own id into one slot of the buffer."""

    def __init__(self):
        self.replay_buffer = torch.tensor([0, 0, 0, 0], dtype=torch.float32)
        # Uncomment to place the buffer in shared memory, so that writes
        # from one worker process become visible to all the others:
        # self.replay_buffer.share_memory_()

    def __iter__(self):
        worker_id = torch.utils.data.get_worker_info().id
        self.replay_buffer[worker_id] = worker_id
        print(self.replay_buffer)
        for _ in range(np.random.randint(5, 10)):
            yield self.replay_buffer

class MLightning(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = MLP()

    def training_step(self, batch, batch_idx):
        batch = torch.cat(batch, dim=-1).unsqueeze(dim=-1)
        loss = self.net(batch).mean()

        self.log('loss',
                 loss,
                 on_step=True,
                 on_epoch=True,
                 prog_bar=False,
                 logger=False)

        return loss

    def configure_optimizers(self):
        """Initialize the Adam optimizer."""
        optimizer = optim.Adam(self.net.parameters(), lr=0.001)
        return [optimizer]

    def __dataloader(self) -> DataLoader:
        """Initialize the replay buffer dataset used for retrieving
        experiences."""
        dataset = RLDataset()
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=4,
            num_workers=4,  # several workers, so each one writes a different buffer slot
            collate_fn=default_convert,
            pin_memory=True,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader"""
        return self.__dataloader()

def main():
    model = MLightning()

    trainer = pl.Trainer(
        accelerator='gpu',
        devices=1,
        # strategy='dp',
        max_epochs=2000,
    )

    trainer.fit(model)

if __name__ == '__main__':
    # The guard is required when worker processes are started with the
    # 'spawn' method (the default on Windows and macOS).
    torch.manual_seed(42)
    np.random.seed(42)
    main()
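
Expected behaviour: with the share_memory_() line commented out, each worker only ever sees its own write, since it mutates its private copy-on-write copy of the buffer; with the line enabled, all four workers write into the same storage, and the printed buffer accumulates every worker's id.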
github-actions[bot] commented 3 years ago

Hi! Thanks for your contribution, great first issue!

Borda commented 3 years ago

@MihaiAnca13 good point, mind sending a PR?

MihaiAnca13 commented 3 years ago

@Borda Thanks for your reply. Can you please have a quick look at the SharedReplayBuffer I've added? I'm not sure what I should do or add next to make it useful.