Open Ghelfi opened 1 week ago
Here is a snippet to reproduce the error:
from collections.abc import Callable
from typing import cast
import torch.nn.functional as F
from composer import Trainer
from composer.core import DataSpec
from composer.models import ComposerClassifier
from torch import Tensor, nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class Model(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.num_classes = num_classes
        self.conv = nn.Conv2d(1, 16, (3, 3), padding=0)
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        out = self.conv(x)
        out = F.relu(out)
        # Pool to (N, 16) so the linear layer receives 16 features per sample.
        out = F.adaptive_avg_pool2d(out, 1).flatten(1)
        return cast(Tensor, self.fc(out))

def get_device_transform(device: str) -> Callable[[tuple[Tensor, Tensor]], tuple[Tensor, Tensor]]:
    def _transform(batch: tuple[Tensor, Tensor]) -> tuple[Tensor, Tensor]:
        """Dummy on-device transform to check that `batch` is actually on the right device."""
        assert (
            batch[0].device.type == device
        ), f"found data on device {batch[0].device.type} while expecting device_transform to run on {device}"
        assert (
            batch[1].device.type == device
        ), f"found data on device {batch[1].device.type} while expecting device_transform to run on {device}"
        return batch

    return _transform

transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.MNIST("data", train=True, download=True, transform=transform)
train_dataloader = DataSpec(
    dataloader=DataLoader(dataset, batch_size=16),
    device_transforms=get_device_transform(device="cuda"),
    get_num_samples_in_batch=lambda batch: len(batch[0]),
)
trainer = Trainer(
    model=ComposerClassifier(module=Model(), num_classes=10),
    train_dataloader=train_dataloader,
    max_duration="2ep",
    device="gpu",
)
trainer.fit()
This outputs
AssertionError: found data on device cpu while expecting device_transform to run on cuda
--> device transforms run on cpu
This requires a GPU to run, but one can replace both "cuda" and "gpu" with "mps" to reproduce the error locally on a Mac accelerator.
We currently measure a ~10x slowdown on heavy transform pipelines. If this is reproducible, it would be nice to ship a patch rather than wait for the next release cycle, since this appears to be a regression.
@mvpatel2000 I saw that you patched some algorithms to move data to the device ahead of time.
Is this something that should be done for every transform? That would be unfortunate, since it breaks the native switch between on-CPU and on-device transforms.
Thanks for flagging this!
The original motivation here is that large batches of images and multimodal data add significant memory pressure if we do not move them incrementally per microbatch.
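To put rough numbers on the memory-pressure argument above, here is a back-of-the-envelope sketch. The batch size, microbatch size and image shape are hypothetical values chosen for illustration, not figures from this thread, and the calculation only counts the raw input tensors (no activations or gradients).

```python
def batch_bytes(num_samples: int, sample_shape: tuple, bytes_per_element: int = 4) -> int:
    """Size in bytes of a dense batch of `num_samples` samples (float32 by default)."""
    elements_per_sample = 1
    for dim in sample_shape:
        elements_per_sample *= dim
    return num_samples * elements_per_sample * bytes_per_element


MIB = 1024 ** 2
image_shape = (3, 224, 224)  # hypothetical RGB image shape

# Moving the whole batch at once keeps every sample resident on the device.
whole_batch = batch_bytes(512, image_shape)
# Moving one microbatch at a time keeps only that slice resident.
one_microbatch = batch_bytes(64, image_shape)

print(f"whole batch:    {whole_batch / MIB:.2f} MiB")     # 294.00 MiB
print(f"one microbatch: {one_microbatch / MIB:.2f} MiB")  # 36.75 MiB
```

With these assumed sizes, the per-microbatch move keeps one eighth of the input data on the device at any time.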
We decided to avoid transforms at microbatch level because transforms may use batch statistics, and at the time, we didn't think there would be many tradeoffs to leaving it on CPU (our workloads are rarely CPU bottlenecked). Given this issue, I see a few options:
We are still discussing, but if you have any feedback @Ghelfi feel free to chime in. I'm especially curious if option 1 affects you
We do have some transforms that are coherent across a batch, meaning the same parameters are applied to the whole batch. Moving the transforms to the microbatch level might change this. This somewhat falls under what you describe as "batch statistics".
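To illustrate the concern, here is a minimal pure-Python sketch of a batch-coherent transform. The names `coherent_offset`, `apply_per_batch` and `apply_per_microbatch` are hypothetical and not part of Composer; the "transform" simply tags each sample with the parameter that was drawn for it.

```python
import random
from typing import List, Sequence, Tuple


def coherent_offset(samples: Sequence[int], rng: random.Random) -> List[Tuple[int, float]]:
    """Draw ONE random parameter per call and attach it to every sample."""
    offset = rng.random()
    return [(sample, offset) for sample in samples]


def apply_per_batch(batch: Sequence[int], rng: random.Random) -> List[Tuple[int, float]]:
    # One call for the whole batch: every sample shares the same parameter.
    return coherent_offset(batch, rng)


def apply_per_microbatch(
    batch: Sequence[int], rng: random.Random, microbatch_size: int
) -> List[Tuple[int, float]]:
    # One call per microbatch: a new parameter is drawn for each slice.
    out: List[Tuple[int, float]] = []
    for start in range(0, len(batch), microbatch_size):
        out.extend(coherent_offset(batch[start:start + microbatch_size], rng))
    return out


batch = list(range(8))
per_batch = apply_per_batch(batch, random.Random(0))
per_microbatch = apply_per_microbatch(batch, random.Random(0), microbatch_size=4)

params_per_batch = {param for _, param in per_batch}
params_per_microbatch = {param for _, param in per_microbatch}
print(len(params_per_batch), len(params_per_microbatch))  # 1 2
```

The same transform applied per microbatch no longer guarantees a single set of parameters for the full batch, which is the behaviour change described above.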
Options 2 and 3 could work. An opt-in trainer flag stating that the move happens at the batch level would let people have it both ways.
Did you keep the name device_transforms for backward compatibility? Wouldn't batched_transforms be a less confusing name?
> We do have some transforms that are coherent across a batch, meaning the same parameters are applied to the whole batch. Moving the transforms to the microbatch level might change this. This somewhat falls under what you describe as "batch statistics".
How intensive are these, and do you care about ordering? For example, what if we had batch_transforms run on CPU and microbatch_transforms run on GPU?
> Did you keep the name device_transforms for backward compatibility? Wouldn't batched_transforms be a less confusing name?
Yes, primarily because of this
We use intensive transforms that are well accelerated on GPU.
I'd be more in favor of an opt-in flag that enables the transfer at the batch level. It gives an easy way to fall back to the previous behaviour, is opt-in, and enables "true" device_transforms.
If breaking backward compatibility is OK, we can have batch_transforms on CPU and device_transforms coupled with the data transfer, either at the batch or microbatch level depending on a flag data_transfer_stage: Literal["batch", "microbatch"] = "microbatch".
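The proposal above could look roughly like the following. This is a pure-Python sketch of the ordering only, not Composer's implementation: `FakeBatch`, `split` and `prepare_microbatches` are hypothetical stand-ins, and `data_transfer_stage` is the flag suggested in this comment, not an existing Trainer argument.

```python
from dataclasses import dataclass, replace
from typing import Callable, List, Literal, Tuple


@dataclass(frozen=True)
class FakeBatch:
    """Stand-in for a batch of tensors: sample ids plus a device tag."""
    samples: Tuple[int, ...]
    device: str = "cpu"


Transform = Callable[[FakeBatch], FakeBatch]


def split(batch: FakeBatch, size: int) -> List[FakeBatch]:
    s = batch.samples
    return [replace(batch, samples=s[i:i + size]) for i in range(0, len(s), size)]


def prepare_microbatches(
    batch: FakeBatch,
    batch_transforms: Transform,
    device_transforms: Transform,
    device: str,
    microbatch_size: int,
    data_transfer_stage: Literal["batch", "microbatch"] = "microbatch",
) -> List[FakeBatch]:
    batch = batch_transforms(batch)  # always on CPU, on the whole batch
    if data_transfer_stage == "batch":
        # Move everything once, transform on device, then split.
        batch = device_transforms(replace(batch, device=device))
        return split(batch, microbatch_size)
    if data_transfer_stage == "microbatch":
        # Split on CPU, then move and transform each microbatch in turn.
        return [
            device_transforms(replace(mb, device=device))
            for mb in split(batch, microbatch_size)
        ]
    raise ValueError(f"unknown data_transfer_stage: {data_transfer_stage!r}")


def recorder(name: str, log: list) -> Transform:
    """Identity transform that logs where and on how many samples it ran."""
    def _transform(b: FakeBatch) -> FakeBatch:
        log.append((name, b.device, len(b.samples)))
        return b
    return _transform


batch = FakeBatch(samples=tuple(range(8)))
batch_log: list = []
microbatch_log: list = []
for stage, log in (("batch", batch_log), ("microbatch", microbatch_log)):
    prepare_microbatches(
        batch,
        batch_transforms=recorder("batch_transforms", log),
        device_transforms=recorder("device_transforms", log),
        device="gpu",
        microbatch_size=4,
        data_transfer_stage=stage,
    )
```

In both modes device_transforms sees data tagged with the target device; the flag only decides whether it sees the whole batch at once (which keeps batch-coherent transforms intact) or one microbatch at a time (lower peak device memory).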
Description
Since release 0.25.0, DataSpec.device_transforms do not run on device. Two successive PRs changed where batches / microbatches are moved to device and where device_transforms are applied, in an inconsistent way:
device_transforms and copy at the microbatch level
device_transforms before moving the batch to device.
Hence, device_transforms are applied per batch but systematically on the cpu.