The repository provides code for running inference with the Meta Segment Anything Model 2 (SAM 2), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Problem with pos_embed, window_embed, and window_spec in Hiera #407
Hello, I am currently fine-tuning SAM2 on my own dataset (image size 472x472). I have already converted my dataset into a format similar to SA-1B and written the config file. However, when I run the training script (training/train.py), I get the following error:
rank0: Traceback (most recent call last):
rank0: File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/train.py", line 270, in
rank0: File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/train.py", line 240, in main
rank0: single_node_runner(cfg, main_port)
rank0: File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/train.py", line 53, in single_node_runner
rank0: single_proc_run(local_rank=0, main_port=main_port, cfg=cfg, world_size=num_proc)
rank0: File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/train.py", line 41, in single_proc_run
rank0: File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/trainer.py", line 515, in run
rank0: File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/trainer.py", line 532, in run_train
rank0: outs = self.train_epoch(dataloader)
rank0: File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/trainer.py", line 749, in train_epoch
rank0: self._run_step(batch, phase, loss_mts, extra_loss_mts)
rank0: File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/trainer.py", line 865, in _run_step
rank0: loss_dict, batch_size, extra_losses = self._step(
rank0: File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/trainer.py", line 457, in _step
rank0: outputs = model(batch)
rank0: File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
rank0: return self._call_impl(*args, **kwargs)
rank0: File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
rank0: return forward_call(*args, **kwargs)
rank0: File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1643, in forward
rank0: else self._run_ddp_forward(*inputs, **kwargs)
rank0: File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1459, in _run_ddp_forward
rank0: return self.module(*inputs, **kwargs) # type: ignore[index]
rank0: File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
rank0: return self._call_impl(*args, **kwargs)
rank0: File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
rank0: return forward_call(*args, **kwargs)
rank0: File "/home/zhaobh/1128_Workspace/Fracture/sam2/training/model/sam2.py", line 110, in forward
rank0: backbone_out = self.forward_image(input.flat_img_batch)
rank0: File "/home/zhaobh/1128_Workspace/Fracture/sam2/sam2/modeling/sam2_base.py", line 469, in forward_image
rank0: backbone_out = self.image_encoder(img_batch)
rank0: File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
rank0: return self._call_impl(*args, **kwargs)
rank0: File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
rank0: return forward_call(*args, **kwargs)
rank0: File "/home/zhaobh/1128_Workspace/Fracture/sam2/sam2/modeling/backbones/image_encoder.py", line 31, in forward
rank0: features, pos = self.neck(self.trunk(sample))
rank0: File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
rank0: return self._call_impl(*args, **kwargs)
rank0: File "/home/zhaobh/miniconda3/envs/SAM2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
rank0: return forward_call(*args, **kwargs)
rank0: File "/home/zhaobh/1128_Workspace/Fracture/sam2/sam2/modeling/backbones/hieradet.py", line 294, in forward
rank0: x = x + self._get_pos_embed(x.shape[1:3])
rank0: File "/home/zhaobh/1128_Workspace/Fracture/sam2/sam2/modeling/backbones/hieradet.py", line 280, in _get_pos_embed
rank0: pos_embed = pos_embed + window_embed.tile(
rank0: RuntimeError: The size of tensor a (118) must match the size of tensor b (112) at non-singleton dimension 3
From the error message, I found the following code in hieradet.py:

def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
    h, w = hw
    window_embed = self.pos_embed_window
    pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
    #print(pos_embed.shape)
    # Ensure h and w are multiples of the corresponding window_embed dimensions
    pos_embed = pos_embed + window_embed.tile(
        [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
    )
    pos_embed = pos_embed.permute(0, 2, 3, 1)
    return pos_embed
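The mismatch can be reproduced with plain arithmetic. This is a minimal sketch, assuming the default SAM2 Hiera configuration: patch stride 4 and a pos_embed_window of size 8x8 (window_spec[0] = 8); those numbers are not stated in the traceback itself.

```python
# Minimal reproduction of the tiling mismatch (assumed sizes:
# 472x472 input, patch stride 4 -> 118x118 feature map, and an
# 8x8 window embedding, Hiera's default first-stage window size).
h = w = 472 // 4   # 118: spatial size after patch embedding

win = 8            # assumed pos_embed_window spatial size
reps = h // win    # integer division truncates: 118 // 8 == 14
tiled = reps * win # 14 * 8 == 112, not 118

print(reps, tiled)    # 14 112
print(tiled == h)     # False -> broadcast fails, as in the traceback
```

The 118-vs-112 pair is exactly what the RuntimeError reports, which suggests the problem is the feature-map size not being an exact multiple of the window size.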
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
    x = self.patch_embed(x)
    # x: (B, H, W, C)
    # Add pos embed
    #print(x.shape)
    x = x + self._get_pos_embed(x.shape[1:3])
    outputs = []
    for i, blk in enumerate(self.blocks):
        x = blk(x)
        if (i == self.stage_ends[-1]) or (
            i in self.stage_ends and self.return_interm_layers
        ):
            feats = x.permute(0, 3, 1, 2)
            outputs.append(feats)
    return outputs
I find that in this line:

    [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]

the integer division x // y truncates when the feature-map size is not an exact multiple of the window embedding size (118 // 8 == 14, and 14 * 8 == 112, which matches the 118-vs-112 mismatch in the error). The tiled window_embed is therefore smaller than pos_embed, and the addition fails. Could you please tell me how to solve this?
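One possible workaround, not an official fix from the maintainers, is to tile the window embedding with ceiling division and then crop back to the feature-map size, so sizes that are not exact multiples of the window still work. The function name get_pos_embed_safe and the tensor shapes below are hypothetical; the real module would use self.pos_embed and self.pos_embed_window.

```python
import torch
import torch.nn.functional as F

def get_pos_embed_safe(pos_embed, window_embed, h, w):
    """Hypothetical variant of Hiera's _get_pos_embed: tile the window
    embedding with ceiling division, then crop to (h, w), so feature
    maps whose size is not a multiple of the window size still work."""
    pe = F.interpolate(pos_embed, size=(h, w), mode="bicubic")
    # ceil(x / y) via negated floor division, per dimension
    reps = [-(-x // y) for x, y in zip(pe.shape, window_embed.shape)]
    we = window_embed.tile(reps)[..., :h, :w]  # crop the excess rows/cols
    return (pe + we).permute(0, 2, 3, 1)       # (B, H, W, C)

# e.g. a 118x118 feature map with an 8x8 window embedding
# (144 channels here is an arbitrary illustrative width):
out = get_pos_embed_safe(torch.zeros(1, 144, 7, 7),
                         torch.zeros(1, 144, 8, 8), 118, 118)
print(out.shape)  # torch.Size([1, 118, 118, 144])
```

A simpler alternative, if you can change the data pipeline, is to pick an input resolution whose feature-map size is divisible by the window size (for stride 4 and window 8, any multiple of 32, e.g. 480x480 or 512x512), which avoids touching the model code at all.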