google / gemma_pytorch

The official PyTorch implementation of Google's Gemma models
https://ai.google.dev/gemma
Apache License 2.0

Quantised weights are bfloat16 not int8 #7

Closed: dsanmart closed this issue 6 months ago

dsanmart commented 6 months ago

According to the Kaggle model card, it is an int8-quantized 7B-parameter base model.

I'm trying to run the quantised version of the model using the torch.uint8 dtype, but the parameters are in torch.bfloat16. Is it possible to get int8 weights, or to convert the bfloat16 ones to int8?

michaelmoynihan commented 6 months ago

Hi there Diego, we do use int8 for quantized (not uint8). If you look at the implementation, we explicitly use dtype int8 for our weights (if quantized is used) here: https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L107C1-L154C22. Additionally, in Kaggle, the weights are int8 (and the scalers associated with the weights are bfloat16). Can you give me more details about how you are trying to run this?
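
For anyone reading along, here is a minimal, hedged sketch of the layout described above: the weight matrix is stored as int8 with a per-output-row bfloat16 scaler, and is dequantized on the fly in the forward pass. The class and attribute names below are illustrative only, not the repo's; see gemma/model.py at the link above for the actual quantized Linear.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Int8LinearSketch(nn.Module):
    """Illustrative only: int8 weights with a bfloat16 per-row scaler.

    This is NOT the repo's code; it just mirrors the scheme described
    above (see gemma/model.py for the real layer used when --quant is set).
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Quantized weights live as int8; one bfloat16 scaler per output row.
        self.register_buffer(
            "weight", torch.zeros(out_features, in_features, dtype=torch.int8))
        self.register_buffer(
            "weight_scaler", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dequantize on the fly: cast int8 to the activation dtype, scale each row.
        w = self.weight.to(x.dtype) * self.weight_scaler.to(x.dtype).unsqueeze(-1)
        return F.linear(x, w)
```

Inspecting a few tensors from the downloaded 7B-quant checkpoint should likewise show torch.int8 for the weights and torch.bfloat16 for the matching scalers, as described above.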

pengchongjin commented 6 months ago

To use the quantized checkpoint, please do not explicitly set dtype in config.py.

Instead, you should be able to use the 7B-quant PyTorch checkpoint downloaded from Kaggle out of the box with the following command.

PROMPT="The meaning of life is"

docker run -t --rm \
    --gpus all \
    -v ${CKPT_PATH}:/tmp/ckpt \
    ${DOCKER_URI} \
    python scripts/run.py \
    --device=cuda \
    --ckpt=/tmp/ckpt \
    --variant="${VARIANT}" \
    --prompt="${PROMPT}" \
    --quant

dsanmart commented 6 months ago

> Hi there Diego, we do use int8 for quantized (not uint8). If you look at the implementation, we explicitly use dtype int8 for our weights (if quantized is used) here: https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L107C1-L154C22. Additionally, in Kaggle, the weights are int8 (and the scalers associated with the weights are bfloat16). Can you give me more details about how you are trying to run this?

Hi @michaelmoynihan, you are correct, thanks for the clarification.

I was trying to run it on MPS using the PyTorch Python implementation and got a TypeError: BFloat16 is not supported on MPS, so I assumed the weights were bfloat16. However, it was not the weights but the model_config dtype that was set to bfloat16. I also tried removing the scaler, even though that would reduce output quality, but had no success, since there are other ComplexFloat tensors that are not supported by the MPS backend.
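
For context on the errors mentioned here, the snippet below reproduces the two MPS limitations in isolation, using only core torch APIs and nothing from this repo. Behaviour depends on the PyTorch/macOS version in use at the time of this thread; newer releases may have added partial bfloat16 or complex support on MPS.

```python
import torch

# bfloat16 works fine on CPU (and CUDA), which is why the checkpoint loads there.
print(torch.zeros(2, 2, dtype=torch.bfloat16))

if torch.backends.mps.is_available():
    try:
        # On the PyTorch versions discussed here, this raises the error Diego saw.
        torch.zeros(2, 2, dtype=torch.bfloat16, device="mps")
    except (TypeError, RuntimeError) as e:
        print("bfloat16 on MPS:", e)  # e.g. "BFloat16 is not supported on MPS"
    try:
        # The ComplexFloat tensors mentioned above (likely the rotary-embedding
        # cache) hit a similar limitation.
        torch.zeros(2, dtype=torch.complex64, device="mps")
    except (TypeError, RuntimeError) as e:
        print("complex64 on MPS:", e)
```

So even with the int8 weights and a non-bfloat16 config dtype, running this implementation on MPS would still require reworking the complex-valued parts; CPU or CUDA remains the straightforward path.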