google / generative-ai-docs

Documentation for Google's Gen AI site - including the Gemini API and Gemma
https://ai.google.dev
Apache License 2.0

Why did Gemma7B perform poorly #365

Open JuJoker opened 6 months ago

JuJoker commented 6 months ago

Description of the bug:

I ran the Gemma-7B model based on the code in the example and found that the model's answers were rather poor; it didn't seem to understand my question at all. Is this normal? My device is an Nvidia 4090 GPU. My code is as follows:

```python
import os
import sys

import torch

# Make the cloned gemma_pytorch repository importable.
sys.path.append('gemma_pytorch')

from gemma_pytorch.gemma.config import get_config_for_7b, get_config_for_2b
from gemma_pytorch.gemma.model import GemmaForCausalLM

VARIANT = '7b'
MACHINE_TYPE = 'cuda'

# Set up model config.
model_config = get_config_for_7b()

# Ensure that the tokenizer is present
tokenizer_path = os.path.join('./gemma-7b', 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join('./gemma-7b', 'gemma-7b.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

model_config.tokenizer = tokenizer_path

# Quantized weights are used only if the variant name contains 'quant' (False for '7b').
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()

# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'
MODEL_CHAT_START = '<start_of_turn>model\n'
MODEL_CHAT_END = '<end_of_turn>\n'

MULTI_CHAT = ''

# while True:
#     input_text = input("User input(press q to exit):")
#     if input_text != 'q':
#         user_prompt = USER_CHAT_TEMPLATE.format(prompt=input_text)
#         MULTI_CHAT = MULTI_CHAT + user_prompt + MODEL_CHAT_START
#         model_response = model.generate(
#             user_prompt,
#             device=device,
#             output_len=64,
#         )
#         print(f'Model reply: {model_response}')
#         MULTI_CHAT = MULTI_CHAT + model_response + MODEL_CHAT_END
#     else:
#         break

# Multi-turn prompt: two user turns, one model turn, then an open model turn.
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

# `prompt` already contains the chat turns; here it is passed through
# USER_CHAT_TEMPLATE a second time before generation.
res = model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=100,
)
print(res)
```
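One detail visible in the code above: `prompt` is already a complete chat transcript ending in an open model turn, and the `generate` call wraps it in `USER_CHAT_TEMPLATE` again, so the text the model sees ends with `<end_of_turn>` instead of `<start_of_turn>model`. The commented-out loop has a related issue: it sends only `user_prompt`, with no earlier turns and no open model turn. A minimal sketch of building the prompt once, using the template strings from the issue; `build_chat_prompt` and `history` are hypothetical names, and no model is loaded here. It may also be worth checking which checkpoint is loaded: the path points at `gemma-7b`, while Gemma's `<start_of_turn>` chat format is intended for the instruction-tuned (`-it`) variants.

```python
# Template strings copied from the issue's code.
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'
MODEL_CHAT_START = '<start_of_turn>model\n'


def build_chat_prompt(history, user_message):
    """Hypothetical helper: format prior (user, model) pairs plus a new user message.

    The result is wrapped in the chat template exactly once and ends with an
    open model turn, so the model's continuation is its reply.
    """
    parts = []
    for user_text, model_text in history:
        parts.append(USER_CHAT_TEMPLATE.format(prompt=user_text))
        parts.append(MODEL_CHAT_TEMPLATE.format(prompt=model_text))
    parts.append(USER_CHAT_TEMPLATE.format(prompt=user_message))
    parts.append(MODEL_CHAT_START)
    return ''.join(parts)


history = [('What is a good place for travel in the US?', 'California.')]
prompt = build_chat_prompt(history, 'What can I do in California?')

# What the issue's generate call receives instead: the same text wrapped twice.
double_wrapped = USER_CHAT_TEMPLATE.format(prompt=prompt)

print(prompt.count('<start_of_turn>'))            # 4
print(double_wrapped.count('<start_of_turn>'))    # 5
print(double_wrapped.endswith(MODEL_CHAT_START))  # False
```

With the model loaded as in the issue, the call would then be `model.generate(prompt, device=device, output_len=100)`. For an interactive loop, the same helper can be called on each new input after appending the previous `(input_text, model_response)` pair to `history`.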

The corresponding response is as follows: [screenshot]

Actual vs expected behavior:

But when I run this code on Kaggle, the result seems correct: [screenshot]

Any other information you'd like to share?

No response

cog-master commented 6 months ago

Not only do they perform poorly, they are slow too. Given that you have an "Nvidia 4090 GPU", it seems the model is not working properly.

MarkDaoust commented 6 months ago

@gustheman