InterDigitalInc / CompressAI

A PyTorch library and evaluation platform for end-to-end compression research
https://interdigitalinc.github.io/CompressAI/
BSD 3-Clause Clear License

CheckerboardLatentCodec broken in torch<2.0 #302

Open lucasrelic99 opened 2 months ago

lucasrelic99 commented 2 months ago

Bug

Using the CheckerboardLatentCodec with a non-identity context_prediction module results in a runtime error during the forward pass. I believe this only occurs with torch versions earlier than 2.0.

To Reproduce

Steps to reproduce the behavior:

  1. Instantiate a CheckerboardLatentCodec.
  2. Create any tensor and pass it to the forward() method of the latent codec.
  3. Observe bug.

Minimal working example:

import torch
from compressai.latent_codecs import CheckerboardLatentCodec, GaussianConditionalLatentCodec
from compressai.layers.layers import CheckerboardMaskedConv2d

lc = CheckerboardLatentCodec(
    latent_codec = {
        "y": GaussianConditionalLatentCodec()
    },
    context_prediction = CheckerboardMaskedConv2d(4, 8, kernel_size=5, stride=1, padding=2)
)

t = torch.randn((1, 4, 64, 64)) # arbitrary spatial size; channels must match the in_channels of the context_prediction layer
ctx = torch.randn((1, 8, 16, 16)) # placeholder side_params; the error is raised before this tensor is actually used

output = lc(t, ctx)

This code results in the error:

File ~/conda/miniconda3-ubuntu22/envs/sdv2-new/lib/python3.8/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/software/compressai/compressai/latent_codecs/checkerboard.py:149, in CheckerboardLatentCodec.forward(self, y, side_params)
    147     return self._forward_onepass(y, side_params)
    148 if self.forward_method == "twopass":
--> 149     return self._forward_twopass(y, side_params)
    150 if self.forward_method == "twopass_faster":
    151     return self._forward_twopass_faster(y, side_params)

File ~/software/compressai/compressai/latent_codecs/checkerboard.py:192, in CheckerboardLatentCodec._forward_twopass(self, y, side_params)
    187 B, C, H, W = y.shape
    189 params = y.new_zeros((B, C * 2, H, W))
    191 y_hat_anchors = self._forward_twopass_step(
--> 192     y, side_params, params, self._y_ctx_zero(y), "anchor"
    193 )
    195 y_hat_non_anchors = self._forward_twopass_step(
    196     y, side_params, params, self.context_prediction(y_hat_anchors), "non_anchor"
    197 )
    199 y_hat = y_hat_anchors + y_hat_non_anchors

File ~/conda/miniconda3-ubuntu22/envs/sdv2-new/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File ~/software/compressai/compressai/latent_codecs/checkerboard.py:272, in CheckerboardLatentCodec._y_ctx_zero(self, y)
    269 @torch.no_grad()
    270 def _y_ctx_zero(self, y: Tensor) -> Tensor:
    271     """Create a zero tensor with correct shape for y_ctx."""
--> 272     y_ctx_meta = self.context_prediction(y.to("meta"))
    273     return y.new_zeros(y_ctx_meta.shape)

File ~/conda/miniconda3-ubuntu22/envs/sdv2-new/lib/python3.8/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/software/compressai/compressai/layers/layers.py:144, in MaskedConv2d.forward(self, x)
    141 def forward(self, x: Tensor) -> Tensor:
    142     # TODO(begaintj): weight assigment is not supported by torchscript
    143     self.weight.data = self.weight.data * self.mask
--> 144     return super().forward(x)

File ~/conda/miniconda3-ubuntu22/envs/sdv2-new/lib/python3.8/site-packages/torch/nn/modules/conv.py:457, in Conv2d.forward(self, input)
    456 def forward(self, input: Tensor) -> Tensor:
--> 457     return self._conv_forward(input, self.weight, self.bias)

File ~/conda/miniconda3-ubuntu22/envs/sdv2-new/lib/python3.8/site-packages/torch/nn/modules/conv.py:453, in Conv2d._conv_forward(self, input, weight, bias)
    449 if self.padding_mode != 'zeros':
    450     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    451                     weight, bias, self.stride,
    452                     _pair(0), self.dilation, self.groups)
--> 453 return F.conv2d(input, weight, bias, self.stride,
    454                 self.padding, self.dilation, self.groups)

NotImplementedError: convolution_overrideable not implemented. You are likely triggering this with tensor backend other than CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL to override this function

Expected behavior

The code should not throw an error.

Environment

Output from python3 -m torch.utils.collect_env:

PyTorch version: 1.12.1
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.8.18 (default, Sep 11 2023, 13:40:15)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.5.0-44-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.7.99
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090
Nvidia driver version: 535.183.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.1
[pip3] open-clip-torch==2.7.0
[pip3] pytorch-lightning==1.4.2
[pip3] pytorch-msssim==1.0.0
[pip3] torch==1.12.1
[pip3] torch_geometric==2.5.3
[pip3] torchaudio==0.12.1
[pip3] torchmetrics==0.6.0
[pip3] torchvision==0.13.1
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               11.3.1              hb98b00a_13    conda-forge
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] mkl                       2021.4.0           h06a4308_640  
[conda] mkl-service               2.4.0            py38h7f8727e_0  
[conda] mkl_fft                   1.3.1            py38hd3c417c_0  
[conda] mkl_random                1.2.2            py38h51133e4_0  
[conda] numpy                     1.23.1           py38h6c91a56_0  
[conda] numpy-base                1.23.1           py38ha15fc14_0  
[conda] open-clip-torch           2.7.0                    pypi_0    pypi
[conda] pytorch                   1.12.1          py3.8_cuda11.3_cudnn8.3.2_0    pytorch
[conda] pytorch-lightning         1.4.2                    pypi_0    pypi
[conda] pytorch-msssim            1.0.0                    pypi_0    pypi
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torch                     1.13.0+cu117             pypi_0    pypi
[conda] torch-geometric           2.5.3                    pypi_0    pypi
[conda] torchaudio                0.13.0+cu117             pypi_0    pypi
[conda] torchmetrics              0.6.0              pyhd8ed1ab_0    conda-forge
[conda] torchvision               0.14.0+cu117             pypi_0    pypi
- PyTorch / CompressAI Version: 1.12.1 / 1.2.6
- OS: Linux, Ubuntu 22.04.3
- How you installed PyTorch / CompressAI: source
- Build command you used (if compiling from source):
    git clone https://github.com/InterDigitalInc/CompressAI compressai
    cd compressai
    pip install -U pip && pip install -e .
- Python version: 3.8.18
- CUDA/cuDNN version: 11.7
- GPU models and configuration: 1x NVIDIA GeForce RTX 3090
- Any other relevant information: N/A

Additional context

I am quite certain this is because older PyTorch versions do not support certain operations (here, convolution) on tensors placed on the "meta" device. I think support was introduced around PyTorch 2.0, but I couldn't find anything definitive from a quick search.

I traced this back to commit eddb1bc, which uses meta device tensors to compute the expected size of the checkerboard context tensor. Replacing these lines with the previous version resolved the issue for me.
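
For reference, the pattern that _y_ctx_zero relies on can be reproduced in isolation. A minimal sketch (using a plain nn.Conv2d instead of CompressAI's masked convolution): on torch 1.12 the call below raises the same NotImplementedError as in the traceback, while on versions that provide meta kernels for convolution it returns a meta tensor that only carries the output shape.

import torch

conv = torch.nn.Conv2d(4, 8, kernel_size=5, padding=2)
y = torch.randn(1, 4, 64, 64)

# Shape-only evaluation on the meta device: no data is allocated or computed.
# On torch 1.12 this raises "NotImplementedError: convolution_overrideable not
# implemented"; on newer versions it succeeds and only the shape is meaningful.
out = conv(y.to("meta"))
print(out.shape)  # torch.Size([1, 8, 64, 64])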

YodaEmbedding commented 2 months ago

Thanks for the report.

I could add a version check for torch<2.0:

from packaging.version import Version

class CheckerboardLatentCodec(LatentCodec):
    def _y_ctx_zero(self, y: Tensor) -> Tensor:
        if Version(torch.__version__) < Version("2.0.0"):
            return self._mask(self.context_prediction(y).detach(), "all")
        return y.new_zeros(self.context_prediction(y.to("meta")).shape)

...but perhaps simpler is just to revert:

class CheckerboardLatentCodec(LatentCodec):
    def _y_ctx_zero(self, y: Tensor) -> Tensor:
        return self._mask(self.context_prediction(y).detach(), "all")

lucasrelic99 commented 2 months ago

To be fair, I don't actually know the earliest torch version that supports meta-device tensors, as I couldn't find any solid information.

That said, I think the simpler fix is probably good enough. On my machine (14900K CPU, RTX 3090 GPU), with an (unreasonably large) context size of (16, 192, 512, 512), that line takes about 0.06 ms to execute on the GPU. It takes about 4 seconds on the CPU, but with a more reasonable context size of (16, 192, 32, 32) it runs in roughly 80 ms on the CPU.
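
In case it's useful, here is a rough standalone sketch of how one might time that cost on CPU (not the exact script I used; the channel count of 192 and the masked-conv parameters are assumptions, and it only measures the context_prediction call that dominates the reverted line):

import time
import torch
from compressai.layers.layers import CheckerboardMaskedConv2d

M = 192  # assumed number of latent channels
context_prediction = CheckerboardMaskedConv2d(M, 2 * M, kernel_size=5, stride=1, padding=2)
y = torch.randn(16, M, 32, 32)  # the "reasonable" context size from above

with torch.no_grad():
    context_prediction(y)  # warm-up
    n_runs = 10
    start = time.perf_counter()
    for _ in range(n_runs):
        context_prediction(y)
    elapsed = (time.perf_counter() - start) / n_runs

# For GPU timing, move the module and tensor to CUDA and wrap the measured
# region with torch.cuda.synchronize() calls.
print(f"average per call: {elapsed * 1e3:.1f} ms")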