Lightning-AI / utilities

Common Python utilities and GitHub Actions in Lightning Ecosystem
Apache License 2.0

`apply_to_collection` doesn't work for cached properties #279

Open jackdent opened 2 months ago

jackdent commented 2 months ago

Motivation

When running apply_to_collection on a dataclass, cached properties do not get modified. This can cause subtle issues: for example, suppose I initialize a dataclass on CPU in a data worker, and then move it onto GPU for a model batch. All of the dataclass fields that contain Tensors get moved correctly, but the cached_property values continue to reside on the original device.

Steps to reproduce

import dataclasses
from functools import cached_property

import torch
from lightning_utilities import apply_to_collection
from torch import Tensor

@dataclasses.dataclass
class Data:
    a: Tensor

    @cached_property
    def b(self):
        print("*" * 10)
        print("Computing and cache prop b")
        print("*" * 10)
        return self.a * 2

print("*" * 10)
print("Start with data on GPU")
print("*" * 10)

data = Data(a=torch.tensor([1, 2, 3], device="cuda"))
print(f"{data.a=}")
print(f"{data.a.device=}")

print(f"{data.b=}")
print(f"{data.b=}")  # do this a second time to make sure we're caching it
print(f"{data.b.device=}")

print("*" * 10)
print("Move Data to CPU")
print("*" * 10)

new_data = apply_to_collection(data, Tensor, lambda x: x.to("cpu"))
print(f"{new_data.a=}")
print(f"{new_data.a.device=}")

print(f"{new_data.b=}")
print(f"{new_data.b=}")  # do this a second time to make sure we're caching it
print(f"{new_data.b.device=}")

Yields the following output:

**********
Start with data on GPU
**********
data.a=tensor([1, 2, 3], device='cuda:0')
data.a.device=device(type='cuda', index=0)
**********
Computing and cache prop b
**********
data.b=tensor([2, 4, 6], device='cuda:0')
data.b=tensor([2, 4, 6], device='cuda:0')
data.b.device=device(type='cuda', index=0)
**********
Move Data to CPU
**********
new_data.a=tensor([1, 2, 3])
new_data.a.device=device(type='cpu')
new_data.b=tensor([2, 4, 6], device='cuda:0')
new_data.b=tensor([2, 4, 6], device='cuda:0')
new_data.b.device=device(type='cuda', index=0)

jackdent commented 2 months ago

The Lightning apply_to_collection logic is defined here and relies on dataclasses.fields, which doesn't include cached properties.
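
A minimal, torch-free sketch of why this happens. It assumes the instance is copied wholesale and that only its declared fields are then overwritten; the Point class and the use of copy.deepcopy are illustrative stand-ins, not the library's actual code path. The point is that functools.cached_property stores its value in the instance __dict__, so dataclasses.fields never reports it, but a copy of the instance carries it along.

```python
import copy
import dataclasses
from functools import cached_property


@dataclasses.dataclass
class Point:  # hypothetical stand-in for the Data class in the report
    x: int

    @cached_property
    def doubled(self):
        return self.x * 2


p = Point(x=3)
_ = p.doubled  # first access computes the value and stores it on the instance

# Only declared dataclass fields are listed; the cached property is not one of them.
field_names = [f.name for f in dataclasses.fields(p)]

# The cached value lives in the instance __dict__ under the property's name.
cached_on_instance = "doubled" in vars(p)

# A copy of the instance therefore inherits the already-computed value.
clone = copy.deepcopy(p)
cached_on_clone = "doubled" in vars(clone)
```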

Borda commented 1 month ago

@awaelchli, do you have any experience with this one?

awaelchli commented 1 month ago

Hey @jackdent, this is a rare use case and I won't have the bandwidth to look into it. We would be grateful for a contribution here if you're interested. The fix is probably to just reset the cache when running apply_to_collection.
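
One way the suggested cache reset could look, as a sketch only. reset_cached_properties is a hypothetical helper, not part of lightning_utilities, and plain integers stand in for tensors so the example runs without torch. It walks the class hierarchy, finds functools.cached_property descriptors, and removes their stored values from the instance so the next access recomputes them from the updated fields.

```python
import copy
import dataclasses
from functools import cached_property


def reset_cached_properties(obj):
    """Hypothetical helper: drop cached_property values stored on obj."""
    for klass in type(obj).__mro__:
        for attr in vars(klass).values():
            if isinstance(attr, cached_property):
                # attrname is the key under which the value is cached in __dict__
                vars(obj).pop(attr.attrname, None)
    return obj


@dataclasses.dataclass
class Data:
    a: int  # stand-in for a Tensor field

    @cached_property
    def b(self):
        return self.a * 2


data = Data(a=1)
_ = data.b  # populate the cache

moved = copy.deepcopy(data)
moved.a = 10  # stand-in for the per-field transform (e.g. moving to another device)

stale = moved.b  # still the value computed from the old field
reset_cached_properties(moved)
fresh = moved.b  # recomputed from the updated field
```

Resetting recomputes the property lazily instead of applying the function to the cached value itself, which avoids guessing whether the cached value should be transformed in the same way as the fields.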