Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Cannot compile CenterSpatialCrop #8191

Open. ziw-liu opened this issue 3 hours ago.

Describe the bug

Using torch.compile to optimize MONAI transforms generally works (apart from graph breaks), but CenterSpatialCrop (and its dictionary wrapper, CenterSpatialCropd) does not.

To Reproduce

import torch
from monai.data.meta_obj import set_track_meta
from monai.transforms import (
    CenterSpatialCrop,
    RandAdjustContrast,
    RandAffine,
    RandFlip,
    RandGaussianNoise,
    RandGaussianSmooth,
    RandScaleIntensity,
    RandSpatialCropSamples,
)

# disable MetaTensor tracking so transforms return plain torch.Tensor (avoids the tensor subclass)
set_track_meta(False)

transforms = [
    RandAffine(
        prob=1.0,
        rotate_range=(torch.pi, 0, 0),
        scale_range=(0, 0.3, 0.3),
        padding_mode="zeros",
        mode="bilinear",
    ),
    CenterSpatialCrop(roi_size=(1, 256, 256)),
    RandSpatialCropSamples(roi_size=(1, 256, 256), num_samples=2),
    RandFlip(prob=0.5, spatial_axis=(1, 2)),
    RandAdjustContrast(prob=0.5, gamma=(0.8, 1.2)),
    RandScaleIntensity(factors=0.5, prob=0.5),
    RandGaussianNoise(prob=0.5, mean=0.0, std=0.3),
    RandGaussianSmooth(
        sigma_x=(0.25, 0.75),
        sigma_y=(0.25, 0.75),
        sigma_z=(0.0, 0.0),
        prob=0.5,
    ),
]

img = torch.rand(1, 1, 512, 512, dtype=torch.float32, device="cuda")

@torch.compile
def apply_transform(x, tf):
    tf(x)

for tf in transforms:
    try:
        apply_transform(img, tf)
        print(f"{type(tf)} compiled successfully.")
    except Exception as e:
        assert isinstance(tf, CenterSpatialCrop)
        print(f"Failed to compile {type(tf)}.")
        print(e)

Running this script prints the following error:

Failed to compile <class 'monai.transforms.croppad.array.CenterSpatialCrop'>.
Failed running call_function <built-in method as_tensor of type object at 0x7f147a4e8240>(*([FakeTensor(..., size=(), dtype=torch.int16), FakeTensor(..., size=(), dtype=torch.int16), FakeTensor(..., size=(), dtype=torch.int16)],), **{'dtype': torch.int16, 'device': 'cpu'}):
The tensor has a non-zero number of elements, but its data is not allocated yet.
If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
If you're using Caffe2, Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

from user code:
   File "/home/user.name/viscy/viscy/scripts/bench_compile_transform.py", line 44, in apply_transform
    tf(x)
  File "/hpc/mydata/user.name/anaconda/2022.05/x86_64/envs/viscy/lib/python3.11/site-packages/monai/transforms/croppad/array.py", line 533, in __call__
    slices=self.compute_slices(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]),
  File "/hpc/mydata/user.name/anaconda/2022.05/x86_64/envs/viscy/lib/python3.11/site-packages/monai/transforms/croppad/array.py", line 522, in compute_slices
    return super().compute_slices(roi_center=roi_center, roi_size=roi_size)
  File "/hpc/mydata/user.name/anaconda/2022.05/x86_64/envs/viscy/lib/python3.11/site-packages/monai/transforms/croppad/array.py", line 392, in compute_slices
    roi_center_t = convert_to_tensor(data=roi_center, dtype=torch.int16, wrap_sequence=True, device="cpu")
  File "/hpc/mydata/user.name/anaconda/2022.05/x86_64/envs/viscy/lib/python3.11/site-packages/monai/utils/type_conversion.py", line 174, in convert_to_tensor
    return _convert_tensor(list_ret, dtype=dtype, device=device) if wrap_sequence else list_ret
  File "/hpc/mydata/user.name/anaconda/2022.05/x86_64/envs/viscy/lib/python3.11/site-packages/monai/utils/type_conversion.py", line 149, in _convert_tensor
    tensor = torch.as_tensor(tensor, **kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Expected behavior

CenterSpatialCrop can be compiled just like the other transforms.

Environment

Output of the following command, run with the relevant Python executable:

python -c 'import monai; monai.config.print_debug_info()'

================================
Printing MONAI config...
================================
MONAI version: 1.4.0
Numpy version: 1.26.4
Pytorch version: 2.5.0+cu124
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: /hpc/mydata/<username>/anaconda/2022.05/x86_64/envs/viscy/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: 0.24.0
scipy version: 1.14.0
Pillow version: 10.4.0
Tensorboard version: 2.17.1
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.20.0+cu124
tqdm version: 4.66.5
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 6.0.0
pandas version: 2.2.2
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

================================
Printing system config...
================================
System: Linux
Linux version: Rocky Linux 8.10 (Green Obsidian)
Platform: Linux-4.18.0-553.16.1.el8_10.x86_64-x86_64-with-glibc2.28
Processor: x86_64
Machine: x86_64
Python version: 3.11.9
Process name: python
Command: ['python', '-c', 'import monai; monai.config.print_debug_info()']
Open files: [popenfile(path='/home/<username>/.vscode-server/data/logs/20241031T101853/remoteagent.log', fd=19, position=5336, mode='a', flags=33793), popenfile(path='/home/<username>/.vscode-server/data/logs/20241031T101853/ptyhost.log', fd=20, position=4686, mode='a', flags=33793)]
Num physical CPUs: 16
Num logical CPUs: 16
Num usable CPUs: 16
CPU usage (%): [8.5, 8.5, 3.5, 8.1, 3.9, 3.9, 3.2, 4.6, 5.0, 18.6, 23.9, 3.5, 3.5, 3.6, 4.3, 4.6]
CPU freq. (MHz): 2935
Load avg. in last 1, 5, 15 mins (%): [0.6, 0.5, 1.4]
Disk usage (%): 93.3
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 503.8
Available memory (GB): 440.0
Used memory (GB): 27.3

================================
Printing GPU config...
================================
Num GPUs: 1
Has CUDA: True
CUDA version: 12.4
cuDNN enabled: True
NVIDIA_TF32_OVERRIDE: None
TORCH_ALLOW_TF32_CUBLAS_OVERRIDE: None
cuDNN version: 90100
Current device: 0
Library compiled for CUDA architectures: ['sm_50', 'sm_60', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90']
GPU 0 Name: NVIDIA A40
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 84
GPU 0 Total memory (GB): 44.7
GPU 0 CUDA capability (maj.min): 8.6

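Additional context

The error message points to the tensor conversion called in the Crop class: compute_slices converts roi_center to an int16 CPU tensor via convert_to_tensor(..., dtype=torch.int16, device="cpu"), and the final torch.as_tensor call on the list of 0-dim tensors is what fails during tracing. Curiously, the other cropping transform in the list (RandSpatialCropSamples) does compile.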
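
Since the traceback ends in the int16 tensor conversion inside compute_slices, one possible workaround is to compute the center-crop bounds with plain Python ints, so there is no list of tensors for torch.as_tensor to convert. This is a minimal sketch, not MONAI's code: the helper name center_crop_slices is hypothetical, and its arithmetic (start = max(center - size // 2, 0), end = start + size, with non-positive ROI entries falling back to the full dimension) is assumed to mirror what CenterSpatialCrop computes.

```python
from typing import List, Sequence


def center_crop_slices(spatial_size: Sequence[int], roi_size: Sequence[int]) -> List[slice]:
    """Hypothetical helper: center-crop slices computed with plain Python ints.

    Assumption: this mirrors the centered-ROI arithmetic of MONAI's cropping
    transforms, but never builds int16 tensors, which is the step that fails
    under torch.compile in the reported traceback.
    """
    if len(spatial_size) != len(roi_size):
        raise ValueError("spatial_size and roi_size must have the same length")
    slices = []
    for dim, roi in zip(spatial_size, roi_size):
        # Non-positive ROI entries fall back to the full dimension
        # (assumed to match MONAI's fall_back_tuple default).
        size = roi if roi > 0 else dim
        center = dim // 2
        # Clamp the start at 0; Python slicing clips an end past the boundary.
        start = max(center - size // 2, 0)
        slices.append(slice(start, start + size))
    return slices


if __name__ == "__main__":
    # Spatial shape of the (C, D, H, W) image in the repro, minus the channel axis.
    print(center_crop_slices((1, 512, 512), (1, 256, 256)))
    # prints [slice(0, 1, None), slice(128, 384, None), slice(128, 384, None)]
```

On a channel-first tensor the slices would be applied as img[(slice(None), *slices)], leaving the channel axis untouched.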


ziw-liu commented 3 hours ago

Also, despite the upstream fix for https://github.com/pytorch/pytorch/issues/117026, set_track_meta(False) is still needed in torch 2.5.1 to avoid using MetaTensor.