google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

Accuracy is not matching for generated tiny_llama model #258

Open akshatshah17 opened 1 week ago

akshatshah17 commented 1 week ago

Description of the bug:

I get the correct output from verify.py for the tiny_llama model, but when I run the model through text_generation the output is quite different. Generative API (nightly) 133

Command : bazel run -c opt //ai_edge_torch/generative/examples/cpp:text_generator_main -- --tflite_model="/home/test/Akashat/google-ai-edge/models/tflite/tiny_llama_fp32_seq512_ekv1024.tflite" --sentencepiece_model="/home/test/Akashat/google-ai-edge/models/TinyLlama-1.1B-Chat-v1.0/tokenizer.model" --start_token="<s>" --stop_token="</s>" --num_threads=16 --prompt="Write me a function to calculate the first 10 digits of the fibonacci sequence in Python and print it out to the CLI."

Actual vs expected behavior:

[2024-09-25 10:32:06.419151] Generating answer with the original model...
[2024-09-25 10:32:13.257219] outputs_from_original_model: [[ Write me a function to calculate the first 10 digits of the fibonacci sequence in Python and print it out to the CLI. The function should take an integer argument representing the number of digits to print. The sequence should start with 0 and 1, and the first digit of each subsequent sequence should be the sum of the previous two digits. The function should return the first 10 digits of the sequence. ]]

[2024-09-25 10:32:13.257282] Generating answer with the reauthored model...
[2024-09-25 10:32:47.001326] outputs from reauthored model: [[ Write me a function to calculate the first 10 digits of the fibonacci sequence in Python and print it out to the CLI. The function should take an integer argument representing the number of digits to print. The sequence should start with 0 and 1, and the first digit of each subsequent sequence should be the sum of the previous two digits. The function should return the first 10 digits of the sequence. ]]

text generation output
Prompt: Write me a function to calculate the first 10 digits of the fibonacci sequence in Python and print it out to the CLI.
Output text: The function should take in the sequence as an argument and return the first 10 digits. The function should also handle edge cases such as an empty sequence or a sequence with only one digit. The output should be formatted in a readable way.

Any other information you'd like to share?

No response

hheydary commented 1 week ago

I briefly looked at verify.py; two things stood out to me:

Please note that verify.py is a simple script for testing the converted model against a reference point. Greedy sampling is not the best choice for generating quality text from the model. Additionally, instruction-tuned models work best when the prompt fed to them adheres to the prompt template they were trained with. Please refer to the model card for more details on that.
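
To illustrate the sampling point, here is a minimal sketch of greedy selection next to temperature sampling over a single step's logits. It is not taken from verify.py or text_generator_main; the function names, the logits, and the temperature value are all made up for illustration.

```python
import math
import random


def greedy_pick(logits):
    """Return the index of the largest logit (always the same answer)."""
    return max(range(len(logits)), key=lambda i: logits[i])


def temperature_pick(logits, temperature=0.8, rng=random):
    """Sample an index from softmax(logits / temperature)."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(x - peak) for x in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


# Hypothetical scores for a 3-token vocabulary.
logits = [2.0, 1.5, 0.1]
rng = random.Random(0)  # seeded so that runs are repeatable

greedy = [greedy_pick(logits) for _ in range(5)]
sampled = [temperature_pick(logits, rng=rng) for _ in range(5)]
print("greedy :", greedy)
print("sampled:", sampled)
```

Greedy selection returns the same token every time, which is convenient for comparing two models but tends to give flatter text. Sampling can pick lower-scored tokens, so two runs (or two tools with different sampling settings) may produce different continuations from the same weights.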

akshatshah17 commented 1 week ago

@hheydary, the start and stop tokens were inserted, but if I write <s> and </s> in a GitHub comment, it renders the text like this example, so I removed them. One more thing: what is the limit set for max_token in verify.py?

akshatshah17 commented 1 week ago

@pkgoogle, how can I check whether a generated GenAI model, whether it's tiny_llama, gemma, or stable diffusion, is mathematically accurate with respect to the base model?

pkgoogle commented 6 days ago

@akshatshah17, I recommend taking something like 5 random tensors (fixing them once you have chosen them), running them through both runtimes, and checking how close the outputs are. If the conversion is mathematically equivalent (i.e. there is no quantization, no change in precision of any sort, and no optimizations), then theoretically there should be no difference. Typically we pick an epsilon that "makes sense" and compare the difference against it. Precision/dtype conversions introduce errors, which makes it much harder to say whether the result is completely accurate, so for those pick an epsilon that is acceptable for your use case. Does that help?
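
A rough sketch of that check, assuming each runtime can be wrapped as a plain function from an input list to an output list. The two model functions below are hypothetical stand-ins, not real ai-edge-torch or TFLite calls, and the epsilon is only an example value. With real tensors you would likely use something like `numpy.allclose` instead of the hand-rolled comparison.

```python
import random


def make_fixed_inputs(num_tensors=5, size=8, seed=0):
    """Generate random inputs once; the seed keeps them fixed across runs."""
    rng = random.Random(seed)
    return [[rng.uniform(-1.0, 1.0) for _ in range(size)]
            for _ in range(num_tensors)]


def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two outputs."""
    return max(abs(x - y) for x, y in zip(a, b))


def outputs_match(reference_fn, converted_fn, inputs, eps):
    """Run every input through both functions; compare the worst error to eps."""
    worst = max(max_abs_diff(reference_fn(t), converted_fn(t)) for t in inputs)
    return worst <= eps, worst


# Hypothetical stand-ins for the original and converted runtimes.
def reference_model(t):
    return [2.0 * x + 1.0 for x in t]


def converted_model(t):
    # Simulates a small numerical drift introduced by conversion.
    return [2.0 * x + 1.0 + 1e-7 for x in t]


inputs = make_fixed_inputs()
ok, worst = outputs_match(reference_model, converted_model, inputs, eps=1e-5)
print("match:", ok, "worst abs diff:", worst)
```

Comparing raw output tensors this way sidesteps sampling and prompt formatting entirely, so any difference that remains comes from the conversion itself.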