martinkorelic closed this issue 3 months ago.
I'm running into the same issue:
python convert_to_tflite.py
2024-08-01 22:26:04.401644: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-08-01 22:26:05.621391: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md
WARNING:root:Defaulting to PJRT_DEVICE=CPU
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1722551169.675369 817544 cpu_client.cc:424] TfrtCpuClient created.
Traceback (most recent call last):
File "xxxxxxxx/git/ai-edge-torch/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py", line 67, in <module>
convert_tiny_llama_to_tflite(checkpoint_path)
File "xxxxxxxx/git/ai-edge-torch/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py", line 43, in convert_tiny_llama_to_tflite
pytorch_model = tiny_llama.build_model(
^^^^^^^^^^^^^^^^^^^^^^^
File "xxxxxxxx/miniconda3/envs/aet_115/lib/python3.11/site-packages/ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py", line 147, in build_model
loader.load(model)
File "xxxxxxxx/miniconda3/envs/aet_115/lib/python3.11/site-packages/ai_edge_torch/generative/utilities/loader.py", line 161, in load
model.load_state_dict(converted_state, strict=strict)
File "xxxxxxxx/miniconda3/envs/aet_115/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2191, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for TinyLLamma:
Missing key(s) in state_dict: "transformer_blocks.0.atten_func.qkv_projection.weight", "transformer_blocks.0.atten_func.output_projection.weight", "transformer_blocks.1.atten_func.qkv_projection.weight", "transformer_blocks.1.atten_func.output_projection.weight", "transformer_blocks.2.atten_func.qkv_projection.weight", "transformer_blocks.2.atten_func.output_projection.weight", "transformer_blocks.3.atten_func.qkv_projection.weight", "transformer_blocks.3.atten_func.output_projection.weight", "transformer_blocks.4.atten_func.qkv_projection.weight", "transformer_blocks.4.atten_func.output_projection.weight", "transformer_blocks.5.atten_func.qkv_projection.weight", "transformer_blocks.5.atten_func.output_projection.weight", "transformer_blocks.6.atten_func.qkv_projection.weight", "transformer_blocks.6.atten_func.output_projection.weight", "transformer_blocks.7.atten_func.qkv_projection.weight", "transformer_blocks.7.atten_func.output_projection.weight", "transformer_blocks.8.atten_func.qkv_projection.weight", "transformer_blocks.8.atten_func.output_projection.weight", "transformer_blocks.9.atten_func.qkv_projection.weight", "transformer_blocks.9.atten_func.output_projection.weight", "transformer_blocks.10.atten_func.qkv_projection.weight", "transformer_blocks.10.atten_func.output_projection.weight", "transformer_blocks.11.atten_func.qkv_projection.weight", "transformer_blocks.11.atten_func.output_projection.weight", "transformer_blocks.12.atten_func.qkv_projection.weight", "transformer_blocks.12.atten_func.output_projection.weight", "transformer_blocks.13.atten_func.qkv_projection.weight", "transformer_blocks.13.atten_func.output_projection.weight", "transformer_blocks.14.atten_func.qkv_projection.weight", "transformer_blocks.14.atten_func.output_projection.weight", "transformer_blocks.15.atten_func.qkv_projection.weight", "transformer_blocks.15.atten_func.output_projection.weight", "transformer_blocks.16.atten_func.qkv_projection.weight", 
"transformer_blocks.16.atten_func.output_projection.weight", "transformer_blocks.17.atten_func.qkv_projection.weight", "transformer_blocks.17.atten_func.output_projection.weight", "transformer_blocks.18.atten_func.qkv_projection.weight", "transformer_blocks.18.atten_func.output_projection.weight", "transformer_blocks.19.atten_func.qkv_projection.weight", "transformer_blocks.19.atten_func.output_projection.weight", "transformer_blocks.20.atten_func.qkv_projection.weight", "transformer_blocks.20.atten_func.output_projection.weight", "transformer_blocks.21.atten_func.qkv_projection.weight", "transformer_blocks.21.atten_func.output_projection.weight".
Unexpected key(s) in state_dict: "transformer_blocks.0.atten_func.attn.weight", "transformer_blocks.0.atten_func.proj.weight", "transformer_blocks.1.atten_func.attn.weight", "transformer_blocks.1.atten_func.proj.weight", "transformer_blocks.2.atten_func.attn.weight", "transformer_blocks.2.atten_func.proj.weight", "transformer_blocks.3.atten_func.attn.weight", "transformer_blocks.3.atten_func.proj.weight", "transformer_blocks.4.atten_func.attn.weight", "transformer_blocks.4.atten_func.proj.weight", "transformer_blocks.5.atten_func.attn.weight", "transformer_blocks.5.atten_func.proj.weight", "transformer_blocks.6.atten_func.attn.weight", "transformer_blocks.6.atten_func.proj.weight", "transformer_blocks.7.atten_func.attn.weight", "transformer_blocks.7.atten_func.proj.weight", "transformer_blocks.8.atten_func.attn.weight", "transformer_blocks.8.atten_func.proj.weight", "transformer_blocks.9.atten_func.attn.weight", "transformer_blocks.9.atten_func.proj.weight", "transformer_blocks.10.atten_func.attn.weight", "transformer_blocks.10.atten_func.proj.weight", "transformer_blocks.11.atten_func.attn.weight", "transformer_blocks.11.atten_func.proj.weight", "transformer_blocks.12.atten_func.attn.weight", "transformer_blocks.12.atten_func.proj.weight", "transformer_blocks.13.atten_func.attn.weight", "transformer_blocks.13.atten_func.proj.weight", "transformer_blocks.14.atten_func.attn.weight", "transformer_blocks.14.atten_func.proj.weight", "transformer_blocks.15.atten_func.attn.weight", "transformer_blocks.15.atten_func.proj.weight", "transformer_blocks.16.atten_func.attn.weight", "transformer_blocks.16.atten_func.proj.weight", "transformer_blocks.17.atten_func.attn.weight", "transformer_blocks.17.atten_func.proj.weight", "transformer_blocks.18.atten_func.attn.weight", "transformer_blocks.18.atten_func.proj.weight", "transformer_blocks.19.atten_func.attn.weight", "transformer_blocks.19.atten_func.proj.weight", "transformer_blocks.20.atten_func.attn.weight", 
"transformer_blocks.20.atten_func.proj.weight", "transformer_blocks.21.atten_func.attn.weight", "transformer_blocks.21.atten_func.proj.weight".
I0000 00:00:1722551177.566826 817544 cpu_client.cc:427] TfrtCpuClient destroyed.
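The two key lists above are easier to read when collapsed by block index. The following minimal sketch reduces each key to its per-block suffix. The helper name `group_by_suffix` is hypothetical and not part of ai_edge_torch. Applied to the keys from this log, it shows that all 22 blocks (0 to 21) are missing the same two attention weights and instead carry two differently named ones (`attn`/`proj` instead of `qkv_projection`/`output_projection`). That pattern suggests the loader and the model definition disagree on naming, rather than the checkpoint lacking data.

```python
import re
from collections import defaultdict

# Matches keys like "transformer_blocks.7.atten_func.proj.weight".
BLOCK_KEY = re.compile(r"^transformer_blocks\.(\d+)\.(.+)$")


def group_by_suffix(keys):
    """Map each per-block key suffix to the sorted block indices it appears in.

    Keys outside transformer_blocks are kept whole, with an empty index list.
    """
    groups = defaultdict(set)
    for key in keys:
        match = BLOCK_KEY.match(key)
        if match:
            groups[match.group(2)].add(int(match.group(1)))
        else:
            groups.setdefault(key, set())
    return {suffix: sorted(indices) for suffix, indices in groups.items()}


# Rebuild the key lists from the log above (22 transformer blocks).
num_blocks = 22
missing = [f"transformer_blocks.{i}.atten_func.{name}.weight"
           for i in range(num_blocks) for name in ("qkv_projection", "output_projection")]
unexpected = [f"transformer_blocks.{i}.atten_func.{name}.weight"
              for i in range(num_blocks) for name in ("attn", "proj")]

print(sorted(group_by_suffix(missing)))
# ['atten_func.output_projection.weight', 'atten_func.qkv_projection.weight']
print(sorted(group_by_suffix(unexpected)))
# ['atten_func.attn.weight', 'atten_func.proj.weight']
```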
Hi, I'm unable to replicate this error, but I noticed in @pkgoogle's log that the ai_edge_torch being used is the one installed in site-packages, not the one in this repo. Could you go to the root directory of this repo (the directory with setup.py) and run
pip install -e .
so that the ai_edge_torch you get is the actual code in this repo rather than the package from PyPI, and try again? In other words, starting at the top of this repo, it would be:
pip install -e .
cd ai_edge_torch/generative/examples/tiny_llama/
python convert_to_tflite.py
If you still get the error, could you copy/paste the entire error message so we can see which Python file is raising the exception (in site-packages or not)? Thanks!
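To check which copy of ai_edge_torch Python will pick up before rerunning the conversion, you can resolve the module's location without importing it. This is a minimal sketch using only the standard library (`importlib.util.find_spec` and `sysconfig`). The helper name `describe_import_location` is hypothetical. Run it with the same interpreter you use for the conversion. After `pip install -e .`, the reported path would be expected to point into the cloned repo rather than site-packages, though this depends on how the editable install is set up.

```python
import importlib.util
import sys
import sysconfig
from pathlib import Path


def describe_import_location(module_name):
    """Return (origin, in_site_packages) for a top-level module, or None.

    find_spec resolves where the module would be loaded from without
    executing it. None means the module is not importable (or has no
    single-file origin, as with namespace packages).
    """
    spec = importlib.util.find_spec(module_name)
    if spec is None or spec.origin is None:
        return None
    origin = Path(spec.origin).resolve()
    # purelib/platlib are this interpreter's site-packages directories.
    site_dirs = {Path(sysconfig.get_paths()[key]).resolve() for key in ("purelib", "platlib")}
    in_site_packages = any(site_dir in origin.parents for site_dir in site_dirs)
    return str(origin), in_site_packages


if __name__ == "__main__":
    # Prints None if ai_edge_torch is not importable from this interpreter.
    print(sys.executable, describe_import_location("ai_edge_torch"))
```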
Hi @talumbau, thanks! I was able to successfully run the script after creating a new conda environment and installing the local files instead of the current package.
@martinkorelic can you try creating a fresh Python environment with 3.9, 3.10, or 3.11 (venv and conda are popular choices) and then installing the local files as noted above?
For example:
git clone https://github.com/google-ai-edge/ai-edge-torch.git
cd ai-edge-torch
pip install -e .
# ensure the model checkpoint path is where you want it to be
cd ai_edge_torch/generative/examples/tiny_llama/
python convert_to_tflite.py
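A small guard like the following can fail fast if the fresh environment's interpreter is outside the versions suggested here. This is a sketch: the version set simply mirrors this comment and may not match what later ai-edge-torch releases support. `check_python_version` is a hypothetical helper.

```python
import sys

# Interpreter versions suggested in this comment; later ai-edge-torch
# releases may support a different range.
SUGGESTED_VERSIONS = {(3, 9), (3, 10), (3, 11)}


def check_python_version(version_info=None):
    """Return a warning string if the interpreter is outside SUGGESTED_VERSIONS, else None."""
    if version_info is None:
        version_info = sys.version_info
    major, minor = version_info[0], version_info[1]
    if (major, minor) in SUGGESTED_VERSIONS:
        return None
    return f"Python {major}.{minor} is not one of the suggested versions (3.9, 3.10, 3.11)"


if __name__ == "__main__":
    print(check_python_version() or f"Python {sys.version_info[0]}.{sys.version_info[1]} looks fine")
```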
@pkgoogle This seems to work as expected for the conversion; however, I'm running into other issues:
~~The conversion seems to take an enormous amount of memory (which is expected); however, sometimes the whole process crashes without any warning, even though I have allocated more swap space, which doesn't seem to get fully used.
A similar TensorFlow Lite conversion of bigger models did use the available swap space, but for some reason this one does not (I have 90 GB of RAM available).
Is there anything that might be preventing this? Are there any options that decrease the memory needed for this conversion?~~
Edit: It seems the OS was killing the process; solved with:
systemctl disable --now systemd-oomd
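Because the crash came with no Python traceback, it helps to confirm from the exit status that the process was killed by a signal. On POSIX systems, `subprocess.run` reports a child terminated by signal N as `returncode == -N`, and out-of-memory killers (the kernel's, or systemd-oomd) generally use SIGKILL. The following is a minimal sketch that assumes a POSIX system. `run_and_report` is a hypothetical helper.

```python
import signal
import subprocess
import sys


def run_and_report(cmd):
    """Run cmd and describe how it ended, flagging SIGKILL (POSIX only)."""
    result = subprocess.run(cmd)
    if result.returncode == 0:
        return "ok"
    # A negative return code -N means the child was terminated by signal N.
    if result.returncode == -signal.SIGKILL:
        return "killed by SIGKILL (possibly an out-of-memory killer)"
    if result.returncode < 0:
        return f"terminated by signal {-result.returncode}"
    return f"exited with code {result.returncode}"


if __name__ == "__main__":
    # Usage (hypothetical file name): python run_and_report.py python convert_to_tflite.py
    if len(sys.argv) > 1:
        print(run_and_report(sys.argv[1:]))
```

A SIGKILL alone does not prove an out-of-memory kill, so it is still worth cross-checking the system logs.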
One more question unrelated to the issue: with this framework, would it be possible to define a training function signature in PyTorch and expose it through the TFLite conversion, so that it would be available later for fine-tuning?
Interesting idea. IIUC, you want to define a training loop in PyTorch (run a forward pass with some data through a model, compute a loss metric, execute a backward pass to compute gradients, define and use an optimizer to update the model parameters, repeat).
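To make the four steps of that loop concrete without depending on PyTorch, here is a framework-free sketch of the same loop on a toy one-feature linear model with hand-derived gradients. It illustrates the general technique only. In PyTorch, `loss.backward()` and an optimizer such as `torch.optim.SGD` would replace the hand-written gradient and update lines. `train_linear` and the toy data are hypothetical.

```python
def train_linear(xs, ys, lr=0.05, epochs=500):
    """Fit y approx w * x + b by full-batch gradient descent on mean squared error.

    Returns the learned (w, b) and the per-epoch loss history.
    """
    w, b = 0.0, 0.0
    n = len(xs)
    losses = []
    for _ in range(epochs):
        # Forward pass: predictions from the current parameters.
        preds = [w * x + b for x in xs]
        # Loss metric: mean squared error.
        losses.append(sum((p - y) ** 2 for p, y in zip(preds, ys)) / n)
        # Backward pass: gradients of the MSE with respect to w and b.
        grad_w = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
        grad_b = sum(2 * (p - y) for p, y in zip(preds, ys)) / n
        # Optimizer step: plain gradient-descent update of the parameters.
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b, losses


xs = [0.0, 1.0, 2.0, 3.0]
ys = [2 * x + 1 for x in xs]  # toy data generated from y = 2x + 1
w, b, losses = train_linear(xs, ys)
print(round(w, 3), round(b, 3))  # 2.0 1.0
```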
This is not currently on our engineering roadmap, which is mostly focused on making it very smooth to author (or re-author) a model in PyTorch and then convert it to an edge-compatible artifact for on-device, inference-only execution.
For on-device training, I would recommend looking at the infrastructure offered via TensorFlow (also available in Keras, I believe):
https://www.tensorflow.org/lite/examples/on_device_training/overview
This looks resolved. Please re-open if needed.
Description of the bug:
Followed the examples for TFLite conversion and installed:
pip install ai-edge-torch
pip install torch-xla
Downloaded the model.safetensors and corrected the path to the model checkpoint from TinyLlama/TinyLlama-1.1B-Chat-v1.0. Then simply ran the script in examples/tiny_llama:
python convert_to_tflite.py
Actual vs expected behavior:
Expected: The model conversion to TFLite should begin.
Actual: Mismatch in model state dict when loading the model:
Any other information you'd like to share?
Is there something different in the model structure, or is the wrong model checkpoint perhaps linked in the examples?
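To check what a downloaded checkpoint actually contains without loading PyTorch, you can read its header directly. A .safetensors file starts with an 8-byte little-endian unsigned header length, followed by a JSON header that maps tensor names to their dtype, shape, and data offsets, plus an optional `__metadata__` entry. This is a minimal stdlib-only sketch, and the helper names are hypothetical. The names in the file follow the checkpoint's own naming scheme. Because ai_edge_torch's loader converts the state dict before `load_state_dict`, these names will not match the keys in the error message verbatim.

```python
import json
import struct
import sys


def parse_header(blob):
    """Parse the JSON header from bytes holding at least the start of a .safetensors file."""
    (header_len,) = struct.unpack("<Q", blob[:8])  # u64, little-endian
    header = json.loads(blob[8:8 + header_len].decode("utf-8"))
    header.pop("__metadata__", None)  # optional free-form metadata, not a tensor
    return header


def tensor_shapes(path):
    """Map each tensor name in a .safetensors file to its shape, without reading tensor data."""
    with open(path, "rb") as f:
        prefix = f.read(8)
        (header_len,) = struct.unpack("<Q", prefix)
        header = parse_header(prefix + f.read(header_len))
    return {name: info["shape"] for name, info in sorted(header.items())}


if __name__ == "__main__":
    # Usage (hypothetical path): python list_tensors.py /path/to/model.safetensors
    for path in sys.argv[1:]:
        for name, shape in tensor_shapes(path).items():
            print(name, shape)
```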