VicaYang closed this issue 2 years ago.
Well, I dug into the code and used
config:
  MODEL:
    FEATURE_EVAL_SETTINGS:
      EVAL_MODE_ON: True
      FREEZE_TRUNK_ONLY: True
      EXTRACT_TRUNK_FEATURES_ONLY: True
      SHOULD_FLATTEN_FEATS: False
      LINEAR_EVAL_FEAT_POOL_OPS_MAP: [
        ["lastCLS", ["Identity", []]],
      ]
    TRUNK:
      NAME: vision_transformer
      VISION_TRANSFORMERS:
        IMAGE_SIZE: 224
        PATCH_SIZE: 16
        NUM_LAYERS: 12
        NUM_HEADS: 12
        HIDDEN_DIM: 768
        MLP_DIM: 3072
        DROPOUT_RATE: 0.1
        ATTENTION_DROPOUT_RATE: 0
        CLASSIFIER: token
to get the features correctly. However, I am not very familiar with the ViT architecture. Any discussion is welcome!
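The `VISION_TRANSFORMERS` fields in this config (ViT-B/16) determine the shapes the trunk works with. The sketch below derives those shapes with plain arithmetic. It assumes the standard ViT layout, in which one learnable CLS token is prepended to the patch tokens, as `CLASSIFIER: token` suggests. It also assumes the "lastCLS" vector has size `HIDDEN_DIM`, with no extra projection. `ViTConfig` and `token_shapes` are hypothetical helpers, not VISSL APIs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ViTConfig:
    # Mirrors the VISION_TRANSFORMERS fields from the YAML above (ViT-B/16).
    image_size: int = 224
    patch_size: int = 16
    num_layers: int = 12
    num_heads: int = 12
    hidden_dim: int = 768
    mlp_dim: int = 3072


def token_shapes(cfg: ViTConfig) -> dict:
    """Derive the per-image sequence and feature sizes implied by the config."""
    if cfg.image_size % cfg.patch_size:
        raise ValueError("IMAGE_SIZE must be divisible by PATCH_SIZE")
    if cfg.hidden_dim % cfg.num_heads:
        raise ValueError("HIDDEN_DIM must be divisible by NUM_HEADS")
    patches_per_side = cfg.image_size // cfg.patch_size  # 224 / 16 = 14
    num_patches = patches_per_side ** 2                  # 14 * 14 = 196
    return {
        "num_patches": num_patches,
        "seq_len": num_patches + 1,                      # +1 for the CLS token (assumed layout)
        "head_dim": cfg.hidden_dim // cfg.num_heads,     # 768 / 12 = 64
        "mlp_ratio": cfg.mlp_dim / cfg.hidden_dim,       # 3072 / 768 = 4.0
        "cls_feature_dim": cfg.hidden_dim,               # size of one CLS vector (assumed)
    }


print(token_shapes(ViTConfig()))
# {'num_patches': 196, 'seq_len': 197, 'head_dim': 64, 'mlp_ratio': 4.0, 'cls_feature_dim': 768}
```

Under these assumptions, one image yields 197 tokens of width 768 at every layer, and the CLS feature taken from the last layer is a single 768-dimensional vector per image.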
Hi @VicaYang,
Thanks for using VISSL (and sorry for the late answer: I got COVID and then went on a month of PTO).
So indeed, the configuration you are using works and extracts the ViT features. The "lastCLS" feature is the output of the classification (CLS) token at the last layer.
These are overall good features to use :)
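To make "lastCLS" concrete: the last transformer block outputs one vector per token, and the CLS feature is the vector at the CLS position. For a PyTorch tensor of shape (batch, seq_len, hidden_dim), that is `tokens[:, 0]`. The pure-Python sketch below assumes the CLS token sits at index 0, as in the original ViT. `last_cls` is a hypothetical helper. The sketch does not cover whether VISSL applies the final LayerNorm before taking this vector.

```python
from typing import List

Vector = List[float]


def last_cls(last_layer_tokens: List[List[Vector]], cls_index: int = 0) -> List[Vector]:
    """Return the CLS-token vector of each sample in a batch.

    last_layer_tokens holds the final block's output as nested lists shaped
    [batch][seq_len][hidden_dim]. The CLS token is assumed to be prepended,
    i.e. stored at position 0 of every sequence.
    """
    return [sample[cls_index] for sample in last_layer_tokens]


# Toy batch: 2 images, 3 tokens each (CLS + 2 patches), hidden_dim = 2.
batch = [
    [[1.0, 1.0], [0.1, 0.2], [0.3, 0.4]],
    [[2.0, 2.0], [0.5, 0.6], [0.7, 0.8]],
]
print(last_cls(batch))  # [[1.0, 1.0], [2.0, 2.0]]
```

The patch tokens at positions 1 and onward are discarded. That is why the "Identity" pooling op in the config is enough: the selected feature is already a flat vector per image.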
Please ask if you have any specific questions about this; otherwise, I suggest we close this issue!
Thank you, Quentin
Thanks a lot! I hope you are well.
It seems that this repo does not provide a YAML file for extracting features with ViT models, so I tried to write one myself, but I ran into trouble.
git diff
I added the file configs/config/feature_extraction/trunk_only/vit_b16.yaml and tried two different implementations. The first one is borrowed from the other YAML files in the same folder, while the second one follows the instructions in extract-features-from-several-layers-of-the-trunk.
# @package _global_
config:
  MODEL:
    FEATURE_EVAL_SETTINGS:
      EVAL_MODE_ON: True
      FREEZE_TRUNK_ONLY: True
      EXTRACT_TRUNK_FEATURES_ONLY: True
      SHOULD_FLATTEN_FEATS: False
      LINEAR_EVAL_FEAT_POOL_OPS_MAP: [
        ["norm", ["Identity", []]],
      ]
    TRUNK:
      NAME: vision_transformer
      VISION_TRANSFORMERS:
        IMAGE_SIZE: 224
        PATCH_SIZE: 16
        NUM_LAYERS: 12
        NUM_HEADS: 12
        HIDDEN_DIM: 768
        MLP_DIM: 3072
        DROPOUT_RATE: 0.1
        ATTENTION_DROPOUT_RATE: 0
        CLASSIFIER: token

config:
  MODEL:
    FEATURE_EVAL_SETTINGS:
      EVAL_MODE_ON: True
      FREEZE_TRUNK_ONLY: True
      EXTRACT_TRUNK_FEATURES_ONLY: True
      SHOULD_FLATTEN_FEATS: False
    TRUNK:
      NAME: vision_transformer
      VISION_TRANSFORMERS:
        IMAGE_SIZE: 224
        PATCH_SIZE: 16
        NUM_LAYERS: 12
        NUM_HEADS: 12
        HIDDEN_DIM: 768
        MLP_DIM: 3072
        DROPOUT_RATE: 0.1
        ATTENTION_DROPOUT_RATE: 0
        CLASSIFIER: token
  EXTRACT_FEATURES:
    OUTPUT_DIR: "extracted_feature"
    CHUNK_THRESHOLD: 0
python run_distributed_engines.py \
    hydra.verbose=true \
    config=feature_extraction/extract_resnet_in1k_8gpu \
    +config/feature_extraction/trunk_only=vit_b16 \
    config.CHECKPOINT.DIR="feature/supervised" \
    config.MODEL.WEIGHTS_INIT.PARAMS_FILE="../weights/vit_b16_p16_in22k_ep90_supervised.torch" \
    config.MODEL.WEIGHTS_INIT.APPEND_PREFIX="trunk.base_model." \
    config.MODEL.WEIGHTS_INIT.STATE_DICT_KEY_NAME=classy_state_dict
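The `APPEND_PREFIX` override in this command is intended to make the parameter names stored in the checkpoint match the names inside the VISSL model. The value `trunk.base_model.` was presumably chosen because the trunk is wrapped in a feature extractor during extraction. Conceptually this amounts to a key rename. The sketch below illustrates that idea only. It is not VISSL's loading code, `append_prefix` is a hypothetical helper, and the checkpoint keys in it are made up.

```python
from typing import Dict


def append_prefix(state_dict: Dict[str, object], prefix: str) -> Dict[str, object]:
    """Return a copy of state_dict with `prefix` prepended to every key.

    Illustration only: models the idea behind APPEND_PREFIX, not VISSL's loader.
    """
    return {prefix + name: value for name, value in state_dict.items()}


# Hypothetical checkpoint entries; real ViT parameter names may differ.
trunk_weights = {"cls_token": "...", "patch_embed.weight": "..."}
renamed = append_prefix(trunk_weights, "trunk.base_model.")
print(sorted(renamed))
# ['trunk.base_model.cls_token', 'trunk.base_model.patch_embed.weight']
```

If the prefix is wrong, the renamed keys match no model parameter. Checking a few renamed keys against the model's own parameter names is therefore a cheap sanity check before launching a multi-GPU run.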
Exception in thread Thread-1:
Traceback (most recent call last):
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/threading.py", line 932, in _bootstrap_inner
self.run()
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/threading.py", line 870, in run
self._target(*self._args, **self._kwargs)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py", line 25, in _pin_memory_loop
r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/multiprocessing/queues.py", line 116, in get
return _ForkingPickler.loads(res)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 282, in rebuild_storage_fd
fd = df.detach()
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/multiprocessing/resource_sharer.py", line 57, in detach
with _resource_sharer.get_connection(self._id) as conn:
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/multiprocessing/resource_sharer.py", line 87, in get_connection
c = Client(address, authkey=process.current_process().authkey)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/multiprocessing/connection.py", line 508, in Client
answer_challenge(c, authkey)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/multiprocessing/connection.py", line 752, in answer_challenge
message = connection.recv_bytes(256) # reject large message
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/multiprocessing/connection.py", line 216, in recv_bytes
buf = self._recv_bytes(maxlength)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/multiprocessing/connection.py", line 414, in _recv_bytes
buf = self._recv(4)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/multiprocessing/connection.py", line 379, in _recv
chunk = read(handle, remaining)
ConnectionResetError: [Errno 104] Connection reset by peer
Traceback (most recent call last):
File "run_distributed_engines.py", line 57, in <module>
hydra_main(overrides=overrides)
File "run_distributed_engines.py", line 42, in hydra_main
launch_distributed(
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/utils/distributed_launcher.py", line 135, in launch_distributed
torch.multiprocessing.spawn(
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
while not context.join():
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 150, in join
raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:
-- Process 1 terminated with the following error:
Traceback (most recent call last):
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
fn(i, *args)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/utils/distributed_launcher.py", line 192, in _distributed_worker
run_engine(
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/engines/engine_registry.py", line 86, in run_engine
engine.run_engine(
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/engines/extract_features.py", line 39, in run_engine
extract_main(
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/engines/extract_features.py", line 106, in extract_main
trainer.extract(output_folder=cfg.EXTRACT_FEATURES.OUTPUT_DIR or checkpoint_folder)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/trainer/trainer_main.py", line 365, in extract
self._extract_split_features(feat_names, self.task, split, output_folder)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/trainer/trainer_main.py", line 443, in _extract_split_features
features = task.model(input_sample["input"])
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 705, in forward
output = self.module(*inputs[0], **kwargs[0])
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/classy_vision/models/classy_model.py", line 97, in __call__
return self.forward(*args, **kwargs)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/classy_vision/models/classy_model.py", line 111, in forward
out = self.classy_model(*args, **kwargs)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/models/base_ssl_model.py", line 180, in forward
return self.single_input_forward(batch, self._output_feature_names, self.heads)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/models/base_ssl_model.py", line 128, in single_input_forward
feats = self.trunk(batch, feature_names)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/models/trunks/feature_extractor.py", line 36, in forward
assert len(feats) == len(
AssertionError: #features returned by base model (0) != #Pooling Ops (1)
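The first AssertionError is a count check: the number of features the trunk returns must equal the number of entries in `LINEAR_EVAL_FEAT_POOL_OPS_MAP`. The hypothetical `select_features` below mirrors only that check and its message. It is not VISSL's code. It assumes the trunk returns nothing for a requested name it does not produce. The feature names and trunk outputs are illustrative, and the logs alone do not show which names the ViT trunk actually emits.

```python
from typing import Dict, List, Sequence, Tuple

# One entry of LINEAR_EVAL_FEAT_POOL_OPS_MAP: (feature_name, (op_name, op_args)).
PoolOpEntry = Tuple[str, Tuple[str, list]]


def select_features(
    trunk_outputs: Dict[str, object], pool_ops_map: Sequence[PoolOpEntry]
) -> List[object]:
    """Hypothetical model of the count check behind the first AssertionError.

    trunk_outputs maps the feature names the trunk produced to their values;
    pool_ops_map lists the requested names together with their pooling ops.
    """
    requested = [name for name, _op in pool_ops_map]
    # Requested names the trunk did not produce are silently skipped here.
    feats = [trunk_outputs[name] for name in requested if name in trunk_outputs]
    assert len(feats) == len(pool_ops_map), (
        f"#features returned by base model ({len(feats)}) != "
        f"#Pooling Ops ({len(pool_ops_map)})"
    )
    return feats


ops = [("norm", ("Identity", []))]
try:
    # Illustrative trunk output that does not contain the requested "norm".
    select_features({"lastCLS": [0.0] * 768}, ops)
except AssertionError as err:
    print(err)  # #features returned by base model (0) != #Pooling Ops (1)
```

In this model, the count can only match when every name in the map is one the trunk actually returns. That is consistent with the first post in this thread, where requesting "lastCLS" worked.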
ERROR 2022-05-15 07:59:19,346 base_ssl_model.py: 163: Mismatch in #head: 0 and #features: 1
ERROR 2022-05-15 07:59:19,420 base_ssl_model.py: 163: Mismatch in #head: 0 and #features: 1
ERROR 2022-05-15 07:59:19,505 base_ssl_model.py: 163: Mismatch in #head: 0 and #features: 1
ERROR 2022-05-15 07:59:19,749 base_ssl_model.py: 163: Mismatch in #head: 0 and #features: 1
ERROR 2022-05-15 07:59:19,983 base_ssl_model.py: 163: Mismatch in #head: 0 and #features: 1
ERROR 2022-05-15 07:59:20,110 base_ssl_model.py: 163: Mismatch in #head: 0 and #features: 1
ERROR 2022-05-15 07:59:20,123 base_ssl_model.py: 163: Mismatch in #head: 0 and #features: 1
ERROR 2022-05-15 07:59:20,250 base_ssl_model.py: 163: Mismatch in #head: 0 and #features: 1
Traceback (most recent call last):
File "run_distributed_engines.py", line 57, in <module>
hydra_main(overrides=overrides)
File "run_distributed_engines.py", line 42, in hydra_main
launch_distributed(
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/utils/distributed_launcher.py", line 135, in launch_distributed
torch.multiprocessing.spawn(
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
while not context.join():
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 150, in join
raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:
-- Process 4 terminated with the following error:
Traceback (most recent call last):
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
fn(i, *args)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/utils/distributed_launcher.py", line 192, in _distributed_worker
run_engine(
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/engines/engine_registry.py", line 86, in run_engine
engine.run_engine(
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/engines/extract_features.py", line 39, in run_engine
extract_main(
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/engines/extract_features.py", line 106, in extract_main
trainer.extract(output_folder=cfg.EXTRACT_FEATURES.OUTPUT_DIR or checkpoint_folder)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/trainer/trainer_main.py", line 365, in extract
self._extract_split_features(feat_names, self.task, split, output_folder)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/trainer/trainer_main.py", line 444, in _extract_split_features
flat_features_list = self._flatten_features_list(features)
File "/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl/trainer/trainer_main.py", line 372, in _flatten_features_list
assert isinstance(features, list), "features must be of type list"
AssertionError: features must be of type list
sys.platform            linux
Python                  3.8.13 (default, Mar 28 2022, 11:38:47) [GCC 7.5.0]
numpy                   1.19.2
Pillow                  7.1.2
vissl                   0.1.6 @/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/vissl
GPU available           True
GPU 0,1,2,3,4,5,6       NVIDIA GeForce RTX 3090
CUDA_HOME               /usr/local/cuda
torchvision             0.10.1 @/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torchvision
hydra                   1.0.7 @/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/hydra
classy_vision           0.7.0.dev @/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/classy_vision
apex                    0.1 @/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/apex
PyTorch                 1.9.1 @/home/vica/anaconda3/envs/vissl/lib/python3.8/site-packages/torch
PyTorch debug build     False
PyTorch built with:
CPU info:
Architecture            x86_64
CPU op-mode(s)          32-bit, 64-bit
Byte Order              Little Endian
CPU(s)                  80
On-line CPU(s) list     0-79
Thread(s) per core      2
Core(s) per socket      20
Socket(s)               2
NUMA node(s)            2
Vendor ID               GenuineIntel
CPU family              6
Model                   85
Model name              Intel(R) Xeon(R) Gold 5218R CPU @ 2.10GHz
Stepping                7
CPU MHz                 3247.779
CPU max MHz             4000.0000
CPU min MHz             800.0000
BogoMIPS                4200.00
Virtualization          VT-x
L1d cache               32K
L1i cache               32K
L2 cache                1024K
L3 cache                28160K
NUMA node0 CPU(s)       0-19,40-59
NUMA node1 CPU(s)       20-39,60-79