microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0
33.78k stars 3.96k forks source link

[BUG] Loading a safetensors checkpoint with DeepSpeed raises _pickle.UnpicklingError: invalid load key, '\xec'. #2945

Open stevensu1977 opened 1 year ago

stevensu1977 commented 1 year ago

Describe the bug: I can use the CompVis `load_model_from_config` function with safetensors loading added, and it works. But when I use DeepSpeed, I get the error `_pickle.UnpicklingError: invalid load key, '\xec'`.

Environment (package versions): accelerate==0.16.0 deepspeed==0.7.5 diffusers==0.10.0 invisible-watermark==0.1.5 omegaconf==2.3.0 Pillow==9.4.0 safetensors==0.3.0 torch==1.11.0+cu113 torchmetrics==0.11.3 torchvision==0.12.0+cu113 transformers==4.25.1 triton==2.0.0.dev20221202

from safetensors.torch import load_file

def _resolve_model_format(ckpt, model_format):
    """Return the loader to use for ``ckpt``: ``"ckpt"`` or ``"safetensors"``.

    A path ending in ``.safetensors`` (case-insensitive) is always loaded with
    safetensors, even when ``model_format`` is ``"ckpt"``: ``torch.load``
    cannot parse that format and fails with
    ``_pickle.UnpicklingError: invalid load key``. ``"auto"`` otherwise
    falls back to ``"ckpt"``.

    Raises:
        ValueError: If ``model_format`` is not ``"ckpt"``, ``"safetensors"``
            or ``"auto"``.
    """
    if model_format not in ("ckpt", "safetensors", "auto"):
        raise ValueError(
            f"unsupported model_format {model_format!r}; "
            "expected 'ckpt', 'safetensors' or 'auto'"
        )
    if model_format == "safetensors" or str(ckpt).lower().endswith(".safetensors"):
        return "safetensors"
    return "ckpt"


def load_model_from_config(config, ckpt, model_format="ckpt", verbose=False):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: Config object; ``config.model`` is passed to
            ``instantiate_from_config``.
        ckpt: Path to the checkpoint file.
        model_format: ``"ckpt"`` (pickled checkpoint, read with ``torch.load``),
            ``"safetensors"`` (read with ``safetensors.torch.load_file``) or
            ``"auto"`` (pick from the file extension). A ``.safetensors``
            path is always read with safetensors.
        verbose: If true, print missing / unexpected state-dict keys.

    Returns:
        The model with the checkpoint weights loaded (non-strict), moved to
        CUDA and switched to eval mode.

    Raises:
        ValueError: If ``model_format`` is not a supported value.
    """
    print(f"Loading model from {ckpt}")
    fmt = _resolve_model_format(ckpt, model_format)
    if fmt == "ckpt":
        pl_sd = torch.load(ckpt, map_location="cpu")
    else:
        # Load to CPU, matching the torch.load branch. model.cuda() below
        # moves the weights; loading to cuda:0 here would keep an extra copy
        # of every tensor on that GPU and hard-code the device index.
        pl_sd = load_file(ckpt, device="cpu")

    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    # Lightning-style checkpoints nest the weights under "state_dict";
    # plain state dicts (e.g. safetensors files) are used as-is.
    sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd

    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and missing:
        print("missing keys:")
        print(missing)
    if verbose and unexpected:
        print("unexpected keys:")
        print(unexpected)

    model.cuda()
    model.eval()
    return model

# DeepSpeed inference wrapper: fp16 weights, no separate DeepSpeed checkpoint
# (weights come from `model` itself), kernel injection with automatic
# module-replacement detection.
ds_engine = deepspeed.init_inference(
    model,
    dtype=torch.float16,
    checkpoint=None,
    replace_method="auto",
    replace_with_kernel_inject=True,
)

Expected behavior: the model loads without error. Actual behavior (traceback):

    File "/home/ubuntu/aigc/stable-diffusion/ldm_txt2img.py", line 317, in <module>
      main(opt)
    File "/home/ubuntu/aigc/stable-diffusion/ldm_txt2img.py", line 218, in main
      model = load_model_from_config(config, f"{opt.ckpt}",f"{opt.model_format}")
    File "/home/ubuntu/aigc/stable-diffusion/ldm_txt2img.py", line 38, in load_model_from_config
      pl_sd = torch.load(ckpt, map_location="cpu")
    File "/opt/conda/envs/deepspeed01/lib/python3.9/site-packages/torch/serialization.py", line 713, in load
      return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
    File "/opt/conda/envs/deepspeed01/lib/python3.9/site-packages/torch/serialization.py", line 920, in _legacy_load
      magic_number = pickle_module.load(f, **pickle_load_args)
  _pickle.UnpicklingError: invalid load key, '\xec'.

Note that the traceback goes through `torch.load`, i.e. the `"ckpt"` branch of `load_model_from_config` was taken for this file. `torch.load` cannot read safetensors files, which is what produces this unpickling error.

ys2899 commented 1 year ago

I have the same problem