Open · GoldenFishes opened this issue 2 months ago
Hi @GoldenFishes, bfloat16 outputs are intended to be used together with Automatic Mixed Precision (AMP) in PyTorch.

In the predictor notebooks (https://github.com/facebookresearch/sam2/blob/main/notebooks/video_predictor_example.ipynb), we use the line `torch.autocast("cuda", dtype=torch.bfloat16).__enter__()` to turn on bfloat16 AMP, and the vos_inference script has a similar line: https://github.com/facebookresearch/sam2/blob/52198ead0eb13ae8270bea6ca768ef175f5bf167/tools/vos_inference.py#L117. Turning on AMP should resolve this issue, since under AMP the bfloat16 inputs and float32 weights are compatible.
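For reference, a minimal sketch of how the autocast line fits into the video-predictor workflow from the notebook; the config/checkpoint paths, video path, and prompt-adding step are placeholders, not taken from this thread:

```python
import torch
from sam2.build_sam import build_sam2_video_predictor

# Placeholder paths; use whichever SAM 2 config/checkpoint you actually have.
predictor = build_sam2_video_predictor(
    "configs/sam2.1/sam2.1_hiera_l.yaml", "checkpoints/sam2.1_hiera_large.pt"
)

# Turn on bfloat16 AMP for the whole inference session (equivalent to the
# notebook's `torch.autocast("cuda", dtype=torch.bfloat16).__enter__()`).
with torch.autocast("cuda", dtype=torch.bfloat16):
    state = predictor.init_state(video_path="videos/example_frames")  # placeholder path
    # ... add clicks/boxes with predictor.add_new_points_or_box(...) ...
    for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
        pass  # per-frame masks come out here; dtypes are handled by autocast
```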
**Background**

`propagate_in_video` allows passing `start_frame_idx` to specify the starting frame for propagation, `max_frame_num_to_track` to set the maximum length of propagation, and `reverse` to control the propagation direction, where `True` indicates reverse propagation.

**Error**

When we set `reverse=True` to propagate from the most recent frame for a fixed length, we got an error:

`RuntimeError: mat1 and mat2 must have the same dtype, but got BFloat16 and Float`
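The call pattern described above looks roughly like this (assuming `predictor` and `state` are set up as in the example notebook; the frame indices are made-up example values):

```python
# Reverse propagation for a fixed length, starting from a recent frame.
for frame_idx, obj_ids, masks in predictor.propagate_in_video(
    state,
    start_frame_idx=120,        # example: start from frame 120
    max_frame_num_to_track=30,  # example: propagate at most 30 frames
    reverse=True,               # propagate backwards in time
):
    pass  # without AMP, this loop raises the BFloat16/Float RuntimeError above
```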
**Reason**

When reverse propagation is set from the most recent frame for a fixed length, the `_prepare_memory_conditioned_features` method in the `SAM2Base` class enters the `if not is_init_cond_frame:` branch. This branch is generally not entered during forward propagation, because the first frame in forward propagation is usually a conditioning frame. The purpose of this branch is to condition the visual features of the current frame on previous memory. In this branch, when `feats = prev["maskmem_features"].to(device, non_blocking=True)` is retrieved, the value is `torch.bfloat16`, which causes the subsequent `MemoryAttention` dtype mismatch error.

**Temporary solution**

Add `feats = feats.to(torch.float32)` right after the features are retrieved.
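The dtype mismatch itself is easy to reproduce outside SAM 2. The following standalone sketch (an illustration, not code from the repository) shows why casting the cached features to float32, or enabling autocast, resolves it:

```python
import torch

linear = torch.nn.Linear(256, 256).cuda()  # float32 weights, standing in for a MemoryAttention projection
feats = torch.randn(4096, 1, 256, device="cuda").to(torch.bfloat16)  # cached bfloat16 memory features

# linear(feats) would fail here with:
#   RuntimeError: mat1 and mat2 must have the same dtype, but got BFloat16 and Float

out_cast = linear(feats.to(torch.float32))  # workaround from this issue: cast back to float32

with torch.autocast("cuda", dtype=torch.bfloat16):
    out_amp = linear(feats)  # intended fix: run inference under bfloat16 AMP
```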
**Question**

Why does `prev["maskmem_features"]` in the `SAM2Base` class have `torch.bfloat16` dtype? Where is it generated?