facebookresearch / sam2

The repository provides code for running inference with the Meta Segment Anything Model 2 (SAM 2), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0

How to add val_dataset #365

Closed z-jiaming closed 3 weeks ago

z-jiaming commented 3 weeks ago

Thanks for your great work on both the model and the training framework! We want to add a val_dataset, and it looks like the framework already implements support for it.

We added the following changes to sam2.1_hiera_b+_MOSE_finetune.yaml:

# @package _global_

scratch:
  resolution: 1024
  train_batch_size: 1
  num_train_workers: 10
  num_frames: 8
  max_num_objects: 3
  base_lr: 5.0e-6
  vision_lr: 3.0e-06
  phases_per_epoch: 1
  num_epochs: 40

dataset:
  # PATHS to Dataset
  img_folder: /fsx-onevision/shared/data/academic_vos_data/MOSE/train/JPEGImages # PATH to MOSE JPEGImages folder
  gt_folder: /fsx-onevision/shared/data/academic_vos_data/MOSE/train/Annotations/  # PATH to MOSE Annotations folder
  file_list_txt: training/assets/MOSE_sample_train_list.txt # Optional PATH to filelist containing a subset of videos to be used for training
  multiplier: 2

# Video transforms
vos:
  train_transforms:
    - _target_: training.dataset.transforms.ComposeAPI
      transforms:
        - _target_: training.dataset.transforms.RandomHorizontalFlip
          consistent_transform: True
        - _target_: training.dataset.transforms.RandomAffine
          degrees: 25
          shear: 20
          image_interpolation: bilinear
          consistent_transform: True
        - _target_: training.dataset.transforms.RandomResizeAPI
          sizes: ${scratch.resolution}
          square: true
          consistent_transform: True
        - _target_: training.dataset.transforms.ColorJitter
          consistent_transform: True
          brightness: 0.1
          contrast: 0.03
          saturation: 0.03
          hue: null
        - _target_: training.dataset.transforms.RandomGrayscale
          p: 0.05
          consistent_transform: True
        - _target_: training.dataset.transforms.ColorJitter
          consistent_transform: False
          brightness: 0.1
          contrast: 0.05
          saturation: 0.05
          hue: null
        - _target_: training.dataset.transforms.ToTensorAPI
        - _target_: training.dataset.transforms.NormalizeAPI
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]

trainer:
  _target_: training.trainer.Trainer
  mode: train_only
  max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
  accelerator: cuda
  seed_value: 123

  model:
    _target_: training.model.sam2.SAM2Train
    image_encoder:
      _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
      scalp: 1
      trunk:
        _target_: sam2.modeling.backbones.hieradet.Hiera
        embed_dim: 112
        num_heads: 2
        drop_path_rate: 0.1
      neck:
        _target_: sam2.modeling.backbones.image_encoder.FpnNeck
        position_encoding:
          _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
          num_pos_feats: 256
          normalize: true
          scale: null
          temperature: 10000
        d_model: 256
        backbone_channel_list: [896, 448, 224, 112]
        fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
        fpn_interp_model: nearest

    memory_attention:
      _target_: sam2.modeling.memory_attention.MemoryAttention
      d_model: 256
      pos_enc_at_input: true
      layer:
        _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
        activation: relu
        dim_feedforward: 2048
        dropout: 0.1
        pos_enc_at_attn: false
        self_attention:
          _target_: sam2.modeling.sam.transformer.RoPEAttention
          rope_theta: 10000.0
          feat_sizes: [32, 32]
          embedding_dim: 256
          num_heads: 1
          downsample_rate: 1
          dropout: 0.1
        d_model: 256
        pos_enc_at_cross_attn_keys: true
        pos_enc_at_cross_attn_queries: false
        cross_attention:
          _target_: sam2.modeling.sam.transformer.RoPEAttention
          rope_theta: 10000.0
          feat_sizes: [32, 32]
          rope_k_repeat: True
          embedding_dim: 256
          num_heads: 1
          downsample_rate: 1
          dropout: 0.1
          kv_in_dim: 64
      num_layers: 4

    memory_encoder:
        _target_: sam2.modeling.memory_encoder.MemoryEncoder
        out_dim: 64
        position_encoding:
          _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
          num_pos_feats: 64
          normalize: true
          scale: null
          temperature: 10000
        mask_downsampler:
          _target_: sam2.modeling.memory_encoder.MaskDownSampler
          kernel_size: 3
          stride: 2
          padding: 1
        fuser:
          _target_: sam2.modeling.memory_encoder.Fuser
          layer:
            _target_: sam2.modeling.memory_encoder.CXBlock
            dim: 256
            kernel_size: 7
            padding: 3
            layer_scale_init_value: 1e-6
            use_dwconv: True  # depth-wise convs
          num_layers: 2

    num_maskmem: 7
    image_size: ${scratch.resolution}
    # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
    sigmoid_scale_for_mem_enc: 20.0
    sigmoid_bias_for_mem_enc: -10.0
    use_mask_input_as_output_without_sam: true
    # Memory
    directly_add_no_mem_embed: true
    no_obj_embed_spatial: true
    # use high-resolution feature map in the SAM mask decoder
    use_high_res_features_in_sam: true
    # output 3 masks on the first click on initial conditioning frames
    multimask_output_in_sam: true
    # SAM heads
    iou_prediction_use_sigmoid: True
    # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
    use_obj_ptrs_in_encoder: true
    add_tpos_enc_to_obj_ptrs: true
    proj_tpos_enc_in_obj_ptrs: true
    use_signed_tpos_enc_to_obj_ptrs: true
    only_obj_ptrs_in_the_past_for_eval: true
    # object occlusion prediction
    pred_obj_scores: true
    pred_obj_scores_mlp: true
    fixed_no_obj_ptr: true
    # multimask tracking settings
    multimask_output_for_tracking: true
    use_multimask_token_for_obj_ptr: true
    multimask_min_pt_num: 0
    multimask_max_pt_num: 1
    use_mlp_for_obj_ptr_proj: true
    # Compilation flag
    # compile_image_encoder: False

    ####### Training specific params #######
    # box/point input and corrections
    prob_to_use_pt_input_for_train: 0.5
    prob_to_use_pt_input_for_eval: 0.0
    prob_to_use_box_input_for_train: 0.5  # 0.5*0.5 = 0.25 prob to use box instead of points
    prob_to_use_box_input_for_eval: 0.0
    prob_to_sample_from_gt_for_train: 0.1  # with a small prob, sampling correction points from GT mask instead of prediction errors
    num_frames_to_correct_for_train: 2  # iteratively sample on random 1~2 frames (always include the first frame)
    num_frames_to_correct_for_eval: 1  # only iteratively sample on first frame
    rand_frames_to_correct_for_train: True  # random #init-cond-frame ~ 2
    add_all_frames_to_correct_as_cond: True  # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
    # maximum 2 initial conditioning frames
    num_init_cond_frames_for_train: 2
    rand_init_cond_frames_for_train: True  # random 1~2
    num_correction_pt_per_frame: 7
    use_act_ckpt_iterative_pt_sampling: false

    num_init_cond_frames_for_eval: 1  # only mask on the first frame
    forward_backbone_per_frame_for_eval: True

  data:
    train:
      _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
      phases_per_epoch: ${scratch.phases_per_epoch}
      batch_sizes:
        - ${scratch.train_batch_size}

      datasets:
        - _target_: training.dataset.utils.RepeatFactorWrapper
          dataset:
            _target_: training.dataset.utils.ConcatDataset
            datasets:
            - _target_: training.dataset.vos_dataset.VOSDataset
              transforms: ${vos.train_transforms}
              training: true
              video_dataset:
                _target_: training.dataset.vos_raw_dataset.PNGRawDataset
                img_folder: ${dataset.img_folder}
                gt_folder: ${dataset.gt_folder}
                file_list_txt: training/assets/MOSE_sample_train_list.txt
              sampler:
                _target_: training.dataset.vos_sampler.RandomUniformSampler
                num_frames: ${scratch.num_frames}
                max_num_objects: ${scratch.max_num_objects}
              multiplier: ${dataset.multiplier}
      shuffle: True
      num_workers: ${scratch.num_train_workers}
      pin_memory: True
      drop_last: True
      collate_fn:
        _target_: training.utils.data_utils.collate_fn
        _partial_: true
        dict_key: all
    val:
      _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
      phases_per_epoch: ${scratch.phases_per_epoch}
      batch_sizes:
        - ${scratch.train_batch_size}

      datasets:
        - _target_: training.dataset.utils.RepeatFactorWrapper
          dataset:
            _target_: training.dataset.utils.ConcatDataset
            datasets:
              - _target_: training.dataset.vos_dataset.VOSDataset
                transforms: ${vos.train_transforms}
                training: false
                video_dataset:
                  _target_: training.dataset.vos_raw_dataset.PNGRawDataset
                  img_folder: ${dataset.img_folder}
                  gt_folder: ${dataset.gt_folder}
                  file_list_txt: training/assets/MOSE_sample_val_list.txt
                sampler:
                  _target_: training.dataset.vos_sampler.RandomUniformSampler
                  num_frames: ${scratch.num_frames}
                  max_num_objects: ${scratch.max_num_objects}
                multiplier: ${dataset.multiplier}
      shuffle: True
      num_workers: ${scratch.num_train_workers}
      pin_memory: True
      drop_last: True
      collate_fn:
        _target_: training.utils.data_utils.collate_fn
        _partial_: true
        dict_key: all

  optim:
    amp:
      enabled: True
      amp_dtype: bfloat16

    optimizer:
      _target_: torch.optim.AdamW

    gradient_clip:
      _target_: training.optimizer.GradientClipper
      max_norm: 0.1
      norm_type: 2

    param_group_modifiers:
      - _target_: training.optimizer.layer_decay_param_modifier
        _partial_: True
        layer_decay_value: 0.9
        apply_to: 'image_encoder.trunk'
        overrides:
          - pattern: '*pos_embed*'
            value: 1.0

    options:
      lr:
        - scheduler:
            _target_: fvcore.common.param_scheduler.CosineParamScheduler
            start_value: ${scratch.base_lr}
            end_value: ${divide:${scratch.base_lr},10}
        - scheduler:
            _target_: fvcore.common.param_scheduler.CosineParamScheduler
            start_value: ${scratch.vision_lr}
            end_value: ${divide:${scratch.vision_lr},10}
          param_names:
            - 'image_encoder.*'
      weight_decay:
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.1
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.0
          param_names:
            - '*bias*'
          module_cls_names: ['torch.nn.LayerNorm']

  loss:
    all:
      _target_: training.loss_fns.MultiStepMultiMasksAndIous
      weight_dict:
        loss_mask: 20
        loss_dice: 1
        loss_iou: 1
        loss_class: 1
      supervise_all_iou: true
      iou_use_l1_loss: true
      pred_obj_scores: true
      focal_gamma_obj_score: 0.0
      focal_alpha_obj_score: -1.0

  distributed:
    backend: nccl
    find_unused_parameters: True

  logging:
    tensorboard_writer:
      _target_: training.utils.logger.make_tensorboard_logger
      log_dir:  ${launcher.experiment_log_dir}/tensorboard
      flush_secs: 120
      should_log: True
    log_dir: ${launcher.experiment_log_dir}/logs
    log_freq: 10

  # initialize from a SAM 2 checkpoint
  checkpoint:
    save_dir: ${launcher.experiment_log_dir}/checkpoints
    save_freq: 0 # 0 only last checkpoint is saved.
    model_weight_initializer:
      _partial_: True
      _target_: training.utils.checkpoint_utils.load_state_dict_into_model
      strict: True
      ignore_unexpected_keys: null
      ignore_missing_keys: null

      state_dict:
        _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
        checkpoint_path: ./checkpoints/sam2.1_hiera_base_plus.pt # PATH to SAM 2.1 checkpoint
        ckpt_state_dict_keys: ['model']

launcher:
  num_nodes: 1
  gpus_per_node: 8
  experiment_log_dir: null # Path to log directory, defaults to ./sam2_logs/${config_name}

# SLURM args if running on a cluster
submitit:
  partition: null
  account: null
  qos: null
  cpus_per_task: 10
  use_cluster: false
  timeout_hour: 24
  name: null
  port_range: [10000, 65000]

However, it fails at the assert in https://github.com/z-jiaming/sam2_plus/blob/aa9b8722d0585b661ded4b3dff1bd103540554ae/training/trainer.py#L973

What should I do to fix it? For example, do I also need to adjust trainer.loss?

ronghanghu commented 3 weeks ago

Hi @z-jiaming, we don't have support for val_dataset in this release (our internal version of validation depends on Meta infrastructure). The vos_inference.py script at https://github.com/facebookresearch/sam2/tree/main/tools is the adapted public version that can run in open-source environments.

For validation, you can use vos_inference.py to save the predictions on VOS datasets (and possibly split the dataset video list into per-GPU chunks to save PNG outputs using multiple GPUs), and then use the eval toolkits on each dataset to get the results.
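To illustrate the per-GPU chunking described above, here is a minimal sketch (not part of the repo) that round-robins a validation video list across GPUs and launches one vos_inference.py process per chunk. The command-line flags and the MOSE validation paths are assumptions here; verify them against `python tools/vos_inference.py --help` in your checkout before running.

# chunked_vos_inference.py -- minimal sketch, not part of the sam2 repo.
# Splits a VOS video-list file into one chunk per GPU and launches
# tools/vos_inference.py on each chunk in parallel.
import os
import subprocess
from pathlib import Path

NUM_GPUS = 8
VIDEO_LIST = Path("training/assets/MOSE_sample_val_list.txt")  # list referenced in the config above
OUT_DIR = Path("./outputs/mose_val_pred_pngs")

videos = [v.strip() for v in VIDEO_LIST.read_text().splitlines() if v.strip()]
procs = []
for gpu in range(NUM_GPUS):
    chunk = videos[gpu::NUM_GPUS]  # round-robin split so the chunks stay balanced
    if not chunk:
        continue
    chunk_file = Path(f"./video_list_gpu{gpu}.txt")
    chunk_file.write_text("\n".join(chunk) + "\n")
    cmd = [
        "python", "tools/vos_inference.py",
        "--sam2_cfg", "configs/sam2.1/sam2.1_hiera_b+.yaml",             # assumed flag
        "--sam2_checkpoint", "./checkpoints/sam2.1_hiera_base_plus.pt",  # assumed flag
        "--base_video_dir", "/path/to/MOSE/valid/JPEGImages",            # placeholder path
        "--input_mask_dir", "/path/to/MOSE/valid/Annotations",           # placeholder path
        "--video_list_file", str(chunk_file),                            # assumed flag
        "--output_mask_dir", str(OUT_DIR),                               # assumed flag
    ]
    env = {**os.environ, "CUDA_VISIBLE_DEVICES": str(gpu)}  # pin each worker to one GPU
    procs.append(subprocess.Popen(cmd, env=env))

for p in procs:
    p.wait()  # block until every chunk has finished

Once all chunks finish, the predicted PNG masks for every video land in the same output directory, and the corresponding dataset's evaluation toolkit can then be run on that directory to compute the final scores.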

z-jiaming commented 3 weeks ago

Thank you for your quick response. I saw the val_dataset option and assumed that the current framework had already implemented it. Looking forward to your subsequent implementation!