@ooe1123 Once you are done with g2p, I would appreciate it if you could take on this one as well.
〇 GroundingDINO/groundingdino/models/GroundingDINO/backbone/swin_transformer.py
class PatchMerging(nn.Module):
    ...
    def forward(self, x, H, W):
        ...
        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

class BasicLayer(nn.Module):
    ...
    def __init__(
        ...
    ):
        ...
        self.use_checkpoint = use_checkpoint

    def forward(self, x, H, W):
        ...
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        ...
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W
↓
class PatchMerging(nn.Module):
    ...
    def forward(self, x, H, W):
        ...
        # padding
        if torch.onnx.is_in_onnx_export():
            # Updated
            x = F.pad(x, (0, 0, 0, torch.remainder(W, 2), 0, torch.remainder(H, 2)))
        else:
            pad_input = (H % 2 == 1) or (W % 2 == 1)
            if pad_input:
                x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

class BasicLayer(nn.Module):
    ...
    def __init__(
        ...
    ):
        ...
        self.use_checkpoint = False

    def forward(self, x, H, W):
        ...
        if torch.onnx.is_in_onnx_export():
            Hp = (H + self.window_size - 1).floor_divide(self.window_size) * self.window_size
            Wp = (W + self.window_size - 1).floor_divide(self.window_size) * self.window_size
        else:
            Hp = int(np.ceil(H / self.window_size)) * self.window_size
            Wp = int(np.ceil(W / self.window_size)) * self.window_size
        ...
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = torch.floor_divide(H + 1, 2), torch.floor_divide(W + 1, 2)
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W
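As a side note on the BasicLayer change: `int(np.ceil(...))` is evaluated in Python while tracing, so the padded size would be frozen into the ONNX graph as a constant, whereas keeping the arithmetic in tensor ops lets it follow the runtime input size. A small standalone sketch (the sizes below are arbitrary examples, not taken from the model) showing that the floor-division form computes the same rounded-up size:

import numpy as np
import torch

window_size = 7
for h in [33, 35, 42]:  # arbitrary example heights
    H = torch.tensor(h)  # stands in for a size that is a traced tensor during export
    # Python/NumPy form: returns a plain int, which the tracer records as a constant.
    hp_static = int(np.ceil(h / window_size)) * window_size
    # Tensor form from the patch: stays a torch op, so the exported graph recomputes it.
    hp_dynamic = (H + window_size - 1).floor_divide(window_size) * window_size
    # torch.remainder plays the same role for the W % 2 / H % 2 padding in PatchMerging.
    assert hp_static == int(hp_dynamic)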
〇 GroundingDINO/groundingdino/models/GroundingDINO/ms_deform_attn.py
def multi_scale_deformable_attn_pytorch(
    ...
) -> torch.Tensor:
    ...
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        ...
        sampling_value_l_ = F.grid_sample(
            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
        )

class MultiScaleDeformableAttention(nn.Module):
    ...
    def forward(
        ...
    ) -> torch.Tensor:
        ...
        if torch.cuda.is_available() and value.is_cuda:
            ...
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights
            )
↓
def multi_scale_deformable_attn_pytorch(
    ...
) -> torch.Tensor:
    ...
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        ...
        sampling_value_l_ = F.grid_sample(
            value_l_.to(torch.float32), sampling_grid_l_.to(torch.float32), mode="bilinear", padding_mode="zeros", align_corners=False
        ).type(torch.float16)

class MultiScaleDeformableAttention(nn.Module):
    ...
    def forward(
        ...
    ) -> torch.Tensor:
        ...
        if False:
            ...
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights
            )
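The grid_sample change simply runs the bilinear sampling in float32 and casts the result back, since the half-precision path is the part that causes trouble when the rest of the model stays in fp16. A standalone sketch with made-up shapes (not the real value/grid tensors):

import torch
import torch.nn.functional as F

# Made-up shapes: (batch*heads, channels, H, W) value map and (batch*heads, num_queries, num_points, 2) grid.
value_l_ = torch.randn(1, 32, 16, 16).half()
sampling_grid_l_ = (torch.rand(1, 64, 4, 2) * 2 - 1).half()  # normalized coordinates in [-1, 1]

# Same pattern as the patch: sample in float32, then cast the result back to float16.
sampling_value_l_ = F.grid_sample(
    value_l_.to(torch.float32), sampling_grid_l_.to(torch.float32),
    mode="bilinear", padding_mode="zeros", align_corners=False,
).type(torch.float16)
print(sampling_value_l_.dtype, sampling_value_l_.shape)  # torch.float16 torch.Size([1, 32, 64, 4])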
〇 GroundingDINO/groundingdino/models/GroundingDINO/transformer.py
class Transformer(nn.Module):
    ...
    def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, text_dict=None):
        ...
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            ...
            if self.num_feature_levels > 1 and self.level_embed is not None:
                lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            else:
                lvl_pos_embed = pos_embed
        ...
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=src_flatten.device
        )

class TransformerEncoder(nn.Module):
    def __init__(
        ...
    ):
        self.use_checkpoint = use_checkpoint
        self.use_transformer_ckpt = use_transformer_ckpt
↓
class Transformer(nn.Module):
    ...
    def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, text_dict=None):
        ...
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            ...
            if torch.onnx.is_in_onnx_export():
                lvl_pos_embed = pos_embed + self._level_embed[lvl].view(1, 1, -1)
            else:
                if self.num_feature_levels > 1 and self.level_embed is not None:
                    lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
                else:
                    lvl_pos_embed = pos_embed
        ...
        if torch.onnx.is_in_onnx_export():
            spatial_shapes = torch.stack([torch.stack(x) for x in spatial_shapes]).to(src_flatten.device)
        else:
            spatial_shapes = torch.as_tensor(
                spatial_shapes, dtype=torch.long, device=src_flatten.device
            )

class TransformerEncoder(nn.Module):
    def __init__(
        ...
    ):
        self.use_checkpoint = False
        self.use_transformer_ckpt = False
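The intent of the `spatial_shapes` branch appears to be that during export each per-level (H, W) is a traced 0-dim tensor rather than a Python int, so `torch.as_tensor` over the list would materialize the sizes as constants in the graph, while stacking keeps them as graph ops. A tiny sketch with made-up level sizes:

import torch

# Made-up per-level sizes, written as 0-dim tensors the way traced .size() values behave during export.
spatial_shapes = [
    (torch.tensor(100), torch.tensor(167)),
    (torch.tensor(50), torch.tensor(84)),
    (torch.tensor(25), torch.tensor(42)),
    (torch.tensor(13), torch.tensor(21)),
]

# Export branch from the patch: build the (num_levels, 2) tensor with stack so the
# height/width values stay tensor ops instead of being baked in as Python constants.
stacked = torch.stack([torch.stack(x) for x in spatial_shapes])
print(stacked.shape)  # torch.Size([4, 2])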
〇 GroundingDINO/groundingdino/util/inference.py
def predict(
    ...
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    ...
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])

    prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0]  # prediction_logits.shape = (nq, 256)
↓
def predict(
    ...
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    ...
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])

    prediction_logits = outputs["pred_logits"].to(torch.float32).cpu().sigmoid()[0]  # prediction_logits.shape = (nq, 256)
〇 GroundingDINO/groundingdino/models/GroundingDINO/groundingdino.py
class GroundingDINO(nn.Module):
    def forward(self, samples: NestedTensor, targets: List = None, **kw):
        ...
        bert_output = self.bert(**tokenized_for_encoder)  # bs, 195, 768
↓
class GroundingDINO(nn.Module):
    def forward(self, samples: NestedTensor, targets: List = None, **kw):
        ...
        if 1:
            class Exp(torch.nn.Module):
                def __init__(self, bert, feat_map, backbone, input_proj, transformer, bbox_embed, class_embed):
                    super().__init__()
                    self.bert = bert
                    self.feat_map = feat_map
                    self.backbone = backbone
                    self.input_proj = input_proj
                    # transformer.level_embed.cuda().detach()
                    transformer._level_embed = transformer.level_embed.cuda().detach()
                    self.transformer = transformer
                    self.bbox_embed = bbox_embed
                    self.class_embed = class_embed
                    self.max_text_len = 256
                    self.num_feature_levels = 4

                def forward(self, samples, input_ids, token_type_ids, attention_mask, position_ids, text_token_mask):
                    inputs = {
                        "input_ids": input_ids,
                        "token_type_ids": token_type_ids,
                        "attention_mask": attention_mask,
                        "position_ids": position_ids,
                    }
                    bert_output = self.bert(**inputs)
                    encoded_text = self.feat_map(bert_output["last_hidden_state"])
                    encoded_text = encoded_text[:, : self.max_text_len, :]
                    text_token_mask = text_token_mask[:, : self.max_text_len]
                    position_ids = position_ids[:, : self.max_text_len]
                    attention_mask = attention_mask[
                        :, : self.max_text_len, : self.max_text_len
                    ]
                    text_dict = {
                        "encoded_text": encoded_text,  # bs, 195, d_model
                        "text_token_mask": text_token_mask,  # bs, 195
                        "position_ids": position_ids,  # bs, 195
                        "text_self_attention_masks": text_self_attention_masks,  # bs, 195, 195
                    }
                    samples = nested_tensor_from_tensor_list(samples)
                    a = self.backbone(samples)
                    features, poss = a
                    srcs = []
                    masks = []
                    for l, feat in enumerate(features):
                        src, mask = feat.decompose()
                        srcs.append(self.input_proj[l](src))
                        masks.append(mask)
                        assert mask is not None
                    if self.num_feature_levels > len(srcs):
                        _len_srcs = len(srcs)
                        for l in range(_len_srcs, self.num_feature_levels):
                            if l == _len_srcs:
                                src = self.input_proj[l](features[-1].tensors)
                            else:
                                src = self.input_proj[l](srcs[-1])
                            m = samples.mask
                            mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
                            pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                            srcs.append(src)
                            masks.append(mask)
                            poss.append(pos_l)
                    input_query_bbox = input_query_label = attn_mask = dn_meta = None
                    a = self.transformer(
                        srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict
                    )
                    hs, reference, hs_enc, ref_enc, init_box_proposal = a

                    # deformable-detr-like anchor update
                    outputs_coord_list = []
                    for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
                        zip(reference[:-1], self.bbox_embed, hs)
                    ):
                        layer_delta_unsig = layer_bbox_embed(layer_hs)
                        layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
                        layer_outputs_unsig = layer_outputs_unsig.sigmoid()
                        outputs_coord_list.append(layer_outputs_unsig)
                    outputs_coord_list = torch.stack(outputs_coord_list)

                    # output
                    outputs_class = torch.stack(
                        [
                            layer_cls_embed(layer_hs, text_dict)
                            for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
                        ]
                    )
                    pred_logits = outputs_class[-1]
                    pred_boxes = outputs_coord_list[-1]
                    print("pred_logits---", pred_logits)
                    print("pred_logits---", pred_logits.shape)
                    print("pred_boxes---", pred_boxes)
                    print("pred_boxes---", pred_boxes.shape)
                    return pred_logits, pred_boxes

            print("------>")
            model = Exp(self.bert, self.feat_map, self.backbone, self.input_proj, self.transformer, self.bbox_embed, self.class_embed)
            x = (samples, tokenized_for_encoder["input_ids"], tokenized_for_encoder["token_type_ids"], tokenized_for_encoder["attention_mask"], tokenized_for_encoder["position_ids"], tokenized.attention_mask.bool())
            torch.onnx.export(
                model, x, 'groundingdino_swint_ogc.onnx',
                input_names=["samples", "input_ids", "token_type_ids", "attention_mask", "position_ids", "text_token_mask"],
                output_names=["pred_logits", "pred_boxes"],
                dynamic_axes={"samples": {2: "h", 3: "w"}, "input_ids": [1], "token_type_ids": [1], "attention_mask": [1, 2], "position_ids": [1], "text_token_mask": [1]},
                # do_constant_folding=False,
                verbose=False, opset_version=17
            )
            print("<------")
            exit()
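For completeness, a hypothetical smoke test of the exported file with onnxruntime. This is not part of the original change; the shapes, dtypes, and dummy values below are placeholders, and in practice input_ids, position_ids, and the masks would come from the BERT tokenizer exactly as in the original predict() path.

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession("groundingdino_swint_ogc.onnx")

seq_len = 6  # tokenized caption length including special tokens (placeholder)
feed = {
    # float32 here; use float16 instead if the model was converted to half precision before export.
    "samples": np.random.rand(1, 3, 800, 1200).astype(np.float32),   # preprocessed image
    "input_ids": np.zeros((1, seq_len), dtype=np.int64),             # tokenizer output (placeholder)
    "token_type_ids": np.zeros((1, seq_len), dtype=np.int64),
    "attention_mask": np.ones((1, seq_len, seq_len), dtype=bool),    # text self-attention mask
    "position_ids": np.arange(seq_len, dtype=np.int64)[None, :],
    "text_token_mask": np.ones((1, seq_len), dtype=bool),
}
pred_logits, pred_boxes = sess.run(["pred_logits", "pred_boxes"], feed)
print(pred_logits.shape, pred_boxes.shape)  # (1, nq, 256), (1, nq, 4)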
〇 segment_anything/segment_anything/modeling/mask_decoder.py
class MaskDecoder(nn.Module):
    ...
    def predict_masks(
        ...
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        ...
        # Expand per-image data in batch direction to be per-mask
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
↓
class MaskDecoder(nn.Module):
    ...
    def predict_masks(
        ...
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        ...
        # Expand per-image data in batch direction to be per-mask
        if torch.onnx.is_in_onnx_export():
            src = image_embeddings
        else:
            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        if torch.onnx.is_in_onnx_export():
            pos_src = image_pe
        else:
            pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
〇 segment_anything/segment_anything/modeling/sam.py
class Sam(nn.Module):
    ...
    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        ...
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks
↓
class Sam(nn.Module):
    ...
    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        ...
        masks = F.interpolate(masks, (original_size[0], original_size[1]), mode="bilinear", align_corners=False)
        return masks
〇 segment_anything/segment_anything/predictor.py
class SamPredictor:
    ...
    def set_torch_image(
        ...
    ) -> None:
        ...
        self.original_size = original_image_size
        self.input_size = tuple(transformed_image.shape[-2:])
        input_image = self.model.preprocess(transformed_image)
        self.features, self.interm_features = self.model.image_encoder(input_image)
        self.is_image_set = True

    ...
    def predict(
        ...
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        ...
        masks, iou_predictions, low_res_masks = self.predict_torch(
            coords_torch,
            labels_torch,
            box_torch,
            mask_input_torch,
            multimask_output,
            return_logits=return_logits,
            hq_token_only=hq_token_only,
        )
↓
class SamPredictor:
    ...
    def set_torch_image(
        ...
    ) -> None:
        ...
        self.original_size = original_image_size
        self.input_size = tuple(transformed_image.shape[-2:])
        input_image = self.model.preprocess(transformed_image)
        self.input_image = input_image  # Add
        self.features, self.interm_features = self.model.image_encoder(input_image)
        self.is_image_set = True

    ...
    def predict(
        ...
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        ...
        if 1:
            class Exp(torch.nn.Module):
                def __init__(self, model):
                    super().__init__()
                    self.model = model
                    self.sam_model = model.model

                def forward(self, input_image, box_torch, input_size, original_size):
                    features, interm_features = self.sam_model.image_encoder(input_image)
                    self.model.features = features
                    self.model.interm_features = interm_features
                    self.model.input_size = input_size
                    self.model.original_size = original_size
                    masks, iou_predictions, low_res_masks = self.model.predict_torch(
                        point_coords=None,
                        point_labels=None,
                        boxes=box_torch,
                        mask_input=None,
                        multimask_output=True,
                        return_logits=False,
                        hq_token_only=False,
                    )
                    return masks, iou_predictions, low_res_masks

            with torch.no_grad():
                print("------>")
                from torch.autograd import Variable
                model = Exp(self)
                x = (self.input_image, box_torch, torch.tensor(self.input_size, dtype=torch.long).cuda(), torch.tensor(self.original_size, dtype=torch.long).cuda())
                torch.onnx.export(
                    model, x, 'onnx/sam_vit_h_4b8939.onnx',
                    input_names=["image", "box", "input_size", "original_size"],
                    output_names=["masks", "iou_predictions", "low_res_masks"],
                    dynamic_axes={"masks": {2: "h", 3: "w"}, "iou_predictions": [1], "box": [1]},
                    # do_constant_folding=False,
                    verbose=False, opset_version=17
                )
                print("<------")
                exit()
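And a similar hypothetical smoke test for the SAM export. Again this is not from the original post; the shapes and values are placeholders, with the box assumed to be a single XYXY box already mapped into the resized 1024-scale frame by ResizeLongestSide, as in SamPredictor.predict.

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession("onnx/sam_vit_h_4b8939.onnx")

feed = {
    "image": np.random.rand(1, 3, 1024, 1024).astype(np.float32),        # preprocessed (resized + padded) image
    "box": np.array([[100.0, 100.0, 600.0, 500.0]], dtype=np.float32),   # XYXY box in the resized frame (placeholder)
    "input_size": np.array([1024, 683], dtype=np.int64),                 # size after ResizeLongestSide, before padding
    "original_size": np.array([1200, 800], dtype=np.int64),              # raw image size
}
masks, iou_predictions, low_res_masks = sess.run(
    ["masks", "iou_predictions", "low_res_masks"], feed
)
print(masks.shape, iou_predictions.shape, low_res_masks.shape)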
https://github.com/IDEA-Research/Grounded-Segment-Anything