Closed seareale closed 1 year ago
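I solved the error by loading the autoencoder, text encoder, and diffusion weights from the pretrained weights instead of from my custom checkpoint.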
def load_ckpt(ckpt_path):
    # Custom-trained checkpoint
    saved_ckpt = torch.load(ckpt_path)
    config = saved_ckpt["config_dict"]["_content"]

    model = instantiate_from_config(config["model"]).to(device).eval()
    autoencoder = instantiate_from_config(config["autoencoder"]).to(device).eval()
    text_encoder = instantiate_from_config(config["text_encoder"]).to(device).eval()
    diffusion = instantiate_from_config(config["diffusion"]).to(device)

    # ADD #################
    # Pretrained weights
    model_parts = torch.load("diffusion_pytorch_model.bin")
    ######################

    # The model weights still come from the custom checkpoint
    model.load_state_dict(saved_ckpt["model"])
    autoencoder.load_state_dict(model_parts["autoencoder"])  # MODIFIED
    text_encoder.load_state_dict(model_parts["text_encoder"])  # MODIFIED
    diffusion.load_state_dict(model_parts["diffusion"])  # MODIFIED

    return model, autoencoder, text_encoder, diffusion, config
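The same idea can be sketched as a small helper that borrows whichever parts a custom checkpoint lacks from the pretrained one. This is only an illustration: `fill_missing_parts` and `REQUIRED_PARTS` are hypothetical names, not part of GLIGEN, and it assumes (the thread does not confirm this) that the custom checkpoint simply has no `autoencoder`, `text_encoder`, and `diffusion` entries. Plain dicts stand in for what `torch.load` would return.

```python
# Hypothetical helper, not part of GLIGEN: a sketch of the workaround above.
# Assumption: both checkpoints behave like dicts keyed by part name.
REQUIRED_PARTS = ("model", "autoencoder", "text_encoder", "diffusion")


def fill_missing_parts(custom_ckpt, pretrained_ckpt, required=REQUIRED_PARTS):
    """Return (merged, borrowed): a shallow copy of custom_ckpt in which every
    required part that is absent is taken from pretrained_ckpt, plus the list
    of part names that were borrowed."""
    merged = dict(custom_ckpt)  # shallow copy; the input is not modified
    borrowed = []
    for name in required:
        if name in merged:
            continue  # the custom checkpoint already provides this part
        if name not in pretrained_ckpt:
            raise KeyError(f"'{name}' is missing from both checkpoints")
        merged[name] = pretrained_ckpt[name]
        borrowed.append(name)
    return merged, borrowed


# Toy data standing in for the two loaded checkpoints.
custom = {"model": {"w": 1}, "config_dict": {"_content": {}}}
pretrained = {
    "model": {"w": 0},
    "autoencoder": {"a": 1},
    "text_encoder": {"t": 1},
    "diffusion": {"d": 1},
}

merged, borrowed = fill_missing_parts(custom, pretrained)
print("borrowed from pretrained:", borrowed)
```

Printing `saved_ckpt.keys()` on the real checkpoint first would show whether this assumption holds for a given training run.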
First, thank you for sharing this great work!
I trained GLIGEN on my custom dataset and obtained weights.
But I can't run gligen_inference.py with my custom weights; it fails with the error below. The code with the problem is here: https://github.com/gligen/GLIGEN/blob/d52d93a25fc9727ad6085f267345db0237f658a8/gligen_inference.py#L68-L84
How can I solve it?