facebookresearch / vissl

VISSL is FAIR's library of extensible, modular and scalable components for SOTA Self-Supervised Learning with images.
https://vissl.ai
MIT License

Training multiple output heads on frozen trunk #479

Open markbarna opened 2 years ago

markbarna commented 2 years ago

Hello,

I'm trying to use a pre-trained XCiT trunk to train 3 parallel heads for a multi-output image classification problem. Basically, I have images that need to be classified across three categories (so each image will receive three labels--one from each head).

I have set up my configuration (pasted at the bottom) to contain three mlp heads for three, four, and three label classes, respectively. When I try to run a forward training pass, I get this error:

Traceback (most recent call last):
  File "/home/mbarna/Projects/mldevautomation/run_pipeline.py", line 35, in <module>
    main()
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/hydra/main.py", line 32, in decorated_main
    _run_hydra(
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/hydra/_internal/utils.py", line 346, in _run_hydra
    run_and_report(
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/hydra/_internal/utils.py", line 201, in run_and_report
    raise ex
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/hydra/_internal/utils.py", line 198, in run_and_report
    return func()
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/hydra/_internal/utils.py", line 347, in <lambda>
    lambda: hydra.run(
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/hydra/_internal/hydra.py", line 107, in run
    return run_job(
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/hydra/core/utils.py", line 129, in run_job
    ret.return_value = task_function(task_cfg)
  File "/home/mbarna/Projects/mldevautomation/run_pipeline.py", line 30, in main
    summary = pipeline.execute()
  File "/home/mbarna/Projects/mldevautomation/pipelines/training.py", line 49, in execute
    request = self.trainer.handle(request=request)
  File "/home/mbarna/Projects/mldevautomation/trainers/vissl_runner.py", line 42, in handle
    launch_distributed(
  File "/home/mbarna/Projects/vissl/vissl/utils/distributed_launcher.py", line 164, in launch_distributed
    raise e
  File "/home/mbarna/Projects/vissl/vissl/utils/distributed_launcher.py", line 150, in launch_distributed
    _distributed_worker(
  File "/home/mbarna/Projects/vissl/vissl/utils/distributed_launcher.py", line 192, in _distributed_worker
    run_engine(
  File "/home/mbarna/Projects/vissl/vissl/engines/engine_registry.py", line 86, in run_engine
    engine.run_engine(
  File "/home/mbarna/Projects/vissl/vissl/engines/train.py", line 39, in run_engine
    train_main(
  File "/home/mbarna/Projects/vissl/vissl/engines/train.py", line 130, in train_main
    trainer.train()
  File "/home/mbarna/Projects/vissl/vissl/trainer/trainer_main.py", line 211, in train
    raise e
  File "/home/mbarna/Projects/vissl/vissl/trainer/trainer_main.py", line 193, in train
    task = train_step_fn(task)
  File "/home/mbarna/Projects/vissl/vissl/trainer/train_steps/standard_train_step.py", line 143, in standard_train_step
    model_output = task.model(sample["input"])
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/classy_vision/models/classy_model.py", line 97, in __call__
    return self.forward(*args, **kwargs)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/classy_vision/models/classy_model.py", line 111, in forward
    out = self.classy_model(*args, **kwargs)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/mbarna/Projects/vissl/vissl/models/base_ssl_model.py", line 180, in forward
    return self.single_input_forward(batch, self._output_feature_names, self.heads)
  File "/home/mbarna/Projects/vissl/vissl/models/base_ssl_model.py", line 138, in single_input_forward
    return self.heads_forward(feats, heads)
  File "/home/mbarna/Projects/vissl/vissl/models/base_ssl_model.py", line 159, in heads_forward
    output = head(output)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/mbarna/Projects/vissl/vissl/models/heads/mlp.py", line 111, in forward
    out = self.clf(batch)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/apex/amp/wrap.py", line 28, in wrapper
    return orig_fn(*new_args, **kwargs)
  File "/home/mbarna/.pyenv/versions/3.8.12/envs/mldevautomation-vissl-3.8/lib/python3.8/site-packages/torch/nn/functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x3 and 384x4)

After digging through the VISSL code, specifically the heads_forward() function in vissl/models/base_ssl_model.py, it looks like this bit of code is applying the heads in series rather than in parallel:

# Example case: Head consisting of several layers
elif (len(heads) > 1) and (len(feats) == 1):
    output = feats[0]
    for head in heads:
        output = head(output)
    # our model is multiple output.
    return [output]

I have checked that the heads themselves are instantiated as parallel modules in the ModuleList:

(heads): ModuleList(
    (0): Sequential(
      (0): MLP(
        (clf): Sequential(
          (0): Linear(in_features=384, out_features=3, bias=True)
        )
      )
    )
    (1): Sequential(
      (0): MLP(
        (clf): Sequential(
          (0): Linear(in_features=384, out_features=4, bias=True)
        )
      )
    )
    (2): Sequential(
      (0): MLP(
        (clf): Sequential(
          (0): Linear(in_features=384, out_features=3, bias=True)
        )
      )
    )
  )

It looks like I want the function to take the first if branch instead, but that branch seems to be meant for extracting outputs from multiple trunk layers, not for feeding the final trunk feature to multiple heads. I assume I must be missing something in the configuration setup.
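
For reference, this is roughly the parallel behavior I expected. It's just a sketch with plain torch modules, not VISSL code; the dimensions mirror my HEAD.PARAMS config (384 -> 3, 384 -> 4, 384 -> 3) and trunk_feature is a placeholder for the frozen-trunk output:

import torch
from torch import nn

# Three independent linear heads, matching the config dims.
heads = nn.ModuleList([
    nn.Linear(384, 3),
    nn.Linear(384, 4),
    nn.Linear(384, 3),
])

trunk_feature = torch.randn(2, 384)  # batch of 2, trunk feature dim 384

# Parallel application: every head consumes the SAME trunk feature,
# so the model returns one logits tensor per head instead of chaining them.
outputs = [head(trunk_feature) for head in heads]
print([o.shape for o in outputs])
# [torch.Size([2, 3]), torch.Size([2, 4]), torch.Size([2, 3])]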

Incidentally, I did not see any guidance on how to structure the disk_filelist labels for multiple heads, so I assumed a 2D array where each column corresponds to one head:

array([[2, 3, 2],
       [0, 0, 1],
       [0, 0, 1],
       [0, 0, 1],
       [0, 0, 1]])
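
For completeness, this is roughly how I build that labels file (the values above are just the first rows; the output path matches my config):

import numpy as np

# One row per image, one column per head (3-class, 4-class, 3-class),
# in the same order as the HEAD.PARAMS entries.
labels = np.array([
    [2, 3, 2],
    [0, 0, 1],
    [0, 0, 1],
    [0, 0, 1],
    [0, 0, 1],
])
np.save("/home/test_data/labeled/train_labels.npy", labels)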

I was working on implementing my own loss function for this.
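
The loss I had in mind is roughly the sketch below. It assumes the model returns one logits tensor per head and the target is an (N, 3) integer tensor; it is my own code, not an existing VISSL loss:

import torch
from torch import nn

class MultiHeadCrossEntropy(nn.Module):
    """Sketch: sum of per-head cross-entropy losses."""

    def forward(self, outputs, target):
        # outputs: list of per-head logits, e.g. [(N, 3), (N, 4), (N, 3)]
        # target:  (N, 3) integer labels, one column per head
        return sum(
            nn.functional.cross_entropy(logits, target[:, i])
            for i, logits in enumerate(outputs)
        )

# Quick check with random data
heads_out = [torch.randn(2, 3), torch.randn(2, 4), torch.randn(2, 3)]
labels = torch.tensor([[2, 3, 2], [0, 0, 1]])
print(MultiHeadCrossEntropy()(heads_out, labels))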

Here is the full config. Thanks for your help!

CHECKPOINT:
  APPEND_DISTR_RUN_ID: false
  AUTO_RESUME: false
  BACKEND: disk
  CHECKPOINT_FREQUENCY: 5
  CHECKPOINT_ITER_FREQUENCY: -1
  DIR: /home/test_data/labeled/checkpoints
  LATEST_CHECKPOINT_RESUME_FILE_NUM: 1
  OVERWRITE_EXISTING: true
  USE_SYMLINK_CHECKPOINT_FOR_RESUME: false
CLUSTERFIT:
  CLUSTER_BACKEND: faiss
  DATA_LIMIT: -1
  DATA_LIMIT_SAMPLING:
    SEED: 0
  FEATURES:
    DATASET_NAME: ''
    DATA_PARTITION: TRAIN
    DIMENSIONALITY_REDUCTION: 0
    EXTRACT: false
    LAYER_NAME: ''
    PATH: .
    TEST_PARTITION: TEST
  NUM_CLUSTERS: 16000
  NUM_ITER: 50
  OUTPUT_DIR: .
DATA:
  DDP_BUCKET_CAP_MB: 25
  ENABLE_ASYNC_GPU_COPY: true
  NUM_DATALOADER_WORKERS: 5
  PIN_MEMORY: true
  TEST:
    BASE_DATASET: generic_ssl
    BATCHSIZE_PER_REPLICA: 2
    COLLATE_FUNCTION: default_collate
    COLLATE_FUNCTION_PARAMS: {}
    COPY_DESTINATION_DIR: ''
    COPY_TO_LOCAL_DISK: false
    DATASET_NAMES:
    - hand_hygiene_example_data
    DATA_LIMIT: -1
    DATA_LIMIT_SAMPLING:
      IS_BALANCED: false
      SEED: 0
      SKIP_NUM_SAMPLES: 0
    DATA_PATHS:
    - /home/test_data/labeled/val_images.npy
    DATA_SOURCES:
    - disk_filelist
    DEFAULT_GRAY_IMG_SIZE: 224
    DROP_LAST: false
    ENABLE_QUEUE_DATASET: false
    INPUT_KEY_NAMES:
    - data
    LABEL_PATHS:
    - /home/test_data/labeled/val_labels.npy
    LABEL_SOURCES:
    - disk_filelist
    LABEL_TYPE: standard
    MMAP_MODE: true
    NEW_IMG_PATH_PREFIX: ''
    RANDOM_SYNTHETIC_IMAGES: false
    REMOVE_IMG_PATH_PREFIX: ''
    TARGET_KEY_NAMES:
    - label
    TRANSFORMS:
    - name: Resize
      size: 256
    - name: CenterCrop
      size: 224
    - name: ToTensor
    - mean:
      - 0.485
      - 0.456
      - 0.406
      name: Normalize
      std:
      - 0.229
      - 0.224
      - 0.225
    USE_DEBUGGING_SAMPLER: false
    USE_STATEFUL_DISTRIBUTED_SAMPLER: false
  TRAIN:
    BASE_DATASET: generic_ssl
    BATCHSIZE_PER_REPLICA: 2
    COLLATE_FUNCTION: default_collate
    COLLATE_FUNCTION_PARAMS: {}
    COPY_DESTINATION_DIR: ''
    COPY_TO_LOCAL_DISK: false
    DATASET_NAMES:
    - hand_hygiene_example_data
    DATA_LIMIT: -1
    DATA_LIMIT_SAMPLING:
      IS_BALANCED: false
      SEED: 0
      SKIP_NUM_SAMPLES: 0
    DATA_PATHS:
    - /home/test_data/labeled/train_images.npy
    DATA_SOURCES:
    - disk_filelist
    DEFAULT_GRAY_IMG_SIZE: 224
    DROP_LAST: false
    ENABLE_QUEUE_DATASET: false
    INPUT_KEY_NAMES:
    - data
    LABEL_PATHS:
    - /home/test_data/labeled/train_labels.npy
    LABEL_SOURCES:
    - disk_filelist
    LABEL_TYPE: standard
    MMAP_MODE: true
    NEW_IMG_PATH_PREFIX: ''
    RANDOM_SYNTHETIC_IMAGES: false
    REMOVE_IMG_PATH_PREFIX: ''
    TARGET_KEY_NAMES:
    - label
    TRANSFORMS:
    - name: RandomResizedCrop
      size: 224
    - name: RandomHorizontalFlip
    - name: ToTensor
    - mean:
      - 0.485
      - 0.456
      - 0.406
      name: Normalize
      std:
      - 0.229
      - 0.224
      - 0.225
    USE_DEBUGGING_SAMPLER: false
    USE_STATEFUL_DISTRIBUTED_SAMPLER: false
DISTRIBUTED:
  BACKEND: nccl
  BROADCAST_BUFFERS: true
  INIT_METHOD: tcp
  MANUAL_GRADIENT_REDUCTION: false
  NCCL_DEBUG: false
  NCCL_SOCKET_NTHREADS: ''
  NUM_NODES: 1
  NUM_PROC_PER_NODE: 1
  RUN_ID: auto
EXTRACT_FEATURES:
  CHUNK_THRESHOLD: 0
  OUTPUT_DIR: ''
HOOKS:
  CHECK_NAN: true
  LOG_GPU_STATS: true
  MEMORY_SUMMARY:
    DUMP_MEMORY_ON_EXCEPTION: false
    LOG_ITERATION_NUM: 0
    PRINT_MEMORY_SUMMARY: true
  MODEL_COMPLEXITY:
    COMPUTE_COMPLEXITY: false
    INPUT_SHAPE:
    - 3
    - 224
    - 224
  PERF_STATS:
    MONITOR_PERF_STATS: true
    PERF_STAT_FREQUENCY: 10
    ROLLING_BTIME_FREQ: 313
  TENSORBOARD_SETUP:
    EXPERIMENT_LOG_DIR: null
    FLUSH_EVERY_N_MIN: 20
    LOG_DIR: .
    LOG_PARAMS: true
    LOG_PARAMS_EVERY_N_ITERS: 310
    LOG_PARAMS_GRADIENTS: true
    USE_TENSORBOARD: false
  WANDB_SETUP:
    EXPERIMENT_LOG_DIR: wandb
    EXP_NAME: ''
    LOG_DIR: ''
    LOG_PARAMS: false
    LOG_PARAMS_EVERY_N_ITERS: -1
    LOG_PARAMS_GRADIENTS: false
    PROJECT_NAME: ''
    USE_WANDB: false
IMG_RETRIEVAL:
  CROP_QUERY_ROI: false
  DATASET_PATH: ''
  DEBUG_MODE: false
  EVAL_BINARY_PATH: ''
  EVAL_DATASET_NAME: Paris
  FEATS_PROCESSING_TYPE: ''
  GEM_POOL_POWER: 4.0
  IMG_SCALINGS:
  - 1
  NORMALIZE_FEATURES: true
  NUM_DATABASE_SAMPLES: -1
  NUM_QUERY_SAMPLES: -1
  NUM_TRAINING_SAMPLES: -1
  N_PCA: 512
  RESIZE_IMG: 1024
  SAVE_FEATURES: false
  SAVE_RETRIEVAL_RANKINGS_SCORES: true
  SIMILARITY_MEASURE: cosine_similarity
  SPATIAL_LEVELS: 3
  TRAIN_DATASET_NAME: Oxford
  TRAIN_PCA_WHITENING: true
  USE_DISTRACTORS: false
  WHITEN_IMG_LIST: ''
LOG_FREQUENCY: 10
LOSS:
  CrossEntropyLoss:
    ignore_index: -1
  barlow_twins_loss:
    embedding_dim: 8192
    lambda_: 0.0051
    scale_loss: 0.024
  bce_logits_multiple_output_single_target:
    normalize_output: false
    reduction: none
    world_size: 1
  cross_entropy_multiple_output_multiple_target:
    ignore_index: -1
    normalize_output: false
    reduction: mean
    temperature: 1.0
    weight: null
  cross_entropy_multiple_output_single_target:
    ignore_index: -1
    normalize_output: false
    reduction: mean
    temperature: 1.0
    weight: null
  deepclusterv2_loss:
    BATCHSIZE_PER_REPLICA: 256
    DROP_LAST: true
    kmeans_iters: 10
    memory_params:
      crops_for_mb:
      - 0
      embedding_dim: 128
    num_clusters:
    - 3000
    - 3000
    - 3000
    num_crops: 2
    num_train_samples: -1
    temperature: 0.1
  dino_loss:
    crops_for_teacher:
    - 0
    - 1
    ema_center: 0.9
    momentum: 0.996
    normalize_last_layer: true
    output_dim: 65536
    student_temp: 0.1
    teacher_temp_max: 0.07
    teacher_temp_min: 0.04
    teacher_temp_warmup_iters: 37500
  moco_loss:
    embedding_dim: 128
    momentum: 0.999
    queue_size: 65536
    temperature: 0.2
  multicrop_simclr_info_nce_loss:
    buffer_params:
      effective_batch_size: 4096
      embedding_dim: 128
      world_size: 64
    num_crops: 2
    temperature: 0.1
  name: cross_entropy_multiple_output_multiple_target
  nce_loss_with_memory:
    loss_type: nce
    loss_weights:
    - 1.0
    memory_params:
      embedding_dim: 128
      memory_size: -1
      momentum: 0.5
      norm_init: true
      update_mem_on_forward: true
    negative_sampling_params:
      num_negatives: 16000
      type: random
    norm_constant: -1
    norm_embedding: true
    num_train_samples: -1
    temperature: 0.07
    update_mem_with_emb_index: -100
  simclr_info_nce_loss:
    buffer_params:
      effective_batch_size: 4096
      embedding_dim: 128
      world_size: 64
    temperature: 0.1
  swav_loss:
    crops_for_assign:
    - 0
    - 1
    embedding_dim: 128
    epsilon: 0.05
    normalize_last_layer: true
    num_crops: 2
    num_iters: 3
    num_prototypes:
    - 3000
    output_dir: .
    queue:
      local_queue_length: 0
      queue_length: 0
      start_iter: 0
    temp_hard_assignment_iters: 0
    temperature: 0.1
    use_double_precision: false
  swav_momentum_loss:
    crops_for_assign:
    - 0
    - 1
    embedding_dim: 128
    epsilon: 0.05
    momentum: 0.99
    momentum_eval_mode_iter_start: 0
    normalize_last_layer: true
    num_crops: 2
    num_iters: 3
    num_prototypes:
    - 3000
    queue:
      local_queue_length: 0
      queue_length: 0
      start_iter: 0
    temperature: 0.1
    use_double_precision: false
MACHINE:
  DEVICE: gpu
METERS:
  accuracy_list_meter:
    meter_names: []
    num_meters: 1
    topk_values:
    - 1
  enable_training_meter: true
  mean_ap_list_meter:
    max_cpu_capacity: -1
    meter_names: []
    num_classes: 9605
    num_meters: 1
  model_output_mask: false
  name: accuracy_list_meter
  names:
  - accuracy_list_meter
  precision_at_k_list_meter:
    meter_names: []
    num_meters: 1
    topk_values:
    - 1
  recall_at_k_list_meter:
    meter_names: []
    num_meters: 1
    topk_values:
    - 1
MODEL:
  ACTIVATION_CHECKPOINTING:
    NUM_ACTIVATION_CHECKPOINTING_SPLITS: 2
    USE_ACTIVATION_CHECKPOINTING: false
  AMP_PARAMS:
    AMP_ARGS:
      opt_level: O1
    AMP_TYPE: apex
    USE_AMP: true
  BASE_MODEL_NAME: multi_input_output_model
  CUDA_CACHE:
    CLEAR_CUDA_CACHE: false
    CLEAR_FREQ: 100
  FEATURE_EVAL_SETTINGS:
    EVAL_MODE_ON: true
    EVAL_TRUNK_AND_HEAD: false
    EXTRACT_TRUNK_FEATURES_ONLY: false
    FREEZE_TRUNK_AND_HEAD: false
    FREEZE_TRUNK_ONLY: true
    LINEAR_EVAL_FEAT_POOL_OPS_MAP: []
    SHOULD_FLATTEN_FEATS: true
  FSDP_CONFIG:
    AUTO_WRAP_THRESHOLD: 0
    bucket_cap_mb: 0
    clear_autocast_cache: true
    compute_dtype: float32
    flatten_parameters: true
    fp32_reduce_scatter: false
    mixed_precision: true
    verbose: true
  GRAD_CLIP:
    MAX_NORM: 1
    NORM_TYPE: 2
    USE_GRAD_CLIP: false
  HEAD:
    BATCHNORM_EPS: 1.0e-05
    BATCHNORM_MOMENTUM: 0.1
    PARAMS:
    - - - mlp
        - dims:
          - 384
          - 3
    - - - mlp
        - dims:
          - 384
          - 4
    - - - mlp
        - dims:
          - 384
          - 3
    PARAMS_MULTIPLIER: 1.0
  INPUT_TYPE: rgb
  MULTI_INPUT_HEAD_MAPPING: []
  NON_TRAINABLE_PARAMS: []
  SHARDED_DDP_SETUP:
    USE_SDP: false
    reduce_buffer_size: -1
  SINGLE_PASS_EVERY_CROP: false
  SYNC_BN_CONFIG:
    CONVERT_BN_TO_SYNC_BN: false
    GROUP_SIZE: 8
    SYNC_BN_TYPE: apex
  TEMP_FROZEN_PARAMS_ITER_MAP: []
  TRUNK:
    CONVIT:
      CLASS_TOKEN_IN_LOCAL_LAYERS: false
      LOCALITY_DIM: 10
      LOCALITY_STRENGTH: 1.0
      N_GPSA_LAYERS: 10
      USE_LOCAL_INIT: true
    EFFICIENT_NETS: {}
    NAME: xcit
    REGNET: {}
    RESNETS:
      DEPTH: 50
      GROUPNORM_GROUPS: 32
      GROUPS: 1
      LAYER4_STRIDE: 2
      NORM: BatchNorm
      STANDARDIZE_CONVOLUTIONS: false
      WIDTH_MULTIPLIER: 1
      WIDTH_PER_GROUP: 64
      ZERO_INIT_RESIDUAL: false
    VISION_TRANSFORMERS:
      ATTENTION_DROPOUT_RATE: 0
      CLASSIFIER: token
      DROPOUT_RATE: 0
      DROP_PATH_RATE: 0.1
      HIDDEN_DIM: 384
      IMAGE_SIZE: 224
      MLP_DIM: 1536
      NUM_HEADS: 6
      NUM_LAYERS: 12
      PATCH_SIZE: 16
      QKV_BIAS: false
      QK_SCALE: false
      name: null
    XCIT:
      ATTENTION_DROPOUT_RATE: 0
      DROPOUT_RATE: 0
      DROP_PATH_RATE: 0.05
      ETA: 1
      HIDDEN_DIM: 384
      IMAGE_SIZE: 224
      NUM_HEADS: 8
      NUM_LAYERS: 12
      PATCH_SIZE: 16
      QKV_BIAS: true
      QK_SCALE: false
      TOKENS_NORM: true
      name: null
  WEIGHTS_INIT:
    APPEND_PREFIX: ''
    PARAMS_FILE: /home/data/pre_trained_weights/vissl/dino_300ep_xcitsmall16.torch
    REMOVE_PREFIX: ''
    SKIP_LAYERS:
    - num_batches_tracked
    STATE_DICT_KEY_NAME: classy_state_dict
  _MODEL_INIT_SEED: 0
MONITORING:
  MONITOR_ACTIVATION_STATISTICS: 0
MULTI_PROCESSING_METHOD: forkserver
NEAREST_NEIGHBOR:
  L2_NORM_FEATS: false
  SIGMA: 0.1
  TOPK: 200
OPTIMIZER:
  betas:
  - 0.9
  - 0.999
  construct_single_param_group_only: false
  head_optimizer_params:
    use_different_lr: false
    use_different_wd: false
    weight_decay: 0
  larc_config:
    clip: false
    eps: 1.0e-08
    trust_coefficient: 0.001
  momentum: 0.9
  name: sgd
  nesterov: false
  non_regularized_parameters: []
  num_epochs: 5
  param_schedulers:
    lr:
      auto_lr_scaling:
        auto_scale: true
        base_lr_batch_size: 256
        base_value: 0.1
        scaling_type: linear
      end_value: 0.0
      interval_scaling: &id001
      - rescaled
      - rescaled
      lengths: &id002
      - 0.1
      - 0.9
      milestones: &id003
      - 1
      name: cosine
      schedulers: &id004
      - end_value: 0.0
        name: cosine
        start_value: 0.00078125
      start_value: 0.00078125
      update_interval: step
      value: 0.1
      values: &id005
      - 0.01
      - 0.001
    lr_head:
      auto_lr_scaling:
        auto_scale: true
        base_lr_batch_size: 256
        base_value: 0.1
        scaling_type: linear
      end_value: 0.0
      interval_scaling: *id001
      lengths: *id002
      milestones: *id003
      name: cosine
      schedulers: *id004
      start_value: 0.00078125
      update_interval: step
      value: 0.1
      values: *id005
  regularize_bias: true
  regularize_bn: true
  use_larc: false
  use_zero: false
  weight_decay: 0
PROFILING:
  MEMORY_PROFILING:
    TRACK_BY_LAYER_MEMORY: false
  NUM_ITERATIONS: 10
  OUTPUT_FOLDER: .
  PROFILED_RANKS:
  - 0
  - 1
  RUNTIME_PROFILING:
    LEGACY_PROFILER: false
    PROFILE_CPU: true
    PROFILE_GPU: true
    USE_PROFILER: false
  START_ITERATION: 0
  STOP_TRAINING_AFTER_PROFILING: false
  WARMUP_ITERATIONS: 0
REPRODUCIBILITY:
  CUDDN_DETERMINISTIC: false
SEED_VALUE: 0
SLURM:
  ADDITIONAL_PARAMETERS: {}
  COMMENT: vissl job
  CONSTRAINT: ''
  LOG_FOLDER: .
  MEM_GB: 250
  NAME: vissl
  NUM_CPU_PER_PROC: 8
  PARTITION: ''
  PORT_ID: 40050
  TIME_HOURS: 72
  TIME_MINUTES: 0
  USE_SLURM: false
SVM:
  cls_list: []
  costs:
    base: -1.0
    costs_list:
    - 0.1
    - 0.01
    power_range:
    - 4
    - 20
  cross_val_folds: 3
  dual: true
  force_retrain: false
  loss: squared_hinge
  low_shot:
    dataset_name: voc
    k_values:
    - 1
    - 2
    - 4
    - 8
    - 16
    - 32
    - 64
    - 96
    sample_inds:
    - 1
    - 2
    - 3
    - 4
    - 5
  max_iter: 2000
  normalize: true
  penalty: l2
TEST_EVERY_NUM_EPOCH: 1
TEST_MODEL: true
TEST_ONLY: false
TRAINER:
  TASK_NAME: self_supervision_task
  TRAIN_STEP_NAME: standard_train_step
VERBOSE: false