Describe the bug
I took the CompVis `load_model_from_config` function and added safetensors loading to it. That works on its own, but when I use DeepSpeed I get the error `_pickle.UnpicklingError: invalid load key, '\xec'`.
from safetensors.torch import load_file
def load_model_from_config(config, ckpt, model_format="ckpt", verbose=False):
    """Build the model from ``config`` and load its weights from ``ckpt``.

    Args:
        config: Config whose ``model`` entry is passed to
            ``instantiate_from_config``.
        ckpt: Path to the weights file.
        model_format: ``"ckpt"`` for a ``torch.save`` checkpoint or
            ``"safetensors"`` (case-insensitive). A path ending in
            ``.safetensors`` is always read with safetensors regardless of
            this argument, because ``torch.load`` cannot parse that format and
            fails with ``_pickle.UnpicklingError: invalid load key``.
        verbose: If true, print missing and unexpected state-dict keys.

    Returns:
        The instantiated model, moved to CUDA and put in eval mode.

    Raises:
        ValueError: If the format is neither ``"ckpt"`` nor ``"safetensors"``.
    """
    print(f"Loading model from {ckpt}")
    fmt = str(model_format).strip().lower()
    # A .safetensors file must never reach torch.load, so the file extension
    # overrides a mismatched (e.g. default "ckpt") model_format.
    if str(ckpt).lower().endswith(".safetensors"):
        fmt = "safetensors"

    if fmt == "ckpt":
        pl_sd = torch.load(ckpt, map_location="cpu")
    elif fmt == "safetensors":
        # Load on CPU, like the torch.load branch. model.cuda() below moves the
        # weights to the current CUDA device; hard-coding "cuda:0" here would
        # hold an extra copy of every tensor on GPU 0 in each process.
        pl_sd = load_file(ckpt, device="cpu")
    else:
        raise ValueError(
            f"unsupported model_format {model_format!r}; "
            "expected 'ckpt' or 'safetensors'"
        )

    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    # Lightning-style checkpoints nest the weights under "state_dict";
    # plain state dicts (and safetensors files) are used as-is.
    sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd

    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    # load_state_dict has copied the tensors into the model; drop our
    # references so the loaded copy can be freed before moving to the GPU.
    del pl_sd, sd

    if verbose and missing:
        print("missing keys:")
        print(missing)
    if verbose and unexpected:
        print("unexpected keys:")
        print(unexpected)

    model.cuda()
    model.eval()
    return model
# DeepSpeed part: wrap `model` for fp16 inference with kernel injection.
# No checkpoint path is handed to DeepSpeed (checkpoint=None), so the weights
# already in `model` are the ones used.
ds_engine = deepspeed.init_inference(
    model,
    replace_with_kernel_inject=True,
    replace_method="auto",
    dtype=torch.float16,
    checkpoint=None,
    # mp_size (model-parallel degree) is intentionally not passed; the
    # original snippet had `mp_size=1` commented out.
)
Actual behavior (traceback):
File "/home/ubuntu/aigc/stable-diffusion/ldm_txt2img.py", line 317, in <module>
main(opt)
File "/home/ubuntu/aigc/stable-diffusion/ldm_txt2img.py", line 218, in main
model = load_model_from_config(config, f"{opt.ckpt}",f"{opt.model_format}")
File "/home/ubuntu/aigc/stable-diffusion/ldm_txt2img.py", line 38, in load_model_from_config
pl_sd = torch.load(ckpt, map_location="cpu")
File "/opt/conda/envs/deepspeed01/lib/python3.9/site-packages/torch/serialization.py", line 713, in load
return _legacy_load(opened_file, map_location, pickle_module, pickle_load_args)
File "/opt/conda/envs/deepspeed01/lib/python3.9/site-packages/torch/serialization.py", line 920, in _legacy_load
magic_number = pickle_module.load(f, pickle_load_args)
_pickle.UnpicklingError: invalid load key, '\xec'.
Describe the bug: I took the CompVis `load_model_from_config` function and added safetensors loading to it. That works on its own, but when I use DeepSpeed I get the error `_pickle.UnpicklingError: invalid load key, '\xec'`.
Environment (installed packages): accelerate==0.16.0, deepspeed==0.7.5, diffusers==0.10.0, invisible-watermark==0.1.5, omegaconf==2.3.0, Pillow==9.4.0, safetensors==0.3.0, torch==1.11.0+cu113, torchmetrics==0.11.3, torchvision==0.12.0+cu113, transformers==4.25.1, triton==2.0.0.dev20221202
Actual behavior (traceback):
File "/home/ubuntu/aigc/stable-diffusion/ldm_txt2img.py", line 317, in <module>
main(opt)
File "/home/ubuntu/aigc/stable-diffusion/ldm_txt2img.py", line 218, in main
model = load_model_from_config(config, f"{opt.ckpt}",f"{opt.model_format}")
File "/home/ubuntu/aigc/stable-diffusion/ldm_txt2img.py", line 38, in load_model_from_config
pl_sd = torch.load(ckpt, map_location="cpu")
File "/opt/conda/envs/deepspeed01/lib/python3.9/site-packages/torch/serialization.py", line 713, in load
return _legacy_load(opened_file, map_location, pickle_module, pickle_load_args)
File "/opt/conda/envs/deepspeed01/lib/python3.9/site-packages/torch/serialization.py", line 920, in _legacy_load
magic_number = pickle_module.load(f, pickle_load_args)
_pickle.UnpicklingError: invalid load key, '\xec'.