gligen / GLIGEN

Open-Set Grounded Text-to-Image Generation
MIT License

Unable to load my trained checkpoints #16

Closed: seareale closed this issue 1 year ago

seareale commented 1 year ago

First, thank you for sharing this great work!

I trained GLIGEN on my custom dataset and obtained the weights.

However, I can't run gligen_inference.py with my custom weights; it fails with the error below.

[Screenshot of the error message]

The problematic code is here: https://github.com/gligen/GLIGEN/blob/d52d93a25fc9727ad6085f267345db0237f658a8/gligen_inference.py#L68-L84

How can I solve it?
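To narrow down a failure like this, it can help to check which top-level keys the trained checkpoint actually contains, since load_ckpt() indexes it by "config_dict", "model", "autoencoder", "text_encoder", and "diffusion". This is a minimal sketch, assuming the checkpoint has already been loaded with torch.load into a plain dict; the helper missing_ckpt_keys is hypothetical and not part of GLIGEN:

```python
from collections.abc import Mapping
from typing import Any

# Top-level keys that load_ckpt() in gligen_inference.py reads from the checkpoint.
REQUIRED_KEYS = ("config_dict", "model", "autoencoder", "text_encoder", "diffusion")


def missing_ckpt_keys(ckpt: Mapping[str, Any]) -> list[str]:
    """Return the required keys that are absent from a loaded checkpoint dict."""
    return [key for key in REQUIRED_KEYS if key not in ckpt]


# Usage (requires torch; "your_checkpoint.pth" is a placeholder path):
#   ckpt = torch.load("your_checkpoint.pth", map_location="cpu")
#   print(sorted(ckpt.keys()))
#   print(missing_ckpt_keys(ckpt))
```

A non-empty result would mean load_ckpt() cannot find some of the weights it expects in that file.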

seareale commented 1 year ago

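I solved the error by loading the autoencoder, text encoder, and diffusion weights from the pretrained checkpoint (diffusion_pytorch_model.bin) instead of from my trained checkpoint: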

import torch
from ldm.util import instantiate_from_config  # same import as in gligen_inference.py

device = "cuda"  # as defined at the top of gligen_inference.py


def load_ckpt(ckpt_path, pretrained_path="diffusion_pytorch_model.bin"):
    # Trained checkpoint: provides the config and the trained GLIGEN model weights.
    saved_ckpt = torch.load(ckpt_path)
    config = saved_ckpt["config_dict"]["_content"]

    model = instantiate_from_config(config["model"]).to(device).eval()
    autoencoder = instantiate_from_config(config["autoencoder"]).to(device).eval()
    text_encoder = instantiate_from_config(config["text_encoder"]).to(device).eval()
    diffusion = instantiate_from_config(config["diffusion"]).to(device)

    # ADD: pretrained checkpoint that still holds the frozen components
    model_parts = torch.load(pretrained_path)

    model.load_state_dict(saved_ckpt["model"])  # trained weights
    autoencoder.load_state_dict(model_parts["autoencoder"])  # MODIFIED: from pretrained
    text_encoder.load_state_dict(model_parts["text_encoder"])  # MODIFIED: from pretrained
    diffusion.load_state_dict(model_parts["diffusion"])  # MODIFIED: from pretrained

    return model, autoencoder, text_encoder, diffusion, config
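An alternative to editing gligen_inference.py is to build one combined checkpoint offline. It takes the config and trained model weights from your checkpoint and the frozen autoencoder, text encoder, and diffusion weights from the pretrained file, and you then pass that file to the unmodified script. This is a hedged sketch, assuming both checkpoints load as plain dicts with the keys used in the fix above; merge_gligen_checkpoints and the file paths are hypothetical:

```python
from collections.abc import Mapping
from typing import Any

# Taken from the trained checkpoint: config plus the fine-tuned GLIGEN model.
TRAINED_KEYS = ("config_dict", "model")
# Taken from the pretrained checkpoint: the frozen components.
PRETRAINED_KEYS = ("autoencoder", "text_encoder", "diffusion")


def merge_gligen_checkpoints(
    trained: Mapping[str, Any], pretrained: Mapping[str, Any]
) -> dict[str, Any]:
    """Combine a trained checkpoint with the frozen parts of a pretrained one.

    Raises KeyError for the first missing key, so a bad file fails early.
    """
    merged: dict[str, Any] = {}
    for key in TRAINED_KEYS:
        merged[key] = trained[key]
    for key in PRETRAINED_KEYS:
        merged[key] = pretrained[key]
    return merged


# Usage (requires torch; "my_trained.pth" and "merged.pth" are placeholder paths):
#   trained = torch.load("my_trained.pth", map_location="cpu")
#   pretrained = torch.load("diffusion_pytorch_model.bin", map_location="cpu")
#   torch.save(merge_gligen_checkpoints(trained, pretrained), "merged.pth")
```

The merged file should then be readable by the original load_ckpt() as-is, at the cost of storing the frozen weights a second time on disk.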