google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

Replace Int64 with Int32 for edge #246

Open rfechtner opened 1 week ago

rfechtner commented 1 week ago

Description of the bug:

Hi,

I am trying to convert a PyTorch model to TFLite that uses torch.max(..).indices (i.e. an argmax) and torch.gather(..), which creates LongTensors (int64). My target runtime delegate does not support any int64 ops (including the int64 -> int32 cast), so I am looking for a way to replace the int64 ops with corresponding int32 ones.

Min rep. example:

import ai_edge_torch
import torch

sample_inputs = (torch.randn(1, 3, 224, 224),)

class Model(torch.nn.Module):
  def __init__(self):
    super().__init__()

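  def forward(self, tensor):
    # channel index of the max per pixel; torch returns it as int64 (LongTensor)
    select = tensor.max(dim=1).indices.unsqueeze(0)
    # gather along the channel axis, which also takes an int64 index
    return torch.gather(tensor, dim=1, index=select)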

model = Model().eval()
edge_model = ai_edge_torch.convert(model, sample_inputs)
edge_model(*sample_inputs)
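Side note on this reduced example: because the gather reads from the same tensor the argmax was taken over, it only re-reads the maximum itself, so the result equals a keep-dim max and needs no integer tensor at all. A minimal NumPy sketch of that equivalence, assuming `np.take_along_axis` as the counterpart of `torch.gather` (the shapes are illustrative, not from the issue):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 3, 8, 8)).astype(np.float32)

# argmax over the channel axis, keeping the axis so it can serve as a gather index
idx = np.argmax(x, axis=1, keepdims=True)          # shape (1, 1, 8, 8), integer dtype

# gather along axis 1: NumPy's counterpart of torch.gather(x, dim=1, index=idx)
gathered = np.take_along_axis(x, idx, axis=1)      # shape (1, 1, 8, 8)

# the same values come out of a plain keep-dim max, with no index tensor involved
maxed = x.max(axis=1, keepdims=True)

print(gathered.shape, np.array_equal(gathered, maxed))
```

If the real model gathers from a different tensor than the one the argmax runs on, this shortcut does not apply.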

In the past I have been doing this via an intermediate ONNX representation, where I modified the relevant nodes and then converted ONNX to TFLite, but with this new framework I'd hoped to get rid of ONNX.

I have tried to replace torch.argmax() with tf.math.argmax(.., output_type=tf.int32), or with NumPy's np.argmax(), which accepts an out= array, but that fails during torch.export() with:

/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py in unimplemented(msg, from_exc)
    219     if from_exc is not _NOTHING:
    220         raise Unsupported(msg) from from_exc
--> 221     raise Unsupported(msg)

Unsupported: 'skip function argmax_v2 in file /usr/local/lib/python3.10/dist-packages/tensorflow/python/util/traceback_utils.py'

from user code:
   File "<ipython-input-3-7b1c313a80d9>", line 10, in forward
    idx = tf.math.argmax(tensor.detach().numpy(), output_tzpe=tf.int32)

One remaining avenue I can think of is post-processing the resulting flatbuffer and replacing the int64 ops there, but that seems quite brittle and overly complicated.

Any other suggestions? Or is there a way to dynamically replace functions?

Note: I had to pin tf-nightly==2.18.0.dev20240722 otherwise the export fails with:


---------------------------------------------------------------------------
ConverterError                            Traceback (most recent call last)
in ()
     13 
     14 model = Model().eval()
---> 15 edge_model = ai_edge_torch.convert(model, sample_inputs)
     16 edge_model(*sample_inputs)

12 frames
/usr/local/lib/python3.10/dist-packages/ai_edge_torch/_convert/converter.py in convert(module, sample_args, sample_kwargs, quant_config, dynamic_shapes, _ai_edge_converter_flags)
    239     _ai_edge_converter_flags = {}
    240 
--> 241   return Converter().convert(
    242       module,
    243       sample_args,

/usr/local/lib/python3.10/dist-packages/ai_edge_torch/_convert/converter.py in convert(self, module, sample_args, sample_kwargs, quant_config, dynamic_shapes, _ai_edge_converter_flags)
    161             " specified."
    162         )
--> 163     return conversion.convert_signatures(
    164         self._signatures,
    165         quant_config=quant_config,

/usr/local/lib/python3.10/dist-packages/ai_edge_torch/_convert/conversion.py in convert_signatures(signatures, quant_config, _tfl_converter_flags)
    102   # Apply default fx passes
    103   exported_programs = list(map(_run_convert_passes, exported_programs))
--> 104   tflite_model = lowertools.exported_programs_to_tflite(
    105       exported_programs,
    106       signatures,

/usr/local/lib/python3.10/dist-packages/ai_edge_torch/lowertools/_shim.py in exported_programs_to_tflite(exported_programs, signatures, quant_config, _tfl_converter_flags)
     73   )
     74 
---> 75   return utils.merged_bundle_to_tfl_model(
     76       merged_bundle,
     77       signatures,

/usr/local/lib/python3.10/dist-packages/ai_edge_torch/lowertools/torch_xla_utils.py in merged_bundle_to_tfl_model(merged_bundle, signatures, quant_config, _tfl_converter_flags)
    271     conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
    272 
--> 273     tflite_model = converter.convert()
    274 
    275     if (

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/lite.py in wrapper(self, *args, **kwargs)
   1236   def wrapper(self, *args, **kwargs):
   1237     # pylint: disable=protected-access
-> 1238     return self._convert_and_export_metrics(convert_func, *args, **kwargs)
   1239     # pylint: enable=protected-access
   1240 

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/lite.py in _convert_and_export_metrics(self, convert_func, *args, **kwargs)
   1188     self._save_conversion_params_metric()
   1189     start_time = time.process_time()
-> 1190     result = convert_func(self, *args, **kwargs)
   1191     elapsed_time_ms = (time.process_time() - start_time) * 1000
   1192     if result:

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/lite.py in convert(self)
   1570     del trackable_obj
   1571     gc.collect()
-> 1572     return self._convert_from_saved_model(graph_def)
   1573 
   1574 

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/lite.py in _convert_from_saved_model(self, graph_def)
   1428     converter_kwargs.update(quant_mode.converter_flags())
   1429 
-> 1430     result = _convert_saved_model(**converter_kwargs)
   1431     return self._optimize_tflite_model(
   1432         result,

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/convert_phase.py in wrapper(*args, **kwargs)
    210         else:
    211           report_error_message(str(converter_error))
--> 212         raise converter_error from None  # Re-throws the exception.
    213       except Exception as error:
    214         report_error_message(str(error))

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/convert_phase.py in wrapper(*args, **kwargs)
    203     def wrapper(*args, **kwargs):
    204       try:
--> 205         return func(*args, **kwargs)
    206       except ConverterError as converter_error:
    207         if converter_error.errors:

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/convert.py in convert_saved_model(**kwargs)
   1043   model_flags = build_model_flags(**kwargs)
   1044   conversion_flags = build_conversion_flags(**kwargs)
-> 1045   data = convert(
   1046       model_flags,
   1047       conversion_flags,

/usr/local/lib/python3.10/dist-packages/tensorflow/lite/python/convert.py in convert(model_flags, conversion_flags, input_data_str, debug_info_str, enable_mlir_converter)
    374               enable_mlir_converter,
    375           )
--> 376       raise converter_error
    377 
    378   return _run_deprecated_conversion_binary(

ConverterError: Could not translate MLIR to FlatBuffer.:0: error: loc(callsite(callsite(callsite("__main__.Model;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_5"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_11"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): 'vhlo.iota_v1' op is not part of the vhlo support yet.
:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from
:0: note: loc(callsite(callsite(callsite("__main__.Model;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_5"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_11"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): see current operation: %12 = "vhlo.iota_v1"() <{iota_dimension = #vhlo.integer_v1<0 : i64>}> : () -> tensor<224xui32>
:0: error: failed while converting: 'main': 
:0: note: see current operation: 
"func.func"() <{arg_attrs = [{tf_saved_model.index_path = ["args_0"]}], function_type = (tensor<1x3x224x224xf32>) -> tensor<1x1x224x224xf32>, res_attrs = [{tf_saved_model.index_path = ["output_0"]}], sym_name = "main"}> ({
^bb0(%arg0: tensor<1x3x224x224xf32>):
  %0 = "arith.constant"() <{value = dense<[1, 1, 224, 224, 1]> : tensor<5xi32>}> : () -> tensor<5xi32>
  %1 = "arith.constant"() <{value = dense<[1, 1, 224, 224]> : tensor<4xi32>}> : () -> tensor<4xi32>
  %2 = "arith.constant"() <{value = dense<[1, 1, 1, 224, 1]> : tensor<5xi32>}> : () -> tensor<5xi32>
  %3 = "arith.constant"() <{value = dense<[1, 1, 224, 1, 1]> : tensor<5xi32>}> : () -> tensor<5xi32>
  %4 = "arith.constant"() <{value = dense<[1, 1, 224, 224, 1]> : tensor<5xi64>}> : () -> tensor<5xi64>
  %5 = "arith.constant"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
  %6 = "arith.constant"() <{value = dense<0> : tensor<1x1x224x224x1xui32>}> : () -> tensor<1x1x224x224x1xui32>
  %7 = "tfl.arg_max"(%arg0, %5) : (tensor<1x3x224x224xf32>, tensor<1xi32>) -> tensor<1x224x224xi32>
  %8 = "tfl.cast"(%7) : (tensor<1x224x224xi32>) -> tensor<1x224x224xi64>
  %9 = "tfl.reshape"(%8, %1) : (tensor<1x224x224xi64>, tensor<4xi32>) -> tensor<1x1x224x224xi64>
  %10 = "tfl.cast"(%9) : (tensor<1x1x224x224xi64>) -> tensor<1x1x224x224xui32>
  %11 = "tfl.reshape"(%10, %0) : (tensor<1x1x224x224xui32>, tensor<5xi32>) -> tensor<1x1x224x224x1xui32>
  %12 = "vhlo.iota_v1"() <{iota_dimension = #vhlo.integer_v1<0 : i64>}> : () -> tensor<224xui32>
  %13 = "tfl.reshape"(%12, %3) : (tensor<224xui32>, tensor<5xi32>) -> tensor<1x1x224x1x1xui32>
  %14 = "tfl.broadcast_to"(%13, %4) : (tensor<1x1x224x1x1xui32>, tensor<5xi64>) -> tensor<1x1x224x224x1xui32>
  %15 = "tfl.reshape"(%12, %2) : (tensor<224xui32>, tensor<5xi32>) -> tensor<1x1x1x224x1xui32>
  %16 = "tfl.broadcast_to"(%15, %4) : (tensor<1x1x1x224x1xui32>, tensor<5xi64>) -> tensor<1x1x224x224x1xui32>
  %17 = "tfl.concatenation"(%6, %11, %14, %16) <{axis = 4 : i32, fused_activation_function = "NONE"}> : (tensor<1x1x224x224x1xui32>, tensor<1x1x224x224x1xui32>, tensor<1x1x224x224x1xui32>, tensor<1x1x224x224x1xui32>) -> tensor<1x1x224x224x4xui32>
  %18 = "tfl.cast"(%17) : (tensor<1x1x224x224x4xui32>) -> tensor<1x1x224x224x4xi64>
  %19 = "tfl.gather_nd"(%arg0, %18) : (tensor<1x3x224x224xf32>, tensor<1x1x224x224x4xi64>) -> tensor<1x1x224x224xf32>
  "func.return"(%19) : (tensor<1x1x224x224xf32>) -> ()
}) {tf.entry_function = {control_outputs = "", inputs = "serving_default_args_0:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} : () -> ()
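Reading the MLIR dump above: torch.gather(dim=1) is lowered to a gather_nd whose index tensor is built by concatenating four coordinate planes (zeros for the batch axis, the argmax result for the channel axis, and two iota ramps for H and W). The NumPy sketch below mirrors that construction with int32 coordinates throughout, to show the values themselves do not need the int64 cast. It imitates the dump under the assumption of batch size 1 (which is why the batch plane is all zeros); it is not the converter's actual code, and the small H/W sizes are illustrative.

```python
import numpy as np

C, H, W = 3, 6, 5
rng = np.random.default_rng(1)
x = rng.standard_normal((1, C, H, W)).astype(np.float32)   # batch size 1, as in the issue

# channel index per pixel, cast down to int32, shape (1, 1, H, W)
chan = np.argmax(x, axis=1, keepdims=True).astype(np.int32)

# coordinate planes mirroring the dump: zeros for batch (valid only because batch == 1),
# argmax for channel, iota ramps for H and W, each broadcast to (1, 1, H, W)
batch = np.zeros((1, 1, H, W), dtype=np.int32)
rows = np.broadcast_to(np.arange(H, dtype=np.int32).reshape(1, 1, H, 1), (1, 1, H, W))
cols = np.broadcast_to(np.arange(W, dtype=np.int32).reshape(1, 1, 1, W), (1, 1, H, W))

# concatenate along a new last axis -> (1, 1, H, W, 4) int32 coordinate tuples
coords = np.stack([batch, chan, rows, cols], axis=-1)

# gather_nd emulation: each length-4 tuple addresses one element of x
out = x[coords[..., 0], coords[..., 1], coords[..., 2], coords[..., 3]]

# reference result via take_along_axis (NumPy's torch.gather counterpart)
ref = np.take_along_axis(x, chan, axis=1)
print(coords.dtype, out.shape, np.array_equal(out, ref))
```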

pkgoogle commented 5 days ago

Hi @rfechtner, I was actually not able to replicate this issue using the latest code on main, i.e.:

# navigate to ai-edge-torch repo
git switch main # if not already in the main branch
git pull # update to latest code
pip install -e .
pip install tensorflow-cpu # the latest code works better with torch-xla this way, due to an import conflict
# run your script

Can you give that a try? Let me know what goes wrong if you try it this way. I also recommend using a new venv/conda environment to rule out conflicts. I should note I'm using Python 3.11, in case that makes a difference.

rfechtner commented 5 days ago

Hi @pkgoogle thanks for the swift reply.

I've created a clean env following your suggestions. Same behaviour: the PyTorch model converts just fine, but the exported model contains Int64 tensors (as torch.max() returns a LongTensor).

[Screenshot: Model Explorer graph of the converted model]

but I want to avoid Int64 ops. I was trying to replace the torch function with TensorFlow ops, where I can specify the output dtype, e.g.:

class ModelInt32TF(torch.nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, tensor):
    return tf.math.argmax(
          tensor, axis=1, output_type=tf.int32
    )

model_int32_tf = ModelInt32TF().eval()
edge_model_int32_tf = ai_edge_torch.convert(model_int32_tf, sample_inputs)
edge_model_int32_tf(*sample_inputs)

which yields the error mentioned above:

---------------------------------------------------------------------------
Unsupported                               Traceback (most recent call last)
<ipython-input-31-b7e7c53e3cf8> in <cell line: 11>()
      9 
     10 model_int32_tf = ModelInt32TF().eval()
---> 11 edge_model_int32_tf = ai_edge_torch.convert(model_int32_tf, sample_inputs)
     12 edge_model_int32_tf(*sample_inputs)

35 frames
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py in unimplemented(msg, from_exc)
    219     if from_exc is not _NOTHING:
    220         raise Unsupported(msg) from from_exc
--> 221     raise Unsupported(msg)
    222 
    223 

Unsupported: 'skip function argmax_v2 in file /usr/local/lib/python3.10/dist-packages/tensorflow/python/util/traceback_utils.py'

from user code:
   File "<ipython-input-31-b7e7c53e3cf8>", line 6, in forward
    return tf.math.argmax(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

Note: I can replace

select = tensor.max(dim=1).indices.unsqueeze(0)

with

select = np.empty(.., dtype=np.int32)
np.argmax(tensor, keepdims=True, out=select)

but torch.gather() and np.take_along_axis() (the latter is converted to the former) still require a Long tensor as input...

rfechtner commented 4 days ago

Using np.argmax(..) instead of tf.math.argmax() gets me a step further:

import ai_edge_torch
import numpy as np
import torch

sample_inputs = (torch.randn(1, 3, 224, 224),)

class Model(torch.nn.Module):
  def __init__(self):
    super().__init__()

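  def forward(self, tensor):
    B, C, H, W = tensor.shape
    # argmax over channels, written into a preallocated int32 array
    mode = np.empty((B, H, W), dtype=np.int32)
    np.argmax(tensor.detach().numpy(), axis=1, out=mode)
    mode = torch.from_numpy(mode).unsqueeze(0)  # (1, B, H, W), same as (B, 1, H, W) for B == 1

    # torch.gather() only accepts an int64 index, hence the .long() cast
    return torch.gather(tensor, dim=1, index=mode.long())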

model = Model().eval()
edge_model = ai_edge_torch.convert(model, sample_inputs)
edge_model(*sample_inputs)

This lets me create an index tensor of dtype int32, but torch.gather() still requires a LongTensor as input.
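One possible way around the int64 gather index, offered as an untested suggestion: when the gathered axis is small (here C = 3), the gather can be rewritten as a one-hot mask, built by comparing the int32 argmax against an int32 arange, followed by a multiply and a sum over that axis. Whether ai-edge-torch then emits an int64-free graph is not verified here; the NumPy sketch below only checks that the values match `np.take_along_axis` (the shapes are illustrative).

```python
import numpy as np

C, H, W = 3, 6, 5
rng = np.random.default_rng(2)
x = rng.standard_normal((1, C, H, W)).astype(np.float32)

# int32 channel index per pixel, shape (1, 1, H, W)
idx = np.argmax(x, axis=1, keepdims=True).astype(np.int32)

# one-hot mask over the channel axis: compare against an int32 ramp of shape (1, C, 1, 1)
channels = np.arange(C, dtype=np.int32).reshape(1, C, 1, 1)
onehot = (idx == channels).astype(x.dtype)           # (1, C, H, W), 1.0 at the selected channel

# multiply-and-sum keeps exactly one channel per pixel, keeping the axis like gather does
selected = (x * onehot).sum(axis=1, keepdims=True)   # (1, 1, H, W)

ref = np.take_along_axis(x, idx, axis=1)
print(selected.shape, np.allclose(selected, ref))
```

The mask costs C multiplies per pixel, so this only pays off for a small gathered axis; inf or NaN inputs would also leak through the `0 * x` terms.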

Environment: pip freeze

absl-py==1.4.0
accelerate==0.34.2
ai-edge-litert-nightly==1.0.1.dev20240924
ai-edge-model-explorer==0.1.12
ai-edge-model-explorer-adapter==0.1.5
ai-edge-quantizer-nightly==0.0.1.dev20240924
-e git+https://github.com/google-ai-edge/ai-edge-torch.git@c9973d2e7423e86f420576c0e5cac1181f79ac0e#egg=ai_edge_torch
aiohappyeyeballs==2.4.0
aiohttp==3.10.5
aiosignal==1.3.1
alabaster==0.7.16
albucore==0.0.16
albumentations==1.4.15
altair==4.2.2
annotated-types==0.7.0
anyio==3.7.1
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
array_record==0.5.1
arviz==0.19.0
astropy==6.1.3
astropy-iers-data==0.2024.9.16.0.32.21
astunparse==1.6.3
async-timeout==4.0.3
atpublic==4.1.0
attrs==24.2.0
audioread==3.0.1
autograd==1.7.0
babel==2.16.0
backcall==0.2.0
beautifulsoup4==4.12.3
bidict==0.23.1
bigframes==1.17.0
bigquery-magics==0.2.0
bleach==6.1.0
blinker==1.4
blis==0.7.11
blosc2==2.0.0
bokeh==3.4.3
bqplot==0.12.43
branca==0.7.2
build==1.2.2
CacheControl==0.14.0
cachetools==5.5.0
catalogue==2.0.10
certifi==2024.8.30
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.3.2
chex==0.1.86
clarabel==0.9.0
click==8.1.7
cloud-tpu-client==0.10
cloudpathlib==0.19.0
cloudpickle==2.2.1
cmake==3.30.3
cmdstanpy==1.2.4
colorcet==3.1.0
colorlover==0.3.0
colour==0.1.5
community==1.0.0b1
confection==0.1.5
cons==0.4.6
contextlib2==21.6.0
contourpy==1.3.0
cryptography==43.0.1
cuda-python==12.2.1
cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-24.4.1-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=57366e7ef09dc63e0b389aff20df6c37d91e2790065861ee31a4720149f5b694
cufflinks==0.17.3
cupy-cuda12x==12.2.0
cvxopt==1.3.2
cvxpy==1.5.3
cycler==0.12.1
cymem==2.0.8
Cython==3.0.11
dask==2024.8.0
datascience==0.17.6
db-dtypes==1.3.0
dbus-python==1.2.18
debugpy==1.6.6
decorator==4.4.2
defusedxml==0.7.1
distributed==2024.8.0
distro==1.7.0
dlib==19.24.2
dm-tree==0.1.8
docstring_parser==0.16
docutils==0.18.1
dopamine_rl==4.0.9
duckdb==1.1.0
earthengine-api==1.0.0
easydict==1.13
ecos==2.0.14
editdistance==0.8.1
eerepr==0.0.4
einops==0.8.0
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889
entrypoints==0.4
et-xmlfile==1.1.0
etils==1.9.4
etuples==0.3.9
eval_type_backport==0.2.0
exceptiongroup==1.2.2
fastai==2.7.17
fastcore==1.7.8
fastdownload==0.0.7
fastjsonschema==2.20.0
fastprogress==1.0.3
fastrlock==0.8.2
filelock==3.16.1
firebase-admin==6.5.0
Flask==2.2.5
flatbuffers==24.3.25
flax==0.8.4
folium==0.17.0
fonttools==4.53.1
frozendict==2.4.4
frozenlist==1.4.1
fsspec==2024.6.1
future==1.0.0
gast==0.6.0
gcsfs==2024.6.1
GDAL==3.6.4
gdown==5.2.0
geemap==0.34.2
gensim==4.3.3
geocoder==1.38.1
geographiclib==2.0
geopandas==1.0.1
geopy==2.4.1
gin-config==0.5.0
glob2==0.7
google==2.0.3
google-ai-generativelanguage==0.6.6
google-api-core==1.34.1
google-api-python-client==1.8.0
google-auth==2.27.0
google-auth-httplib2==0.2.0
google-auth-oauthlib==1.2.1
google-cloud-aiplatform==1.67.1
google-cloud-bigquery==3.25.0
google-cloud-bigquery-connection==1.15.5
google-cloud-bigquery-storage==2.26.0
google-cloud-bigtable==2.26.0
google-cloud-core==2.4.1
google-cloud-datastore==2.19.0
google-cloud-firestore==2.16.1
google-cloud-functions==1.16.5
google-cloud-iam==2.15.2
google-cloud-language==2.13.4
google-cloud-pubsub==2.23.1
google-cloud-resource-manager==1.12.5
google-cloud-storage==2.8.0
google-cloud-translate==3.15.5
google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz#sha256=deb182392f5f78765ea686f1200ff7cfd42e31bdf8d172a68d6a29f657e1fe18
google-crc32c==1.6.0
google-generativeai==0.7.2
google-pasta==0.2.0
google-resumable-media==2.7.2
googleapis-common-protos==1.65.0
googledrivedownloader==0.4
graphviz==0.20.3
greenlet==3.1.0
grpc-google-iam-v1==0.13.1
grpcio==1.64.1
grpcio-status==1.48.2
gspread==6.0.2
gspread-dataframe==3.3.1
gym==0.25.2
gym-notices==0.0.8
h5netcdf==1.3.0
h5py==3.11.0
holidays==0.57
holoviews==1.19.1
html5lib==1.1
httpimport==1.4.0
httplib2==0.22.0
huggingface-hub==0.24.7
humanize==4.10.0
hyperopt==0.2.7
ibis-framework==8.0.0
idna==3.10
imageio==2.35.1
imageio-ffmpeg==0.5.1
imagesize==1.4.1
imbalanced-learn==0.12.3
imgaug==0.4.0
immutabledict==4.2.0
importlib_metadata==8.5.0
importlib_resources==6.4.5
imutils==0.5.4
inflect==7.4.0
iniconfig==2.0.0
intel-cmplr-lib-ur==2024.2.1
intel-openmp==2024.2.1
ipyevents==2.0.2
ipyfilechooser==0.6.0
ipykernel==5.5.6
ipyleaflet==0.19.2
ipyparallel==8.8.0
ipython==7.34.0
ipython-genutils==0.2.0
ipython-sql==0.5.0
ipytree==0.2.2
ipywidgets==7.7.1
itsdangerous==2.2.0
jax==0.4.26
jaxlib @ https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.26+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl#sha256=813cf1fe3e7ca4dbf5327d6e7b4fc8521e92d8bba073ee645ae0d5d036a25750
jedi==0.19.1
jeepney==0.7.1
jellyfish==1.1.0
jieba==0.42.1
Jinja2==3.1.4
joblib==1.4.2
jsonpickle==3.3.0
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter-client==6.1.12
jupyter-console==6.1.0
jupyter-leaflet==0.19.2
jupyter-server==1.24.0
jupyter_core==5.7.2
jupyterlab_pygments==0.3.0
jupyterlab_widgets==3.0.13
kaggle==1.6.17
kagglehub==0.3.0
keras==3.4.1
keras-nightly==3.5.0.dev2024092403
keyring==23.5.0
kiwisolver==1.4.7
langcodes==3.4.0
language_data==1.2.0
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
lazy_loader==0.4
libclang==18.1.1
librosa==0.10.2.post1
lightgbm==4.5.0
linkify-it-py==2.0.3
llvmlite==0.43.0
locket==1.0.0
logical-unification==0.4.6
lxml==4.9.4
marisa-trie==1.2.0
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.7.1
matplotlib-inline==0.1.7
matplotlib-venn==1.1.1
mdit-py-plugins==0.4.2
mdurl==0.1.2
miniKanren==1.0.3
missingno==0.5.2
mistune==0.8.4
mizani==0.11.4
mkl==2024.2.2
ml-dtypes==0.4.1
mlxtend==0.23.1
more-itertools==10.5.0
moviepy==1.0.3
mpmath==1.3.0
msgpack==1.0.8
multidict==6.1.0
multipledispatch==1.0.0
multitasking==0.0.11
murmurhash==1.0.10
music21==9.1.0
namex==0.0.8
natsort==8.4.0
nbclassic==1.1.0
nbclient==0.10.0
nbconvert==6.5.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.3
nibabel==5.2.1
nltk==3.8.1
notebook==6.5.5
notebook_shim==0.2.4
numba==0.60.0
numexpr==2.10.1
numpy==1.26.4
nvidia-nccl-cu12==2.23.4
nvtx==0.2.10
oauth2client==4.1.3
oauthlib==3.2.2
opencv-contrib-python==4.10.0.84
opencv-python==4.10.0.84
opencv-python-headless==4.10.0.84
openpyxl==3.1.5
opt-einsum==3.3.0
optax==0.2.2
optree==0.12.1
orbax-checkpoint==0.6.4
osqp==0.6.7.post0
packaging==24.1
pandas==2.1.4
pandas-datareader==0.10.0
pandas-gbq==0.23.1
pandas-stubs==2.1.4.231227
pandocfilters==1.5.1
panel==1.4.5
param==2.1.1
parso==0.8.4
parsy==2.1
partd==1.4.2
pathlib==1.0.1
patsy==0.5.6
peewee==3.17.6
pexpect==4.9.0
pickleshare==0.7.5
pillow==10.4.0
pip-tools==7.4.1
platformdirs==4.3.6
plotly==5.24.1
plotnine==0.13.6
pluggy==1.5.0
polars==1.6.0
pooch==1.8.2
portpicker==1.5.2
prefetch_generator==1.0.3
preshed==3.0.9
prettytable==3.11.0
proglog==0.1.10
progressbar2==4.5.0
prometheus_client==0.20.0
promise==2.3
prompt_toolkit==3.0.47
prophet==1.1.5
proto-plus==1.24.0
protobuf==3.20.3
psutil==5.9.5
psycopg2==2.9.9
ptyprocess==0.7.0
py-cpuinfo==9.0.0
py4j==0.10.9.7
pyarrow==14.0.2
pyarrow-hotfix==0.6
pyasn1==0.6.1
pyasn1_modules==0.4.1
pycocotools==2.0.8
pycparser==2.22
pydantic==2.9.2
pydantic_core==2.23.4
pydata-google-auth==1.8.2
pydot==3.0.1
pydot-ng==2.0.0
pydotplus==2.0.2
PyDrive==1.3.1
PyDrive2==1.20.0
pyerfa==2.0.1.4
pygame==2.6.0
Pygments==2.18.0
PyGObject==3.42.1
PyJWT==2.9.0
pymc==5.16.2
pymystem3==0.2.0
pynvjitlink-cu12==0.3.0
pyogrio==0.9.0
PyOpenGL==3.1.7
pyOpenSSL==24.2.1
pyparsing==3.1.4
pyperclip==1.9.0
pyproj==3.6.1
pyproject_hooks==1.1.0
pyshp==2.3.1
PySocks==1.7.1
pytensor==2.25.4
pytest==7.4.4
python-apt==2.4.0
python-box==7.2.0
python-dateutil==2.8.2
python-louvain==0.16
python-slugify==8.0.4
python-utils==3.8.2
pytz==2024.2
pyviz_comms==3.0.3
PyYAML==6.0.2
pyzmq==24.0.1
qdldl==0.1.7.post4
ratelim==0.1.6
referencing==0.35.1
regex==2024.9.11
requests==2.32.3
requests-oauthlib==1.3.1
requirements-parser==0.9.0
rich==13.8.1
rmm-cu12==24.4.0
rpds-py==0.20.0
rpy2==3.4.2
rsa==4.9
safetensors==0.4.5
scikit-image==0.24.0
scikit-learn==1.5.2
scipy==1.13.1
scooby==0.10.0
scs==3.2.7
seaborn==0.13.1
SecretStorage==3.3.1
Send2Trash==1.8.3
sentencepiece==0.2.0
shapely==2.0.6
shellingham==1.5.4
simple-parsing==0.1.6
six==1.16.0
sklearn-pandas==2.2.0
smart-open==7.0.4
sniffio==1.3.1
snowballstemmer==2.2.0
sortedcontainers==2.4.0
soundfile==0.12.1
soupsieve==2.6
soxr==0.5.0.post1
spacy==3.7.6
spacy-legacy==3.0.12
spacy-loggers==1.0.5
Sphinx==5.0.2
sphinxcontrib-applehelp==2.0.0
sphinxcontrib-devhelp==2.0.0
sphinxcontrib-htmlhelp==2.1.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==2.0.0
sphinxcontrib-serializinghtml==2.0.0
SQLAlchemy==2.0.35
sqlglot==20.11.0
sqlparse==0.5.1
srsly==2.4.8
stanio==0.5.1
statsmodels==0.14.3
StrEnum==0.4.15
sympy==1.13.3
tables==3.8.0
tabulate==0.9.0
tb-nightly==2.18.0a20240924
tbb==2021.13.1
tblib==3.0.0
tenacity==9.0.0
tensorboard==2.17.0
tensorboard-data-server==0.7.2
tensorflow-cpu==2.17.0
tensorflow-datasets==4.9.6
tensorflow-hub==0.16.1
tensorflow-io-gcs-filesystem==0.37.1
tensorflow-metadata==1.15.0
tensorflow-probability==0.24.0
tensorstore==0.1.65
termcolor==2.4.0
terminado==0.18.1
text-unidecode==1.3
textblob==0.17.1
tf-slim==1.1.0
tf_keras==2.17.0
tf_nightly==2.18.0.dev20240923
thinc==8.2.5
threadpoolctl==3.5.0
tifffile==2024.8.30
tinycss2==1.3.0
tokenizers==0.19.1
toml==0.10.2
tomli==2.0.1
toolz==0.12.1
torch==2.4.0+cpu
torch-xla==2.4.0
torchaudio==2.4.0+cpu
torchsummary==1.5.1
torchvision==0.19.0+cpu
tornado==6.3.3
tqdm==4.66.5
traitlets==5.7.1
traittypes==0.2.1
transformers==4.44.2
tweepy==4.14.0
typeguard==4.3.0
typer==0.12.5
types-pytz==2024.2.0.20240913
types-setuptools==75.1.0.20240917
typing_extensions==4.12.2
tzdata==2024.1
tzlocal==5.2
uc-micro-py==1.0.3
uritemplate==3.0.1
urllib3==2.2.3
vega-datasets==0.9.0
wadllib==1.3.6
wasabi==1.1.3
wcwidth==0.2.13
weasel==0.4.1
webcolors==24.8.0
webencodings==0.5.1
websocket-client==1.8.0
Werkzeug==3.0.4
widgetsnbextension==3.6.9
wordcloud==1.9.3
wrapt==1.16.0
xarray==2024.9.0
xarray-einstats==0.8.0
xgboost==2.1.1
xlrd==2.0.1
xyzservices==2024.9.0
yarl==1.11.1
yellowbrick==1.5
yfinance==0.2.43
zict==3.0.0
zipp==3.20.2

rfechtner commented 4 days ago

If I replace torch.gather() with advanced indexing, I still get an int64 LESS op, which seems to be introduced by the slicing:

import ai_edge_torch
import numpy as np
import torch

sample_inputs = (torch.randn(1, 3, 224, 224),)

class Model(torch.nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, tensor):
    B, C, H, W = tensor.shape
    mode = np.empty((B, H, W), dtype=np.int32)
    np.argmax(tensor.detach().numpy(), axis=1, out=mode)
    mode = torch.from_numpy(mode) # (1, H, W)

    collected = torch.empty((B, 1, H, W), dtype=tensor.dtype, device=tensor.device)
    for b in range(B):
      collected[b, 0] = tensor[
          torch.arange(B, dtype=torch.int32).unsqueeze(-1).unsqueeze(-1),
          mode,
          torch.arange(H, dtype=torch.int32),
          torch.arange(W, dtype=torch.int32)
      ]
    return collected

model = Model().eval()
edge_model = ai_edge_torch.convert(model, sample_inputs)
result = edge_model(*sample_inputs)
print(f"Output: {result.shape}")
edge_model.export("fancy.tflite")

[Screenshot: Model Explorer graph of the advanced-indexing model]
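A plausible source of that LESS op, stated as an assumption and not confirmed against the lowering code: advanced indexing must accept negative indices, so lowerings commonly normalize each index as `idx < 0 ? idx + dim_size : idx` before the gather, and that comparison would show up as a LESS op in the index dtype. A NumPy sketch of the normalization, where `normalize_index` is a hypothetical helper for illustration:

```python
import numpy as np

def normalize_index(idx, dim_size):
    """Hypothetical helper: wrap negative indices into [0, dim_size).

    The `idx < 0` comparison is the kind of LESS op a lowering may emit
    (assumption); it runs in idx's dtype, so an int64 index gives an int64 LESS.
    """
    return np.where(idx < 0, idx + dim_size, idx)

H = 5
idx = np.array([0, 2, -1, -5], dtype=np.int32)
wrapped = normalize_index(idx, H)

# wrapped indices address the same elements as the raw (possibly negative) ones
data = np.arange(10, 15)
print(wrapped, wrapped.dtype, np.array_equal(data[idx], data[wrapped]))
```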