axinc-ai / ailia-models-tflite

Quantized version of model library
24 stars 2 forks source link

ADD SAM2 (tflite) #88

Open kyakuno opened 3 weeks ago

kyakuno commented 3 weeks ago

edge-ai-torchで変換を検討。 https://github.com/facebookresearch/segment-anything-2 https://medium.com/axinc/ai-edge-torch%E3%81%A7pytorch%E3%81%8B%E3%82%89tflite%E3%81%AB%E5%A4%89%E6%8F%9B%E3%81%99%E3%82%8B-376be7dc5619 難しそうであれば、下記と同様に、Pytorch -> Kerasを検討。 https://github.com/tirthasheshpatel/segment_anything_keras

kyakuno commented 3 weeks ago

image_encoderをonnxには変換できるが、edge-ai-torchとCUDAでtfliteに変換しようとすると下記のエラーになる。

ValueError: Cannot view a tensor with shape torch.Size([1024, 4, 4, 288]) and strides (4608, 4, 1, 16) as a tensor with shape (16384, 288)!

While executing %view_38 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%view_37, [16384, 288]), kwargs = {})
Original traceback:
  File "/mnt/c/Users/kyakuno/Desktop/segment-anything-2/sam2/modeling/sam2_base.py", line 196, in forward
    backbone_out = self.forward_image(input_image)
  File "/mnt/c/Users/kyakuno/Desktop/segment-anything-2/sam2/modeling/sam2_base.py", line 485, in forward_image
    backbone_out = self.image_encoder(img_batch)
  File "/home/kyakuno/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/c/Users/kyakuno/Desktop/segment-anything-2/sam2/modeling/backbones/image_encoder.py", line 31, in forward
    features, pos = self.neck(self.trunk(sample))
  File "/home/kyakuno/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/c/Users/kyakuno/Desktop/segment-anything-2/sam2/modeling/backbones/hieradet.py", line 284, in forward
    x = blk(x)
  File "/home/kyakuno/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/c/Users/kyakuno/Desktop/segment-anything-2/sam2/modeling/backbones/hieradet.py", line 147, in forward
    x = self.attn(x)
  File "/home/kyakuno/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/c/Users/kyakuno/Desktop/segment-anything-2/sam2/modeling/backbones/hieradet.py", line 77, in forward
    x = self.proj(x)
kyakuno commented 3 weeks ago

segment_anything_kerasにSAM2対応のIssueはあるが未対応。 https://github.com/tirthasheshpatel/segment_anything_keras/issues/4

kyakuno commented 3 weeks ago

edge-ai-torchをCPUモードで動かすとエラーが変わる。

<unknown>:0: error: failed while converting: 'main':
Some ops are not supported by the native TFLite runtime, you can enable TF kernels fallback using TF Select. See instructions: https://www.tensorflow.org/lite/guide/ops_select
TF Select ops: Relu6
Details:
        tf.Relu6(tensor<256x1xi64>) -> (tensor<256x1xi64>)
        tf.Relu6(tensor<256xi64>) -> (tensor<256xi64>)
kyakuno commented 3 weeks ago

tfliteのconverterには_ai_edge_converter_flagsでフラグを与えらえる。 https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/README.md

kyakuno commented 3 weeks ago

下記でFlexを有効にするとImageEncoderのエクスポート自体はできた。

            import ai_edge_torch
            import tensorflow as tf
            sample_inputs = (input_image,)
            tfl_converter_flags = {'target_spec': {'supported_ops': [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]}}
            edge_model = ai_edge_torch.convert(self.model, sample_inputs, _ai_edge_converter_flags=tfl_converter_flags)
            edge_model.export("image_encoder.tflite")
kyakuno commented 3 weeks ago

Relu6はFlexRelu6になる。

スクリーンショット 2024-08-21 15 29 38

kyakuno commented 3 weeks ago

SegmentAnthing Int8の論文。ピークが2つある分布になるため、対処が必要と記載がある。 https://openaccess.thecvf.com/content/CVPR2024/papers/Lv_PTQ4SAM_Post-Training_Quantization_for_Segment_Anything_CVPR_2024_paper.pdf

kyakuno commented 3 weeks ago

edge-ai-torchでflexを有効にして量子化すると、MixedPrecisionのグラフになる。 Convはint8で、それ以外のオペレータはFloatで動く。

kyakuno commented 3 weeks ago

full int8 quantはまだサポートされていない気配がある。 https://github.com/google-ai-edge/ai-edge-torch/issues/150

kyakuno commented 3 weeks ago

量子化対象のオペレータのリストが下記にある。 https://github.com/google-ai-edge/ai-edge-torch/blob/main/ai_edge_torch/quantize/pt2e_quantizer.py

kyakuno commented 3 weeks ago

PromptEncoderは下記のエラーになる。

RuntimeError: This model contains ops not capturable by Pytorch/XLA: aten::nonzero

エクスポートできない原因は、prompt_encoder.pyの_embed_pointsの下記のロジック。 ONNXだとWhereになる部分。

        point_embedding[labels == -1] = self.not_a_point_embed.weight
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        point_embedding[labels == 2] += self.point_embeddings[2].weight
        point_embedding[labels == 3] += self.point_embeddings[3].weight

下記のようにするとエクスポートできるようになる。

        labels = labels.int()
        table = torch.zeros((5, self.point_embeddings[0].weight.shape[1]))
        table[0] = self.not_a_point_embed.weight
        table[1] = self.point_embeddings[0].weight
        table[2] = self.point_embeddings[1].weight
        table[3] = self.point_embeddings[2].weight
        table[4] = self.point_embeddings[3].weight
        for i in range(labels.shape[0]):
            point_embedding[i] = point_embedding[i] + table[labels[i] + 1]
kyakuno commented 3 weeks ago

ImageEncoderをtfliteで推論してみる。ImageEncoderは正常にexportできている。

output1

kyakuno commented 3 weeks ago

MaskDecoderは下記のエラーになる。

  File "/home/kyakuno/.local/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 221, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_function ConstantVariable(int: 512) [ConstantVariable()] {}

from user code:
   File "/mnt/c/Users/kyakuno/Desktop/segment-anything-2/sam2/modeling/sam/mask_decoder.py", line 137, in forward
    masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
  File "/mnt/c/Users/kyakuno/Desktop/segment-anything-2/sam2/modeling/sam/mask_decoder.py", line 198, in predict_masks
    sparse_prompt_embeddings.size(0), -1, -1

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

I0000 00:00:1724378905.868815    9151 cpu_client.cc:470] TfrtCpuClient destroyed.
kyakuno commented 3 weeks ago

size -> shapeに置き換えるとここはpassする。

kyakuno commented 3 weeks ago

Floatモデルはすべてエクスポートできた https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/mask_decoder_hiera_l.tflite https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/prompt_encoder_sparse_hiera_l.tflite https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/image_encoder_hiera_l.tflite

kyakuno commented 3 weeks ago

量子化でキャリブレーションしようとすると下記のエラーになる。

torch.histogram: input tensor and hist tensor should have the same dtype, but got input long int and hist float

https://github.com/pytorch/pytorch/issues/74420

kyakuno commented 3 weeks ago

量子化のフロー https://pytorch.org/tutorials/prototype/quantization_in_pytorch_2_0_export_tutorial.html

kyakuno commented 3 weeks ago

グラフの入力が問題ではなく、グラフの途中でint64のテンソルが出てきて対応できなくなっている。

kyakuno commented 3 weeks ago

torch/ao/quantization/observer.pyのreset_histogramで下記のコードを追加すると通る。

    def reset_histogram(self, x: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor) -> None:
        self.min_val.resize_(min_val.shape)
        self.min_val.copy_(min_val)
        self.max_val.resize_(max_val.shape)
        self.max_val.copy_(max_val)
        assert (
            min_val.numel() == 1 and max_val.numel() == 1
        ), "histogram min/max values must be scalar."
        if x.dtype != torch.float32: # 追加
            x = x.float() # 追加
        torch.histc(
            x, self.bins, min=min_val, max=max_val, out=self.histogram  # type: ignore[arg-type]
        )
kyakuno commented 3 weeks ago

int8版のImageEncoderの出力。魂はあっていそう。

output1

kyakuno commented 3 weeks ago

使用したバージョン torch 2.4.0 ai-edge-torch 0.2.0

kyakuno commented 3 weeks ago

PromptEncoderはlabelがint64で量子化できない。

kyakuno commented 3 weeks ago

PromptEncoderは演算量が少ないのでfloatで動かしてもいい気はする。

kyakuno commented 3 weeks ago

MaskDecoderの量子化結果。ImageEncoderはfloat。

output1

kyakuno commented 3 weeks ago

量子化モデル。ただし、キャリブレーションはtruckでしか行っていない。

https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/mask_decoder_hiera_l_int8.tflite https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/image_encoder_hiera_l_int8.tflite

kyakuno commented 3 weeks ago

ai-edge-torchがLLMを対象にしているためか、テンソルは結構、floatになっている。 量子化ツールに手を入れないといけない感じはする。

kyakuno commented 3 weeks ago

ImageEncoderのAttentionのところは、weightはint8になっていて、floatにして行列積を行っている。

スクリーンショット 2024-08-23 14 12 20

kyakuno commented 2 weeks ago

ImageEncoderの出力にposを含めると、量子化でエラーが発生する。 is_dynamic=Trueにして、DynamicQuantizationにすると通るが、演算は全てFloatになる。

kyakuno commented 2 weeks ago

ImageEncoderのDynamicQuantizationの出力。出力は綺麗。

output1

kyakuno commented 2 weeks ago

is_dynamic = False(後処理もモデルに含むことでposを出力に含めない) https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/image_encoder_hiera_l_int8_is_dynamc_false.tflite

is_dynamic = True(posを出力に含める) https://storage.googleapis.com/ailia-models-tflite/segment-anything-2/image_encoder_hiera_l_int8_is_dynamic_true.tflite

kyakuno commented 1 week ago

ImageEncoder、PromptEncoder、MaskDecoderは正常にtflite (float)に変換できた。

MemoryAttentionの変換は下記で行う。 https://github.com/axinc-ai/ailia-models/issues/1514

kyakuno commented 3 days ago

MaskDecoderをdynamic shapeにしようとすると、下記のエラーになる。

 File "/home/kyakuno/.local/lib/python3.10/site-packages/torch_xla/experimental/unbounded_dynamism_export.py", line 115, in decompose_dynamic_shape_select
    assert symbolic_dims[
AssertionError: Selected dim cannot be symbolic

tfliteはStatic Shapeのみ対応にする。