facebookresearch / sam2

The repository provides code for running inference with the Meta Segment Anything Model 2 (SAM 2), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.

When we set reverse in propagate_in_video, a data type mismatch error occurs. #308

Open GoldenFishes opened 2 months ago

GoldenFishes commented 2 months ago

**Background**

`propagate_in_video` allows passing `start_frame_idx` to specify the starting frame for propagation, `max_frame_num_to_track` to set the maximum length of propagation, and `reverse` to control the propagation direction, where `True` means reverse propagation.
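For reference, a minimal sketch of how such a call might look with the video predictor from the README; the config/checkpoint paths, frame directory, frame index, and point prompt below are placeholders, not values from this issue:

```python
import numpy as np
from sam2.build_sam import build_sam2_video_predictor

# placeholder paths; substitute your own config / checkpoint / frame directory
predictor = build_sam2_video_predictor(
    "configs/sam2.1/sam2.1_hiera_l.yaml", "checkpoints/sam2.1_hiera_large.pt"
)
state = predictor.init_state(video_path="./videos/example_frames")

# prompt an object on a late frame (index 90 is illustrative)
predictor.add_new_points_or_box(
    state,
    frame_idx=90,
    obj_id=1,
    points=np.array([[300, 200]], dtype=np.float32),
    labels=np.array([1], dtype=np.int32),
)

# propagate backwards from frame 90 for at most 30 frames
for frame_idx, obj_ids, masks in predictor.propagate_in_video(
    state, start_frame_idx=90, max_frame_num_to_track=30, reverse=True
):
    pass  # consume the per-frame masks here
```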

**Error**

When we set `reverse=True` to propagate from the most recent frame for a fixed length, we get: `RuntimeError: mat1 and mat2 must have the same dtype, but got BFloat16 and Float`

**Reason**

When setting reverse propagation from the most recent frame for a fixed length, the `_prepare_memory_conditioned_features` method of the `SAM2Base` class enters the `if not is_init_cond_frame:` branch. This branch is generally not entered during forward propagation, because the first frame in forward propagation is usually a conditioning frame. The purpose of this branch is to condition the visual features of the current frame on previous memory. In this branch, the value retrieved by `feats = prev["maskmem_features"].to(device, non_blocking=True)` is `torch.bfloat16`, which causes the subsequent `MemoryAttention` data type mismatch error.
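The mismatch itself is easy to reproduce in isolation: feeding a bfloat16 tensor into a float32 `nn.Linear` (which is what the memory-attention projections are outside of autocast) raises a dtype mismatch, while the same call under autocast does not. A standalone sketch (not SAM 2 code), assuming a recent PyTorch:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
linear = torch.nn.Linear(4, 8).to(device)                        # float32 weights
feats = torch.randn(2, 4, dtype=torch.bfloat16, device=device)   # stands in for prev["maskmem_features"]

try:
    linear(feats)   # bfloat16 input vs. float32 weights
except RuntimeError as e:
    print(e)        # dtype mismatch (on CUDA, the "mat1 and mat2" message quoted above)

with torch.autocast(device, dtype=torch.bfloat16):
    out = linear(feats)   # autocast handles the mixed dtypes
print(out.dtype)          # torch.bfloat16
```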

**Temporary solution**

Add `feats = feats.to(torch.float32)` right after the line quoted above, as sketched below.
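Concretely, the patch would sit inside `_prepare_memory_conditioned_features` (only the touched lines are shown; the placement is as described in this issue):

```python
# inside SAM2Base._prepare_memory_conditioned_features, in the `if not is_init_cond_frame:` branch
feats = prev["maskmem_features"].to(device, non_blocking=True)
feats = feats.to(torch.float32)  # workaround: match MemoryAttention's float32 weights
```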

**Question**

Why is `prev["maskmem_features"]` in the `SAM2Base` class of `torch.bfloat16` dtype? Where is it generated?

ronghanghu commented 1 month ago

Hi @GoldenFishes, the bfloat16 outputs are intended to be used together with Automatic Mixed Precision (AMP) in PyTorch.

In the predictor notebooks (https://github.com/facebookresearch/sam2/blob/main/notebooks/video_predictor_example.ipynb), we are using a line

`torch.autocast("cuda", dtype=torch.bfloat16).__enter__()`

to turn on bfloat16 AMP, and the vos_inference script has a similar line: https://github.com/facebookresearch/sam2/blob/52198ead0eb13ae8270bea6ca768ef175f5bf167/tools/vos_inference.py#L117. Turning on AMP should resolve this issue, since with autocast active the bfloat16 inputs and the float32 weights become compatible.
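Put differently, running the reverse propagation from this issue under autocast should avoid the mismatch; a sketch reusing the `predictor`/`state` names and placeholder frame indices from the snippet above (the scoped `with` form is equivalent to the notebooks' `__enter__()` call):

```python
import torch

# enable bfloat16 AMP for the inference calls, as the notebooks and vos_inference.py do
with torch.autocast("cuda", dtype=torch.bfloat16):
    for frame_idx, obj_ids, masks in predictor.propagate_in_video(
        state, start_frame_idx=90, max_frame_num_to_track=30, reverse=True
    ):
        pass  # no dtype mismatch: autocast reconciles bfloat16 activations with float32 weights
```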