artidoro / qlora

QLoRA: Efficient Finetuning of Quantized LLMs
https://arxiv.org/abs/2305.14314
MIT License

Why is the dtype-changing code needed? #209

Open YooSungHyun opened 1 year ago

YooSungHyun commented 1 year ago

In qlora.py

for name, module in model.named_modules():
    if isinstance(module, LoraLayer):
        if args.bf16:
            # cast LoRA adapter layers to bf16 when training in bf16
            module = module.to(torch.bfloat16)
    if 'norm' in name:
        # keep normalization layers in float32
        module = module.to(torch.float32)
    if 'lm_head' in name or 'embed_tokens' in name:
        if hasattr(module, 'weight'):
            if args.bf16 and module.weight.dtype == torch.float32:
                # cast lm_head / embedding weights from float32 to bf16
                module = module.to(torch.bfloat16)

Why are these lines needed? If I run the code with bf16, isn't every layer's dtype already bfloat16?
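
For reference, this is roughly how I inspect the dtypes the modules actually end up with after 4-bit loading plus LoRA. It is only a minimal sketch, not code from the repo: it assumes transformers, peft, and bitsandbytes are installed, and the checkpoint name and LoRA target modules are just placeholders.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",  # placeholder checkpoint
    quantization_config=bnb_config,
)
model = get_peft_model(model, LoraConfig(r=8, target_modules=["q_proj", "v_proj"]))

# One line per parameter: the 4-bit base weights show up with a packed storage
# dtype, while the freshly initialized LoRA A/B matrices are typically float32
# until something (like the quoted loop) casts them.
for name, param in model.named_parameters():
    print(param.dtype, name)

Printing the dtypes this way is what made me wonder whether the explicit casts in qlora.py are still necessary when bf16 is enabled.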