```python
for name, module in model.named_modules():
    if isinstance(module, LoraLayer):
        if args.bf16:
            module = module.to(torch.bfloat16)
    if 'norm' in name:
        module = module.to(torch.float32)
    if 'lm_head' in name or 'embed_tokens' in name:
        if hasattr(module, 'weight'):
            if args.bf16 and module.weight.dtype == torch.float32:
                module = module.to(torch.bfloat16)
```
In qlora.py, why is this line needed? If I run the code in bfloat16, aren't all layer dtypes already bfloat16?
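One way to check the premise is to print the parameter dtypes right after loading the quantized model, before any casting loop runs. Below is a minimal sketch (the checkpoint name and quantization flags are placeholders, not the exact qlora.py configuration) that groups parameters by dtype so you can see which submodules are not in bfloat16:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Placeholder checkpoint; substitute the model you are actually fine-tuning.
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
)

# Group parameter names by dtype to see what is (and is not) bfloat16,
# e.g. norm or embedding weights that may still be float32.
dtypes = {}
for name, param in model.named_parameters():
    dtypes.setdefault(param.dtype, []).append(name)

for dtype, names in dtypes.items():
    print(dtype, len(names), names[:3])
```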