gligen / GLIGEN

Open-Set Grounded Text-to-Image Generation
MIT License
1.91k stars 145 forks source link

Unexpected key(s) in state_dict: "transformer.text_model.embeddings.position_ids". #67

Closed tienduong-21 closed 7 months ago

tienduong-21 commented 7 months ago

I ran the file gligen_inference.py and got this error: RuntimeError: Error(s) in loading state_dict for FrozenCLIPEmbedder: Unexpected key(s) in state_dict: "transformer.text_model.embeddings.position_ids". Here is my code:

def load_ckpt(ckpt_path):
    """Load a GLIGEN checkpoint and instantiate its four sub-modules.

    Args:
        ckpt_path: Path to a saved GLIGEN checkpoint containing both the
            model weights and the config used to build it.

    Returns:
        Tuple of (model, autoencoder, text_encoder, diffusion, config),
        with all modules moved to the global ``device`` and the frozen
        parts switched to eval mode.
    """
    saved_ckpt = torch.load(ckpt_path)
    config = saved_ckpt["config_dict"]["_content"]

    model = instantiate_from_config(config['model']).to(device).eval()
    autoencoder = instantiate_from_config(config['autoencoder']).to(device).eval()
    text_encoder = instantiate_from_config(config['text_encoder']).to(device).eval()
    diffusion = instantiate_from_config(config['diffusion']).to(device)

    # Pretrained frozen parts (VAE / CLIP / diffusion schedule) come from the
    # shared checkpoint file, not from `saved_ckpt`.
    model_parts = torch.load("gligen_checkpoints/diffusion_pytorch_model.bin")

    # No need to load official_ckpt for self.model here, since we load it
    # from our own checkpoint below.
    model.load_state_dict(saved_ckpt['model'])
    autoencoder.load_state_dict(model_parts["autoencoder"])

    # Newer `transformers` releases no longer register CLIP's `position_ids`
    # as a persistent buffer, so checkpoints saved with older versions carry
    # an unexpected key that makes load_state_dict raise a RuntimeError.
    # Drop it tolerantly (no-op when the key is already absent).
    model_parts["text_encoder"].pop('transformer.text_model.embeddings.position_ids', None)
    text_encoder.load_state_dict(model_parts["text_encoder"])
    diffusion.load_state_dict(model_parts["diffusion"])

    return model, autoencoder, text_encoder, diffusion, config

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--folder", type=str,  default="generation_samples", help="root folder for output")

    parser.add_argument("--batch_size", type=int, default=5, help="")
    parser.add_argument("--no_plms", action='store_true', help="use DDIM instead. WARNING: I did not test the code yet")
    parser.add_argument("--guidance_scale", type=float,  default=7.5, help="")
    parser.add_argument("--negative_prompt", type=str,  default='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality', help="")
    #parser.add_argument("--negative_prompt", type=str,  default=None, help="")
    args = parser.parse_args()

    meta_list = [ 

        # - - - - - - - - GLIGEN on text grounding for generation - - - - - - - - # 
        dict(
            ckpt = "gligen_checkpoints/diffusion_pytorch_model.bin",
            prompt = "a teddy bear sitting next to a bird",
            phrases = ['a teddy bear', 'a bird'],
            locations = [ [0.0,0.09,0.33,0.76], [0.55,0.11,1.0,0.8] ],
            alpha_type = [0.3, 0.0, 0.7],
            save_folder_name="generation_box_text"
        ), 
yejinxueshi commented 7 months ago

Have you solved it? I also encountered the same problem

tienduong-21 commented 7 months ago

I solved this problem! You need to remove that key from the state dictionary before loading it:

def load_ckpt(ckpt_path):
    """Load a GLIGEN checkpoint and instantiate its four sub-modules.

    Args:
        ckpt_path: Path to a saved GLIGEN checkpoint containing both the
            model weights and the config used to build it.

    Returns:
        Tuple of (model, autoencoder, text_encoder, diffusion, config),
        with all modules moved to the global ``device`` and the frozen
        parts switched to eval mode.
    """
    saved_ckpt = torch.load(ckpt_path)
    config = saved_ckpt["config_dict"]["_content"]

    model = instantiate_from_config(config['model']).to(device).eval()
    autoencoder = instantiate_from_config(config['autoencoder']).to(device).eval()
    text_encoder = instantiate_from_config(config['text_encoder']).to(device).eval()
    diffusion = instantiate_from_config(config['diffusion']).to(device)

    model_parts = torch.load("gligen_checkpoints/diffusion_pytorch_model.bin")
    # No need to load official_ckpt for self.model here, since we load it
    # from our own checkpoint below.
    model.load_state_dict(saved_ckpt['model'])
    autoencoder.load_state_dict(model_parts["autoencoder"])

    # Checkpoints saved with older `transformers` versions contain CLIP's
    # `position_ids` buffer, which newer versions reject as an unexpected
    # key. Use pop() with a default instead of `del` so this also works
    # when the key is absent (checkpoints saved with newer versions).
    model_parts["text_encoder"].pop('transformer.text_model.embeddings.position_ids', None)

    text_encoder.load_state_dict(model_parts["text_encoder"])
    diffusion.load_state_dict(model_parts["diffusion"])

    return model, autoencoder, text_encoder, diffusion, config
catsled commented 6 months ago

Have you solved it? I also encountered the same problem

You can also check the versions of your packages, such as transformers and diffusers.

I reinstalled these packages and that solved it for me.