Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Support TorchEval #19716

Open jaanli opened 3 months ago

jaanli commented 3 months ago

Bug description

Half of our team uses vanilla PyTorch and the other half uses (PyTorch) Lightning. We need several custom metrics for our use case, and fine-grained control over which device those metrics live on (including for use with fully sharded data parallel language models).

However, we seem to be blocked from using Lightning with these custom metrics, possibly because Lightning recommends deleting any `model.to(device)` calls: https://lightning.ai/docs/pytorch/stable/accelerators/accelerator_prepare.html#delete-cuda-or-to-calls

Custom metrics, such as those built on torcheval, sometimes need to call `metric.to(device)` because some computation can only happen on GPU and some only on CPU. One example is binned precision-recall curves in extreme multi-label classification: the computation is infeasible (takes too long) on CPU, but GPU memory is exhausted for large language models, so metric state and intermediate results must be copied onto and off of GPU memory during training for early stopping.

Here's one example:

https://pytorch.org/torcheval/main/_modules/torcheval/metrics/classification/binned_precision_recall_curve.html#MulticlassBinnedPrecisionRecallCurve
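
For concreteness, here is a minimal sketch (plain PyTorch, no Lightning) of the device shuffling we mean; the device choice and tensor shapes are placeholders:

import torch
from torcheval.metrics import MulticlassAccuracy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Accumulate metric state on the GPU during training...
metric = MulticlassAccuracy(num_classes=10).to(device)
logits = torch.randn(8, 10, device=device)           # stand-in for model output
targets = torch.randint(0, 10, (8,), device=device)  # stand-in for labels
metric.update(logits.argmax(dim=1), targets)

# ...then move the accumulated state off the GPU before the expensive
# compute() step, freeing GPU memory for the model itself.
metric.to(torch.device("cpu"))
print(metric.compute())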

To reproduce, based on a vanilla PyTorch example: https://raw.githubusercontent.com/pytorch/examples/main/mnist/main.py

The specific error with custom metrics is:

ValueError: `self.log(val_acc, <torcheval.metrics.classification.accuracy.MulticlassAccuracy object at 0x16fc1ed10>)` was called, but `MulticlassAccuracy` values cannot be logged

Any advice on how to include custom metrics in Lightning from torcheval or otherwise, that require passing intermediate states on or off of GPUs?

Or must such evaluation for early stopping happen outside Lightning modules?

Thank you; any advice on the canonical way to solve this problem is appreciated. We can't be the only ones running into this blocker when using Lightning with standard PyTorch tools...

What version are you seeing the problem on?

v2.2

How to reproduce the bug

Vanilla PyTorch:

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

from torcheval.metrics import MulticlassAccuracy

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

def train(args, model, device, train_loader, optimizer, epoch, metric):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        metric.update(output.argmax(dim=1), target)

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.2f}%'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item(),
                100. * metric.compute()
            ))
            metric.reset()
            if args.dry_run:
                break

def test(model, device, test_loader, metric):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss

            metric.update(output.argmax(dim=1), target)

    test_loss /= len(test_loader.dataset)

    accuracy = metric.compute()
    print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
        test_loss, 100. * accuracy))
    metric.reset()

def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--no-mps', action='store_true', default=False,
                        help='disables macOS GPU training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    use_mps = not args.no_mps and torch.backends.mps.is_available()

    torch.manual_seed(args.seed)

    if use_cuda:
        device = torch.device("cuda")
    elif use_mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset1 = datasets.MNIST('../data', train=True, download=True,
                       transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                       transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    metric = MulticlassAccuracy(num_classes=10).to(device)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch, metric)
        test(model, device, test_loader, metric)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")

if __name__ == '__main__':
    main()

This runs fine.

With PyTorch Lightning:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torcheval.metrics import MulticlassAccuracy
import lightning as L

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

class LitMNIST(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = Net()
        self.metric = MulticlassAccuracy(num_classes=10)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.nll_loss(output, target)
        self.metric.update(output.argmax(dim=1), target)
        self.log('train_loss', loss)
        self.log('train_acc', self.metric, on_step=True, on_epoch=True, prog_bar=True)  # passing the metric object itself triggers the ValueError
        return loss

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.nll_loss(output, target)
        self.metric.update(output.argmax(dim=1), target)
        self.log('val_loss', loss)
        self.log('val_acc', self.metric, on_step=True, on_epoch=True, prog_bar=True)  # same problem: a torcheval metric object, not a tensor

    def configure_optimizers(self):
        optimizer = optim.Adadelta(self.parameters(), lr=1.0)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)
        return [optimizer], [scheduler]

def main(args):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('../data', train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, num_workers=1, pin_memory=True, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch_size, num_workers=1, pin_memory=True)

    model = LitMNIST()

    trainer = L.Trainer(
        max_epochs=args.epochs,
        devices=1,
        accelerator='auto',
        log_every_n_steps=args.log_interval
    )

    trainer.fit(model, train_loader, test_loader)  # test set doubles as the validation set in this repro

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    main(args)

### Error messages and logs

ValueError: self.log(val_acc, <torcheval.metrics.classification.accuracy.MulticlassAccuracy object at 0x16fc1ed10>) was called, but MulticlassAccuracy values cannot be logged



### Environment

<details>
  <summary>Current environment</summary>

* CUDA:
    - GPU:               None
    - available:         False
    - version:           None
* Lightning:
    - lightning:         2.2.1
    - lightning-utilities: 0.11.0
    - pytorch-lightning: 2.2.1
    - torch:             2.1.2
    - torchaudio:        2.1.2
    - torcheval:         0.0.7
    - torchmetrics:      1.3.2
    - torchvision:       0.16.2
* Packages:
    - aiobotocore:       2.12.0
    - aiohttp:           3.9.1
    - aiohttp-cors:      0.7.0
    - aioitertools:      0.11.0
    - aiorwlock:         1.4.0
    - aiosignal:         1.3.1
    - annotated-types:   0.6.0
    - antlr4-python3-runtime: 4.9.3
    - anyio:             4.2.0
    - appdirs:           1.4.4
    - appnope:           0.1.3
    - argon2-cffi:       23.1.0
    - argon2-cffi-bindings: 21.2.0
    - arrow:             1.3.0
    - astroid:           3.0.3
    - asttokens:         2.4.1
    - async-lru:         2.0.4
    - attrs:             23.2.0
    - babel:             2.14.0
    - beartype:          0.17.0
    - beautifulsoup4:    4.12.2
    - black:             24.2.0
    - bleach:            6.1.0
    - blessed:           1.20.0
    - boto3:             1.34.31
    - botocore:          1.34.51
    - build:             1.0.3
    - cachetools:        5.3.2
    - certifi:           2023.11.17
    - cffi:              1.16.0
    - cfgv:              3.4.0
    - charset-normalizer: 3.3.2
    - click:             8.1.7
    - cloudpickle:       3.0.0
    - colorcet:          3.0.1
    - colorful:          0.5.6
    - colorspacious:     1.1.2
    - comm:              0.2.0
    - contourpy:         1.2.0
    - cramjam:           2.7.0
    - cryptography:      42.0.5
    - cycler:            0.12.1
    - dask:              2024.1.0
    - datamapplot:       0.1.0
    - datashader:        0.16.0
    - debugpy:           1.8.0
    - decorator:         5.1.1
    - defusedxml:        0.7.1
    - dill:              0.3.8
    - distlib:           0.3.8
    - docker-pycreds:    0.4.0
    - einops:            0.7.0
    - et-xmlfile:        1.1.0
    - executing:         2.0.1
    - fastapi:           0.109.0
    - fastjsonschema:    2.19.1
    - fastparquet:       2023.10.1
    - filelock:          3.13.1
    - fonttools:         4.47.0
    - fqdn:              1.5.1
    - frozenlist:        1.4.1
    - fsspec:            2024.2.0
    - functiontrace:     0.3.7
    - gitdb:             4.0.11
    - gitpython:         3.1.41
    - google-api-core:   2.15.0
    - google-api-python-client: 2.123.0
    - google-auth:       2.26.2
    - google-auth-httplib2: 0.2.0
    - googleapis-common-protos: 1.62.0
    - gpustat:           1.1.1
    - grpcio:            1.60.0
    - h11:               0.14.0
    - hnswlib:           0.8.0
    - httplib2:          0.22.0
    - httptools:         0.6.1
    - huggingface-hub:   0.20.2
    - identify:          2.5.35
    - idna:              3.6
    - imageio:           2.33.1
    - importlib-metadata: 7.0.1
    - iniconfig:         2.0.0
    - ipdb:              0.13.13
    - iprogress:         0.4
    - ipykernel:         6.28.0
    - ipython:           8.19.0
    - ipywidgets:        8.1.1
    - isoduration:       20.11.0
    - isort:             5.13.2
    - jaxtyping:         0.2.25
    - jedi:              0.19.1
    - jinja2:            3.1.2
    - jmespath:          1.0.1
    - joblib:            1.3.2
    - json5:             0.9.14
    - jsonpointer:       2.4
    - jsonschema:        4.20.0
    - jsonschema-specifications: 2023.12.1
    - jupyter:           1.0.0
    - jupyter-client:    8.6.0
    - jupyter-console:   6.6.3
    - jupyter-core:      5.5.1
    - jupyter-events:    0.9.0
    - jupyter-lsp:       2.2.1
    - jupyter-server:    2.12.3
    - jupyter-server-terminals: 0.5.1
    - jupyterlab:        4.0.10
    - jupyterlab-pygments: 0.3.0
    - jupyterlab-server: 2.25.2
    - jupyterlab-widgets: 3.0.9
    - kiwisolver:        1.4.5
    - lazy-loader:       0.3
    - lightning:         2.2.1
    - lightning-utilities: 0.11.0
    - line-profiler:     4.1.2
    - llvmlite:          0.41.1
    - locket:            1.0.0
    - lxml:              5.1.0
    - markupsafe:        2.1.3
    - matplotlib:        3.8.2
    - matplotlib-inline: 0.1.6
    - mccabe:            0.7.0
    - mistune:           3.0.2
    - mpl-axes-aligner:  1.3
    - mpmath:            1.3.0
    - msgpack:           1.0.7
    - multidict:         6.0.4
    - multipledispatch:  1.0.0
    - mypy-extensions:   1.0.0
    - nbclient:          0.9.0
    - nbconvert:         7.14.0
    - nbformat:          5.9.2
    - nest-asyncio:      1.5.8
    - networkx:          3.2.1
    - nltk:              3.8.1
    - nodeenv:           1.8.0
    - notebook:          7.0.6
    - notebook-shim:     0.2.3
    - numba:             0.58.1
    - numpy:             1.26.2
    - nvidia-ml-py:      12.535.133
    - omegaconf:         2.3.0
    - opencensus:        0.11.4
    - opencensus-context: 0.1.3
    - openpyxl:          3.1.2
    - overrides:         7.4.0
    - packaging:         23.2
    - pandas:            2.1.4
    - pandas-read-xml:   0.3.1
    - pandocfilters:     1.5.0
    - param:             2.0.1
    - parso:             0.8.3
    - partd:             1.4.1
    - pathspec:          0.12.1
    - pexpect:           4.9.0
    - pillow:            10.1.0
    - pip:               24.0
    - pip-tools:         7.4.0
    - platformdirs:      4.1.0
    - pluggy:            1.4.0
    - polars:            0.20.2
    - pre-commit:        3.6.2
    - prometheus-client: 0.19.0
    - prompt-toolkit:    3.0.43
    - protobuf:          4.25.2
    - psutil:            5.9.7
    - ptyprocess:        0.7.0
    - pure-eval:         0.2.2
    - py-spy:            0.3.14
    - pyarrow:           14.0.2
    - pyasn1:            0.5.1
    - pyasn1-modules:    0.3.0
    - pycparser:         2.21
    - pyct:              0.5.0
    - pydantic:          2.5.3
    - pydantic-core:     2.14.6
    - pygments:          2.17.2
    - pyinstrument:      4.6.2
    - pylint:            3.0.3
    - pyod:              1.1.2
    - pyparsing:         3.1.1
    - pyproject-hooks:   1.0.0
    - pystemmer:         2.2.0.1
    - pytest:            8.1.1
    - python-dateutil:   2.8.2
    - python-dotenv:     1.0.0
    - python-json-logger: 2.0.7
    - pytorch-lightning: 2.2.1
    - pytz:              2023.3.post1
    - pyyaml:            6.0.1
    - pyzmq:             25.1.2
    - qtconsole:         5.5.1
    - qtpy:              2.4.1
    - ray:               2.9.1
    - referencing:       0.32.1
    - regex:             2023.12.25
    - requests:          2.31.0
    - rfc3339-validator: 0.1.4
    - rfc3986-validator: 0.1.1
    - rpds-py:           0.16.2
    - rsa:               4.9
    - s3fs:              2024.2.0
    - s3transfer:        0.10.0
    - safetensors:       0.4.1
    - scikit-image:      0.22.0
    - scikit-learn:      1.3.2
    - scipy:             1.11.4
    - seaborn:           0.13.0
    - send2trash:        1.8.2
    - sentry-sdk:        1.39.2
    - setproctitle:      1.3.3
    - setuptools:        69.0.3
    - six:               1.16.0
    - smart-open:        6.4.0
    - smmap:             5.0.1
    - sniffio:           1.3.0
    - soupsieve:         2.5
    - stack-data:        0.6.3
    - starlette:         0.35.1
    - sympy:             1.12
    - tensorboardx:      2.6.2.2
    - terminado:         0.18.0
    - threadpoolctl:     3.2.0
    - tifffile:          2023.12.9
    - tinycss2:          1.2.1
    - tokenizers:        0.15.0
    - tomlkit:           0.12.3
    - toolz:             0.12.0
    - torch:             2.1.2
    - torchaudio:        2.1.2
    - torcheval:         0.0.7
    - torchmetrics:      1.3.2
    - torchvision:       0.16.2
    - tornado:           6.4
    - tqdm:              4.66.1
    - traitlets:         5.14.0
    - transformers:      4.36.2
    - typeguard:         2.13.3
    - types-python-dateutil: 2.8.19.20240106
    - typing-extensions: 4.9.0
    - tzdata:            2023.4
    - uri-template:      1.3.0
    - uritemplate:       4.1.1
    - urllib3:           2.0.7
    - uvicorn:           0.26.0
    - uvloop:            0.19.0
    - virtualenv:        20.25.0
    - wandb:             0.16.2
    - watchfiles:        0.21.0
    - wcwidth:           0.2.12
    - webcolors:         1.13
    - webencodings:      0.5.1
    - websocket-client:  1.7.0
    - websockets:        12.0
    - wheel:             0.42.0
    - widgetsnbextension: 4.0.9
    - wrapt:             1.16.0
    - xarray:            2023.12.0
    - xmltodict:         0.13.0
    - yarl:              1.9.4
    - zipfile36:         0.1.3
    - zipp:              3.17.0
* System:
    - OS:                Darwin
    - architecture:
        - 64bit
    - processor:         arm
    - python:            3.11.5
    - release:           23.4.0
    - version:           Darwin Kernel Version 23.4.0: Fri Mar 15 00:12:25 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6030

</details>

### More info

_No response_

cc @borda @carmocca
awaelchli commented 3 months ago

Hey @jaanli

I'm converting this to a feature request because it isn't a bug. The docs for `self.log` state that you can log a float, a tensor, or a TorchMetrics `Metric`. The `ValueError` you get comes directly from Lightning, informing the user that arbitrary objects can't be logged, which is correct.

If you want to log a TorchEval metric, I suggest you compute it normally and then pass the value (a scalar tensor) to the `self.log()` call. Regarding device management, that's just general advice that Lightning gives. If you see the need, feel free to force your metric modules onto the CPU to perform the computations.
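
Applied to the repro above, a minimal sketch of that suggestion (keeping the metric on CPU and resetting it once per epoch are assumptions, not the only option):

class LitMNIST(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = Net()
        # torcheval metric kept on CPU; its state never touches the GPU.
        self.metric = MulticlassAccuracy(num_classes=10)

    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self.model(data)
        loss = F.nll_loss(output, target)
        # Update with CPU copies of the predictions and labels.
        self.metric.update(output.argmax(dim=1).cpu(), target.cpu())
        self.log('train_loss', loss)
        # Log the computed value (a scalar tensor), not the metric object.
        self.log('train_acc', self.metric.compute(), prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        self.metric.reset()  # start each epoch with fresh state

Once the value is logged as a plain scalar, callbacks like EarlyStopping can monitor it as usual. For heavier metrics such as binned precision-recall curves, computing once per epoch (e.g. calling compute() and self.log() in on_train_epoch_end) avoids a per-step compute() cost.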

jaanli commented 3 months ago

Super helpful, thanks so much @awaelchli ! Will see if we have the engineering support to fix this compatibility issue :)