google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

Replace Int64 with Int32 for edge #246

Open rfechtner opened 1 month ago

rfechtner commented 1 month ago

Description of the bug:

Hi,

I am trying to convert a PyTorch model to TFLite which uses torch.argmax(..).indices and torch.gather(..), hence creating LongTensors (int64). As my targeted runtime delegate does not support any int64 ops (including the cast int64 -> int32), I am looking to replace the int64 ops with corresponding int32 ones.

Minimal reproducible example:

import ai_edge_torch
import torch

sample_inputs = (torch.randn(1, 3, 224, 224),)

class Model(torch.nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, tensor):
    select = tensor.max(dim=1).indices.unsqueeze(0)
    return torch.gather(tensor, dim=1, index=select)

model = Model().eval()
edge_model = ai_edge_torch.convert(model, sample_inputs)
edge_model(*sample_inputs)

In the past I have been doing this via an intermediate ONNX model representation, where I modified the relevant nodes and then converted the ONNX model to TFLite, but with this new framework I'd hoped to get rid of the ONNX step.
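
For reference, that ONNX surgery looked roughly like this (a sketch with illustrative file names; a real model usually also needs its int64 initializers and value infos rewritten):

import onnx

model = onnx.load("model.onnx")  # hypothetical path
for node in model.graph.node:
  # Retarget explicit casts that produce int64 to int32 instead.
  if node.op_type == "Cast":
    for attr in node.attribute:
      if attr.name == "to" and attr.i == onnx.TensorProto.INT64:
        attr.i = onnx.TensorProto.INT32
onnx.save(model, "model_int32.onnx")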

I have tried to replace the torch.argmax() with tf.math.argmax(.., output_type=tf.int32) or the NumPy equivalent (which supports specifying an output array of a given dtype), but that fails during torch.export() and results in:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py in unimplemented(msg, from_exc)
    219     if from_exc is not _NOTHING:
    220         raise Unsupported(msg) from from_exc
--> 221     raise Unsupported(msg)

Unsupported: 'skip function argmax_v2 in file /usr/local/lib/python3.10/dist-packages/tensorflow/python/util/traceback_utils.py'

from user code:
   File "<ipython-input-3-7b1c313a80d9>", line 10, in forward
    idx = tf.math.argmax(tensor.detach().numpy(), output_tzpe=tf.int32)

One remaining avenue I can think of is post-processing the resulting flatbuffer representation and replacing the int64 ops there, but that seems quite brittle and overly complicated.

Any other suggestions? Or is there a way to dynamically replace functions?
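
(Aside: to check whether a given conversion is actually int64-free, I list the tensors in the flatbuffer with the standard TF Lite interpreter API; the model path below is illustrative.)

import numpy as np
import tensorflow as tf

# Print the names of all int64 tensors in a converted model.
interpreter = tf.lite.Interpreter(model_path="model.tflite")
int64_tensors = [
  d["name"] for d in interpreter.get_tensor_details() if d["dtype"] == np.int64
]
print(int64_tensors)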

Note: I had to pin tf-nightly==2.18.0.dev20240722, otherwise the export fails with:


---------------------------------------------------------------------------
ConverterError                            Traceback (most recent call last)
in ()
     13 
     14 model = Model().eval()
---> 15 edge_model = ai_edge_torch.convert(model, sample_inputs)
     16 edge_model(*sample_inputs)

12 frames
/usr/local/lib/python3.10/dist-packages/ai_edge_torch/_convert/converter.py in convert(module, sample_args, sample_kwargs, quant_config, dynamic_shapes, _ai_edge_converter_flags)
    239     _ai_edge_converter_flags = {}
    240 
--> 241   return Converter().convert(
    242       module,
    243       sample_args,

/usr/local/lib/python3.10/dist-packages/ai_edge_torch/_convert/converter.py in convert(self, module, sample_args, sample_kwargs, quant_config, dynamic_shapes, _ai_edge_converter_flags)
    161             " specified."
    162         )
--> 163     return conversion.convert_signatures(
    164         self._signatures,
    165         quant_config=quant_config,

/usr/local/lib/python3.10/dist-packages/ai_edge_torch/_convert/conversion.py in convert_signatures(signatures, quant_config, _tfl_converter_flags)
    102   # Apply default fx passes
    103   exported_programs = list(map(_run_convert_passes, exported_programs))
--> 104   tflite_model = lowertools.exported_programs_to_tflite(
    105       exported_programs,
    106       signatures,

/usr/local/lib/python3.10/dist-packages/ai_edge_torch/lowertools/_shim.py in exported_programs_to_tflite(exported_programs, signatures, quant_config, _tfl_converter_flags)
     73   )
     74 
---> 75   return utils.merged_bundle_to_tfl_model(
     76       merged_bundle,
     77       signatures,

/usr/local/lib/python3.10/dist-packages/ai_edge_torch/lowertools/torch_xla_utils.py in merged_bundle_to_tfl_model(merged_bundle, signatures, quant_config, _tfl_converter_flags)
    271     conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
    272 
--> 273     tflite_model = converter.convert()
    274 
    275     if (

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/lite.py in wrapper(self, *args, **kwargs)
   1236   def wrapper(self, *args, **kwargs):
   1237     # pylint: disable=protected-access
-> 1238     return self._convert_and_export_metrics(convert_func, *args, **kwargs)
   1239     # pylint: enable=protected-access
   1240 

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/lite.py in _convert_and_export_metrics(self, convert_func, *args, **kwargs)
   1188     self._save_conversion_params_metric()
   1189     start_time = time.process_time()
-> 1190     result = convert_func(self, *args, **kwargs)
   1191     elapsed_time_ms = (time.process_time() - start_time) * 1000
   1192     if result:

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/lite.py in convert(self)
   1570     del trackable_obj
   1571     gc.collect()
-> 1572     return self._convert_from_saved_model(graph_def)
   1573 
   1574 

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/lite.py in _convert_from_saved_model(self, graph_def)
   1428     converter_kwargs.update(quant_mode.converter_flags())
   1429 
-> 1430     result = _convert_saved_model(**converter_kwargs)
   1431     return self._optimize_tflite_model(
   1432         result,

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/convert_phase.py in wrapper(*args, **kwargs)
    210         else:
    211           report_error_message(str(converter_error))
--> 212         raise converter_error from None  # Re-throws the exception.
    213       except Exception as error:
    214         report_error_message(str(error))

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/convert_phase.py in wrapper(*args, **kwargs)
    203     def wrapper(*args, **kwargs):
    204       try:
--> 205         return func(*args, **kwargs)
    206       except ConverterError as converter_error:
    207         if converter_error.errors:

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/convert.py in convert_saved_model(**kwargs)
   1043   model_flags = build_model_flags(**kwargs)
   1044   conversion_flags = build_conversion_flags(**kwargs)
-> 1045   data = convert(
   1046       model_flags,
   1047       conversion_flags,

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/convert.py in convert(model_flags, conversion_flags, input_data_str, debug_info_str, enable_mlir_converter)
    374               enable_mlir_converter,
    375           )
--> 376       raise converter_error
    377 
    378   return _run_deprecated_conversion_binary(

ConverterError: Could not translate MLIR to FlatBuffer.:0: error: loc(callsite(callsite(callsite("__main__.Model;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_5"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_11"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): 'vhlo.iota_v1' op is not part of the vhlo support yet.
:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from
:0: note: loc(callsite(callsite(callsite("__main__.Model;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_5"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_11"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): see current operation: %12 = "vhlo.iota_v1"() <{iota_dimension = #vhlo.integer_v1<0 : i64>}> : () -> tensor<224xui32>
:0: error: failed while converting: 'main': 
:0: note: see current operation: 
"func.func"() <{arg_attrs = [{tf_saved_model.index_path = ["args_0"]}], function_type = (tensor<1x3x224x224xf32>) -> tensor<1x1x224x224xf32>, res_attrs = [{tf_saved_model.index_path = ["output_0"]}], sym_name = "main"}> ({
^bb0(%arg0: tensor<1x3x224x224xf32>):
  %0 = "arith.constant"() <{value = dense<[1, 1, 224, 224, 1]> : tensor<5xi32>}> : () -> tensor<5xi32>
  %1 = "arith.constant"() <{value = dense<[1, 1, 224, 224]> : tensor<4xi32>}> : () -> tensor<4xi32>
  %2 = "arith.constant"() <{value = dense<[1, 1, 1, 224, 1]> : tensor<5xi32>}> : () -> tensor<5xi32>
  %3 = "arith.constant"() <{value = dense<[1, 1, 224, 1, 1]> : tensor<5xi32>}> : () -> tensor<5xi32>
  %4 = "arith.constant"() <{value = dense<[1, 1, 224, 224, 1]> : tensor<5xi64>}> : () -> tensor<5xi64>
  %5 = "arith.constant"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
  %6 = "arith.constant"() <{value = dense<0> : tensor<1x1x224x224x1xui32>}> : () -> tensor<1x1x224x224x1xui32>
  %7 = "tfl.arg_max"(%arg0, %5) : (tensor<1x3x224x224xf32>, tensor<1xi32>) -> tensor<1x224x224xi32>
  %8 = "tfl.cast"(%7) : (tensor<1x224x224xi32>) -> tensor<1x224x224xi64>
  %9 = "tfl.reshape"(%8, %1) : (tensor<1x224x224xi64>, tensor<4xi32>) -> tensor<1x1x224x224xi64>
  %10 = "tfl.cast"(%9) : (tensor<1x1x224x224xi64>) -> tensor<1x1x224x224xui32>
  %11 = "tfl.reshape"(%10, %0) : (tensor<1x1x224x224xui32>, tensor<5xi32>) -> tensor<1x1x224x224x1xui32>
  %12 = "vhlo.iota_v1"() <{iota_dimension = #vhlo.integer_v1<0 : i64>}> : () -> tensor<224xui32>
  %13 = "tfl.reshape"(%12, %3) : (tensor<224xui32>, tensor<5xi32>) -> tensor<1x1x224x1x1xui32>
  %14 = "tfl.broadcast_to"(%13, %4) : (tensor<1x1x224x1x1xui32>, tensor<5xi64>) -> tensor<1x1x224x224x1xui32>
  %15 = "tfl.reshape"(%12, %2) : (tensor<224xui32>, tensor<5xi32>) -> tensor<1x1x1x224x1xui32>
  %16 = "tfl.broadcast_to"(%15, %4) : (tensor<1x1x1x224x1xui32>, tensor<5xi64>) -> tensor<1x1x224x224x1xui32>
  %17 = "tfl.concatenation"(%6, %11, %14, %16) <{axis = 4 : i32, fused_activation_function = "NONE"}> : (tensor<1x1x224x224x1xui32>, tensor<1x1x224x224x1xui32>, tensor<1x1x224x224x1xui32>, tensor<1x1x224x224x1xui32>) -> tensor<1x1x224x224x4xui32>
  %18 = "tfl.cast"(%17) : (tensor<1x1x224x224x4xui32>) -> tensor<1x1x224x224x4xi64>
  %19 = "tfl.gather_nd"(%arg0, %18) : (tensor<1x3x224x224xf32>, tensor<1x1x224x224x4xi64>) -> tensor<1x1x224x224xf32>
  "func.return"(%19) : (tensor<1x1x224x224xf32>) -> ()
}) {tf.entry_function = {control_outputs = "", inputs = "serving_default_args_0:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} : () -> ()

pkgoogle commented 1 month ago

Hi @rfechtner, I was actually not able to replicate this issue using the latest code on main, i.e.:

# navigate to ai-edge-torch repo
git switch main # if not already in the main branch
git pull # update to latest code
pip install -e .
pip install tensorflow-cpu # avoids an import conflict; the latest code works better with torch-xla this way
# run your script

Can you give that a try? Let me know what goes wrong if you try it this way. I also recommend using a fresh venv/conda environment to make sure there's no weird conflict. I should note I'm using Python 3.11, if that makes a difference.

rfechtner commented 1 month ago

Hi @pkgoogle thanks for the swift reply.

I've created a clean env following your suggestions. Same behaviour: I can convert the PyTorch model just fine, but the exported model still contains int64 tensors (as torch.max() returns a LongTensor).

(Model Explorer screenshot of the converted graph, showing the int64 tensors.)

but I want to avoid int64 ops. I was trying to replace the torch function with TensorFlow ops, where I can specify the output dtype, e.g.:

import tensorflow as tf
import torch

class ModelInt32TF(torch.nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, tensor):
    return tf.math.argmax(
        tensor.detach().numpy(), axis=1, output_type=tf.int32
    )

model_int32_tf = ModelInt32TF().eval()
edge_model_int32_tf = ai_edge_torch.convert(model_int32_tf, sample_inputs)
edge_model_int32_tf(*sample_inputs)

which yields the error mentioned above:

---------------------------------------------------------------------------
Unsupported                               Traceback (most recent call last)
<ipython-input-31-b7e7c53e3cf8> in <cell line: 11>()
      9 
     10 model_int32_tf = ModelInt32TF().eval()
---> 11 edge_model_int32_tf = ai_edge_torch.convert(model_int32_tf, sample_inputs)
     12 edge_model_int32_tf(*sample_inputs)

35 frames
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py in unimplemented(msg, from_exc)
    219     if from_exc is not _NOTHING:
    220         raise Unsupported(msg) from from_exc
--> 221     raise Unsupported(msg)
    222 
    223 

Unsupported: 'skip function argmax_v2 in file /usr/local/lib/python3.10/dist-packages/tensorflow/python/util/traceback_utils.py'

from user code:
   File "<ipython-input-31-b7e7c53e3cf8>", line 6, in forward
    return tf.math.argmax(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

Note: I can replace

select = tensor.max(dim=1).indices.unsqueeze(0)

with

select = np.empty(.., dtype=np.int32)
np.argmax(tensor, keepdims=True, out=select)

but torch.gather() and np.take_along_axis() (the latter is converted to the former) will still require a Long index tensor...
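
One index-free workaround, at least for the max/gather pattern from the repro, would be to build the selection as a comparison mask and reduce, so no integer index tensor exists at all. A sketch (ties would merge channels; for the repro itself, tensor.amax(dim=1, keepdim=True) is already equivalent):

import torch

def gather_max_channel(tensor):
  # Select the per-pixel max channel without any integer index tensor.
  # Matches torch.gather(tensor, 1, argmax_indices) when the max is unique.
  max_vals = tensor.amax(dim=1, keepdim=True)     # (B, 1, H, W)
  mask = (tensor == max_vals).to(tensor.dtype)    # mask of max positions
  return (tensor * mask).sum(dim=1, keepdim=True)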

rfechtner commented 1 month ago

Using np.argmax(..) instead of tf.math.argmax() gets me a step further:

import ai_edge_torch
import numpy as np
import torch

sample_inputs = (torch.randn(1, 3, 224, 224),)

class Model(torch.nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, tensor):
    B, C, H, W = tensor.shape
    mode = np.empty((B, H, W), dtype=np.int32)
    np.argmax(tensor.detach().numpy(), axis=1, out=mode)
    mode = torch.from_numpy(mode).unsqueeze(0)

    return torch.gather(tensor, dim=1, index=mode.long())

model = Model().eval()
edge_model = ai_edge_torch.convert(model, sample_inputs)
edge_model(*sample_inputs)

This allows me to create an index tensor of dtype int32, but torch.gather() still requires a LongTensor index.

Environment (pip freeze):

absl-py==1.4.0
accelerate==0.34.2
ai-edge-litert-nightly==1.0.1.dev20240924
ai-edge-model-explorer==0.1.12
ai-edge-model-explorer-adapter==0.1.5
ai-edge-quantizer-nightly==0.0.1.dev20240924
-e git+https://github.com/google-ai-edge/ai-edge-torch.git@c9973d2e7423e86f420576c0e5cac1181f79ac0e#egg=ai_edge_torch
aiohappyeyeballs==2.4.0
aiohttp==3.10.5
aiosignal==1.3.1
alabaster==0.7.16
albucore==0.0.16
albumentations==1.4.15
altair==4.2.2
annotated-types==0.7.0
anyio==3.7.1
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
array_record==0.5.1
arviz==0.19.0
astropy==6.1.3
astropy-iers-data==0.2024.9.16.0.32.21
astunparse==1.6.3
async-timeout==4.0.3
atpublic==4.1.0
attrs==24.2.0
audioread==3.0.1
autograd==1.7.0
babel==2.16.0
backcall==0.2.0
beautifulsoup4==4.12.3
bidict==0.23.1
bigframes==1.17.0
bigquery-magics==0.2.0
bleach==6.1.0
blinker==1.4
blis==0.7.11
blosc2==2.0.0
bokeh==3.4.3
bqplot==0.12.43
branca==0.7.2
build==1.2.2
CacheControl==0.14.0
cachetools==5.5.0
catalogue==2.0.10
certifi==2024.8.30
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.3.2
chex==0.1.86
clarabel==0.9.0
click==8.1.7
cloud-tpu-client==0.10
cloudpathlib==0.19.0
cloudpickle==2.2.1
cmake==3.30.3
cmdstanpy==1.2.4
colorcet==3.1.0
colorlover==0.3.0
colour==0.1.5
community==1.0.0b1
confection==0.1.5
cons==0.4.6
contextlib2==21.6.0
contourpy==1.3.0
cryptography==43.0.1
cuda-python==12.2.1
cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-24.4.1-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=57366e7ef09dc63e0b389aff20df6c37d91e2790065861ee31a4720149f5b694
cufflinks==0.17.3
cupy-cuda12x==12.2.0
cvxopt==1.3.2
cvxpy==1.5.3
cycler==0.12.1
cymem==2.0.8
Cython==3.0.11
dask==2024.8.0
datascience==0.17.6
db-dtypes==1.3.0
dbus-python==1.2.18
debugpy==1.6.6
decorator==4.4.2
defusedxml==0.7.1
distributed==2024.8.0
distro==1.7.0
dlib==19.24.2
dm-tree==0.1.8
docstring_parser==0.16
docutils==0.18.1
dopamine_rl==4.0.9
duckdb==1.1.0
earthengine-api==1.0.0
easydict==1.13
ecos==2.0.14
editdistance==0.8.1
eerepr==0.0.4
einops==0.8.0
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889
entrypoints==0.4
et-xmlfile==1.1.0
etils==1.9.4
etuples==0.3.9
eval_type_backport==0.2.0
exceptiongroup==1.2.2
fastai==2.7.17
fastcore==1.7.8
fastdownload==0.0.7
fastjsonschema==2.20.0
fastprogress==1.0.3
fastrlock==0.8.2
filelock==3.16.1
firebase-admin==6.5.0
Flask==2.2.5
flatbuffers==24.3.25
flax==0.8.4
folium==0.17.0
fonttools==4.53.1
frozendict==2.4.4
frozenlist==1.4.1
fsspec==2024.6.1
future==1.0.0
gast==0.6.0
gcsfs==2024.6.1
GDAL==3.6.4
gdown==5.2.0
geemap==0.34.2
gensim==4.3.3
geocoder==1.38.1
geographiclib==2.0
geopandas==1.0.1
geopy==2.4.1
gin-config==0.5.0
glob2==0.7
google==2.0.3
google-ai-generativelanguage==0.6.6
google-api-core==1.34.1
google-api-python-client==1.8.0
google-auth==2.27.0
google-auth-httplib2==0.2.0
google-auth-oauthlib==1.2.1
google-cloud-aiplatform==1.67.1
google-cloud-bigquery==3.25.0
google-cloud-bigquery-connection==1.15.5
google-cloud-bigquery-storage==2.26.0
google-cloud-bigtable==2.26.0
google-cloud-core==2.4.1
google-cloud-datastore==2.19.0
google-cloud-firestore==2.16.1
google-cloud-functions==1.16.5
google-cloud-iam==2.15.2
google-cloud-language==2.13.4
google-cloud-pubsub==2.23.1
google-cloud-resource-manager==1.12.5
google-cloud-storage==2.8.0
google-cloud-translate==3.15.5
google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz#sha256=deb182392f5f78765ea686f1200ff7cfd42e31bdf8d172a68d6a29f657e1fe18
google-crc32c==1.6.0
google-generativeai==0.7.2
google-pasta==0.2.0
google-resumable-media==2.7.2
googleapis-common-protos==1.65.0
googledrivedownloader==0.4
graphviz==0.20.3
greenlet==3.1.0
grpc-google-iam-v1==0.13.1
grpcio==1.64.1
grpcio-status==1.48.2
gspread==6.0.2
gspread-dataframe==3.3.1
gym==0.25.2
gym-notices==0.0.8
h5netcdf==1.3.0
h5py==3.11.0
holidays==0.57
holoviews==1.19.1
html5lib==1.1
httpimport==1.4.0
httplib2==0.22.0
huggingface-hub==0.24.7
humanize==4.10.0
hyperopt==0.2.7
ibis-framework==8.0.0
idna==3.10
imageio==2.35.1
imageio-ffmpeg==0.5.1
imagesize==1.4.1
imbalanced-learn==0.12.3
imgaug==0.4.0
immutabledict==4.2.0
importlib_metadata==8.5.0
importlib_resources==6.4.5
imutils==0.5.4
inflect==7.4.0
iniconfig==2.0.0
intel-cmplr-lib-ur==2024.2.1
intel-openmp==2024.2.1
ipyevents==2.0.2
ipyfilechooser==0.6.0
ipykernel==5.5.6
ipyleaflet==0.19.2
ipyparallel==8.8.0
ipython==7.34.0
ipython-genutils==0.2.0
ipython-sql==0.5.0
ipytree==0.2.2
ipywidgets==7.7.1
itsdangerous==2.2.0
jax==0.4.26
jaxlib @ https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.26+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl#sha256=813cf1fe3e7ca4dbf5327d6e7b4fc8521e92d8bba073ee645ae0d5d036a25750
jedi==0.19.1
jeepney==0.7.1
jellyfish==1.1.0
jieba==0.42.1
Jinja2==3.1.4
joblib==1.4.2
jsonpickle==3.3.0
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter-client==6.1.12
jupyter-console==6.1.0
jupyter-leaflet==0.19.2
jupyter-server==1.24.0
jupyter_core==5.7.2
jupyterlab_pygments==0.3.0
jupyterlab_widgets==3.0.13
kaggle==1.6.17
kagglehub==0.3.0
keras==3.4.1
keras-nightly==3.5.0.dev2024092403
keyring==23.5.0
kiwisolver==1.4.7
langcodes==3.4.0
language_data==1.2.0
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
lazy_loader==0.4
libclang==18.1.1
librosa==0.10.2.post1
lightgbm==4.5.0
linkify-it-py==2.0.3
llvmlite==0.43.0
locket==1.0.0
logical-unification==0.4.6
lxml==4.9.4
marisa-trie==1.2.0
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.7.1
matplotlib-inline==0.1.7
matplotlib-venn==1.1.1
mdit-py-plugins==0.4.2
mdurl==0.1.2
miniKanren==1.0.3
missingno==0.5.2
mistune==0.8.4
mizani==0.11.4
mkl==2024.2.2
ml-dtypes==0.4.1
mlxtend==0.23.1
more-itertools==10.5.0
moviepy==1.0.3
mpmath==1.3.0
msgpack==1.0.8
multidict==6.1.0
multipledispatch==1.0.0
multitasking==0.0.11
murmurhash==1.0.10
music21==9.1.0
namex==0.0.8
natsort==8.4.0
nbclassic==1.1.0
nbclient==0.10.0
nbconvert==6.5.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.3
nibabel==5.2.1
nltk==3.8.1
notebook==6.5.5
notebook_shim==0.2.4
numba==0.60.0
numexpr==2.10.1
numpy==1.26.4
nvidia-nccl-cu12==2.23.4
nvtx==0.2.10
oauth2client==4.1.3
oauthlib==3.2.2
opencv-contrib-python==4.10.0.84
opencv-python==4.10.0.84
opencv-python-headless==4.10.0.84
openpyxl==3.1.5
opt-einsum==3.3.0
optax==0.2.2
optree==0.12.1
orbax-checkpoint==0.6.4
osqp==0.6.7.post0
packaging==24.1
pandas==2.1.4
pandas-datareader==0.10.0
pandas-gbq==0.23.1
pandas-stubs==2.1.4.231227
pandocfilters==1.5.1
panel==1.4.5
param==2.1.1
parso==0.8.4
parsy==2.1
partd==1.4.2
pathlib==1.0.1
patsy==0.5.6
peewee==3.17.6
pexpect==4.9.0
pickleshare==0.7.5
pillow==10.4.0
pip-tools==7.4.1
platformdirs==4.3.6
plotly==5.24.1
plotnine==0.13.6
pluggy==1.5.0
polars==1.6.0
pooch==1.8.2
portpicker==1.5.2
prefetch_generator==1.0.3
preshed==3.0.9
prettytable==3.11.0
proglog==0.1.10
progressbar2==4.5.0
prometheus_client==0.20.0
promise==2.3
prompt_toolkit==3.0.47
prophet==1.1.5
proto-plus==1.24.0
protobuf==3.20.3
psutil==5.9.5
psycopg2==2.9.9
ptyprocess==0.7.0
py-cpuinfo==9.0.0
py4j==0.10.9.7
pyarrow==14.0.2
pyarrow-hotfix==0.6
pyasn1==0.6.1
pyasn1_modules==0.4.1
pycocotools==2.0.8
pycparser==2.22
pydantic==2.9.2
pydantic_core==2.23.4
pydata-google-auth==1.8.2
pydot==3.0.1
pydot-ng==2.0.0
pydotplus==2.0.2
PyDrive==1.3.1
PyDrive2==1.20.0
pyerfa==2.0.1.4
pygame==2.6.0
Pygments==2.18.0
PyGObject==3.42.1
PyJWT==2.9.0
pymc==5.16.2
pymystem3==0.2.0
pynvjitlink-cu12==0.3.0
pyogrio==0.9.0
PyOpenGL==3.1.7
pyOpenSSL==24.2.1
pyparsing==3.1.4
pyperclip==1.9.0
pyproj==3.6.1
pyproject_hooks==1.1.0
pyshp==2.3.1
PySocks==1.7.1
pytensor==2.25.4
pytest==7.4.4
python-apt==2.4.0
python-box==7.2.0
python-dateutil==2.8.2
python-louvain==0.16
python-slugify==8.0.4
python-utils==3.8.2
pytz==2024.2
pyviz_comms==3.0.3
PyYAML==6.0.2
pyzmq==24.0.1
qdldl==0.1.7.post4
ratelim==0.1.6
referencing==0.35.1
regex==2024.9.11
requests==2.32.3
requests-oauthlib==1.3.1
requirements-parser==0.9.0
rich==13.8.1
rmm-cu12==24.4.0
rpds-py==0.20.0
rpy2==3.4.2
rsa==4.9
safetensors==0.4.5
scikit-image==0.24.0
scikit-learn==1.5.2
scipy==1.13.1
scooby==0.10.0
scs==3.2.7
seaborn==0.13.1
SecretStorage==3.3.1
Send2Trash==1.8.3
sentencepiece==0.2.0
shapely==2.0.6
shellingham==1.5.4
simple-parsing==0.1.6
six==1.16.0
sklearn-pandas==2.2.0
smart-open==7.0.4
sniffio==1.3.1
snowballstemmer==2.2.0
sortedcontainers==2.4.0
soundfile==0.12.1
soupsieve==2.6
soxr==0.5.0.post1
spacy==3.7.6
spacy-legacy==3.0.12
spacy-loggers==1.0.5
Sphinx==5.0.2
sphinxcontrib-applehelp==2.0.0
sphinxcontrib-devhelp==2.0.0
sphinxcontrib-htmlhelp==2.1.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==2.0.0
sphinxcontrib-serializinghtml==2.0.0
SQLAlchemy==2.0.35
sqlglot==20.11.0
sqlparse==0.5.1
srsly==2.4.8
stanio==0.5.1
statsmodels==0.14.3
StrEnum==0.4.15
sympy==1.13.3
tables==3.8.0
tabulate==0.9.0
tb-nightly==2.18.0a20240924
tbb==2021.13.1
tblib==3.0.0
tenacity==9.0.0
tensorboard==2.17.0
tensorboard-data-server==0.7.2
tensorflow-cpu==2.17.0
tensorflow-datasets==4.9.6
tensorflow-hub==0.16.1
tensorflow-io-gcs-filesystem==0.37.1
tensorflow-metadata==1.15.0
tensorflow-probability==0.24.0
tensorstore==0.1.65
termcolor==2.4.0
terminado==0.18.1
text-unidecode==1.3
textblob==0.17.1
tf-slim==1.1.0
tf_keras==2.17.0
tf_nightly==2.18.0.dev20240923
thinc==8.2.5
threadpoolctl==3.5.0
tifffile==2024.8.30
tinycss2==1.3.0
tokenizers==0.19.1
toml==0.10.2
tomli==2.0.1
toolz==0.12.1
torch==2.4.0+cpu
torch-xla==2.4.0
torchaudio==2.4.0+cpu
torchsummary==1.5.1
torchvision==0.19.0+cpu
tornado==6.3.3
tqdm==4.66.5
traitlets==5.7.1
traittypes==0.2.1
transformers==4.44.2
tweepy==4.14.0
typeguard==4.3.0
typer==0.12.5
types-pytz==2024.2.0.20240913
types-setuptools==75.1.0.20240917
typing_extensions==4.12.2
tzdata==2024.1
tzlocal==5.2
uc-micro-py==1.0.3
uritemplate==3.0.1
urllib3==2.2.3
vega-datasets==0.9.0
wadllib==1.3.6
wasabi==1.1.3
wcwidth==0.2.13
weasel==0.4.1
webcolors==24.8.0
webencodings==0.5.1
websocket-client==1.8.0
Werkzeug==3.0.4
widgetsnbextension==3.6.9
wordcloud==1.9.3
wrapt==1.16.0
xarray==2024.9.0
xarray-einstats==0.8.0
xgboost==2.1.1
xlrd==2.0.1
xyzservices==2024.9.0
yarl==1.11.1
yellowbrick==1.5
yfinance==0.2.43
zict==3.0.0
zipp==3.20.2

rfechtner commented 1 month ago

If I replace the torch.gather() with advanced indexing, I still get an int64 Less op that seems to be introduced for slicing:

import ai_edge_torch
import numpy as np
import torch

sample_inputs = (torch.randn(1, 3, 224, 224),)

class Model(torch.nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, tensor):
    B, C, H, W = tensor.shape
    mode = np.empty((B, H, W), dtype=np.int32)
    np.argmax(tensor.detach().numpy(), axis=1, out=mode)
    mode = torch.from_numpy(mode) # (1, H, W)

    collected = torch.empty((B, 1, H, W), dtype=tensor.dtype, device=tensor.device)
    for b in range(B):
      collected[b, 0] = tensor[
          torch.arange(B, dtype=torch.int32).unsqueeze(-1).unsqueeze(-1),
          mode,
          torch.arange(H, dtype=torch.int32),
          torch.arange(W, dtype=torch.int32)
      ]
    return collected

model = Model().eval()
edge_model = ai_edge_torch.convert(model, sample_inputs)
result = edge_model(*sample_inputs)
print(f"Output: {result.shape}")
edge_model.export("fancy.tflite")

(Model Explorer screenshot of the converted graph, showing the remaining int64 Less op.)

pkgoogle commented 2 weeks ago

Hi @rfechtner, I think it's because np.argmax still returns an int64 value/tensor. I'm wondering if you can implement an explicit int32 (or lower precision) argmax function which never touches, or gets turned into, int64 values. I suppose a feature like index quantization, or just general precision lowering, might be interesting.

rfechtner commented 2 weeks ago

Hi @pkgoogle, yes, precisely. There seems to be no out-of-the-box approach to advanced indexing without implicit int64 calls, due to the underlying int64 long-tensor indices in NumPy's and PyTorch's implementations.

Going through ONNX, it's quite straightforward to modify the graph after export and replace the dtype of the relevant ops to avoid int64. Unfortunately, it's less trivial in the flatbuffer format...

Would you be so kind as to provide some references for the index quantisation & precision lowering approaches you mentioned?

Cheers

pkgoogle commented 2 weeks ago

I think that would be a feature we might implement -- something you may want to try yourself is to reimplement np.argmax so that it never becomes or touches an int64 tensor (just as a general Python/PyTorch function), and see if you can use that instead of np.argmax. The general form: for each index permutation of the non-axis dimensions, take the argmax of the 1D slice produced by fixing that index permutation and moving through the axis dimension; that argmax is the output tensor's value at that index permutation.

If you just want to test whether it'll work, maybe implement a version where you know the input shape, or just try a 2D tensor first; see the sketch below.
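
A minimal sketch of that idea, staying vectorized rather than looping over index permutations (assumption: comparison, where, and min-reduce all lower without introducing int64 -- worth verifying in Model Explorer):

import torch

def argmax_int32(x, dim):
  # Index of the first maximum along `dim`, computed entirely in int32:
  # mark the max positions, then take the smallest int32 position among them.
  n = x.shape[dim]
  shape = [1] * x.ndim
  shape[dim] = n
  positions = torch.arange(n, dtype=torch.int32).reshape(shape)
  is_max = x == x.amax(dim=dim, keepdim=True)
  sentinel = torch.tensor(n, dtype=torch.int32)  # beyond any valid position
  return torch.where(is_max, positions, sentinel).amin(dim=dim)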

rfechtner commented 2 weeks ago

I will give it a shot, thanks a lot for the feedback!

An out-of-the-box optimisation to avoid int64 ops via converter flags would be an awesome addition toward a one-stop deployment pipeline. Please keep me posted on these efforts! :)

I'll report back as well.