Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Predict on TPU using all cores #11417

Open stekiri opened 2 years ago

stekiri commented 2 years ago

🐛 Bug

When writing predictions with `torch.save` from a `BasePredictionWriter` (see this example) on Colab, using a TPU runtime with all 8 cores, only an eighth of the predictions are actually saved to disk.

To Reproduce

The following code is based on the TPU tutorial with a few modifications:

Package installation:

!pip install torch==1.9.1 torchtext==0.10.1 torchvision==0.10.1 pytorch-lightning==1.5.8 cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

Code:

import os
from typing import Any, List, Optional

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import BasePredictionWriter
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchmetrics.functional import accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

BATCH_SIZE = 1024

class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

        # self.dims is returned when you call dm.size()
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)

    def predict_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)

class LitModel(LightningModule):
    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
        super().__init__()

        self.save_hyperparameters()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any:
        x, y = batch
        return self(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

class CustomWriter(BasePredictionWriter):
    def __init__(self, output_dir: str, write_interval: str):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_batch_end(
            self, trainer, pl_module: LightningModule, prediction: Any, batch_indices: List[int], batch: Any,
            batch_idx: int, dataloader_idx: int):
        torch.save(prediction, os.path.join(self.output_dir, f"{dataloader_idx}_{batch_idx}.pt"))

    def write_on_epoch_end(
            self, trainer, pl_module: LightningModule, predictions: List[Any], batch_indices: List[Any]):
        torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))

tmp_dir = "/tmp"
dm = MNISTDataModule()
model = LitModel(*dm.size(), dm.num_classes)
prediction_writer = CustomWriter(
    output_dir=tmp_dir,
    write_interval="epoch")
trainer = Trainer(
    tpu_cores=8,
    callbacks=[prediction_writer])

trainer.predict(model=model, datamodule=dm)

written_predictions = torch.load(os.path.join(tmp_dir, 'predictions.pt'))
nb_predictions = sum([t.shape[0] for t in written_predictions[0]])

assert nb_predictions == 10_000

When using `tpu_cores=[1]`, all predictions are saved correctly, with the downside of using only one core instead of all eight.

Expected behavior

The predictions from all cores should be saved in the file.

Environment

Colab with TPU runtime.

Additional context

Using the BasePredictionWriter was suggested in this issue. As requested by @kaushikb11, I created this new issue.

cc @kaushikb11 @rohitgr7

ananthsub commented 2 years ago

In your callback, each TPU core is overwriting the same files, `f"{self.output_dir}/{dataloader_idx}_{batch_idx}.pt"` and `predictions.pt`. So when you load them afterwards, you only see a portion of the total (whatever was saved last). Either open the files in append mode, or write to a different file per core and group them together afterwards.

If you open the file in append mode, be sure to close it at the end of prediction. If you're partitioning the files, you can use `trainer.global_rank` to distinguish each process's outputs.

stekiri commented 2 years ago

Thanks @ananthsub for your very helpful guidance!

Your suggestion to write to 8 separate files works like a charm. For anyone coming across this issue in the future, here's how I modified the code:

class MultiFileWriter(BasePredictionWriter):
    def __init__(self, output_dir: str, write_interval: str):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_batch_end(
            self, trainer, pl_module: LightningModule, prediction: Any, batch_indices: List[int], batch: Any,
            batch_idx: int, dataloader_idx: int):
        torch.save(prediction, os.path.join(self.output_dir, f"predictions-dataloader_{dataloader_idx}-batch_{batch_idx}-globalrank_{trainer.global_rank}.pt"))

    def write_on_epoch_end(
            self, trainer, pl_module: LightningModule, predictions: List[Any], batch_indices: List[Any]):
        torch.save(predictions, os.path.join(self.output_dir, f"predictions-globalrank_{trainer.global_rank}.pt"))

tmp_dir = "/tmp"
dm = MNISTDataModule()
model = LitModel(*dm.size(), dm.num_classes)
prediction_writer = MultiFileWriter(
    output_dir=tmp_dir,
    write_interval="epoch")
trainer = Trainer(
    tpu_cores=8,
    callbacks=[prediction_writer])

trainer.predict(model=model, datamodule=dm)
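
For the "group them together afterwards" part, reading the per-core files back and flattening them could look something like this. This is just a sketch, assuming the predictions-globalrank_*.pt naming above, a single predict dataloader, and the imports from the reproduction code:

import glob

# collect the per-core epoch files and flatten them into one list of batch tensors
all_predictions = []
for path in sorted(glob.glob(os.path.join(tmp_dir, "predictions-globalrank_*.pt"))):
    per_rank = torch.load(path)          # List[Any]: one entry per predict dataloader
    all_predictions.extend(per_rank[0])  # batch-level prediction tensors of the only dataloader

nb_predictions = sum(t.shape[0] for t in all_predictions)  # expected: 10_000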

Unfortunately, I couldn't make your alternative suggestion work, the one that appends to a single file. This is what I've tried:

class SingleFileWriter(BasePredictionWriter):
    def __init__(self, file_buffer, write_interval: str):
        super().__init__(write_interval)
        self.file_buffer = file_buffer

    def write_on_batch_end(
            self, trainer, pl_module: LightningModule, prediction: Any, batch_indices: List[int], batch: Any,
            batch_idx: int, dataloader_idx: int):
        torch.save(prediction, self.file_buffer)

    def write_on_epoch_end(
            self, trainer, pl_module: LightningModule, predictions: List[Any], batch_indices: List[Any]):
        torch.save(predictions, self.file_buffer)

tmp_dir = "/tmp"
dm = MNISTDataModule()
model = LitModel(*dm.size(), dm.num_classes)
with open(os.path.join(tmp_dir, 'sf_predictions.pt'), 'ab') as f:
    prediction_writer = SingleFileWriter(
        file_buffer=f,
        write_interval="epoch")
    trainer = Trainer(
        tpu_cores=8,
        callbacks=[prediction_writer])

    trainer.predict(model=model, datamodule=dm)

It seems that all data is written, as the file has the expected size; however, when reading the file with torch.load(), only an eighth of the predictions actually end up in the loaded object. It looks like the written data is somehow colliding. Maybe you have another clever tip to make this work?

ananthsub commented 2 years ago

> It seems that all data is written, as the file has the expected size; however, when reading the file with torch.load(), only an eighth of the predictions actually end up in the loaded object. It looks like the written data is somehow colliding. Maybe you have another clever tip to make this work?

Could you try writing directly to the file buffer? For instance, does this work?

    def write_on_batch_end(
            self, trainer, pl_module: LightningModule, prediction: Any, batch_indices: List[int], batch: Any,
            batch_idx: int, dataloader_idx: int):
        self.file_buffer.write(<some value>)

    def write_on_epoch_end(
            self, trainer, pl_module: LightningModule, predictions: List[Any], batch_indices: List[Any]):
        self.file_buffer.write(<some dummy value>)

stekiri commented 2 years ago

I get the same behavior. I write with `self.file_buffer.write(pickle.dumps(predictions))` and read it back with `pickle.load()`, as `torch.load()` fails with `RuntimeError: Invalid magic number; corrupt file?` when loading the written buffer.
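
For what it's worth (this is not confirmed anywhere in the thread): a file that has had several `pickle.dumps(...)` payloads appended to it contains one pickled object per dump, and a single `pickle.load()` only returns the first of them, which by itself would explain seeing only a fraction of the predictions. A sketch of reading every appended object back, assuming the sf_predictions.pt file above; whether the concurrent writes from 8 processes interleave cleanly is a separate question:

import pickle

# read back every pickled object that was appended to the file, not just the first
all_dumped = []
with open(os.path.join(tmp_dir, "sf_predictions.pt"), "rb") as f:
    while True:
        try:
            all_dumped.append(pickle.load(f))
        except EOFError:
            break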

bokey007 commented 1 year ago

Hi, this discussion was very helpful for me!

But I still need to figure out how to save the image file names along with their predictions. Also, I am working with a much larger dataset than MNIST, so I may not be able to fit all the predictions in a TPU core's memory (8 GB) at once.

I would really appreciate any possible help! Thanks
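
Not an answer from the thread, but a rough sketch of one way to handle both points: use write_interval="batch" so each batch is flushed to disk as soon as it is predicted instead of being accumulated in memory, and have predict_step return the identifiers next to the outputs so they end up in the saved files. This assumes a hypothetical dataset whose items are (image, filename) pairs, which is not the MNIST setup above, and reuses the imports from the reproduction code:

class BatchFileWriter(BasePredictionWriter):
    def __init__(self, output_dir: str):
        # "batch" interval: each batch is written as soon as it is predicted,
        # so predictions never accumulate in memory until the end of the epoch
        super().__init__(write_interval="batch")
        self.output_dir = output_dir

    def write_on_batch_end(
            self, trainer, pl_module: LightningModule, prediction: Any, batch_indices: List[int], batch: Any,
            batch_idx: int, dataloader_idx: int):
        # one small file per (core, dataloader, batch); group them offline afterwards
        torch.save(
            prediction,
            os.path.join(
                self.output_dir,
                f"preds-rank_{trainer.global_rank}-dl_{dataloader_idx}-batch_{batch_idx}.pt"))

and in the LightningModule:

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any:
        x, filenames = batch  # hypothetical dataset that yields (image, filename)
        return {"filename": filenames, "prediction": self(x)}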