ViCCo-Group / thingsvision

Python package for extracting representations from state-of-the-art computer vision models
https://vicco-group.github.io/thingsvision/
MIT License

TypeError from flatten_acts() when executing extractor.extract_features() #109

Closed PhilippKaniuth closed 2 years ago

PhilippKaniuth commented 2 years ago

Executing this:

import torch
from thingsvision import get_extractor
from thingsvision.utils.storing import save_features
from thingsvision.utils.data import ImageDataset, DataLoader
from thingsvision.core.extraction import center_features

model_name = 'clip'
module_name = 'visual'
source = 'custom'
batch_size = 32
root = '/home/pkaniuth/path/root'
out_path = '/home/pkaniuth/path/out'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

extractor = get_extractor(
  model_name=model_name, 
  pretrained=True, 
  model_path=None, 
  device=device, 
  source=source, 
  model_parameters={'variant': 'ViT-B/32'},
)

dataset = ImageDataset(
  root=root,
  out_path=out_path,
  backend=extractor.get_backend(),
  transforms=extractor.get_transformations()
)

batches = DataLoader(
  dataset=dataset,
  batch_size=batch_size, 
  backend=extractor.get_backend()
)

features = extractor.extract_features(
  batches=batches,
  module_name=module_name,
  flatten_acts=True
)

throws:

...Creating dataset.
Batch:   0%|          | 0/2 [00:06<?, ?it/s]
Traceback (most recent call last):
  File "/home/pkaniuth/extractivations/new.py", line 38, in <module>
    features = extractor.extract_features(
  File "/home/pkaniuth/thingsvision/conda_env/lib/python3.9/site-packages/thingsvision/core/extraction/base.py", line 95, in extract_features
    self._extract_features(
  File "/home/pkaniuth/thingsvision/conda_env/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/pkaniuth/thingsvision/conda_env/lib/python3.9/site-packages/thingsvision/core/extraction/mixin.py", line 58, in _extract_features
    act = self.flatten_acts(act)
TypeError: flatten_acts() missing 2 required positional arguments: 'img' and 'module_name'

@andropar, you hypothesized that the flattening function may not have been adjusted accordingly.
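For readers unfamiliar with this class of error, here is a minimal sketch of how such a TypeError can arise when a helper's signature gains required arguments but a call site still passes only one. The classes and the `call_like_old_mixin` function are hypothetical illustrations, not thingsvision code; only the method name `flatten_acts` and the argument names `img` and `module_name` are taken from the traceback above.

```python
class OldStyleExtractor:
    """Hypothetical extractor whose flattening helper takes a single argument."""

    def flatten_acts(self, act):
        # Flatten a nested list of activations into one flat list.
        return [x for row in act for x in row]


class NewStyleExtractor:
    """Hypothetical extractor whose helper now requires two more arguments."""

    def flatten_acts(self, act, img, module_name):
        # Same flattening; `img` and `module_name` are unused in this sketch.
        return [x for row in act for x in row]


def call_like_old_mixin(extractor, act):
    """Call the helper with `act` only, as a call site that was not updated would."""
    try:
        return extractor.flatten_acts(act)
    except TypeError as err:
        # Return the message instead of crashing so both cases can be compared.
        return f"TypeError: {err}"


if __name__ == "__main__":
    activations = [[1, 2], [3, 4]]
    print(call_like_old_mixin(OldStyleExtractor(), activations))
    print(call_like_old_mixin(NewStyleExtractor(), activations))
```

If the call site and the method definition come from mismatched versions of the same package, the second call pattern is what the traceback would show. Whether that is what happened here is only the hypothesis stated above.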

PhilippKaniuth commented 2 years ago

In a clean env with a fresh thingsvision install (v2.2.8) I tested:

and everything works like a charm. This issue therefore seems resolved.
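To confirm which version an environment actually has installed before re-running a script, a small check along these lines may help. It is a sketch that uses only the standard-library `importlib.metadata` module; the helper name `installed_version` is made up for illustration and is not part of thingsvision.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package_name):
    """Return the installed version string of a package, or None if it is absent."""
    try:
        return version(package_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Prints the installed thingsvision version, or None if it is not installed.
    print(installed_version("thingsvision"))
```

Running this inside the environment that raised the error, and again inside the clean one, would show whether the two differ in the installed version.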