openvinotoolkit / anomalib

An anomaly detection library comprising state-of-the-art algorithms and features such as experiment management, hyper-parameter optimization, and edge inference.
https://anomalib.readthedocs.io/en/latest/
Apache License 2.0

[Bug]: v1.1.0.dev0 timm-1.0.3, feature_extractor, 'FeatureListNet' object has no attribute 'grad_checkpointing' #2129

Open letmejoin opened 3 weeks ago

letmejoin commented 3 weeks ago

Describe the bug

After upgrading timm from 0.6.13 to 1.0.3, a model that previously ran inference without problems now raises an error. The object used for feature extraction is missing member variables. It may be that its methods are being called without `__init__` having run, so not all member variables were initialized. The specific code location is https://github.com/openvinotoolkit/anomalib/blob/22caf3badf610641c6b0d4f7ba5d6e1b1e419ce8/src/anomalib/models/components/feature_extractors/timm.py#L124, and these are the member variables of `self.feature_extractor` at that point (two listings):

['T_destination', '__annotations__', '__call__', '__class__', '__contains__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_collect', '_compiled_call_impl', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_is_full_backward_hook', '_load_from_state_dict', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', '_named_members', '_non_persistent_buffers_set', '_parameters', '_register_load_state_dict_pre_hook', '_register_state_dict_hook', '_replicate_for_data_parallel', '_save_to_state_dict', '_slow_forward', '_state_dict_hooks', '_state_dict_pre_hooks', '_version', '_wrapped_call_impl', 'act1', 'add_module', 'apply', 'bfloat16', 'bn1', 'buffers', 'call_super_init', 'children', 'clear', 'compile', 'concat', 'conv1', 'cpu', 'cuda', 'default_cfg', 'double', 'dump_patches', 'eval', 'extra_repr', 'feature_info', 'float', 'forward', 'get_buffer', 'get_extra_state', 'get_parameter', 'get_submodule', 'half', 'ipu', 'items', 'keys', 'layer1', 'layer2', 'layer3', 'load_state_dict', 'maxpool', 'modules', 'named_buffers', 'named_children', 'named_modules', 'named_parameters', 'parameters', 'pop', 'pretrained_cfg', 'register_backward_hook', 'register_buffer', 'register_forward_hook', 'register_forward_pre_hook', 'register_full_backward_hook', 'register_full_backward_pre_hook', 'register_load_state_dict_post_hook', 'register_module', 'register_parameter', 'register_state_dict_pre_hook', 'requires_grad_', 'return_layers', 'set_extra_state', 'set_grad_checkpointing', 'share_memory', 'state_dict', 'to', 'to_empty', 'train', 'training', 'type', 'update', 'values', 'xpu', 'zero_grad']

['__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__get__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__self__', '__self_class__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__thisclass__', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_is_full_backward_hook', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_modules', '_non_persistent_buffers_set', '_parameters', '_state_dict_hooks', '_state_dict_pre_hooks', 'concat', 'default_cfg', 'feature_info', 'pretrained_cfg', 'return_layers', 'training']

Neither listing contains `grad_checkpointing` (or `output_fmt`), even though the first one includes `set_grad_checkpointing`.

However, if I add `self.feature_extractor.set_grad_checkpointing(False)` before https://github.com/openvinotoolkit/anomalib/blob/22caf3badf610641c6b0d4f7ba5d6e1b1e419ce8/src/anomalib/models/components/feature_extractors/timm.py#L123, it works fine. I don't know if there is a better way.
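The `set_grad_checkpointing(False)` workaround can be generalized into a small helper that walks every submodule of a loaded model and calls that method wherever it exists but the `grad_checkpointing` attribute does not. This is a sketch that assumes the missing flag is the only incompatibility; `patch_missing_grad_checkpointing` is a hypothetical name, not part of anomalib or timm. With a PyTorch model you would pass `model.modules()` right after loading and before inference.

```python
from typing import Iterable


def patch_missing_grad_checkpointing(modules: Iterable[object]) -> int:
    """Restore a missing ``grad_checkpointing`` flag on loaded modules.

    Hypothetical helper: for every object that exposes ``set_grad_checkpointing``
    but has no ``grad_checkpointing`` attribute (for example a timm FeatureListNet
    restored from a model saved under an older timm), call
    ``set_grad_checkpointing(False)``, which is the workaround from this issue.

    Returns the number of objects that were patched.
    """
    patched = 0
    for module in modules:
        has_setter = hasattr(module, "set_grad_checkpointing")
        # On an nn.Module, a missing attribute makes __getattr__ raise
        # AttributeError, so hasattr() returns False here.
        if has_setter and not hasattr(module, "grad_checkpointing"):
            # Keep checkpointing disabled, matching the default set in __init__.
            module.set_grad_checkpointing(False)
            patched += 1
    return patched
```

Modules that already have the attribute, or that do not expose `set_grad_checkpointing` at all, are left untouched, so the helper is a no-op on models created under the current timm version.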

The relevant timm code (`timm/models/_features.py`, as installed with timm 1.0.3):

```python
# Excerpt from timm/models/_features.py. Helpers such as _get_feature_info,
# _get_return_layers, _module_list, Format and OutIndicesT are defined
# elsewhere in timm and are not reproduced here.
from collections import OrderedDict
from typing import Dict, List, Sequence, Union

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class FeatureDictNet(nn.ModuleDict):
    """ Feature extractor with OrderedDict return

    Wrap a model and extract features as specified by the out indices, the network is
    partially re-built from contained modules.

    There is a strong assumption that the modules have been registered into the model in the same
    order as they are used. There should be no reuse of the same nn.Module more than once, including
    trivial modules like `self.relu = nn.ReLU`.

    Only submodules that are directly assigned to the model class (`model.feature1`) or at most
    one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured.
    All Sequential containers that are directly assigned to the original model will have their
    modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: OutIndicesT = (0, 1, 2, 3, 4),
            out_map: Sequence[Union[int, str]] = None,
            output_fmt: str = 'NCHW',
            feature_concat: bool = False,
            flatten_sequential: bool = False,
    ):
        """
        Args:
            model: Model from which to extract features.
            out_indices: Output indices of the model features to extract.
            out_map: Return id mapping for each output index, otherwise str(index) is used.
            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
                first element e.g. `x[0]`
            flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
        """
        super(FeatureDictNet, self).__init__()
        self.feature_info = _get_feature_info(model, out_indices)
        self.output_fmt = Format(output_fmt)
        self.concat = feature_concat
        self.grad_checkpointing = False
        self.return_layers = {}

        return_layers = _get_return_layers(self.feature_info, out_map)
        modules = _module_list(model, flatten_sequential=flatten_sequential)
        remaining = set(return_layers.keys())
        layers = OrderedDict()
        for new_name, old_name, module in modules:
            layers[new_name] = module
            if old_name in remaining:
                # return id has to be consistently str type for torchscript
                self.return_layers[new_name] = str(return_layers[old_name])
                remaining.remove(old_name)
            if not remaining:
                break
        assert not remaining and len(self.return_layers) == len(return_layers), \
            f'Return layers ({remaining}) are not present in model'
        self.update(layers)

    def set_grad_checkpointing(self, enable: bool = True):
        self.grad_checkpointing = enable

    def _collect(self, x) -> (Dict[str, torch.Tensor]):
        out = OrderedDict()
        for i, (name, module) in enumerate(self.items()):
            # The line below (highlighted by the reporter) raises the AttributeError in the traceback.
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # Skipping checkpoint of first module because need a gradient at input
                # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
                # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
                x = module(x) if first_or_last_module else checkpoint(module, x)
            else:
                x = module(x)

            if name in self.return_layers:
                out_id = self.return_layers[name]
                if isinstance(x, (tuple, list)):
                    # If model tap is a tuple or list, concat or select first element
                    # FIXME this may need to be more generic / flexible for some nets
                    out[out_id] = torch.cat(x, 1) if self.concat else x[0]
                else:
                    out[out_id] = x
        return out

    def forward(self, x) -> Dict[str, torch.Tensor]:
        return self._collect(x)


class FeatureListNet(FeatureDictNet):
    """ Feature extractor with list return

    A specialization of FeatureDictNet that always returns features as a list (values() of dict).
    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: OutIndicesT = (0, 1, 2, 3, 4),
            output_fmt: str = 'NCHW',
            feature_concat: bool = False,
            flatten_sequential: bool = False,
    ):
        """
        Args:
            model: Model from which to extract features.
            out_indices: Output indices of the model features to extract.
            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
                first element e.g. `x[0]`
            flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
        """
        super().__init__(
            model,
            out_indices=out_indices,
            output_fmt=output_fmt,
            feature_concat=feature_concat,
            flatten_sequential=flatten_sequential,
        )

    def forward(self, x) -> (List[torch.Tensor]):
        return list(self._collect(x).values())
```
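The guess that methods are being called on an object whose `__init__` never ran is consistent with how Python unpickling works: the saved instance `__dict__` is restored without calling `__init__`, while the methods come from whichever class version is installed at load time. The pure-Python sketch below imitates that situation, assuming the model was saved as a whole pickled object (not only a `state_dict`) under timm 0.6.13 and loaded under timm 1.0.3. `ExtractorV1`, `ExtractorV2` and `simulate_unpickle` are hypothetical stand-ins, not timm or anomalib code.

```python
class ExtractorV1:
    """Stand-in for the old FeatureListNet: __init__ never sets grad_checkpointing."""

    def __init__(self) -> None:
        self.concat = False
        self.return_layers = {"layer1": "0"}


class ExtractorV2:
    """Stand-in for the new FeatureListNet: __init__ sets grad_checkpointing."""

    def __init__(self) -> None:
        self.concat = False
        self.grad_checkpointing = False
        self.return_layers = {"layer1": "0"}

    def collect(self) -> str:
        # Mirrors the first check inside timm's _collect loop.
        if self.grad_checkpointing:
            return "checkpointed"
        return "plain"


def simulate_unpickle(saved_state: dict, new_cls: type) -> object:
    """Rebuild an object the way default unpickling does: __new__, then a __dict__ update.

    __init__ is never called, so attributes that only the new __init__ sets are absent.
    """
    obj = new_cls.__new__(new_cls)
    obj.__dict__.update(saved_state)
    return obj


old_state = dict(vars(ExtractorV1()))  # state as saved under the "old library"
restored = simulate_unpickle(old_state, ExtractorV2)

try:
    restored.collect()
except AttributeError as err:
    # Same failure mode as in the traceback: the new method reads a missing attribute.
    print(f"AttributeError: {err}")

print(ExtractorV2().collect())  # a freshly constructed object works
```

This would also explain why retraining under the new versions helps: a model built from scratch runs the current `__init__`, so every attribute the current methods expect is present.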

Dataset

N/A

Model

N/A

Steps to reproduce the behavior

Upgrade timm from 0.6.13 to 1.0.3, then run inference with a model that worked before the upgrade.

OS information

OS information:

Expected behavior

Inference produces correct results, as it did before the upgrade.

Pip/GitHub

pip

What version/branch did you use?

No response

Configuration YAML

# anomalib==1.1.0dev
seed_everything: true
trainer:
  accelerator: auto
  strategy: auto
  devices: 1
  num_nodes: 1
  precision: 16
  logger: null
  callbacks: null
  fast_dev_run: false
  max_epochs: 200
  min_epochs: null
  max_steps: -1
  min_steps: null
  max_time: null
  limit_train_batches: null
  limit_val_batches: null
  limit_test_batches: null
  limit_predict_batches: null
  overfit_batches: 0.0
  val_check_interval: null
  check_val_every_n_epoch: 10
  num_sanity_val_steps: null
  log_every_n_steps: null
  enable_checkpointing: true
  enable_progress_bar: true
  enable_model_summary: true
  accumulate_grad_batches: 1
  gradient_clip_val: null
  gradient_clip_algorithm: null
  deterministic: null
  benchmark: true
  inference_mode: true
  use_distributed_sampler: true
  profiler: null
  detect_anomaly: true
  barebones: false
  plugins: null
  sync_batchnorm: true
  reload_dataloaders_every_n_epochs: 0
normalization:
  normalization_method: MIN_MAX
task: CLASSIFICATION
metrics:
  image:
  - AUROC
  - F1Score
  pixel: null
  threshold:
    class_path: anomalib.metrics.F1AdaptiveThreshold
    init_args:
      default_value: 0.5
      thresholds: null
      ignore_index: null
      validate_args: true
      compute_on_cpu: false
      dist_sync_on_step: false
      sync_on_compute: true
      compute_with_cache: true
logging:
  log_graph: true
default_root_dir: results
ckpt_path: null
data:
  class_path: anomalib.data.Folder
  init_args:
    name: andriod
    normal_dir: train/good
    root: datasets/Diamond_andriod
    task: CLASSIFICATION
    abnormal_dir: test/bad
    normal_test_dir: test/good
    mask_dir: null
    normal_split_ratio: 0.0
    extensions:
    - .jpeg
    - .jpg
    train_batch_size: 1
    eval_batch_size: 1
    num_workers: 8
    image_size:
    - 1280
    - 1280
    transform: null
    train_transform:
      class_path: torchvision.transforms.v2.Compose
      init_args:
        transforms:
        - class_path: torchvision.transforms.v2.RandomAdjustSharpness
          init_args:
            sharpness_factor: 0.7
            p: 0.5
        - class_path: torchvision.transforms.v2.RandomHorizontalFlip
          init_args:
            p: 0.5
        - class_path: torchvision.transforms.v2.Resize
          init_args:
            size:
            - 1280
            - 1280
            interpolation: BILINEAR
            max_size: null
            antialias: true
        - class_path: torchvision.transforms.v2.CenterCrop
          init_args:
            size:
            - 448
            - 448
        - class_path: torchvision.transforms.v2.Normalize
          init_args:
            mean:
            - 0.485
            - 0.456
            - 0.406
            std:
            - 0.229
            - 0.224
            - 0.225
            inplace: false
    eval_transform:
      class_path: torchvision.transforms.v2.Compose
      init_args:
        transforms:
        - class_path: torchvision.transforms.v2.Resize
          init_args:
            size:
            - 1280
            - 1280
            interpolation: BILINEAR
            max_size: null
            antialias: true
        - class_path: torchvision.transforms.v2.CenterCrop
          init_args:
            size:
            - 448
            - 448
        - class_path: torchvision.transforms.v2.Normalize
          init_args:
            mean:
            - 0.485
            - 0.456
            - 0.406
            std:
            - 0.229
            - 0.224
            - 0.225
            inplace: false
    test_split_mode: from_dir
    test_split_ratio: 0.2
    val_split_mode: same_as_test
    val_split_ratio: 0.5
    seed: null
model:
  class_path: anomalib.models.ReverseDistillation
  init_args:
    backbone: resnet50
    layers:
    - layer1
    - layer2
    - layer3
    anomaly_map_mode: ADD
    pre_trained: true

Logs

Traceback (most recent call last):
  File "/data/wangjx/anomalib/tools/inference/torch_inference.py", line 139, in <module>
    infer(args=args)
  File "/data/wangjx/anomalib/tools/inference/torch_inference.py", line 102, in infer
    predictions = inferencer.predict(image=image, image_size=tuple(args.image_size), center_crop_size=tuple(args.center_crop_size))
  File "/home/algo/wangjx/anomalib/src/anomalib/deploy/inferencers/torch_inferencer.py", line 209, in predict
    predictions = self.forward(processed_image)
  File "/home/algo/wangjx/anomalib/src/anomalib/deploy/inferencers/torch_inferencer.py", line 249, in forward
    return self.model(image)
  File "/opt/workspace/.conda/envs/wjx_anomalib/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/workspace/.conda/envs/wjx_anomalib/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/algo/wangjx/anomalib/src/anomalib/deploy/export.py", line 85, in forward
    return self.model(batch)
  File "/opt/workspace/.conda/envs/wjx_anomalib/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/workspace/.conda/envs/wjx_anomalib/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/algo/wangjx/anomalib/src/anomalib/models/image/reverse_distillation/torch_model.py", line 72, in forward
    encoder_features = self.encoder(images)
  File "/opt/workspace/.conda/envs/wjx_anomalib/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/workspace/.conda/envs/wjx_anomalib/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/algo/wangjx/anomalib/src/anomalib/models/components/feature_extractors/timm.py", line 128, in forward
    features = dict(zip(self.layers, self.feature_extractor(inputs), strict=True))
  File "/opt/workspace/.conda/envs/wjx_anomalib/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/workspace/.conda/envs/wjx_anomalib/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/workspace/.conda/envs/wjx_anomalib/lib/python3.10/site-packages/timm/models/_features.py", line 346, in forward
    return list(self._collect(x).values())
  File "/opt/workspace/.conda/envs/wjx_anomalib/lib/python3.10/site-packages/timm/models/_features.py", line 292, in _collect
    if self.grad_checkpointing and not torch.jit.is_scripting():
  File "/opt/workspace/.conda/envs/wjx_anomalib/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1709, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'FeatureListNet' object has no attribute 'grad_checkpointing'. Did you mean: 'set_grad_checkpointing'?
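The traceback fails at the first read of `self.grad_checkpointing` inside `_collect`. A common way for a class to stay loadable from pickles written by an older version is to fill in defaults for newly added attributes in `__setstate__`; PyTorch's own `nn.Module.__setstate__` does this for some of its internal hook dictionaries. The sketch below shows the pattern on a plain Python class. `FeatureNetSketch` and `_STATE_DEFAULTS` are hypothetical, and nothing in this thread says timm or anomalib has adopted this approach.

```python
class FeatureNetSketch:
    """Hypothetical class that tolerates pickled state from an older version."""

    # Defaults for attributes that older versions of the class did not set.
    _STATE_DEFAULTS = {"grad_checkpointing": False}

    def __init__(self, concat: bool = False) -> None:
        self.concat = concat
        self.grad_checkpointing = False

    def __setstate__(self, state: dict) -> None:
        # pickle calls this with the saved state instead of running __init__.
        # Restore the saved attributes first, then fill in anything missing.
        self.__dict__.update(state)
        for name, default in self._STATE_DEFAULTS.items():
            self.__dict__.setdefault(name, default)

    def set_grad_checkpointing(self, enable: bool = True) -> None:
        self.grad_checkpointing = enable


# Simulate loading state saved by an older version that lacked grad_checkpointing.
restored = FeatureNetSketch.__new__(FeatureNetSketch)
restored.__setstate__({"concat": True})
print(restored.concat, restored.grad_checkpointing)
```

For an `nn.Module` subclass, the override would call `super().__setstate__(state)` instead of updating `__dict__` directly, so that PyTorch's own backward-compatibility handling still runs.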

CarlosNacher commented 3 weeks ago

Hi @letmejoin ,

I had the same problem, and updating to version 1.2.0dev and retraining with that version solved it. I hope it helps!

Cheers

letmejoin commented 2 weeks ago

Sorry for the late reply; I was busy with other things. After updating anomalib to 1.2.0.dev, the problem still exists. Is your timm version 1.0.3?
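When comparing environments, as in the question above, the installed package versions can be read with the standard library's `importlib.metadata`. A small sketch; the default package list is an assumption and can be adjusted.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(
    packages: Iterable[str] = ("anomalib", "timm", "torch"),
) -> Dict[str, Optional[str]]:
    """Return the installed version of each package, or None if it is not installed."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

Posting this output alongside a bug report makes it easier to tell whether two users are actually running the same timm and anomalib versions.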