axinc-ai / ailia-models

The collection of pre-trained, state-of-the-art AI models for ailia SDK
2.04k stars 324 forks source link

ADD Grounded-Segment-Anything #1493

Closed — kyakuno closed this issue 3 months ago

kyakuno commented 4 months ago

https://github.com/IDEA-Research/Grounded-Segment-Anything

kyakuno commented 4 months ago

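@ooe1123 g2pの方が終わりましたら、こちらをお願いできると嬉しいです。
(Translation: @ooe1123 Once you have finished the g2p work, I would appreciate it if you could take this on.)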

ooe1123 commented 3 months ago

groundingdino_swint_ogc.onnx

〇 GroundingDINO/groundingdino/models/GroundingDINO/backbone/swin_transformer.py

# Upstream (unpatched) Swin PatchMerging.forward: pads x only when H or W is odd.
class PatchMerging(nn.Module):
    ...
    def forward(self, x, H, W):
        ...
        # padding
        # F.pad pads the last dim first: (0, 0) leaves the last dim alone, then
        # W % 2 is added to dim -2 and H % 2 to dim -3 (so x is presumably laid
        # out as (B, H, W, C) here — confirm against the elided reshape).
        # NOTE(review): the patched version of this method replaces this
        # Python-level `if` on H/W with an unconditional pad during ONNX export.
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

# Upstream (unpatched) Swin BasicLayer: sizes are computed with np.ceil and
# Python integer division, i.e. H and W are treated as plain numbers.
class BasicLayer(nn.Module):
    ...
    def __init__(
        ...
    ):
        ...
        self.use_checkpoint = use_checkpoint

    def forward(self, x, H, W):
        ...
        # Round H and W up to a multiple of the attention window size.
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        ...
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            # Spatial size after 2x merging: ceil(H / 2), ceil(W / 2).
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W

# Patched Swin PatchMerging.forward for ONNX export.
class PatchMerging(nn.Module):
    ...
    def forward(self, x, H, W):
        ...
        # padding
        if torch.onnx.is_in_onnx_export():
            # Updated
            # Always pad, by H % 2 / W % 2 computed with torch.remainder (a
            # zero-width pad when the size is even). This avoids a Python `if`
            # on the traced sizes, so the graph has no data-dependent branch.
            x = F.pad(x, (0, 0, 0, torch.remainder(W, 2), 0, torch.remainder(H, 2)))
        else:
            # Eager path: unchanged upstream behavior.
            pad_input = (H % 2 == 1) or (W % 2 == 1)
            if pad_input:
                x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

class BasicLayer(nn.Module):
    ...
    def __init__(
        ...
    ):
        ...
        self.use_checkpoint = False

    def forward(self, x, H, W):
        ...
        if torch.onnx.is_in_onnx_export():
            Hp = (H + self.window_size - 1).floor_divide(self.window_size) * self.window_size
            Wp = (W + self.window_size - 1).floor_divide(self.window_size) * self.window_size
        else:
            Hp = int(np.ceil(H / self.window_size)) * self.window_size
            Wp = int(np.ceil(W / self.window_size)) * self.window_size
        ...
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = torch.floor_divide(H + 1, 2), torch.floor_divide(W + 1, 2)
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W

〇 GroundingDINO/groundingdino/models/GroundingDINO/ms_deform_attn.py

# Upstream pure-PyTorch deformable-attention fallback, shown in the issue as a
# diff. The line prefixed with "-" is removed and replaced by the line after
# it, which casts the grid_sample result to float16. (The "-" prefix is diff
# notation, not Python.)
def multi_scale_deformable_attn_pytorch(
    ...
) -> torch.Tensor:
    ...
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        ...
        sampling_value_l_ = F.grid_sample(
            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
-        )
        ).type(torch.float16)

# Upstream MultiScaleDeformableAttention.forward: when CUDA is available and
# `value` is on the GPU it takes the (elided) CUDA path; otherwise it falls
# back to the pure-PyTorch implementation.
class MultiScaleDeformableAttention(nn.Module):
    ...
    def forward(
        ...
    ) -> torch.Tensor:
        ...
        if torch.cuda.is_available() and value.is_cuda:
            ...
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights
            )

# Patched pure-PyTorch deformable-attention fallback.
def multi_scale_deformable_attn_pytorch(
    ...
) -> torch.Tensor:
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        ...
        # Both grid_sample inputs are cast to float32 and the result is cast
        # back to float16.
        # NOTE(review): presumably the inputs can arrive as float16 and
        # grid_sample (or its ONNX export) needs float32 here — confirm.
        sampling_value_l_ = F.grid_sample(
            value_l_.to(torch.float32), sampling_grid_l_.to(torch.float32), mode="bilinear", padding_mode="zeros", align_corners=False
        ).type(torch.float16)

# Patched MultiScaleDeformableAttention.forward: `if False:` disables the CUDA
# branch, so the pure-PyTorch implementation (built on F.grid_sample) is always
# used. Presumably this is what makes the module exportable to ONNX.
class MultiScaleDeformableAttention(nn.Module):
    ...
    def forward(
        ...
    ) -> torch.Tensor:
        ...
        if False:
            ...
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights
            )

〇 GroundingDINO/groundingdino/models/GroundingDINO/transformer.py

# Upstream GroundingDINO Transformer.forward: adds the per-level embedding to
# the positional embedding and converts spatial_shapes with torch.as_tensor.
class Transformer(nn.Module):
    ...
    def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, text_dict=None):
        ...
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            ...
            if self.num_feature_levels > 1 and self.level_embed is not None:
                lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            else:
                lvl_pos_embed = pos_embed
        ...
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=src_flatten.device
        )

# Upstream TransformerEncoder: checkpointing flags come from constructor args.
class TransformerEncoder(nn.Module):
    def __init__(
        ...
    ):
        self.use_checkpoint = use_checkpoint
        self.use_transformer_ckpt = use_transformer_ckpt

# Patched GroundingDINO Transformer.forward for ONNX export.
class Transformer(nn.Module):
    ...
    def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, text_dict=None):
        ...
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            ...
            if torch.onnx.is_in_onnx_export():
                # Export path reads `_level_embed`, an attribute that is not set
                # here. The export wrapper assigns it from level_embed before
                # tracing. Note: this path skips the num_feature_levels /
                # level_embed-is-None check of the eager path.
                lvl_pos_embed = pos_embed + self._level_embed[lvl].view(1, 1, -1)
            else:
                if self.num_feature_levels > 1 and self.level_embed is not None:
                    lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
                else:
                    lvl_pos_embed = pos_embed
        ...
        if torch.onnx.is_in_onnx_export():
            # Each entry of spatial_shapes is stacked with torch.stack, so the
            # per-level sizes are treated as tensors while exporting (they
            # presumably depend on the dynamic input size) instead of being
            # converted to constants by torch.as_tensor.
            spatial_shapes = torch.stack([torch.stack(x) for x in spatial_shapes]).to(src_flatten.device)
        else:
            spatial_shapes = torch.as_tensor(
                spatial_shapes, dtype=torch.long, device=src_flatten.device
            )

# Patched TransformerEncoder: checkpointing is forced off for export; the
# constructor's use_checkpoint / use_transformer_ckpt values are ignored.
class TransformerEncoder(nn.Module):
    def __init__(
        ...
    ):
        self.use_checkpoint = False
        self.use_transformer_ckpt = False

〇 GroundingDINO/groundingdino/util/inference.py

# Upstream GroundingDINO inference helper: runs the model without gradients and
# applies a sigmoid to the first batch item's logits on the CPU.
def predict(
        ...
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    ...
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])

    prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0]  # prediction_logits.shape = (nq, 256)

# Patched GroundingDINO inference helper: logits are cast to float32 before
# moving to the CPU and applying the sigmoid.
# NOTE(review): presumably needed because the patched deformable attention casts
# to float16, so pred_logits may be half precision — confirm.
def predict(
        ...
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    ...
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])

    prediction_logits = outputs["pred_logits"].to(torch.float32).cpu().sigmoid()[0]  # prediction_logits.shape = (nq, 256)

〇 GroundingDINO/groundingdino/models/GroundingDINO/groundingdino.py

# Upstream GroundingDINO.forward: encodes the tokenized caption with BERT. The
# patched version of this method inserts the ONNX export hook at this point.
class GroundingDINO(nn.Module):
    def forward(self, samples: NestedTensor, targets: List = None, **kw):
        ...
        bert_output = self.bert(**tokenized_for_encoder)  # bs, 195, 768

class GroundingDINO(nn.Module):
    def forward(self, samples: NestedTensor, targets: List = None, **kw):
        ...
        # One-shot export hook: wrap the inference path in a traceable module,
        # write groundingdino_swint_ogc.onnx, then terminate the process.
        if 1:
            class Exp(torch.nn.Module):
                """Traceable GroundingDINO forward used only for ONNX export.

                The text-side tensors (token ids, masks, position ids) are
                explicit graph inputs rather than tokenizer output, so they can
                be exported with a dynamic sequence length.
                """

                def __init__(self, bert, feat_map, backbone, input_proj, transformer, bbox_embed, class_embed):
                    super().__init__()
                    self.bert = bert
                    self.feat_map = feat_map
                    self.backbone = backbone
                    self.input_proj = input_proj
                    # The patched Transformer.forward reads `_level_embed` while
                    # exporting. detach() keeps the parameter on whatever device
                    # it already lives on. The former hard-coded .cuda() caused a
                    # device mismatch against pos_embed when exporting on CPU.
                    transformer._level_embed = transformer.level_embed.detach()
                    self.transformer = transformer
                    self.bbox_embed = bbox_embed
                    self.class_embed = class_embed
                    self.max_text_len = 256
                    self.num_feature_levels = 4

                def forward(self, samples, input_ids, token_type_ids, attention_mask, position_ids, text_token_mask):
                    """Return (pred_logits, pred_boxes) of the last decoder layer.

                    attention_mask is the 3-D text self-attention mask (sliced
                    along dims 1 and 2 below). It is fed both to BERT and,
                    truncated to max_text_len, to the transformer's text_dict.
                    """
                    bert_output = self.bert(
                        input_ids=input_ids,
                        token_type_ids=token_type_ids,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                    )
                    encoded_text = self.feat_map(bert_output["last_hidden_state"])

                    # Truncate every text-side tensor to max_text_len tokens.
                    encoded_text = encoded_text[:, : self.max_text_len, :]
                    text_token_mask = text_token_mask[:, : self.max_text_len]
                    position_ids = position_ids[:, : self.max_text_len]
                    text_self_attention_masks = attention_mask[
                        :, : self.max_text_len, : self.max_text_len
                    ]

                    # The mask must come from the traced `attention_mask` input.
                    # Previously this dict used the enclosing forward()'s
                    # `text_self_attention_masks` via closure, which the tracer
                    # bakes into the graph as a constant sized for the export-time
                    # caption, while the sliced input above went unused.
                    text_dict = {
                        "encoded_text": encoded_text,  # bs, 195, d_model
                        "text_token_mask": text_token_mask,  # bs, 195
                        "position_ids": position_ids,  # bs, 195
                        "text_self_attention_masks": text_self_attention_masks,  # bs, 195,195
                    }

                    samples = nested_tensor_from_tensor_list(samples)
                    features, poss = self.backbone(samples)

                    srcs = []
                    masks = []
                    for l, feat in enumerate(features):
                        src, mask = feat.decompose()
                        srcs.append(self.input_proj[l](src))
                        masks.append(mask)
                        assert mask is not None
                    # Build extra feature levels beyond the backbone outputs by
                    # projecting the last feature map (then the last new level).
                    if self.num_feature_levels > len(srcs):
                        _len_srcs = len(srcs)
                        for l in range(_len_srcs, self.num_feature_levels):
                            if l == _len_srcs:
                                src = self.input_proj[l](features[-1].tensors)
                            else:
                                src = self.input_proj[l](srcs[-1])
                            m = samples.mask
                            mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
                            pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                            srcs.append(src)
                            masks.append(mask)
                            poss.append(pos_l)

                    # Inference only: no denoising queries or attention mask.
                    input_query_bbox = input_query_label = attn_mask = None
                    hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(
                        srcs, masks, input_query_bbox, poss, input_query_label, attn_mask, text_dict
                    )

                    # deformable-detr-like anchor update
                    outputs_coord_list = []
                    for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(
                        zip(reference[:-1], self.bbox_embed, hs)
                    ):
                        layer_delta_unsig = layer_bbox_embed(layer_hs)
                        layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
                        layer_outputs_unsig = layer_outputs_unsig.sigmoid()
                        outputs_coord_list.append(layer_outputs_unsig)
                    outputs_coord_list = torch.stack(outputs_coord_list)

                    # output
                    outputs_class = torch.stack(
                        [
                            layer_cls_embed(layer_hs, text_dict)
                            for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
                        ]
                    )
                    # Only the last decoder layer's predictions are exported.
                    pred_logits = outputs_class[-1]
                    pred_boxes = outputs_coord_list[-1]
                    return pred_logits, pred_boxes

            print("------>")
            model = Exp(self.bert, self.feat_map, self.backbone, self.input_proj, self.transformer, self.bbox_embed, self.class_embed)
            x = (
                samples,
                tokenized_for_encoder["input_ids"],
                tokenized_for_encoder["token_type_ids"],
                tokenized_for_encoder["attention_mask"],
                tokenized_for_encoder["position_ids"],
                tokenized.attention_mask.bool(),
            )
            torch.onnx.export(
                model, x, 'groundingdino_swint_ogc.onnx',
                input_names=["samples", "input_ids", "token_type_ids", "attention_mask", "position_ids", "text_token_mask"],
                output_names=["pred_logits", "pred_boxes"],
                dynamic_axes={"samples": {2:"h", 3:"w"}, "input_ids": [1], "token_type_ids": [1], "attention_mask": [1,2], "position_ids": [1], "text_token_mask": [1]},
                # do_constant_folding=False,
                verbose=False, opset_version=17
            )
            print("<------")
            exit()
ooe1123 commented 3 months ago

sam_vit_h_4b8939.onnx

〇 segment_anything/segment_anything/modeling/mask_decoder.py

# Upstream SAM MaskDecoder.predict_masks: the image embedding and positional
# encoding are repeated along dim 0 once per prompt token set
# (tokens.shape[0]).
class MaskDecoder(nn.Module):
    ...
    def predict_masks(
        ...
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        ...
        # Expand per-image data in batch direction to be per-mask
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)

# Patched SAM MaskDecoder.predict_masks for ONNX export: repeat_interleave is
# skipped while exporting.
class MaskDecoder(nn.Module):
    ...
    def predict_masks(
        ...
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        ...
        # Expand per-image data in batch direction to be per-mask
        # NOTE(review): on the export path src / pos_src keep their original
        # batch size and rely on broadcasting (e.g. in the addition with
        # dense_prompt_embeddings below); presumably image_embeddings / image_pe
        # have batch size 1 here — confirm.
        if torch.onnx.is_in_onnx_export():
            src = image_embeddings
        else:
            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        if torch.onnx.is_in_onnx_export():
            pos_src = image_pe
        else:
            pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)

〇 segment_anything/segment_anything/modeling/sam.py

# Upstream SAM postprocess_masks (final step): bilinear resize of the masks to
# original_size.
class Sam(nn.Module):
    ...
    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        ...
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

# Patched SAM postprocess_masks: the target size is passed as an explicit
# (height, width) pair built by indexing original_size. This works both for
# a tuple and for the 1-D tensor the export wrapper supplies as original_size.
class Sam(nn.Module):
    ...
    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        ...
        masks = F.interpolate(masks, (original_size[0], original_size[1]), mode="bilinear", align_corners=False)
        return masks

〇 segment_anything/segment_anything/predictor.py

# Upstream SamPredictor: set_torch_image caches image features; predict calls
# predict_torch with the prepared prompt tensors.
class SamPredictor:
    ...
    def set_torch_image(
        ...
    ) -> None:
        ...
        self.original_size = original_image_size
        self.input_size = tuple(transformed_image.shape[-2:])
        input_image = self.model.preprocess(transformed_image)
        self.features, self.interm_features = self.model.image_encoder(input_image)
        self.is_image_set = True
    ...
    def predict(
        ...
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        ...
        masks, iou_predictions, low_res_masks = self.predict_torch(
            coords_torch,
            labels_torch,
            box_torch,
            mask_input_torch,
            multimask_output,
            return_logits=return_logits,
            hq_token_only=hq_token_only,
        )

class SamPredictor:
    ...
    def set_torch_image(
        ...
    ) -> None:
        ...
        self.original_size = original_image_size
        self.input_size = tuple(transformed_image.shape[-2:])
        input_image = self.model.preprocess(transformed_image)
        self.input_image = input_image  # Add
        self.features, self.interm_features = self.model.image_encoder(input_image)
        self.is_image_set = True
    ...
    def predict(
        ...
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        ...
        if 1:
            class Exp(torch.nn.Module):
                def __init__(self, model):
                    super().__init__()
                    self.model = model
                    self.sam_model = model.model

                def forward(self, input_image, box_torch, input_size, original_size):
                    features, interm_features = self.sam_model.image_encoder(input_image)
                    self.model.features = features
                    self.model.interm_features = interm_features
                    self.model.input_size = input_size
                    self.model.original_size = original_size

                    masks, iou_predictions, low_res_masks = self.model.predict_torch(
                        point_coords=None,
                        point_labels=None,
                        boxes=box_torch,
                        mask_input=None,
                        multimask_output=True,
                        return_logits=False,
                        hq_token_only=False,
                    )
                    return masks, iou_predictions, low_res_masks

            with torch.no_grad():
                print("------>")
                from torch.autograd import Variable
                model = Exp(self)
                x = (self.input_image, box_torch, torch.tensor(self.input_size, dtype=torch.long).cuda(), torch.tensor(self.original_size, dtype=torch.long).cuda())
                torch.onnx.export(
                    model, x, 'onnx/sam_vit_h_4b8939.onnx',
                    input_names=["image", "box", "input_size", "original_size"],
                    output_names=["masks", "iou_predictions", "low_res_masks"],
                    dynamic_axes={"masks": {2: "h", 3: "w"}, "iou_predictions": [1], "box": [1]},
                    #do_constant_folding=False,
                    verbose=False, opset_version=17
                )
                print("<------")
                exit()