PaddlePaddle / PaddleNLP

👑 Easy-to-use and powerful NLP and LLM library with 🤗 Awesome model zoo, supporting wide-range of NLP tasks from research to industrial applications, including 🗂Text Classification, 🔍 Neural Search, ❓ Question Answering, ℹ️ Information Extraction, 📄 Document Intelligence, 💌 Sentiment Analysis etc.
https://paddlenlp.readthedocs.io
Apache License 2.0

Error when converting weights stored in bf16 #8688

Closed zhaogf01 closed 2 weeks ago

zhaogf01 commented 3 months ago

https://github.com/PaddlePaddle/PaddleNLP/blob/be5bb14a8b212bd9586815f11b2c7f32d5823b43/legacy/examples/torch_migration/pipeline/weights/torch2paddle.py#L78

The torch weights are stored as bf16, and I would like to run inference in bf16 as well, but the conversion fails with the error shown in the attached screenshot.

How can this be resolved?

DrownFish19 commented 3 months ago

For now, you can refer to the qwen2 parameter-conversion script: https://gist.github.com/DrownFish19/80b43383c9205ee1cf7cf35445009488

The main conversion code is as follows:

def translate_one_safetensor(file_name):
    # load_file/save_file are the safetensors helpers; model_path, dst_path,
    # src_prefix_key, dst_prefix_key and check_trans are defined in the gist.
    tensors = load_file(os.path.join(model_path, file_name))
    for key in list(tensors.keys()):
        dst_key = key.replace(src_prefix_key, dst_prefix_key)
        logger.info("{} {}".format(key, tensors[key].shape))
        if check_trans(key):  # does this parameter need transposing?
            t = tensors.pop(key).cuda().t().contiguous()
        else:
            t = tensors.pop(key).cuda()
        # Hand the tensor from torch to paddle via DLPack so the bf16 data
        # never has to pass through numpy (which has no bfloat16 dtype).
        capsule = torch.utils.dlpack.to_dlpack(t)
        t = paddle.utils.dlpack.from_dlpack(capsule)
        tensors[dst_key] = t.numpy()

        # Alternative: upcast to float32 on CPU, then cast back to bfloat16:
        # tensors[dst_key] = paddle.to_tensor(tensors.pop(key).cuda().float().cpu().numpy(), dtype="bfloat16").numpy()
        logger.info("{} {}".format(dst_key, tensors[dst_key].shape))

    save_file(tensors, os.path.join(dst_path, file_name), metadata={"format": "np"})
    # os.remove(os.path.join(model_path, file_name))
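For context on the error itself: numpy has no native bfloat16 dtype, which is why calling `.numpy()` on a bf16 torch tensor fails and the snippet above routes the data through DLPack instead. The following self-contained numpy sketch (illustrative only, not part of the fix) shows that bfloat16 is simply the top 16 bits of an IEEE float32:

```python
import numpy as np


def float32_to_bfloat16_bits(x):
    # View float32 as uint32 and keep only the high 16 bits
    # (truncation rounding); the result is the bf16 bit pattern.
    return (x.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)


def bfloat16_bits_to_float32(b):
    # Re-expand the stored high bits back into a float32 value.
    return (b.astype(np.uint32) << 16).view(np.float32)
```

Values whose mantissa fits in bf16's 7 bits (e.g. 1.0, -2.0, 0.5) survive the round trip exactly; everything else loses precision, which is why a lossless torch-to-paddle transfer needs a path like DLPack that preserves the raw bf16 buffer.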
zhaogf01 commented 3 months ago


That handles safetensors conversion; does the same approach work for pdparams?

DrownFish19 commented 3 months ago
  1. If the model and its XXXPretrainedModel._get_name_mappings method are already implemented, the parameters can be converted directly through paddlenlp:

    # XXXForCausalLM is the class whose parameters need converting;
    # replace XXX with your model name.
    from paddlenlp.transformers import XXXForCausalLM

    # args holds any additional keyword arguments; add them as needed.
    XXXForCausalLM.from_pretrained(model_name_or_path, convert_from_torch=True, **args)

Reference code:
(1) convert-from-torch logic: https://github.com/PaddlePaddle/PaddleNLP/blob/2723138738fe179acb32bf619c54a9315acdabe0/paddlenlp/transformers/model_utils.py#L2225C1-L2252C88
(2) convert logic: https://github.com/PaddlePaddle/PaddleNLP/blob/2723138738fe179acb32bf619c54a9315acdabe0/paddlenlp/transformers/conversion_utils.py#L1145C1-L1188C26

  2. If you need a script that writes pdparams, follow the convert logic above: collect all converted parameters into a single dict via dict.update and save that dict as a pdparams file (note that linear-layer weights must be transposed).
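The dict-update approach in step 2 can be sketched as follows. This is a minimal illustration with plain numpy arrays standing in for the tensors; `check_trans` and the key patterns are hypothetical stand-ins for the real name-mapping logic:

```python
import numpy as np


def check_trans(key):
    # Hypothetical rule: 2-D linear weights need transposing, because torch
    # stores nn.Linear weights as (out_features, in_features) while paddle
    # stores them as (in_features, out_features).
    return key.endswith(".weight") and "norm" not in key and "embed" not in key


def merge_shards(shards):
    # Aggregate per-shard tensor dicts into one state dict, transposing
    # linear weights, before saving the result as a pdparams file.
    state = {}
    for tensors in shards:
        for key, value in tensors.items():
            if check_trans(key) and value.ndim == 2:
                state[key] = value.T.copy()
            else:
                state[key] = value
    return state
```

In a real script the merged state dict would then be written out with something like `paddle.save(state, "model_state.pdparams")`.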
github-actions[bot] commented 1 month ago

This issue is stale because it has been open for 60 days with no activity.

github-actions[bot] commented 2 weeks ago

This issue was closed because it has been inactive for 14 days since being marked as stale.