mosaicml / composer

Supercharge Your Model Training
http://docs.mosaicml.com
Apache License 2.0

`DataSpec.device_transforms` do not run on device #3699

Open Ghelfi opened 1 week ago

Ghelfi commented 1 week ago

Description

Since release 0.25.0 DataSpec.device_transforms do not run on device.

Two successive PRs changed, in an inconsistent way, where batches / micro_batches are moved to the device and where device_transforms are applied:

As a result, device_transforms are applied per batch but always on the CPU.
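
For readers following along, here is a toy sketch of that ordering in plain Python. It is not Composer's code: `ToyBatch`, `to_device`, `split`, and `train_step_as_described` are hypothetical names, and the only thing taken from the report is the order of operations (the transform runs on the full batch first, the move to the device happens later, per microbatch).

```python
from dataclasses import dataclass, replace
from typing import Callable, List


@dataclass(frozen=True)
class ToyBatch:
    """Hypothetical stand-in for a batch of tensors; tracks only values and device."""
    values: tuple
    device: str = "cpu"


def to_device(batch: ToyBatch, device: str) -> ToyBatch:
    # Stand-in for moving tensors to an accelerator.
    return replace(batch, device=device)


def split(batch: ToyBatch, microbatch_size: int) -> List[ToyBatch]:
    # Split a batch into microbatches that stay on the batch's current device.
    return [
        replace(batch, values=batch.values[i:i + microbatch_size])
        for i in range(0, len(batch.values), microbatch_size)
    ]


def train_step_as_described(
    batch: ToyBatch,
    device_transforms: Callable[[ToyBatch], ToyBatch],
    device: str,
    microbatch_size: int,
) -> List[ToyBatch]:
    # Assumed order, following the description in this issue: transform the
    # full batch first (still on CPU), then move each microbatch to the device.
    batch = device_transforms(batch)
    return [to_device(mb, device) for mb in split(batch, microbatch_size)]


seen_devices: List[str] = []


def recording_transform(batch: ToyBatch) -> ToyBatch:
    # Record the device on which the transform observes the data.
    seen_devices.append(batch.device)
    return batch


microbatches = train_step_as_described(
    ToyBatch(values=(1, 2, 3, 4)), recording_transform, device="cuda", microbatch_size=2
)
print(seen_devices)                        # ['cpu']
print([mb.device for mb in microbatches])  # ['cuda', 'cuda']
```

The data does end up on the device, but only after the transform has already run, which matches the assertion failure reported here.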

Ghelfi commented 1 week ago

Here is a snippet to reproduce the error:

from collections.abc import Callable
from typing import cast

import torch.nn.functional as F
from composer import Trainer
from composer.core import DataSpec
from composer.models import ComposerClassifier
from torch import Tensor, nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class Model(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()

        self.num_classes = num_classes
        self.conv = nn.Conv2d(1, 16, (3, 3), padding=0)
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        out = self.conv(x)
        out = F.relu(out)
        # Global average pool to (batch, 16) so the shape matches `self.fc`.
        out = out.mean(dim=(2, 3))
        return cast(Tensor, self.fc(out))

def get_device_transform(device: str) -> Callable[[tuple[Tensor, Tensor]], tuple[Tensor, Tensor]]:
    def _transform(batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]:
        """Dummy on-device transform to check that `batch` is actually on the right device."""
        assert (
            batch[0].device.type == device
        ), f"found data on device {batch[0].device.type} while expecting device_transform to run on {device}"
        assert (
            batch[1].device.type == device
        ), f"found data on device {batch[1].device.type} while expecting device_transform to run on {device}"

        return batch

    return _transform

transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.MNIST("data", train=True, download=True, transform=transform)
train_dataloader = DataSpec(
    dataloader=DataLoader(dataset, batch_size=16),
    device_transforms=get_device_transform(device="cuda"),
    get_num_samples_in_batch=lambda batch: len(batch[0]),
)

trainer = Trainer(
    model=ComposerClassifier(module=Model(), num_classes=10),
    train_dataloader=train_dataloader,
    max_duration="2ep",
    device="gpu",
)
trainer.fit()

This outputs

AssertionError: found data on device cpu while expecting device_transform to run on cuda

--> device transforms run on cpu

This requires a GPU to run, but you can replace both "cuda" and "gpu" with "mps" to reproduce the error locally on a Mac accelerator.

We currently measure a ~10x slowdown on heavy transform pipelines. If this is reproducible, it would be nice to ship a patch rather than wait for the next release cycle, since this appears to be a regression.

Ghelfi commented 1 week ago

@mvpatel2000 I saw that you patched some algorithms to move data to the device ahead of time.

Is this something that should be done for every transform? That would be unfortunate, since it breaks the native switch between on-CPU and on-device transforms.

mvpatel2000 commented 1 week ago

Thanks for flagging this!

The original motivation here is that large batches of images and multimodal data add significant memory pressure if we do not move them incrementally per microbatch.

We decided to avoid transforms at the microbatch level because transforms may use batch statistics, and at the time, we didn't think there would be many tradeoffs to leaving them on CPU (our workloads are rarely CPU bottlenecked). Given this issue, I see a few options:

We are still discussing, but if you have any feedback @Ghelfi, feel free to chime in. I'm especially curious whether option 1 affects you.
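
To illustrate the batch-statistics concern mentioned above, here is a small sketch using only the Python standard library. The `standardize` function is a hypothetical transform, not part of Composer; it normalizes with the mean and standard deviation of whatever it receives, so its output depends on whether it sees the full batch or one microbatch at a time.

```python
from statistics import mean, pstdev
from typing import List


def standardize(values: List[float]) -> List[float]:
    # Normalize with the statistics of the values passed in this call.
    mu, sigma = mean(values), pstdev(values)
    return [(v - mu) / sigma for v in values]


batch = [1.0, 2.0, 3.0, 10.0]
microbatches = [batch[:2], batch[2:]]

# Applied once to the full batch: the outlier 10.0 stays clearly separated.
per_batch = standardize(batch)

# Applied per microbatch: each pair is standardized on its own, so both
# pairs collapse to roughly (-1, 1) and the outlier is no longer visible.
per_microbatch = [x for mb in microbatches for x in standardize(mb)]

print(per_batch)
print(per_microbatch)
```

The same samples produce different outputs in the two modes, which is why moving such a transform to the microbatch level is not a transparent change.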

Ghelfi commented 1 week ago

We do have some transforms that are coherent across a batch, meaning the same parameters are applied to the whole batch. Moving the transforms to the microbatch level might change this. This more or less falls under what you call "batch statistics".
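
A minimal sketch of what "coherent across a batch" means in practice, again with a hypothetical transform (`CoherentScale` is not a Composer or torchvision class): it draws one random scale factor per call and applies it to every sample in that call.

```python
import random
from typing import List


class CoherentScale:
    """Hypothetical transform: draws one scale factor per call, shared by all samples."""

    def __init__(self, seed: int) -> None:
        self.rng = random.Random(seed)
        self.drawn: List[float] = []  # history of sampled factors, for inspection

    def __call__(self, values: List[float]) -> List[float]:
        factor = self.rng.uniform(0.5, 2.0)
        self.drawn.append(factor)
        return [v * factor for v in values]


batch = [1.0, 1.0, 1.0, 1.0]

# Batch level: one call, so one factor is shared by all four samples.
per_batch = CoherentScale(seed=0)
out_batch = per_batch(batch)

# Microbatch level: two calls, so each microbatch gets its own factor.
per_microbatch = CoherentScale(seed=0)
out_micro = per_microbatch(batch[:2]) + per_microbatch(batch[2:])

print(len(per_batch.drawn), len(set(out_batch)))       # 1 1
print(len(per_microbatch.drawn), len(set(out_micro)))  # 2 2
```

With the per-microbatch call pattern the parameters are coherent only within each microbatch, no longer across the whole batch.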

Options 2 and 3 could work. An opt-in trainer flag stating that the move happens at the batch level would let people have it both ways.

Did you keep the name device_transforms for backward compatibility? Wouldn't batched_transforms be a less confusing name?

mvpatel2000 commented 1 week ago

> We do have some transforms that are coherent across a batch, meaning the same parameters are applied to the whole batch. Moving the transforms to the microbatch level might change this. This more or less falls under what you call "batch statistics".

How intensive are these, and do you care about ordering? For example, what if we had batch_transforms run on CPU and microbatch_transforms run on GPU?

> Did you keep the name device_transforms for backward compatibility? Wouldn't batched_transforms be a less confusing name?

Yes, primarily because of this.

Ghelfi commented 1 week ago

We use intensive transforms that are well accelerated on GPU.

I'd be more in favor of an opt-in flag that enables the transfer at the batch level. This gives an easy way to fall back to the previous behaviour and enables "true" device_transforms.

If breaking backward compatibility is OK, we can have batch_transforms on CPU and device_transforms coupled with the data transfer, at either the batch or the microbatch level depending on a flag data_transfer_stage: Literal["batch", "microbatch"] = "microbatch".
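
A rough sketch of how that proposal could behave, as a toy model in plain Python rather than Composer code. `batch_transforms`, `device_transforms`, and `data_transfer_stage` are the names floated in this thread; `ToyBatch`, `split`, `prepare_microbatches`, and `trace` are hypothetical helpers made up for the illustration.

```python
from dataclasses import dataclass, replace
from typing import Callable, List, Literal, Tuple


@dataclass(frozen=True)
class ToyBatch:
    """Hypothetical stand-in for a batch; tracks only sample values and device."""
    values: tuple
    device: str = "cpu"


Transform = Callable[[ToyBatch], ToyBatch]


def split(batch: ToyBatch, microbatch_size: int) -> List[ToyBatch]:
    return [
        replace(batch, values=batch.values[i:i + microbatch_size])
        for i in range(0, len(batch.values), microbatch_size)
    ]


def prepare_microbatches(
    batch: ToyBatch,
    batch_transforms: Transform,
    device_transforms: Transform,
    device: str,
    microbatch_size: int,
    data_transfer_stage: Literal["batch", "microbatch"] = "microbatch",
) -> List[ToyBatch]:
    # batch_transforms always see the full batch, on CPU.
    batch = batch_transforms(batch)
    if data_transfer_stage == "batch":
        # Opt-in: move the whole batch, transform it on the device, then split.
        batch = device_transforms(replace(batch, device=device))
        return split(batch, microbatch_size)
    if data_transfer_stage == "microbatch":
        # Default: split on CPU, then move and transform one microbatch at a time.
        # (Eager here for simplicity; a real loop would do this lazily.)
        return [
            device_transforms(replace(mb, device=device))
            for mb in split(batch, microbatch_size)
        ]
    raise ValueError(f"unknown data_transfer_stage: {data_transfer_stage!r}")


def trace(stage: Literal["batch", "microbatch"]) -> List[Tuple[str, str, int]]:
    # Log (transform name, device seen, number of samples seen) for each call.
    calls: List[Tuple[str, str, int]] = []

    def logged(name: str) -> Transform:
        def _transform(b: ToyBatch) -> ToyBatch:
            calls.append((name, b.device, len(b.values)))
            return b
        return _transform

    prepare_microbatches(
        ToyBatch(values=(1, 2, 3, 4)),
        logged("batch_transforms"),
        logged("device_transforms"),
        device="cuda",
        microbatch_size=2,
        data_transfer_stage=stage,
    )
    return calls


print(trace("batch"))
# [('batch_transforms', 'cpu', 4), ('device_transforms', 'cuda', 4)]
print(trace("microbatch"))
# [('batch_transforms', 'cpu', 4), ('device_transforms', 'cuda', 2), ('device_transforms', 'cuda', 2)]
```

In this sketch device_transforms always run on the device; the flag only decides whether they see the whole batch (more device memory, batch-coherent parameters preserved) or one microbatch at a time (less memory pressure).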