Open isyzwang opened 1 month ago
你可能下载了新版本的accelerate,默认保存模型变为了safetensors格式。图中的model.safetensors就是训练好的模型,你可以直接加载这个模型参数。
是这样改吗? 还是报错了!
尝试这样加载: from safetensors.torch import load_file
val_pipe.unet.load_state_dict(load_file(args.output_dir + f"/checkpoint-{checkpoint_step}/unet_target/diffusion_pytorch_model.safetensors"))