axinc-ai / ailia-models

The collection of pre-trained, state-of-the-art AI models for ailia SDK
2.02k stars 319 forks source link

ADD SAM2 (ONNX) #1514

Closed kyakuno closed 1 month ago

kyakuno commented 3 months ago

変換元 https://github.com/facebookresearch/segment-anything-2

kyakuno commented 3 months ago

動画に対するセグメンテーションモデル。

kyakuno commented 2 months ago

推論ブロック https://github.com/facebookresearch/segment-anything-2/blob/main/sam2/sam2_image_predictor.py

kyakuno commented 2 months ago

建設予定地 https://github.com/axinc-ai/segment-anything-2

kyakuno commented 2 months ago

Largeモデルの試作品 https://storage.googleapis.com/ailia-models/segment-anything-2/mask_decoder.onnx https://storage.googleapis.com/ailia-models/segment-anything-2/prompt_encoder.onnx https://storage.googleapis.com/ailia-models/segment-anything-2/image_encoder.onnx

kyakuno commented 2 months ago

推論できるONNXができた https://storage.googleapis.com/ailia-models/segment-anything-2/mask_decoder_hiera_l.onnx https://storage.googleapis.com/ailia-models/segment-anything-2/prompt_encoder_sparse_hiera_l.onnx https://storage.googleapis.com/ailia-models/segment-anything-2/image_encoder_hiera_l.onnx

kyakuno commented 2 months ago

動画の場合は、  sam2_video_predictor.py -> track_step (sam2_base.py) -> _forward_sam_heads (sam2_base.py) で、_forward_sam_headsの中でモデルの推論が走る。

kyakuno commented 2 months ago

とりあえず、import_from_onnxをvideo_predictorにも実装してみると良さそう。

kyakuno commented 2 months ago

mask_decoderのmultimask_outputがONNXに変換すると定数化されるので、必ず4プレーン出力して、後処理で1プレーンか3プレーンを選択するように修正する。

imageの場合は必ず3プレーンの出力(multimask_output = True)だが、videoの場合は1プレーンの出力(multimask_output = False)も使用する。

kyakuno commented 2 months ago

memory_moduleの_encode_new_memoryでは、推論は行わずにFeatureを保存している。

kyakuno commented 2 months ago

静止画と動画で、image encoderの出力のbackbone_outの使用法が異なるので、後処理はnumpyで書いた方が良い。

kyakuno commented 2 months ago

memory_attentionで下記のエクスポートエラー。 ScalarType ComplexFloat is an unexpected tensor scalar type

dynamo_exportでもNGだった。 https://pytorch.org/docs/stable/onnx_dynamo.html

下記に同じ問題のIssueがある。 https://github.com/facebookresearch/segment-anything-2/issues/186

kyakuno commented 2 months ago

RoPEが原因みたい。

it contains ComplexFloat64 ops in Rotary Position Embedding (RoPE) implementation https://github.com/FlagAI-Open/FlagAI/issues/406

kyakuno commented 2 months ago

RoPEでconplexを使わずに2テンソルで処理すれば解決できる?

kyakuno commented 2 months ago

PromptEncoderのmasks対応を行った。masks_enableを追加してwhereで切り替えた。 https://storage.googleapis.com/ailia-models/segment-anything-2/prompt_encoder_hiera_l.onnx

kyakuno commented 2 months ago

largeとtinyでmask decoderとprompt encoderのweightのサイズは同じだが値が異なる。 image encoderをtinyで、mask decoderをlargeで処理すると結果の精度が落ちる。

kyakuno commented 1 month ago

下記のリポジトリではMemoryAttentionのcomplexをmatmulに置き換えているので参考になる。 https://github.com/heyoeyo/muggled_sam/blob/9404efe12cae4c015832cfdb1b4695c9c86f77d7/lib/v2_sam/components/memfuse_attention.py#L294

kyakuno commented 1 month ago

下記でONNXに変換することができた。 https://github.com/axinc-ai/segment-anything-2/blob/onnx/sam2/modeling/sam/transformer.py#L340

kyakuno commented 1 month ago

Memory Attention Tiny : https://storage.googleapis.com/ailia-models/segment-anything-2/memory_attention_hiera_t.onnx Netron : https://netron.app/?url=https://storage.googleapis.com/ailia-models/segment-anything-2/memory_attention_hiera_t.onnx

kyakuno commented 1 month ago

下記もエクスポートが必要。

        if self.use_obj_ptrs_in_encoder:
            # a linear projection on SAM output tokens to turn them into object pointers
            self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
            if self.use_mlp_for_obj_ptr_proj:
                self.obj_ptr_proj = MLP(
                    self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
                )
kyakuno commented 1 month ago

SAM2Baseのデフォルト引数は使用されず、configファイルの定数が使用されるのでサンプル作成時は注意。

kyakuno commented 1 month ago

build_sam.pyで下記のオプションが適用される。

Image

        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
        ]

Video

        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
            # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
            "++model.binarize_mask_from_pts_for_mem_enc=true",
            # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
            "++model.fill_hole_area=8",
        ]
kyakuno commented 1 month ago

動画モードの出力が合わないと思ったら、memory_encoderのskip_mask_sigmoidをFalseでエクスポートしてしまっていた。

kyakuno commented 1 month ago

memory_encoderを差し替え。 https://storage.googleapis.com/ailia-models/segment-anything-2/memory_encoder_hiera_t.onnx

kyakuno commented 1 month ago

この段階でtorchと混在で正常に推論できるようになった。

kyakuno commented 1 month ago

memory_attentionで6次元テンソルが出現する。

ailia.core.AiliaInvalidLayerException: code: -10 (Incorrect layer parameter. [broken or unsupported AI model file])
+ error detail : Layer:/layers.0/cross_attn_image/Tile_output_0(Tile) Error:Unacceptable input shape. [ inputs:1 ]

スクリーンショット 2024-09-03 11 09 12

kyakuno commented 1 month ago

動画の場合、hiera_tは正しい絵が出るが、hiera_lが正しい絵が出ない。

kyakuno commented 1 month ago

ImageEncoderのembed_dimがsとtは96で、b+は112、lは144になっている。

kyakuno commented 1 month ago

エクスポータの方では動作している。

kyakuno commented 1 month ago

なぜかbackbone_featuresの段階で合わない。

Exporter

begin prompt encoder onnx
begin mask decoder onnx
backbone_features 46462.91
image_pe 12279.771
sparse_embeddings 11.907056
dense_embeddings 729.623
high_res_features 8821.962
high_res_features -939.2373

ailia-models

begin prompt encoder onnx
begin mask decoder onnx
backbone_features 41639.44
image_pe 12279.771
sparse_embeddings 11.907056
dense_embeddings 729.623
high_res_features 8821.962
high_res_features -939.2373
kyakuno commented 1 month ago

image encoderの出力。この段階では一致している。

exporter

vision_features 44474.938
vision_pos_enc_0 8304827.0
vision_pos_enc_1 2075745.8
vision_pos_enc_2 518706.06
backbone_fpn_0 8821.962
backbone_fpn_1 -939.2373
backbone_fpn_2 44474.938

ailia-models

vision_features 44474.938
vision_pos_enc_0 8304827.0
vision_pos_enc_1 2075745.8
vision_pos_enc_2 518706.06
backbone_fpn_0 8821.962
backbone_fpn_1 -939.2373
backbone_fpn_2 44474.938
kyakuno commented 1 month ago

no_mem_embedの値が異なる。

kyakuno commented 1 month ago

no_mem_embedはtorchの乱数で作っているので、実行のたびに異なる。

kyakuno commented 1 month ago

exporterでは固定されているので、no_mem_embedの値が学習時に固定されている気配。

kyakuno commented 1 month ago

パラメータを揃えてみたが、出力異常は治らない。

kyakuno commented 1 month ago

memory_encoderのモデルの出力が異なる。

kyakuno commented 1 month ago

exporter

begin memory encoder onnx
pix_feat 44474.938
mask_for_mem -9685020.0
vision_features 117341.8
vision_pos_enc 127058.88

ailia models

begin memory encoder onnx
pix_feat 44474.938
mask_for_mem -9685020.0
vision_features 110066.125
vision_pos_enc 127058.88
kyakuno commented 1 month ago

memory_encoder_hiera_l.onnx.prototxtが古くてsigmoidが入ってしまっている。

kyakuno commented 1 month ago

no_mem_embedはパラメータに持つ必要はなかった。完全一致を目指す場合のみ必要。

kyakuno commented 1 month ago

prototxtを修正すると正しい絵が出力された。あとは、サンプルの整理。

kyakuno commented 1 month ago

RoPEAttentionのfeat_sizeが512x512想定の値になっていて、forwardでrot_matの再確保が走っている。

feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution

ここは固定的に確保するようにした方が良さそう。

kyakuno commented 1 month ago

yamlでも(32, 32)になっているが、  q.shape[-2] が4096で、64*64を期待している。 学習は512x512で行われていて、推論時に1024x1024にしている感。

kyakuno commented 1 month ago

memory_attentionにおいて、初回推論では、  q = 1, 1, 4096, 256  k = 1, 1, 8200, 256 になり、2回目の推論では、  q = 1, 1, 4096, 256  k = 1, 1, 12300, 256 と増加していく。

この値を元に、nk // nqで下方向に丸めてrepeatする。

def apply_rotary_matenc(xq, xk, rotmats, repeat_freqs_k=False):   
    bq, hq, nq, cq = xq.shape
    bk, hk, nk, ck = xk.shape

    q_out = torch.matmul(rotmats, xq.reshape(bq, hq, nq, cq // 2, 2, 1)).flatten(3)
    k_rotmat = rotmats.repeat(1, 1, nk // nq, 1, 1, 1) if repeat_freqs_k else rotmats
    k_out = torch.matmul(k_rotmat, xk.reshape(bk, hk, nk, ck // 2, 2, 1)).flatten(3)

    return q_out, k_out
kyakuno commented 1 month ago

memory_attentionのtfliteのエクスポートエラーの詳細。

tensorflow.lite.python.convert_phase.ConverterError: Variable constant folding is failed. Please consider using enabling `experimental_enable_resource_variables` flag in the TFLite converter object. For example, converter.experimental_enable_resource_variables = True<unknown>:0: error: loc(callsite(callsite(callsite("sam2.modeling.memory_attention.MemoryAttention/sam2.modeling.memory_attention.MemoryAttentionLayer_0/sam2.modeling.sam.transformer.RoPEAttention_cross_attn_image;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_2397"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_2635"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): failed to legalize operation 'tfl.pad' that was explicitly marked illegal
<unknown>:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from
kyakuno commented 1 month ago

num_k_exclude_rope = 0とするとここのエラーは通過するので、TraceするにはintではなくTensorにする必要があるかも。 tfliteのTrace対象がTensorだけな気がする。

kyakuno commented 1 month ago

Tensorにすると、Dynamic slicing on data-dependent value is not supportedになる。

torch._dynamo.exc.Unsupported: Dynamic slicing on data-dependent value is not supported k[:, :, :num_k_rope]

kyakuno commented 1 month ago

配列オーバを考慮できないためなので、checkを書くと通るらしい。

            torch._check_is_size(num_k_rope)
            torch._check(num_k_rope < k.shape[2])

https://github.com/pytorch/pytorch/issues/123592

kyakuno commented 1 month ago

RoPEAttentionのnum_k_exclude_ropeはrot_matの4096の倍数に合わせるために存在している。

kyakuno commented 1 month ago

やっぱり、うまくtfliteに出力できないので、余剰分とそれ以外でkとvを引きまわした方が良さそうな気配がある。

kyakuno commented 1 month ago

memory_1とmemory_2に分割することで、MemoryAttentionをtfliteに変換はできた。 しかし、Reshapeに定数が入ってしまい、Dynamic Shapeに対応できない。

スクリーンショット 2024-09-05 18 21 38

kyakuno commented 1 month ago

ONNX torch.exportだとDynamic Shapeになる。

スクリーンショット 2024-09-05 18 31 42