Open rfechtner opened 1 month ago
Hi @rfechtner, I was actually not able to replicate this issue if I use the latest code in main i.e.:
# navigate to ai-edge-torch repo
git switch main # if not already in the main branch
git pull # update to latest code
pip install -e .
pip install tensorflow-cpu # there was an import conflict; the latest code works better with torch-XLA this way
# run your script
Can you give that a try? Let me know what goes wrong if you try it this way. I also recommend using a new venv/conda environment to ensure there's no weird conflict. I should note I'm using Python 3.11, if that makes a difference.
Hi @pkgoogle, thanks for the swift reply.
I've created a clean env with your suggestions. Same behaviour: I can convert the PyTorch model just fine, but the exported model will contain Int64 tensors (as torch.max() returns a LongTensor).
I want to avoid Int64 ops, so I tried to replace the torch function with TensorFlow ops, where I can specify the output dtype, e.g.:
class ModelInt32TF(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, tensor):
        return tf.math.argmax(
            sample_inputs[0], axis=1, output_type=tf.int32
        )
model_int32_tf = ModelInt32TF().eval()
edge_model_int32_tf = ai_edge_torch.convert(model_int32_tf, sample_inputs)
edge_model_int32_tf(*sample_inputs)
which yields the error mentioned above:
---------------------------------------------------------------------------
Unsupported Traceback (most recent call last)
[<ipython-input-31-b7e7c53e3cf8>](https://v6zwhn4z3l-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab_20240920-060127_RC00_676789073#) in <cell line: 11>()
9
10 model_int32_tf = ModelInt32TF().eval()
---> 11 edge_model_int32_tf = ai_edge_torch.convert(model_int32_tf, sample_inputs)
12 edge_model_int32_tf(*sample_inputs)
35 frames
[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py](https://v6zwhn4z3l-496ff2e9c6d22116-0-colab.googleusercontent.com/outputframe.html?vrz=colab_20240920-060127_RC00_676789073#) in unimplemented(msg, from_exc)
219 if from_exc is not _NOTHING:
220 raise Unsupported(msg) from from_exc
--> 221 raise Unsupported(msg)
222
223
Unsupported: 'skip function argmax_v2 in file /usr/local/lib/python3.10/dist-packages/tensorflow/python/util/traceback_utils.py'
from user code:
File "<ipython-input-31-b7e7c53e3cf8>", line 6, in forward
return tf.math.argmax(
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
Note: I can replace
select = tensor.max(dim=1).indices.unsqueeze(0)
with
select = np.empty(.., dtype=np.int32)
np.argmax(tensor, keepdims=1, out=select)
but torch.gather() and np.take_along_axis() (the latter will be converted to the former) still require a Long tensor as input...
Using np.argmax(..) instead of tf.math.argmax() brings me a step further:
import ai_edge_torch
import numpy as np
import torch

sample_inputs = (torch.randn(1, 3, 224, 224),)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, tensor):
        B, C, H, W = tensor.shape
        # Write the argmax over the channel axis into a preallocated int32 array.
        mode = np.empty((B, H, W), dtype=np.int32)
        np.argmax(tensor.detach().numpy(), axis=1, out=mode)
        mode = torch.from_numpy(mode).unsqueeze(0)
        # torch.gather() only accepts int64 indices, hence the cast back to long.
        return torch.gather(tensor, dim=1, index=mode.long())

model = Model().eval()
edge_model = ai_edge_torch.convert(model, sample_inputs)
edge_model(*sample_inputs)
This allows me to create an index tensor of dtype int32, but torch.gather() still requires a LongTensor as input.
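One possible way around the LongTensor requirement of torch.gather() is to select values with a comparison mask instead of an index lookup. The following is only a minimal sketch: gather_int32 is a hypothetical helper (not part of PyTorch or ai_edge_torch), it assumes the index has size 1 along the gathered dimension, and whether the converter lowers it to a graph free of int64 ops is an untested assumption.

```python
import torch


def gather_int32(src: torch.Tensor, index: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical helper: pick one value along `dim` using an int32 index.

    `index` must have the same shape as `src` except for size 1 along `dim`.
    """
    n = src.shape[dim]
    # int32 positions 0..n-1 laid out along `dim`, broadcastable against `index`.
    shape = [1] * src.dim()
    shape[dim] = n
    positions = torch.arange(n, dtype=torch.int32).reshape(shape)
    # True exactly where the position equals the requested index.
    mask = positions == index
    # Zero out everything except the selected element, then collapse `dim`.
    picked = torch.where(mask, src, torch.zeros_like(src))
    return picked.sum(dim=dim, keepdim=True)
```

For an index of shape (B, 1, H, W) this should match torch.gather(src, 1, index.long()), at the cost of touching every element along the gathered dimension.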
Environment: pip freeze
absl-py==1.4.0
accelerate==0.34.2
ai-edge-litert-nightly==1.0.1.dev20240924
ai-edge-model-explorer==0.1.12
ai-edge-model-explorer-adapter==0.1.5
ai-edge-quantizer-nightly==0.0.1.dev20240924
-e git+https://github.com/google-ai-edge/ai-edge-torch.git@c9973d2e7423e86f420576c0e5cac1181f79ac0e#egg=ai_edge_torch
aiohappyeyeballs==2.4.0
aiohttp==3.10.5
aiosignal==1.3.1
alabaster==0.7.16
albucore==0.0.16
albumentations==1.4.15
altair==4.2.2
annotated-types==0.7.0
anyio==3.7.1
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
array_record==0.5.1
arviz==0.19.0
astropy==6.1.3
astropy-iers-data==0.2024.9.16.0.32.21
astunparse==1.6.3
async-timeout==4.0.3
atpublic==4.1.0
attrs==24.2.0
audioread==3.0.1
autograd==1.7.0
babel==2.16.0
backcall==0.2.0
beautifulsoup4==4.12.3
bidict==0.23.1
bigframes==1.17.0
bigquery-magics==0.2.0
bleach==6.1.0
blinker==1.4
blis==0.7.11
blosc2==2.0.0
bokeh==3.4.3
bqplot==0.12.43
branca==0.7.2
build==1.2.2
CacheControl==0.14.0
cachetools==5.5.0
catalogue==2.0.10
certifi==2024.8.30
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.3.2
chex==0.1.86
clarabel==0.9.0
click==8.1.7
cloud-tpu-client==0.10
cloudpathlib==0.19.0
cloudpickle==2.2.1
cmake==3.30.3
cmdstanpy==1.2.4
colorcet==3.1.0
colorlover==0.3.0
colour==0.1.5
community==1.0.0b1
confection==0.1.5
cons==0.4.6
contextlib2==21.6.0
contourpy==1.3.0
cryptography==43.0.1
cuda-python==12.2.1
cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-24.4.1-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=57366e7ef09dc63e0b389aff20df6c37d91e2790065861ee31a4720149f5b694
cufflinks==0.17.3
cupy-cuda12x==12.2.0
cvxopt==1.3.2
cvxpy==1.5.3
cycler==0.12.1
cymem==2.0.8
Cython==3.0.11
dask==2024.8.0
datascience==0.17.6
db-dtypes==1.3.0
dbus-python==1.2.18
debugpy==1.6.6
decorator==4.4.2
defusedxml==0.7.1
distributed==2024.8.0
distro==1.7.0
dlib==19.24.2
dm-tree==0.1.8
docstring_parser==0.16
docutils==0.18.1
dopamine_rl==4.0.9
duckdb==1.1.0
earthengine-api==1.0.0
easydict==1.13
ecos==2.0.14
editdistance==0.8.1
eerepr==0.0.4
einops==0.8.0
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889
entrypoints==0.4
et-xmlfile==1.1.0
etils==1.9.4
etuples==0.3.9
eval_type_backport==0.2.0
exceptiongroup==1.2.2
fastai==2.7.17
fastcore==1.7.8
fastdownload==0.0.7
fastjsonschema==2.20.0
fastprogress==1.0.3
fastrlock==0.8.2
filelock==3.16.1
firebase-admin==6.5.0
Flask==2.2.5
flatbuffers==24.3.25
flax==0.8.4
folium==0.17.0
fonttools==4.53.1
frozendict==2.4.4
frozenlist==1.4.1
fsspec==2024.6.1
future==1.0.0
gast==0.6.0
gcsfs==2024.6.1
GDAL==3.6.4
gdown==5.2.0
geemap==0.34.2
gensim==4.3.3
geocoder==1.38.1
geographiclib==2.0
geopandas==1.0.1
geopy==2.4.1
gin-config==0.5.0
glob2==0.7
google==2.0.3
google-ai-generativelanguage==0.6.6
google-api-core==1.34.1
google-api-python-client==1.8.0
google-auth==2.27.0
google-auth-httplib2==0.2.0
google-auth-oauthlib==1.2.1
google-cloud-aiplatform==1.67.1
google-cloud-bigquery==3.25.0
google-cloud-bigquery-connection==1.15.5
google-cloud-bigquery-storage==2.26.0
google-cloud-bigtable==2.26.0
google-cloud-core==2.4.1
google-cloud-datastore==2.19.0
google-cloud-firestore==2.16.1
google-cloud-functions==1.16.5
google-cloud-iam==2.15.2
google-cloud-language==2.13.4
google-cloud-pubsub==2.23.1
google-cloud-resource-manager==1.12.5
google-cloud-storage==2.8.0
google-cloud-translate==3.15.5
google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz#sha256=deb182392f5f78765ea686f1200ff7cfd42e31bdf8d172a68d6a29f657e1fe18
google-crc32c==1.6.0
google-generativeai==0.7.2
google-pasta==0.2.0
google-resumable-media==2.7.2
googleapis-common-protos==1.65.0
googledrivedownloader==0.4
graphviz==0.20.3
greenlet==3.1.0
grpc-google-iam-v1==0.13.1
grpcio==1.64.1
grpcio-status==1.48.2
gspread==6.0.2
gspread-dataframe==3.3.1
gym==0.25.2
gym-notices==0.0.8
h5netcdf==1.3.0
h5py==3.11.0
holidays==0.57
holoviews==1.19.1
html5lib==1.1
httpimport==1.4.0
httplib2==0.22.0
huggingface-hub==0.24.7
humanize==4.10.0
hyperopt==0.2.7
ibis-framework==8.0.0
idna==3.10
imageio==2.35.1
imageio-ffmpeg==0.5.1
imagesize==1.4.1
imbalanced-learn==0.12.3
imgaug==0.4.0
immutabledict==4.2.0
importlib_metadata==8.5.0
importlib_resources==6.4.5
imutils==0.5.4
inflect==7.4.0
iniconfig==2.0.0
intel-cmplr-lib-ur==2024.2.1
intel-openmp==2024.2.1
ipyevents==2.0.2
ipyfilechooser==0.6.0
ipykernel==5.5.6
ipyleaflet==0.19.2
ipyparallel==8.8.0
ipython==7.34.0
ipython-genutils==0.2.0
ipython-sql==0.5.0
ipytree==0.2.2
ipywidgets==7.7.1
itsdangerous==2.2.0
jax==0.4.26
jaxlib @ https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.26+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl#sha256=813cf1fe3e7ca4dbf5327d6e7b4fc8521e92d8bba073ee645ae0d5d036a25750
jedi==0.19.1
jeepney==0.7.1
jellyfish==1.1.0
jieba==0.42.1
Jinja2==3.1.4
joblib==1.4.2
jsonpickle==3.3.0
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
jupyter-client==6.1.12
jupyter-console==6.1.0
jupyter-leaflet==0.19.2
jupyter-server==1.24.0
jupyter_core==5.7.2
jupyterlab_pygments==0.3.0
jupyterlab_widgets==3.0.13
kaggle==1.6.17
kagglehub==0.3.0
keras==3.4.1
keras-nightly==3.5.0.dev2024092403
keyring==23.5.0
kiwisolver==1.4.7
langcodes==3.4.0
language_data==1.2.0
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
lazy_loader==0.4
libclang==18.1.1
librosa==0.10.2.post1
lightgbm==4.5.0
linkify-it-py==2.0.3
llvmlite==0.43.0
locket==1.0.0
logical-unification==0.4.6
lxml==4.9.4
marisa-trie==1.2.0
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.7.1
matplotlib-inline==0.1.7
matplotlib-venn==1.1.1
mdit-py-plugins==0.4.2
mdurl==0.1.2
miniKanren==1.0.3
missingno==0.5.2
mistune==0.8.4
mizani==0.11.4
mkl==2024.2.2
ml-dtypes==0.4.1
mlxtend==0.23.1
more-itertools==10.5.0
moviepy==1.0.3
mpmath==1.3.0
msgpack==1.0.8
multidict==6.1.0
multipledispatch==1.0.0
multitasking==0.0.11
murmurhash==1.0.10
music21==9.1.0
namex==0.0.8
natsort==8.4.0
nbclassic==1.1.0
nbclient==0.10.0
nbconvert==6.5.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.3
nibabel==5.2.1
nltk==3.8.1
notebook==6.5.5
notebook_shim==0.2.4
numba==0.60.0
numexpr==2.10.1
numpy==1.26.4
nvidia-nccl-cu12==2.23.4
nvtx==0.2.10
oauth2client==4.1.3
oauthlib==3.2.2
opencv-contrib-python==4.10.0.84
opencv-python==4.10.0.84
opencv-python-headless==4.10.0.84
openpyxl==3.1.5
opt-einsum==3.3.0
optax==0.2.2
optree==0.12.1
orbax-checkpoint==0.6.4
osqp==0.6.7.post0
packaging==24.1
pandas==2.1.4
pandas-datareader==0.10.0
pandas-gbq==0.23.1
pandas-stubs==2.1.4.231227
pandocfilters==1.5.1
panel==1.4.5
param==2.1.1
parso==0.8.4
parsy==2.1
partd==1.4.2
pathlib==1.0.1
patsy==0.5.6
peewee==3.17.6
pexpect==4.9.0
pickleshare==0.7.5
pillow==10.4.0
pip-tools==7.4.1
platformdirs==4.3.6
plotly==5.24.1
plotnine==0.13.6
pluggy==1.5.0
polars==1.6.0
pooch==1.8.2
portpicker==1.5.2
prefetch_generator==1.0.3
preshed==3.0.9
prettytable==3.11.0
proglog==0.1.10
progressbar2==4.5.0
prometheus_client==0.20.0
promise==2.3
prompt_toolkit==3.0.47
prophet==1.1.5
proto-plus==1.24.0
protobuf==3.20.3
psutil==5.9.5
psycopg2==2.9.9
ptyprocess==0.7.0
py-cpuinfo==9.0.0
py4j==0.10.9.7
pyarrow==14.0.2
pyarrow-hotfix==0.6
pyasn1==0.6.1
pyasn1_modules==0.4.1
pycocotools==2.0.8
pycparser==2.22
pydantic==2.9.2
pydantic_core==2.23.4
pydata-google-auth==1.8.2
pydot==3.0.1
pydot-ng==2.0.0
pydotplus==2.0.2
PyDrive==1.3.1
PyDrive2==1.20.0
pyerfa==2.0.1.4
pygame==2.6.0
Pygments==2.18.0
PyGObject==3.42.1
PyJWT==2.9.0
pymc==5.16.2
pymystem3==0.2.0
pynvjitlink-cu12==0.3.0
pyogrio==0.9.0
PyOpenGL==3.1.7
pyOpenSSL==24.2.1
pyparsing==3.1.4
pyperclip==1.9.0
pyproj==3.6.1
pyproject_hooks==1.1.0
pyshp==2.3.1
PySocks==1.7.1
pytensor==2.25.4
pytest==7.4.4
python-apt==2.4.0
python-box==7.2.0
python-dateutil==2.8.2
python-louvain==0.16
python-slugify==8.0.4
python-utils==3.8.2
pytz==2024.2
pyviz_comms==3.0.3
PyYAML==6.0.2
pyzmq==24.0.1
qdldl==0.1.7.post4
ratelim==0.1.6
referencing==0.35.1
regex==2024.9.11
requests==2.32.3
requests-oauthlib==1.3.1
requirements-parser==0.9.0
rich==13.8.1
rmm-cu12==24.4.0
rpds-py==0.20.0
rpy2==3.4.2
rsa==4.9
safetensors==0.4.5
scikit-image==0.24.0
scikit-learn==1.5.2
scipy==1.13.1
scooby==0.10.0
scs==3.2.7
seaborn==0.13.1
SecretStorage==3.3.1
Send2Trash==1.8.3
sentencepiece==0.2.0
shapely==2.0.6
shellingham==1.5.4
simple-parsing==0.1.6
six==1.16.0
sklearn-pandas==2.2.0
smart-open==7.0.4
sniffio==1.3.1
snowballstemmer==2.2.0
sortedcontainers==2.4.0
soundfile==0.12.1
soupsieve==2.6
soxr==0.5.0.post1
spacy==3.7.6
spacy-legacy==3.0.12
spacy-loggers==1.0.5
Sphinx==5.0.2
sphinxcontrib-applehelp==2.0.0
sphinxcontrib-devhelp==2.0.0
sphinxcontrib-htmlhelp==2.1.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==2.0.0
sphinxcontrib-serializinghtml==2.0.0
SQLAlchemy==2.0.35
sqlglot==20.11.0
sqlparse==0.5.1
srsly==2.4.8
stanio==0.5.1
statsmodels==0.14.3
StrEnum==0.4.15
sympy==1.13.3
tables==3.8.0
tabulate==0.9.0
tb-nightly==2.18.0a20240924
tbb==2021.13.1
tblib==3.0.0
tenacity==9.0.0
tensorboard==2.17.0
tensorboard-data-server==0.7.2
tensorflow-cpu==2.17.0
tensorflow-datasets==4.9.6
tensorflow-hub==0.16.1
tensorflow-io-gcs-filesystem==0.37.1
tensorflow-metadata==1.15.0
tensorflow-probability==0.24.0
tensorstore==0.1.65
termcolor==2.4.0
terminado==0.18.1
text-unidecode==1.3
textblob==0.17.1
tf-slim==1.1.0
tf_keras==2.17.0
tf_nightly==2.18.0.dev20240923
thinc==8.2.5
threadpoolctl==3.5.0
tifffile==2024.8.30
tinycss2==1.3.0
tokenizers==0.19.1
toml==0.10.2
tomli==2.0.1
toolz==0.12.1
torch==2.4.0+cpu
torch-xla==2.4.0
torchaudio==2.4.0+cpu
torchsummary==1.5.1
torchvision==0.19.0+cpu
tornado==6.3.3
tqdm==4.66.5
traitlets==5.7.1
traittypes==0.2.1
transformers==4.44.2
tweepy==4.14.0
typeguard==4.3.0
typer==0.12.5
types-pytz==2024.2.0.20240913
types-setuptools==75.1.0.20240917
typing_extensions==4.12.2
tzdata==2024.1
tzlocal==5.2
uc-micro-py==1.0.3
uritemplate==3.0.1
urllib3==2.2.3
vega-datasets==0.9.0
wadllib==1.3.6
wasabi==1.1.3
wcwidth==0.2.13
weasel==0.4.1
webcolors==24.8.0
webencodings==0.5.1
websocket-client==1.8.0
Werkzeug==3.0.4
widgetsnbextension==3.6.9
wordcloud==1.9.3
wrapt==1.16.0
xarray==2024.9.0
xarray-einstats==0.8.0
xgboost==2.1.1
xlrd==2.0.1
xyzservices==2024.9.0
yarl==1.11.1
yellowbrick==1.5
yfinance==0.2.43
zict==3.0.0
zipp==3.20.2
If I replace torch.gather() with advanced indexing, I still get an Int64 op (Less) that seems to be introduced for slicing:
import ai_edge_torch
import numpy as np
import torch

sample_inputs = (torch.randn(1, 3, 224, 224),)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, tensor):
        B, C, H, W = tensor.shape
        mode = np.empty((B, H, W), dtype=np.int32)
        np.argmax(tensor.detach().numpy(), axis=1, out=mode)
        mode = torch.from_numpy(mode) # (1, H, W)
        collected = torch.empty((B, 1, H, W), dtype=tensor.dtype, device=tensor.device)
        for b in range(B):
            collected[b, 0] = tensor[
                torch.arange(B, dtype=torch.int32).unsqueeze(-1).unsqueeze(-1),
                mode,
                torch.arange(H, dtype=torch.int32),
                torch.arange(W, dtype=torch.int32)
            ]
        return collected

model = Model().eval()
edge_model = ai_edge_torch.convert(model, sample_inputs)
result = edge_model(*sample_inputs)
print(f"Output: {result.shape}")
edge_model.export("fancy.tflite")
Hi @rfechtner, I think it's because np.argmax also still returns an int64 value/tensor. I'm wondering if you can implement an explicit int32 or lower-precision argmax function that never touches or gets turned into int64 values. I suppose a feature like index quantization or just general precision lowering might be interesting.
Hi @pkgoogle, yes, precisely. There seems to be no "out of the box" approach to advanced indexing without implicit Int64 calls, due to the underlying int64 (long) tensor indices in NumPy's and PyTorch's implementations.
Going through ONNX, it's quite straightforward to modify the graph after export and replace the dtype of the relevant ops to avoid int64. Unfortunately, this is less trivial in the flatbuffer format.
Would you be so kind as to provide some references for the index quantisation and precision lowering approaches you mentioned?
Cheers
I think that would be a feature we might implement. Something you may want to try yourself is to reimplement np.argmax so that it never becomes or touches an int64 tensor (just as a general Python/PyTorch function) and see if you can use that instead of np.argmax. The general form is: for each combination of indices over the non-axis dimensions, take the 1D tensor obtained by fixing those indices and running through the axis dimension; the argmax of that 1D tensor is the output tensor's value at that index combination.
If you just want to test whether it'll work, maybe implement a version where you know the input shape, or just try a 2D tensor first.
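As a rough illustration of the suggestion above, here is a minimal sketch of an argmax that keeps its indices in int32 throughout. argmax_int32 is a hypothetical helper (not an existing PyTorch or ai_edge_torch function); it assumes the size of the reduced dimension is known at trace time, and it has not been verified that the converter emits an int64-free graph for it.

```python
import torch


def argmax_int32(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical helper: argmax along `dim` with int32 indices only."""
    # One slice per position along `dim`; the loop is unrolled over a static size.
    slices = x.unbind(dim)
    best_val = slices[0]
    best_idx = torch.zeros_like(best_val, dtype=torch.int32)
    for i, cur in enumerate(slices[1:], start=1):
        # Strict '>' keeps the first occurrence when values tie.
        better = cur > best_val
        best_val = torch.where(better, cur, best_val)
        best_idx = torch.where(better, torch.full_like(best_idx, i), best_idx)
    return best_idx
```

Instead of looping over every index combination of the non-axis dimensions, this walks along the axis once and updates all positions in parallel, which gives the same result with only elementwise comparisons and selects.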
I will give it a shot, thanks a lot for the feedback!
An out-of-the-box optimisation to avoid Int64 ops via converter flags would be an awesome addition to make this a one-stop deployment pipeline. Please keep me posted on these efforts! :)
I'll report back as well.
Description of the bug:
Hi,
I am trying to convert a PyTorch model to TFLite which uses torch.max(..).indices and torch.gather(..), hence creating LongTensors (Int64). As my targeted runtime delegate does not support any int64 ops (including cast int64 -> int32), I am seeking to replace int64 ops with corresponding int32 ones.
Minimal reproducible example:
In the past I have been doing this via an intermediate ONNX model representation, where I modified the relevant nodes and then converted ONNX to TFLite, but with this new framework I'd hoped to get rid of ONNX.
I have tried to replace torch.argmax() with tf.math.argmax(.., output_type=tf.int32) or the NumPy equivalent, which supports specifying the output type or array, but that fails during torch.export() and results in an error.
One remaining avenue I can think of is post-processing the resulting flatbuffer representation and replacing the int64 ops there, but that seems quite brittle and overly complicated.
Any other suggestions? Or is there a way to dynamically replace functions?
Note: I had to pin tf-nightly==2.18.0.dev20240722, otherwise the export fails.