pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow for some published model using BatchNorm2d #5104

Open oliver-batchelor opened 2 years ago

oliver-batchelor commented 2 years ago

🐛 Describe the bug

The new torch.fx-based feature extractor has trouble with the standard module BatchNorm2d (and presumably others?). Edited, as the original report was not quite correct.

import timm
import torch
from torchvision.models import feature_extraction

model = timm.create_model('regnetx_004')
images = torch.randn((1, 3, 224, 224))

# create_feature_extractor symbolically traces the model with torch.fx
extractor = feature_extraction.create_feature_extractor(model, ["s1", "s2", "s3", "s4"])
features = extractor(images)
  File "/home/oliver/miniconda3/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py", line 135, in forward
    self._check_input_dim(input)
  File "/home/oliver/miniconda3/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py", line 406, in _check_input_dim
    if input.dim() != 4:
  File "/home/oliver/miniconda3/lib/python3.9/site-packages/torch/fx/proxy.py", line 251, in __bool__
    return self.tracer.to_bool(self)
  File "/home/oliver/miniconda3/lib/python3.9/site-packages/torch/fx/proxy.py", line 152, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow

Code in question (BatchNorm2d._check_input_dim in torch/nn/modules/batchnorm.py):

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError("expected 4D input (got {}D input)".format(input.dim()))
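Why this check breaks tracing: under torch.fx symbolic tracing, `input` is a `Proxy`, so `input.dim()` returns another `Proxy` rather than an int, `!= 4` yields yet another `Proxy`, and the `if` then has to call `Proxy.__bool__`, which raises the `TraceError` shown above. A minimal sketch that reproduces this with plain `torch.fx.symbolic_trace`, using a hypothetical `DimCheck` module (not part of PyTorch) that mirrors the check:

```python
import torch
from torch import fx, nn
from torch.fx.proxy import TraceError


class DimCheck(nn.Module):
    """Hypothetical module mirroring BatchNorm2d's _check_input_dim."""

    def forward(self, x):
        # While tracing, `x.dim() != 4` is a Proxy, but `if` needs a real bool.
        if x.dim() != 4:
            raise ValueError(f"expected 4D input (got {x.dim()}D input)")
        return x * 2


def try_trace(module):
    """Return the traced GraphModule, or the TraceError message on failure."""
    try:
        return fx.symbolic_trace(module)
    except TraceError as err:
        return str(err)


result = try_trace(DimCheck())
print(result)
```

Calling `DimCheck()` eagerly on a real 4D tensor works fine; only symbolic tracing fails.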

Versions

Collecting environment information...
PyTorch version: 1.10.0
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.9.5 (default, Jun 4 2021, 12:28:51) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-91-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.4.120
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2070
Nvidia driver version: 470.86
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.2.1

prabhat00155 commented 2 years ago

I don't see layer1 in regnetx_004 model. You can try printing the model after you create it to see the different layers and their names.
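Besides `print(model)`, submodule names can be listed programmatically with `named_children()` (top level) or `named_modules()` (nested, dot-separated). A minimal sketch on a small hypothetical stand-in model whose stages happen to be named `s1` and `s2`; it is not the real `regnetx_004`:

```python
from collections import OrderedDict

from torch import nn

# Hypothetical stand-in for a model whose stages are named s1, s2, ...
model = nn.Sequential(OrderedDict([
    ("stem", nn.Conv2d(3, 8, kernel_size=3, padding=1)),
    ("s1", nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.ReLU())),
    ("s2", nn.Sequential(nn.Conv2d(8, 16, 3, stride=2, padding=1), nn.ReLU())),
]))

# Top-level submodule names.
top_level = [name for name, _ in model.named_children()]
print(top_level)  # ['stem', 's1', 's2']

# Nested names as well; the root module itself has the empty name ''.
nested = [name for name, _ in model.named_modules() if name]
print(nested)
```

Module names like these are typically what gets passed as return nodes for module outputs.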

oliver-batchelor commented 2 years ago

Yes, you are right. I edited the report: initially I had resnet18 (which has layer1), then realized resnet18 didn't trigger the problem, but forgot to update the layer name. Sorry! It should be "s1" for the regnetx model.

oliver-batchelor commented 2 years ago

Btw, a simple way to fix the issue seems to be changing the if statement to use torch._assert. I am curious: how is it decided whether to use if rather than torch._assert in these cases?
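The torch._assert suggestion works for tracing because it is a function call rather than Python control flow: tracing records it as a node in the graph, and the check runs when the traced module is called. A minimal sketch with a hypothetical `AssertDim` module; it illustrates the suggestion and is not the code PyTorch ships for BatchNorm2d:

```python
import torch
from torch import fx, nn


class AssertDim(nn.Module):
    """Hypothetical module: the 4D check written with torch._assert."""

    def forward(self, x):
        # Recorded as a call_function node instead of an `if` on a Proxy.
        torch._assert(x.dim() == 4, "expected 4D input")
        return x * 2


traced = fx.symbolic_trace(AssertDim())  # tracing succeeds

out = traced(torch.ones(1, 3, 2, 2))  # 4D input passes the check

try:
    traced(torch.ones(3, 2, 2))  # 3D input fails the check at run time
    rejected = False
except AssertionError:
    rejected = True

print(rejected)
```

Note that the failure becomes an `AssertionError` rather than the `ValueError` the original check raises.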

prabhat00155 commented 2 years ago

Is this specific to timm models? I see the same error as above when trying to list the node names.

import timm
from torchvision.models.feature_extraction import get_graph_node_names

model = timm.create_model('regnetx_004')
get_graph_node_names(model)
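get_graph_node_names builds its list by symbolically tracing the model with torch.fx, so it stops on the same `if` as feature extraction does. A rough, torch-only way to see what such tracing produces for a model that does trace cleanly (`TinyNet` is hypothetical, and torchvision's own node naming may differ in detail):

```python
import torch
from torch import fx, nn


class TinyNet(nn.Module):
    """Hypothetical traceable model with two named stages."""

    def __init__(self):
        super().__init__()
        self.s1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.s2 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.s2(torch.relu(self.s1(x)))


graph = fx.symbolic_trace(TinyNet()).graph
for node in graph.nodes:
    # op is one of: placeholder, call_module, call_function, call_method, get_attr, output
    print(node.op, node.target)
```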

I am able to use regnet models from torchvision though:

from torchvision.models import regnet_x_8gf
from torchvision.models.feature_extraction import create_feature_extractor

model = regnet_x_8gf()
extractor = create_feature_extractor(model, ["trunk_output.block1"])

oliver-batchelor commented 2 years ago

Yes, you are right; I don't see any problem with the torchvision regnet_x either.
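A plausible reason for the difference, offered as an assumption rather than something confirmed in this thread: by default torch.fx treats modules defined under `torch.nn` as leaves and records them as single calls without tracing inside, whereas a subclass defined in another package (timm, for example, ships its own batch-norm layers) is traced into, which reaches `_check_input_dim`. A minimal sketch with a hypothetical subclass `MyBatchNorm2d`:

```python
from torch import fx, nn
from torch.fx.proxy import TraceError


class MyBatchNorm2d(nn.BatchNorm2d):
    """Hypothetical subclass defined outside torch.nn; adds no behavior."""


def traces(module):
    """Return True if torch.fx.symbolic_trace succeeds on the module."""
    try:
        fx.symbolic_trace(module)
        return True
    except TraceError:
        return False


tracer = fx.Tracer()
results = {}
for bn in (nn.BatchNorm2d(8), MyBatchNorm2d(8)):
    name = type(bn).__name__
    # is_leaf_module decides whether tracing stops at this module.
    # The module is wrapped in nn.Sequential because the root module's
    # own forward is always traced, leaf or not.
    results[name] = (tracer.is_leaf_module(bn, "bn"), traces(nn.Sequential(bn)))
    print(name, results[name])
```

Whether this is the exact path taken by the timm version used in this report is not verified here.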

earonesty commented 1 year ago

This happens if you, for example, try to analyze a Hugging Face "gpt2" model.