I ran the Gemma-7B model using the code from the example and found that the model's answers were rather poor; it didn't seem to understand my question at all. Is this normal? My device is an Nvidia RTX 4090 GPU. My code is as follows:
import os
import sys
import torch
sys.path.append('gemma_pytorch')
from gemma_pytorch.gemma.config import get_config_for_7b, get_config_for_2b
from gemma_pytorch.gemma.model import GemmaForCausalLM
# Load a Gemma checkpoint with gemma_pytorch and run one multi-turn chat request.
#
# VARIANT selects the checkpoint. The weights directory, checkpoint file name,
# model config, and quantization flag are all derived from it, so switching
# variants only requires editing this one line.
# NOTE(review): the <start_of_turn>/<end_of_turn> chat markers below are meant
# for the instruction-tuned ("-it") checkpoints. A base checkpoint such as '7b'
# may not follow the chat format and can produce off-topic answers. Confirm
# which checkpoint was downloaded (e.g. '7b-it').
VARIANT = '7b'
MACHINE_TYPE = 'cuda'
WEIGHTS_DIR = f'./gemma-{VARIANT}'

# Set up model config: 2b variants use the 2b config, everything else the 7b one.
model_config = get_config_for_2b() if '2b' in VARIANT else get_config_for_7b()

# Ensure that the tokenizer and checkpoint are present. Use explicit raises
# rather than `assert`, which is stripped when Python runs with -O.
tokenizer_path = os.path.join(WEIGHTS_DIR, 'tokenizer.model')
if not os.path.isfile(tokenizer_path):
    raise FileNotFoundError('Tokenizer not found!')
ckpt_path = os.path.join(WEIGHTS_DIR, f'gemma-{VARIANT}.ckpt')
if not os.path.isfile(ckpt_path):
    raise FileNotFoundError('PyTorch checkpoint not found!')

model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights. The default dtype is set from
# the config before construction so the parameters are created in that dtype.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()

# Chat templates: each turn is '<start_of_turn>{role}\n{text}<end_of_turn>\n'.
# A prompt sent for generation must END with an open model turn
# (MODEL_CHAT_START) so the model continues as the assistant.
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'
MODEL_CHAT_START = '<start_of_turn>model\n'
MODEL_CHAT_END = '<end_of_turn>\n'
MULTI_CHAT = ''

# Interactive multi-turn variant (disabled). The accumulated history
# MULTI_CHAT, not only the latest user turn, must be sent to generate();
# otherwise the model never sees earlier turns.
# while True:
#     input_text = input("User input(press q to exit):")
#     if input_text == 'q':
#         break
#     MULTI_CHAT += USER_CHAT_TEMPLATE.format(prompt=input_text) + MODEL_CHAT_START
#     model_response = model.generate(MULTI_CHAT, device=device, output_len=64)
#     print(f'Model reply: {model_response}')
#     MULTI_CHAT += model_response + MODEL_CHAT_END

# Build one complete conversation: user / model / user, then an open model turn.
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + MODEL_CHAT_START
)
print('Chat prompt:\n', prompt)

# Pass the conversation as-is. Wrapping it again in USER_CHAT_TEMPLATE would
# nest every turn inside one closed user turn and leave no open model turn,
# which confuses the model.
res = model.generate(
    prompt,
    device=device,
    output_len=100,
)
print(res)
The corresponding response is as follows:
Actual vs expected behavior:
However, when this code runs on Kaggle, the result seems correct.
Description of the bug:
I ran the Gemma-7B model using the code from the example and found that the model's answers were rather poor; it didn't seem to understand my question at all. Is this normal? My device is an Nvidia RTX 4090 GPU. My code is as follows:
The corresponding response is as follows:
Actual vs expected behavior:
However, when this code runs on Kaggle, the result seems correct.
Any other information you'd like to share?
No response