facebookresearch / sam2

The repository provides code for running inference with the Meta Segment Anything Model 2 (SAM 2), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0

problem about pos_embed, window_embed and window_spec in Hiera #407

Closed: BENgoooo closed this issue 1 week ago.

BENgoooo commented 2 weeks ago

Hello, I am currently fine-tuning SAM2 on my own dataset (image size 472x472). I have already converted the dataset into a format similar to SA-1B and written the config file. However, when I run the training script (`train.py`) in the `training` directory, I get the following error:

```
rank0: Traceback (most recent call last):
rank0:   File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/train.py", line 270, in
rank0:   File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/train.py", line 240, in main
rank0:     single_node_runner(cfg, main_port)
rank0:   File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/train.py", line 53, in single_node_runner
rank0:     single_proc_run(local_rank=0, main_port=main_port, cfg=cfg, world_size=num_proc)
rank0:   File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/train.py", line 41, in single_proc_run
rank0:   File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/trainer.py", line 515, in run
rank0:   File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/trainer.py", line 532, in run_train
rank0:     outs = self.train_epoch(dataloader)
rank0:   File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/trainer.py", line 749, in train_epoch
rank0:     self._run_step(batch, phase, loss_mts, extra_loss_mts)
rank0:   File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/trainer.py", line 865, in _run_step
rank0:     loss_dict, batch_size, extra_losses = self._step(
rank0:   File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/trainer.py", line 457, in _step
rank0:     outputs = model(batch)
rank0:   File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
rank0:     return self._call_impl(*args, **kwargs)
rank0:   File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
rank0:     return forward_call(*args, **kwargs)
rank0:   File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1643, in forward
rank0:     else self._run_ddp_forward(*inputs, **kwargs)
rank0:   File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1459, in _run_ddp_forward
rank0:     return self.module(*inputs, **kwargs)  # type: ignore[index]
rank0:   File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
rank0:     return self._call_impl(*args, **kwargs)
rank0:   File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
rank0:     return forward_call(*args, **kwargs)
rank0:   File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/model/sam2.py", line 110, in forward
rank0:     backbone_out = self.forward_image(input.flat_img_batch)
rank0:   File "/home/zhaobh/1128_Workspace/Fracture/sam2/sam2/modeling/sam2_base.py", line 469, in forward_image
rank0:     backbone_out = self.image_encoder(img_batch)
rank0:   File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
rank0:     return self._call_impl(*args, **kwargs)
rank0:   File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
rank0:     return forward_call(*args, **kwargs)
rank0:   File "/home/zhaobh/1128_Workspace/Fracture/sam2/sam2/modeling/backbones/image_encoder.py", line 31, in forward
rank0:     features, pos = self.neck(self.trunk(sample))
rank0:   File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
rank0:     return self._call_impl(*args, **kwargs)
rank0:   File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
rank0:     return forward_call(*args, **kwargs)
rank0:   File "/home/zhaobh/1128_Workspace/Fracture/sam2/sam2/modeling/backbones/hieradet.py", line 294, in forward
rank0:     x = x + self._get_pos_embed(x.shape[1:3])
rank0:   File "/home/zhaobh/1128_Workspace/Fracture/sam2/sam2/modeling/backbones/hieradet.py", line 280, in _get_pos_embed
rank0:     pos_embed = pos_embed + window_embed.tile(
rank0: RuntimeError: The size of tensor a (118) must match the size of tensor b (112) at non-singleton dimension 3
```
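The two sizes in the error can be reproduced with plain shape arithmetic. This is a sketch under stated assumptions: the Hiera patch embedding is taken to be a convolution with kernel 7, stride 4 and padding 3, and `pos_embed_window` is taken to be 8x8 (`window_spec[0] = 8`). Neither value appears in the traceback itself.

```python
def conv_out(size: int, kernel: int = 7, stride: int = 4, padding: int = 3) -> int:
    # Standard Conv2d output-size formula (dilation 1); kernel/stride/padding are assumed.
    return (size + 2 * padding - kernel) // stride + 1


image_size = 472
grid = conv_out(image_size)  # tokens per side after patch embedding
window = 8                   # assumed window_spec[0] (side of pos_embed_window)
reps = grid // window        # floor division silently drops the remainder
tiled = reps * window        # side length of the tiled window embedding
print(grid, tiled)           # 118 112
```

Under these assumptions, 118 is exactly the `h`/`w` passed to `_get_pos_embed`, and 14 repeats of an 8-wide window give the 112 reported as "tensor b".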

The traceback points to the following code in `sam2/modeling/backbones/hieradet.py`:

```python
from typing import List, Tuple

import torch
import torch.nn.functional as F

# Methods of the Hiera class in sam2/modeling/backbones/hieradet.py
# (the rest of the class is omitted).

def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
    h, w = hw
    window_embed = self.pos_embed_window
    # print(window_embed.shape)

    # Resize the global position embedding to the current token grid (h, w).
    pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
    # print(pos_embed.shape)
    # Make sure h and w are multiples of the corresponding dimensions of window_embed
    pos_embed = pos_embed + window_embed.tile(
        [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
    )

    pos_embed = pos_embed.permute(0, 2, 3, 1)
    return pos_embed

def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
    x = self.patch_embed(x)
    # x: (B, H, W, C)

    # Add pos embed
    # print(x.shape)
    x = x + self._get_pos_embed(x.shape[1:3])

    outputs = []
    for i, blk in enumerate(self.blocks):
        x = blk(x)
        # Collect the last stage's output, plus intermediate stage outputs if requested.
        if (i == self.stage_ends[-1]) or (
            i in self.stage_ends and self.return_interm_layers
        ):
            feats = x.permute(0, 3, 1, 2)
            outputs.append(feats)

    return outputs
```
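To see why the shapes diverge, the tiling step can be mirrored on plain shape tuples. The sketch below assumes an embedding dimension of 96 and an 8x8 window embedding; `check_window_tiling` is a hypothetical helper, not part of SAM 2, that names the offending dimension instead of surfacing a broadcast error.

```python
from typing import Sequence, Tuple


def tiled_shape(pos_shape: Sequence[int], win_shape: Sequence[int]) -> Tuple[int, ...]:
    # Same repeat counts as in _get_pos_embed: floor division per dimension.
    reps = [x // y for x, y in zip(pos_shape, win_shape)]
    # Tensor.tile multiplies each dimension of the window embedding by its repeat count.
    return tuple(r * y for r, y in zip(reps, win_shape))


def check_window_tiling(pos_shape: Sequence[int], win_shape: Sequence[int]) -> None:
    # Hypothetical helper: raise a readable error if any dimension does not divide evenly.
    for dim, (x, y) in enumerate(zip(pos_shape, win_shape)):
        if x % y != 0:
            raise ValueError(
                f"dim {dim}: token grid size {x} is not a multiple of window size {y}"
            )


# Assumed shapes: 96 channels, 8x8 window embedding, 118x118 grid from a 472x472 input.
pos_shape = (1, 96, 118, 118)
win_shape = (1, 96, 8, 8)
print(tiled_shape(pos_shape, win_shape))  # (1, 96, 112, 112)

try:
    check_window_tiling(pos_shape, win_shape)
except ValueError as err:
    print(err)  # dim 2: token grid size 118 is not a multiple of window size 8
```

The repeat counts themselves are valid integers; the problem is that floor division throws away the remainder, so the tiled tensor is 6 tokens short in each spatial dimension.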

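I found that in the line `[x // y for x, y in zip(pos_embed.shape, window_embed.shape)]`, `x // y` is not an exact division (118 is not a multiple of the window size), so the tiled `window_embed` ends up smaller than `pos_embed` and the shapes no longer match. Could you please tell me how to solve this?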
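One way out, assuming the rest of the training pipeline accepts a different resolution, is to resize the images so that the patch grid divides evenly by the window size: for example 480 instead of 472, or the default 1024 resolution used by the released SAM 2 configs. The sketch below checks candidate sizes under the same assumptions as above (patch embedding with kernel 7, stride 4, padding 3; window size 8); `nearest_valid_size` is a hypothetical helper.

```python
def patch_grid(image_size: int) -> int:
    # Side length after the assumed patch embedding (kernel 7, stride 4, padding 3).
    return (image_size + 2 * 3 - 7) // 4 + 1


def nearest_valid_size(image_size: int, multiple: int = 32) -> int:
    # 32 = assumed patch stride 4 x assumed window size 8. A multiple of 32 also lets
    # each later 2x pooling stage (stride 8, 16, 32) halve the grid exactly.
    lower = max(multiple, (image_size // multiple) * multiple)
    upper = lower + multiple
    return lower if image_size - lower <= upper - image_size else upper


for size in (472, 480, 1024):
    print(size, patch_grid(size), patch_grid(size) % 8 == 0)
# 472 118 False
# 480 120 True
# 1024 256 True
print(nearest_valid_size(472))  # 480
```

Other parts of the model (for example the neck or the mask decoder) may impose further constraints, so this only addresses the check that fails in `_get_pos_embed`.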