Closed: sahal-786 closed this issue 1 week ago.
Hi, are you trying to train (or fine-tune) the SAM 2 model? If so, you should follow the MOSE fine-tuning config, which sets `trainer.data.train.phases_per_epoch` to `${scratch.phases_per_epoch}` (which is 1), as shown here: https://github.com/facebookresearch/sam2/blob/c98aa6bea377d5c000cdc80197ce402dbf5304dc/sam2/configs/sam2.1_training/sam2.1_hiera_b%2B_MOSE_finetune.yaml#L205
> Hi, are you trying to train (or fine-tune) the SAM 2 model? If so, you should follow the MOSE fine-tuning config, which sets `trainer.data.train.phases_per_epoch` to `${scratch.phases_per_epoch}` (which is 1), as in the linked MOSE fine-tuning config.
I am training on an image dataset, and I did exactly what you suggested. You can check my YAML file: `# @package global
scratch:
  resolution: 1024
  train_batch_size: 1
  num_train_workers: 10
  num_frames: 8
  max_num_objects: 3
  base_lr: 5.0e-6
  vision_lr: 3.0e-06
  phases_per_epoch: 1
  num_epochs: 40

# Each key sits on its own line: when this block was written on one line, every
# inline `#` comment swallowed the keys that followed it.
dataset:
  # PATH to the training images folder (the MOSE template used its JPEGImages folder here)
  img_folder: /home/sahal/sam2/768_x_1024_input_images
  # PATH to the ground-truth masks folder (the MOSE template used its Annotations folder here)
  gt_folder: /home/sahal/sam2/masks
  # Optional PATH to filelist containing a subset of videos to be used for training.
  # NOTE(review): this still points at the MOSE sample list — confirm it lists
  # entries of this image dataset.
  file_list_txt: training/assets/MOSE_sample_train_list.txt
  multiplier: 2

trainer:
  # `_target_` (with underscores) is the key Hydra instantiates from; the bare
  # `target` in the pasted one-line version is not recognized.
  _target_: training.trainer.Trainer
  mode: train_only
  # Presumably num_epochs x phases_per_epoch via the `times` resolver — confirm
  # against the resolver registered by the training code.
  max_epochs: ${times:${scratch.num_epochs},${scratch.phases_per_epoch}}
  accelerator: cuda
  seed_value: 123

  model:
    _target_: training.model.sam2.SAM2Train
    image_encoder:
      _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
      scalp: 1
      trunk:
        _target_: sam2.modeling.backbones.hieradet.Hiera
        embed_dim: 112
        num_heads: 2
        drop_path_rate: 0.1
      neck:
        _target_: sam2.modeling.backbones.image_encoder.FpnNeck
        position_encoding:
          _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
          num_pos_feats: 256
          normalize: true
          scale: null
          temperature: 10000
        d_model: 256
        backbone_channel_list: [896, 448, 224, 112]
        fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
        fpn_interp_model: nearest
# NOTE(review): the leading indentation of this section was lost in the paste, so
# every key below sits at column 0. Loaded as-is, keys repeated in this section
# (`_target_`, `d_model`, `num_layers`, `dropout`, `kernel_size`, `padding`, ...)
# would collide in a single mapping. Presumably these entries belong under
# `trainer.model` (next to `image_encoder`), with `layer`, `self_attention`,
# `cross_attention`, `position_encoding`, `mask_downsampler` and `fuser` nested as
# in the upstream SAM 2.1 configs — TODO restore the indentation before use.
memory_attention:
_target_: sam2.modeling.memory_attention.MemoryAttention
d_model: 256
pos_enc_at_input: true
layer:
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
activation: relu
dim_feedforward: 2048
dropout: 0.1
pos_enc_at_attn: false
self_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
d_model: 256
pos_enc_at_cross_attn_keys: true
pos_enc_at_cross_attn_queries: false
cross_attention:
_target_: sam2.modeling.sam.transformer.RoPEAttention
rope_theta: 10000.0
feat_sizes: [32, 32]
rope_k_repeat: True
embedding_dim: 256
num_heads: 1
downsample_rate: 1
dropout: 0.1
# kv_in_dim equals memory_encoder.out_dim (64) below — presumably because the
# cross-attention keys/values come from the memory encoder output; keep them in sync.
kv_in_dim: 64
num_layers: 4
memory_encoder:
_target_: sam2.modeling.memory_encoder.MemoryEncoder
out_dim: 64
position_encoding:
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
num_pos_feats: 64
normalize: true
scale: null
temperature: 10000
mask_downsampler:
_target_: sam2.modeling.memory_encoder.MaskDownSampler
kernel_size: 3
stride: 2
padding: 1
fuser:
_target_: sam2.modeling.memory_encoder.Fuser
layer:
_target_: sam2.modeling.memory_encoder.CXBlock
dim: 256
kernel_size: 7
padding: 3
# NOTE(review): `1e-6` has no decimal point; YAML 1.1 loaders such as plain PyYAML
# read it as a string. Confirm the loader in use parses it as a float.
layer_scale_init_value: 1e-6
use_dwconv: True # depth-wise convs
num_layers: 2
num_maskmem: 7
# Model input size follows scratch.resolution (1024 in this config).
image_size: ${scratch.resolution}
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
sigmoid_scale_for_mem_enc: 20.0
sigmoid_bias_for_mem_enc: -10.0
use_mask_input_as_output_without_sam: true
# Memory
directly_add_no_mem_embed: true
no_obj_embed_spatial: true
# use high-resolution feature map in the SAM mask decoder
use_high_res_features_in_sam: true
# output 3 masks on the first click on initial conditioning frames
multimask_output_in_sam: true
# SAM heads
iou_prediction_use_sigmoid: True
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
use_obj_ptrs_in_encoder: true
add_tpos_enc_to_obj_ptrs: true
proj_tpos_enc_in_obj_ptrs: true
use_signed_tpos_enc_to_obj_ptrs: true
only_obj_ptrs_in_the_past_for_eval: true
# object occlusion prediction
pred_obj_scores: true
pred_obj_scores_mlp: true
fixed_no_obj_ptr: true
# multimask tracking settings
multimask_output_for_tracking: true
use_multimask_token_for_obj_ptr: true
multimask_min_pt_num: 0
multimask_max_pt_num: 1
use_mlp_for_obj_ptr_proj: true
# Compilation flag
# compile_image_encoder: False
####### Training specific params #######
# box/point input and corrections
prob_to_use_pt_input_for_train: 0.5
prob_to_use_pt_input_for_eval: 0.0
prob_to_use_box_input_for_train: 0.5 # 0.5*0.5 = 0.25 prob to use box instead of points
prob_to_use_box_input_for_eval: 0.0
prob_to_sample_from_gt_for_train: 0.1 # with a small prob, sampling correction points from GT mask instead of prediction errors
num_frames_to_correct_for_train: 2 # iteratively sample on random 1~2 frames (always include the first frame)
num_frames_to_correct_for_eval: 1 # only iteratively sample on first frame
rand_frames_to_correct_for_train: True # random #init-cond-frame ~ 2
add_all_frames_to_correct_as_cond: True # when a frame receives a correction click, it becomes a conditioning frame (even if it's not initially a conditioning frame)
# maximum 2 initial conditioning frames
num_init_cond_frames_for_train: 2
rand_init_cond_frames_for_train: True # random 1~2
num_correction_pt_per_frame: 7
use_act_ckpt_iterative_pt_sampling: false
num_init_cond_frames_for_eval: 1 # only mask on the first frame
forward_backbone_per_frame_for_eval: True
# `data` through `checkpoint` are nested under `trainer:` (the reported error's
# full_key is `trainer.data.train.phases_per_epoch`); `launcher` and `submitit`
# are top-level. Each key is on its own line because, in the one-line paste,
# every inline `#` comment swallowed the keys that followed it.
  data:
    train:
      _target_: training.dataset.sam2_datasets.TorchTrainMixedDataset
      # Must reference `scratch.phases_per_epoch`: a bare `${phases_per_epoch}`
      # (as in the README template) has no top-level key to resolve against.
      phases_per_epoch: ${scratch.phases_per_epoch}  # Chunks a single epoch into smaller phases
      batch_sizes:  # List of batch sizes corresponding to each dataset
        # NOTE(review): the list items were lost in the paste; one entry per
        # dataset below, presumably scratch.train_batch_size — confirm.
        - ${scratch.train_batch_size}
      # NOTE(review): the `datasets:` key and list marker were lost in the paste.
      datasets:
        - _target_: training.dataset.vos_dataset.VOSDataset
          training: true
          video_dataset:
            _target_: training.dataset.vos_raw_dataset.SA1BRawDataset
            # The README placeholders `${path_to_img_folder}` etc. are not defined
            # anywhere in this config and raise InterpolationKeyError; point them
            # at the `dataset` block defined at the top of the file instead.
            img_folder: ${dataset.img_folder}
            gt_folder: ${dataset.gt_folder}
            file_list_txt: ${dataset.file_list_txt}  # Optional
          sampler:
            _target_: training.dataset.vos_sampler.RandomUniformSampler
            num_frames: 1
            # Was `${max_num_objects_per_image}` (undefined); uses the scratch value.
            max_num_objects: ${scratch.max_num_objects}
          # TODO(review): `image_transforms` is not defined anywhere in this config,
          # so this interpolation still fails. Define the transform list (or point
          # at an existing key) before training.
          transforms: ${image_transforms}
      shuffle: True
      # Was `${num_train_workers}` (undefined at top level).
      num_workers: ${scratch.num_train_workers}
      pin_memory: True
      drop_last: True
      collate_fn:
        _target_: training.utils.data_utils.collate_fn
        _partial_: true
        dict_key: all

  optim:
    amp:
      enabled: True
      amp_dtype: bfloat16

    optimizer:
      _target_: torch.optim.AdamW

    gradient_clip:
      _target_: training.optimizer.GradientClipper
      max_norm: 0.1
      norm_type: 2

    param_group_modifiers:
      - _target_: training.optimizer.layer_decay_param_modifier
        _partial_: True
        layer_decay_value: 0.9
        apply_to: 'image_encoder.trunk'
        # TODO(review): the override entries were lost in the paste; this value is null.
        overrides:

    options:
      # TODO(review): the learning-rate schedule was lost in the paste; `lr` is null
      # here, and scratch.base_lr / scratch.vision_lr are not referenced anywhere.
      lr:

  loss:
    all:
      _target_: training.loss_fns.MultiStepMultiMasksAndIous
      weight_dict:
        loss_mask: 20
        loss_dice: 1
        loss_iou: 1
        loss_class: 1
      supervise_all_iou: true
      iou_use_l1_loss: true
      pred_obj_scores: true
      focal_gamma_obj_score: 0.0
      focal_alpha_obj_score: -1.0

  distributed:
    backend: nccl
    find_unused_parameters: True

  logging:
    tensorboard_writer:
      _target_: training.utils.logger.make_tensorboard_logger
      log_dir: ${launcher.experiment_log_dir}/tensorboard
      flush_secs: 120
      should_log: True
    log_dir: ${launcher.experiment_log_dir}/logs
    log_freq: 10

  checkpoint:
    save_dir: ${launcher.experiment_log_dir}/checkpoints
    save_freq: 0  # 0 only last checkpoint is saved.
    model_weight_initializer:
      _partial_: True
      _target_: training.utils.checkpoint_utils.load_state_dict_into_model
      strict: True
      ignore_unexpected_keys: null
      ignore_missing_keys: null

      state_dict:
        _target_: training.utils.checkpoint_utils.load_checkpoint_and_apply_kernels
        checkpoint_path: sam2.1_hiera_base_plus.pt  # PATH to SAM 2.1 checkpoint
        ckpt_state_dict_keys: ['model']

launcher:
  num_nodes: 1
  gpus_per_node: 1
  experiment_log_dir: null  # Path to log directory, defaults to ./sam2_logs/${config_name}

submitit:
  partition: null
  account: null
  qos: null
  cpus_per_task: 10
  use_cluster: false
  timeout_hour: 24
  name: null
  port_range: [10000, 65000]`
The run fails with:

```text
raise InterpolationKeyError(f"Interpolation key '{inter_key}' not found")
omegaconf.errors.InterpolationKeyError: Interpolation key 'phases_per_epoch' not found
    full_key: trainer.data.train.phases_per_epoch
    object_type=dict
```