facebookresearch / vissl

VISSL is FAIR's library of extensible, modular and scalable components for SOTA Self-Supervised Learning with images.
https://vissl.ai
MIT License

Extract feature using ViT model #547

Closed. VicaYang closed this issue 2 years ago.

VicaYang commented 2 years ago

It seems this repo does not provide a YAML config for extracting features with ViT models, so I tried to write one, but ran into trouble.

  1. Full code you wrote or full changes you made (git diff):
I added the file configs/config/feature_extraction/trunk_only/vit_b16.yaml and tried two different implementations. The first is adapted from the other YAML files in the same folder; the second follows the documentation section "extract-features-from-several-layers-of-the-trunk".

First config:

# @package _global_
config:
  MODEL:
    FEATURE_EVAL_SETTINGS:
      EVAL_MODE_ON: True
      FREEZE_TRUNK_ONLY: True
      EXTRACT_TRUNK_FEATURES_ONLY: True
      SHOULD_FLATTEN_FEATS: False
      LINEAR_EVAL_FEAT_POOL_OPS_MAP: [
        ["norm", ["Identity", []]],
      ]
    TRUNK:
      NAME: vision_transformer
      VISION_TRANSFORMERS:
        IMAGE_SIZE: 224
        PATCH_SIZE: 16
        NUM_LAYERS: 12
        NUM_HEADS: 12
        HIDDEN_DIM: 768
        MLP_DIM: 3072
        DROPOUT_RATE: 0.1
        ATTENTION_DROPOUT_RATE: 0
        CLASSIFIER: token

Second config:

config:
  MODEL:
    FEATURE_EVAL_SETTINGS:
      EVAL_MODE_ON: True
      FREEZE_TRUNK_ONLY: True
      EXTRACT_TRUNK_FEATURES_ONLY: True
      SHOULD_FLATTEN_FEATS: False
    TRUNK:
      NAME: vision_transformer
      VISION_TRANSFORMERS:
        IMAGE_SIZE: 224
        PATCH_SIZE: 16
        NUM_LAYERS: 12
        NUM_HEADS: 12
        HIDDEN_DIM: 768
        MLP_DIM: 3072
        DROPOUT_RATE: 0.1
        ATTENTION_DROPOUT_RATE: 0
        CLASSIFIER: token
  EXTRACT_FEATURES:
    OUTPUT_DIR: "extracted_feature"
    CHUNK_THRESHOLD: 0

2. What exact command you ran:
I extracted the features using the provided weights:

python run_distributed_engines.py \
    hydra.verbose=true \
    config=feature_extraction/extract_resnet_in1k_8gpu \
    +config/feature_extraction/trunk_only=vit_b16 \
    config.CHECKPOINT.DIR="feature/supervised" \
    config.MODEL.WEIGHTS_INIT.PARAMS_FILE="../weights/vit_b16_p16_in22k_ep90_supervised.torch" \
    config.MODEL.WEIGHTS_INIT.APPEND_PREFIX="trunk.base_model." \
    config.MODEL.WEIGHTS_INIT.STATE_DICT_KEY_NAME=classy_state_dict
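The WEIGHTS_INIT overrides in this command control how the checkpoint is loaded. As we understand it, VISSL's default config describes APPEND_PREFIX as a string prepended to every layer name in the loaded state dict, and suggests trunk.base_model. when evaluating the features of a frozen trunk; that matches the extra base_model level added by the feature-extractor wrapper (feature_extractor.py in the tracebacks of this issue). Below is a minimal, hypothetical sketch of that key rewriting in plain Python; append_prefix and the sample keys are illustrative names, not VISSL code or real checkpoint contents.

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def append_prefix(state_dict: Dict[str, V], prefix: str) -> Dict[str, V]:
    """Hypothetical helper: copy `state_dict`, prepending `prefix` to every key.

    Illustrates what an APPEND_PREFIX-style option does to checkpoint keys;
    this is not VISSL's implementation.
    """
    return {prefix + key: value for key, value in state_dict.items()}


# Illustrative checkpoint entries: the key names and string "weights" are made up.
checkpoint = {
    "blocks.0.attn.weight": "w0",
    "blocks.0.mlp.weight": "w1",
}

remapped = append_prefix(checkpoint, "trunk.base_model.")
for key in remapped:
    print(key)
# trunk.base_model.blocks.0.attn.weight
# trunk.base_model.blocks.0.mlp.weight
```

If the prefix is wrong, the renamed keys simply fail to match the model's parameter names, so checking a few remapped keys against the model is a quick sanity test.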

3. Full logs observed:
For the first config, I got the following error:

Exception in thread Thread-1:
Traceback (most recent call last):
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py", line 25, in _pin_memory_loop
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/multiprocessing/queues.py", line 116, in get
    return _ForkingPickler.loads(res)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 282, in rebuild_storage_fd
    fd = df.detach()
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/multiprocessing/resource_sharer.py", line 87, in get_connection
    c = Client(address, authkey=process.current_process().authkey)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/multiprocessing/connection.py", line 508, in Client
    answer_challenge(c, authkey)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/multiprocessing/connection.py", line 752, in answer_challenge
    message = connection.recv_bytes(256) # reject large message
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/multiprocessing/connection.py", line 414, in _recv_bytes
    buf = self._recv(4)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
ConnectionResetError: [Errno 104] Connection reset by peer
Traceback (most recent call last):
  File "run_distributed_engines.py", line 57, in <module>
    hydra_main(overrides=overrides)
  File "run_distributed_engines.py", line 42, in hydra_main
    launch_distributed(
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/utils/distributed_launcher.py", line 135, in launch_distributed
    torch.multiprocessing.spawn(
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/utils/distributed_launcher.py", line 192, in _distributed_worker
    run_engine(
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/engines/engine_registry.py", line 86, in run_engine
    engine.run_engine(
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/engines/extract_features.py", line 39, in run_engine
    extract_main(
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/engines/extract_features.py", line 106, in extract_main
    trainer.extract(output_folder=cfg.EXTRACT_FEATURES.OUTPUT_DIR or checkpoint_folder)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/trainer/trainer_main.py", line 365, in extract
    self._extract_split_features(feat_names, self.task, split, output_folder)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/trainer/trainer_main.py", line 443, in _extract_split_features
    features = task.model(input_sample["input"])
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 705, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/classy_vision/models/classy_model.py", line 97, in __call__
    return self.forward(*args, **kwargs)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/classy_vision/models/classy_model.py", line 111, in forward
    out = self.classy_model(*args, **kwargs)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/models/base_ssl_model.py", line 180, in forward
    return self.single_input_forward(batch, self._output_feature_names, self.heads)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/models/base_ssl_model.py", line 128, in single_input_forward
    feats = self.trunk(batch, feature_names)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/models/trunks/feature_extractor.py", line 36, in forward
    assert len(feats) == len(
AssertionError: #features returned by base model (0) != #Pooling Ops (1)
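The assertion in feature_extractor.py compares the number of features the trunk returned (0) with the number of entries in LINEAR_EVAL_FEAT_POOL_OPS_MAP (1). One plausible reading is that the ViT trunk does not produce a feature named "norm", so nothing comes back for it, while a name the trunk does recognize (such as "lastCLS", used in the working config later in this thread) yields one feature. The toy sketch below reproduces only that count check in plain Python; ToyTrunk, its KNOWN_FEATURES set, and extract are hypothetical stand-ins, not VISSL classes.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical pool-ops map entry: (feature_name, (pool_op_name, pool_op_args))
PoolOpsMap = Sequence[Tuple[str, Tuple[str, list]]]


class ToyTrunk:
    """Stand-in trunk that only knows a fixed set of feature names (assumption)."""

    KNOWN_FEATURES = {"lastCLS"}

    def __call__(self, image: List[float], feature_names: List[str]) -> List[List[float]]:
        # Return one feature per recognized name; unknown names are skipped,
        # which is one way a trunk could end up returning 0 features.
        return [list(image) for name in feature_names if name in self.KNOWN_FEATURES]


def extract(trunk: ToyTrunk, image: List[float], pool_ops_map: PoolOpsMap) -> Dict[str, List[float]]:
    """Pair each returned feature with its pool-op entry, after the same count check as the error."""
    feature_names = [name for name, _ in pool_ops_map]
    feats = trunk(image, feature_names)
    assert len(feats) == len(pool_ops_map), (
        f"#features returned by base model ({len(feats)}) != #Pooling Ops ({len(pool_ops_map)})"
    )
    # This toy ignores the pool op and passes features through unchanged ("Identity").
    return {name: feat for (name, _), feat in zip(pool_ops_map, feats)}


image = [0.1, 0.2, 0.3]
print(extract(ToyTrunk(), image, [("lastCLS", ("Identity", []))]))
# {'lastCLS': [0.1, 0.2, 0.3]}

try:
    extract(ToyTrunk(), image, [("norm", ("Identity", []))])
except AssertionError as err:
    print(err)
# #features returned by base model (0) != #Pooling Ops (1)
```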


For the second config, I got:

ERROR 2022-05-15 07:59:19,346 base_ssl_model.py: 163: Mismatch in #head: 0 and #features: 1
ERROR 2022-05-15 07:59:19,420 base_ssl_model.py: 163: Mismatch in #head: 0 and #features: 1
ERROR 2022-05-15 07:59:19,505 base_ssl_model.py: 163: Mismatch in #head: 0 and #features: 1
ERROR 2022-05-15 07:59:19,749 base_ssl_model.py: 163: Mismatch in #head: 0 and #features: 1
ERROR 2022-05-15 07:59:19,983 base_ssl_model.py: 163: Mismatch in #head: 0 and #features: 1
ERROR 2022-05-15 07:59:20,110 base_ssl_model.py: 163: Mismatch in #head: 0 and #features: 1
ERROR 2022-05-15 07:59:20,123 base_ssl_model.py: 163: Mismatch in #head: 0 and #features: 1
ERROR 2022-05-15 07:59:20,250 base_ssl_model.py: 163: Mismatch in #head: 0 and #features: 1
Traceback (most recent call last):
  File "run_distributed_engines.py", line 57, in <module>
    hydra_main(overrides=overrides)
  File "run_distributed_engines.py", line 42, in hydra_main
    launch_distributed(
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/utils/distributed_launcher.py", line 135, in launch_distributed
    torch.multiprocessing.spawn(
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 4 terminated with the following error:
Traceback (most recent call last):
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/utils/distributed_launcher.py", line 192, in _distributed_worker
    run_engine(
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/engines/engine_registry.py", line 86, in run_engine
    engine.run_engine(
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/engines/extract_features.py", line 39, in run_engine
    extract_main(
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/engines/extract_features.py", line 106, in extract_main
    trainer.extract(output_folder=cfg.EXTRACT_FEATURES.OUTPUT_DIR or checkpoint_folder)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/trainer/trainer_main.py", line 365, in extract
    self._extract_split_features(feat_names, self.task, split, output_folder)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/trainer/trainer_main.py", line 444, in _extract_split_features
    flat_features_list = self._flatten_features_list(features)
  File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/trainer/trainer_main.py", line 372, in _flatten_features_list
    assert isinstance(features, list), "features must be of type list"
AssertionError: features must be of type list
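The second failure comes from a different check: before flattening, the extraction loop asserts that the model output is a Python list. With no LINEAR_EVAL_FEAT_POOL_OPS_MAP, no feature names are requested and, judging by the "Mismatch in #head: 0 and #features: 1" lines, the model appears to return something other than a list. The hypothetical sketch below shows only the contract implied by the error message; flatten_features_list is an illustrative name and its one-level flattening is an assumption, not VISSL's code.

```python
from typing import Any, List


def flatten_features_list(features: Any) -> List[Any]:
    """Hypothetical helper: flatten a list whose items are features or lists of features.

    Enforces the precondition stated in the error message: the input must be a list.
    """
    assert isinstance(features, list), "features must be of type list"
    flat: List[Any] = []
    for item in features:
        if isinstance(item, list):
            flat.extend(item)  # several features produced for one output slot
        else:
            flat.append(item)
    return flat


print(flatten_features_list(["feat_a", ["feat_b", "feat_c"]]))
# ['feat_a', 'feat_b', 'feat_c']

try:
    flatten_features_list("a single tensor-like object")  # not wrapped in a list
except AssertionError as err:
    print(err)
# features must be of type list
```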


## Expected behavior:

The features are extracted correctly.

## Environment:

sys.platform            linux
Python                  3.8.13 (default, Mar 28 2022, 11:38:47) [GCC 7.5.0]
numpy                   1.19.2
Pillow                  7.1.2
vissl                   0.1.6 @/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl
GPU available           True
GPU 0,1,2,3,4,5,6       NVIDIA GeForce RTX 3090
CUDA_HOME               /usr/local/cuda
torchvision             0.10.1 @/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torchvision
hydra                   1.0.7 @/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/hydra
classy_vision           0.7.0.dev @/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/classy_vision
apex                    0.1 @/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/apex
PyTorch                 1.9.1 @/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch
PyTorch debug build     False


PyTorch built with:

CPU info:


Architecture          x86_64
CPU op-mode(s)        32-bit, 64-bit
Byte Order            Little Endian
CPU(s)                80
On-line CPU(s) list   0-79
Thread(s) per core    2
Core(s) per socket    20
Socket(s)             2
NUMA node(s)          2
Vendor ID             GenuineIntel
CPU family            6
Model                 85
Model name            Intel(R) Xeon(R) Gold 5218R CPU @ 2.10GHz
Stepping              7
CPU MHz               3247.779
CPU max MHz           4000.0000
CPU min MHz           800.0000
BogoMIPS              4200.00
Virtualization        VT-x
L1d cache             32K
L1i cache             32K
L2 cache              1024K
L3 cache              28160K
NUMA node0 CPU(s)     0-19,40-59
NUMA node1 CPU(s)     20-39,60-79


VicaYang commented 2 years ago

Well, I dug into the code and used the following config

config:
  MODEL:
    FEATURE_EVAL_SETTINGS:
      EVAL_MODE_ON: True
      FREEZE_TRUNK_ONLY: True
      EXTRACT_TRUNK_FEATURES_ONLY: True
      SHOULD_FLATTEN_FEATS: False
      LINEAR_EVAL_FEAT_POOL_OPS_MAP: [
        ["lastCLS", ["Identity", []]],
      ]
    TRUNK:
      NAME: vision_transformer
      VISION_TRANSFORMERS:
        IMAGE_SIZE: 224
        PATCH_SIZE: 16
        NUM_LAYERS: 12
        NUM_HEADS: 12
        HIDDEN_DIM: 768
        MLP_DIM: 3072
        DROPOUT_RATE: 0.1
        ATTENTION_DROPOUT_RATE: 0
        CLASSIFIER: token

to extract the features correctly. However, I am not very familiar with the ViT architecture, so any discussion is welcome!

QuentinDuval commented 2 years ago

Hi @VicaYang,

Thanks for using VISSL (and sorry for the late answer: I got COVID and then went on a month of PTO).

So indeed, the configuration you used works and extracts the ViT features. The "lastCLS" feature is the output of the classification (CLS) token at the last layer.
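To make "lastCLS" concrete: with the working config's settings (IMAGE_SIZE: 224, PATCH_SIZE: 16, HIDDEN_DIM: 768), a standard ViT cuts each image into (224 / 16)² = 196 patches and prepends one classification token, so every layer outputs 197 tokens of width 768. Assuming the usual convention that the CLS token sits at position 0, "lastCLS" is the 768-dimensional vector at index 0 of the final layer's output. The plain-Python sketch below only illustrates that arithmetic and indexing on a fake token sequence; vit_token_count and last_cls are hypothetical names, not VISSL functions.

```python
from typing import List

IMAGE_SIZE = 224  # values taken from the working config in this thread
PATCH_SIZE = 16
HIDDEN_DIM = 768


def vit_token_count(image_size: int, patch_size: int) -> int:
    """Tokens per image: one per non-overlapping square patch, plus one CLS token."""
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side + 1


def last_cls(last_layer_tokens: List[List[float]]) -> List[float]:
    """Hypothetical helper: take the CLS token (assumed at index 0) from the last layer."""
    return last_layer_tokens[0]


num_tokens = vit_token_count(IMAGE_SIZE, PATCH_SIZE)
print(num_tokens)  # 197

# Fake last-layer output for one image: every entry of token i equals i.
tokens = [[float(i)] * HIDDEN_DIM for i in range(num_tokens)]
feature = last_cls(tokens)
print(len(feature), feature[0])  # 768 0.0
```

The extracted feature per image is therefore a single 768-dim vector, independent of the number of patches.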

These are overall good features to use :)

Please ask if you have any specific questions about this; otherwise, I suggest we close this issue!

Thank you, Quentin

VicaYang commented 2 years ago

Thanks a lot! I hope you are well.