google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

Quantization of Llama results in TFLite file without prefill and decode sequences #369

Open Arya-Hari opened 3 days ago

Arya-Hari commented 3 days ago

Description of the bug:

I tried running the example.py script given as the quantization example, but for Llama: wherever a reference to Gemma was made, I made the corresponding reference to Llama. The modified code looks like this -

# Copyright 2024 The AI Edge Torch Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import ai_edge_torch
from ai_edge_torch.generative.examples.llama import llama
from ai_edge_torch.generative.layers import kv_cache as kv_utils
from ai_edge_torch.generative.quantize import quant_recipes
from ai_edge_torch.generative.utilities import model_builder
import numpy as np
import torch

def main():
  # Build a PyTorch model as usual
  config = llama.get_fake_model_config()
  model = model_builder.DecoderOnlyModel(config).eval()
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
  tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
  tokens[0, :4] = idx
  input_pos = torch.arange(0, 10, dtype=torch.int)
  kv = kv_utils.KVCache.from_model_config(config)

  # Create a quantization recipe to be applied to the model
  quant_config = quant_recipes.full_int8_dynamic_recipe()
  print(quant_config)

  # Convert with quantization
  edge_model = ai_edge_torch.convert(
      model, (tokens, input_pos, kv), quant_config=quant_config
  )
  edge_model.export("/tmp/llama.tflite")

if __name__ == "__main__":
  main()

Actual vs expected behavior:

A proper TFLite model should have been produced. However, the generated .tflite file does not contain the required prefill and decode signatures. As a result, after bundling it with the tokenizer and trying to run it on-device with MediaPipe, I get a Failed to initialise error.
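
For reference, a quick way to check which signatures a converted .tflite file exposes (a minimal sketch; assumes the standard TensorFlow package provides the TFLite interpreter):

import tensorflow as tf

# Load the converted model and list its signatures. A model converted with a
# single ai_edge_torch.convert() call typically exposes only one default
# signature, whereas the MediaPipe LLM runtime expects separate prefill and
# decode signatures.
interpreter = tf.lite.Interpreter(model_path="/tmp/llama.tflite")
print(interpreter.get_signature_list())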

Any other information you'd like to share?

No response

pkgoogle commented 3 days ago

Hi @Arya-Hari, can you try with the actual Llama conversion script? https://github.com/google-ai-edge/ai-edge-torch/blob/main/ai_edge_torch/generative/examples/llama/convert_to_tflite.py. It uses this function, which adds the required signatures: https://github.com/google-ai-edge/ai-edge-torch/blob/main/ai_edge_torch/generative/utilities/converter.py#L27

Please also review the generative API conversion examples to make sure nothing else is missed: https://github.com/google-ai-edge/ai-edge-torch/tree/main/ai_edge_torch/generative/examples#model-conversion. Thanks.
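
Roughly, the converter utility registers separate prefill and decode signatures instead of making a single convert() call, along these lines (a simplified sketch, reusing model, config, and quant_config from your script above; shapes are illustrative, and the actual utility also handles sequence lengths and KV cache sizing):

import ai_edge_torch
import torch
from ai_edge_torch.generative.layers import kv_cache as kv_utils

# Sample inputs for each signature; lengths are illustrative.
prefill_tokens = torch.zeros((1, 10), dtype=torch.int)
prefill_input_pos = torch.arange(0, 10, dtype=torch.int)
decode_token = torch.zeros((1, 1), dtype=torch.int)
decode_input_pos = torch.tensor([10], dtype=torch.int)
kv = kv_utils.KVCache.from_model_config(config)

# Register a "prefill" and a "decode" signature on the same model, then
# convert them together so the resulting .tflite exposes both entry points.
edge_model = (
    ai_edge_torch.signature("prefill", model, (prefill_tokens, prefill_input_pos, kv))
    .signature("decode", model, (decode_token, decode_input_pos, kv))
    .convert(quant_config=quant_config)
)
edge_model.export("/tmp/llama_with_signatures.tflite")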

Arya-Hari commented 2 days ago

Hello. The actual Llama conversion script produces the required result, but the size of the file produced is around 2 GB. Is 8-bit quantization already applied when running the script?

pkgoogle commented 2 days ago

Hi @Arya-Hari, I believe so: https://github.com/google-ai-edge/ai-edge-torch/blob/main/ai_edge_torch/generative/examples/llama/convert_to_tflite.py#L54

Is that for the 1B or 3B model? Quantized models take roughly 1 byte (8 bits) per parameter, so from parameters alone, excluding all overhead, it should be around 1 GB or 3 GB. Perhaps the extra GB is all overhead.
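
To make the arithmetic concrete (parameter counts below are approximate Llama 3.2 figures; converter overhead is ignored):

# Back-of-the-envelope size estimate for int8-quantized weights:
# roughly 1 byte per parameter, ignoring any overhead.
for name, params in [("1B", 1.24e9), ("3B", 3.21e9)]:
  print(f"{name}: ~{params / 2**30:.1f} GiB at ~1 byte/parameter")
# 1B: ~1.2 GiB at ~1 byte/parameter
# 3B: ~3.0 GiB at ~1 byte/parameter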

Arya-Hari commented 1 day ago

Okay, I understand now. Would using any of the quantization recipes given in the repository make a difference?

pkgoogle commented 1 day ago

Hi @Arya-Hari, it can definitely make a difference if you quantize to different precisions, such as mixed activations (where some activations are 16-bit), or if you don't use full int8 quantization. However, it will most likely end up around the same size, unless the model has a lot of ops that can't be quantized or something else like that.
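
For illustration, a few of the recipes in quant_recipes (names as of recent versions; check the module in your install, since availability may differ):

from ai_edge_torch.generative.quantize import quant_recipes

# Dynamic-range int8 (used in your script above): int8 weights, activations
# quantized dynamically at runtime.
quant_config = quant_recipes.full_int8_dynamic_recipe()

# Weight-only int8: int8 weights, float activations. Roughly the same file
# size, but a different accuracy/latency trade-off.
quant_config = quant_recipes.full_int8_weight_only_recipe()

# fp16 weights: about 2 bytes per parameter, so roughly double the file
# size of the int8 recipes.
quant_config = quant_recipes.full_fp16_recipe()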