SysCV / sam-pt

SAM-PT: Extending SAM to zero-shot video segmentation with point-based tracking.
https://arxiv.org/abs/2307.01197
Apache License 2.0

HQ-SAM predictor issue #21

Closed: georgeYanch closed this issue 10 months ago

georgeYanch commented 10 months ago

Hi, I run into a problem when I try to use the HQ-SAM-Light-ViT-T model.

In sam_pt.yaml, I set the following defaults entry:

  - sam@sam_predictor.sam_model: samhq_light_vit_tiny

When I run the demo, I get this error:

Traceback (most recent call last):
  File "C:\Users\Alina\Desktop\Workplace\sam-pt\demo\demo.py", line 55, in main
    logits, trajectories, visibilities, scores = run_inference(model, rgbs, query_points, target_hw)
  File "C:\Users\Alina\Desktop\Workplace\sam-pt\demo\demo.py", line 131, in run_inference
    outputs = model(video)
  File "C:\Users\Alina\anaconda3\envs\sam-pt\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Alina\Desktop\Workplace\sam-pt\sam_pt\modeling\sam_pt.py", line 176, in forward
    query_masks = self.extract_query_masks(images, query_points)
  File "C:\Users\Alina\Desktop\Workplace\sam-pt\sam_pt\modeling\sam_pt.py", line 324, in extract_query_masks
    _, query_masks_logits, _ = self._apply_sam_to_trajectories(
  File "C:\Users\Alina\Desktop\Workplace\sam-pt\sam_pt\modeling\sam_pt.py", line 843, in _apply_sam_to_trajectories
    mask_frame_logits, iou_prediction_score = predict_mask(visible_point_coords, visible_point_labels)
  File "C:\Users\Alina\Desktop\Workplace\sam-pt\sam_pt\modeling\sam_pt.py", line 783, in predict_mask
    _, _, low_res_masks = self.sam_predictor.predict_torch(
  File "C:\Users\Alina\anaconda3\envs\sam-pt\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "C:\Users\Alina\anaconda3\envs\sam-pt\lib\site-packages\segment_anything\predictor.py", line 229, in predict_torch
    low_res_masks, iou_predictions = self.model.mask_decoder(
  File "C:\Users\Alina\anaconda3\envs\sam-pt\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)

TypeError: forward() missing 2 required positional arguments: 'hq_token_only' and 'interm_embeddings'
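The failure pattern in the traceback can be reproduced in miniature: a predictor written for the original SAM decoder calls a decoder whose forward() needs two extra arguments. The classes below are hypothetical, heavily simplified stand-ins (the real decoders take more arguments, such as image_pe and multimask_output); they only mimic the call pattern and are not the segment_anything or segment_anything_hq code.

```python
# Hypothetical stand-ins that mimic the call pattern in the traceback;
# these are NOT the real segment_anything / segment_anything_hq classes.


class MaskDecoderHQ:
    """Decoder whose forward() requires two extra HQ-specific arguments."""

    def forward(self, image_embeddings, sparse_prompt_embeddings,
                hq_token_only, interm_embeddings):
        return "low_res_masks", "iou_predictions"


class VanillaPredictor:
    """Calls the decoder the way a predictor written for plain SAM would."""

    def __init__(self, decoder):
        self.decoder = decoder

    def predict(self):
        # Omits hq_token_only and interm_embeddings -> TypeError.
        return self.decoder.forward(
            image_embeddings="emb",
            sparse_prompt_embeddings="prompts",
        )


class HQPredictor(VanillaPredictor):
    """Also passes the HQ-specific arguments the decoder expects."""

    def predict(self):
        return self.decoder.forward(
            image_embeddings="emb",
            sparse_prompt_embeddings="prompts",
            hq_token_only=False,
            interm_embeddings="interm",
        )


def try_predict(predictor):
    """Run predict() and return either its result or the TypeError text."""
    try:
        return predictor.predict()
    except TypeError as err:
        return f"TypeError: {err}"


print(try_predict(VanillaPredictor(MaskDecoderHQ())))
print(try_predict(HQPredictor(MaskDecoderHQ())))
```

Swapping only the model, not the predictor, corresponds to the first call; the second call only succeeds because the predictor class was changed as well.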

m43 commented 10 months ago

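Hi, thanks for your interest and question. When using the HQ-SAM variants, you need to change not only the model but also the SAM predictor class that gets instantiated. The original segment_anything predictor calls the mask decoder without the hq_token_only and interm_embeddings arguments that the HQ-SAM decoder requires, which is what the TypeError reports. Here's how you can update your sam_pt.yaml configuration: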

# ...

defaults:
  # ...
  - sam@sam_predictor.sam_model: samhq_light_vit_tiny

sam_predictor:
  _target_: segment_anything_hq.predictor.SamPredictor

# ...

iterative_refinement_iterations: 3

# ...
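Assuming the config is loaded with Hydra (as the defaults-list syntax and the _target_ key suggest), _target_ names the class to build, which is why changing only sam_model left the original predictor in place. The following is a simplified, hypothetical sketch of that resolution step using importlib; it is not Hydra's implementation, and the stdlib classes are placeholders for segment_anything.predictor.SamPredictor and segment_anything_hq.predictor.SamPredictor, so the example runs without either package installed.

```python
import importlib


def instantiate_target(config: dict):
    """Simplified sketch of `_target_`-style instantiation (not Hydra's code).

    Imports the module named in `_target_`, looks up the class, and calls it
    with the remaining keys as keyword arguments.
    """
    params = dict(config)  # copy so the caller's config is not mutated
    module_path, _, class_name = params.pop("_target_").rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**params)


# Same parameters, different `_target_`: only the instantiated class changes.
default_cfg = {"_target_": "collections.OrderedDict", "a": 1}
override_cfg = {"_target_": "collections.Counter", "a": 1}

print(type(instantiate_target(default_cfg)).__name__)
print(type(instantiate_target(override_cfg)).__name__)
```

The design point is that the predictor class and its arguments are configured independently, so the model and the predictor have to be switched together.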

Note that, based on a small hyperparameter search, we recommend setting iterative_refinement_iterations to 3 for the Light HQ-SAM variant.

You can also override the configuration via the command line without modifying the .yaml files, as shown in the VOS evaluation examples here. Hope this helps!

georgeYanch commented 10 months ago

Thanks! Everything works fine now.