huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

DDPO trained model error when used to generate images #1775

Closed: nguyenhoa-uit closed this issue 1 month ago

nguyenhoa-uit commented 3 months ago

I trained a simple DDPO model (5 epochs) and saved it to 'Nguyen17/my_DDPO'. But when I use it to generate images, I get this error:

ValueError: Cannot load <class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'> from /root/.cache/huggingface/hub/models--Nguyen17--my_DDPO/snapshots/8e70ecc8d8ff9f04221eb40651b2ed3a3ece90b6/unet because the following keys are missing: down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight, up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_k.weight, up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_v.weight, up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_out.0.bias,
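The ValueError means that some parameter names the UNet expects were not found in the downloaded `unet` folder. Below is a minimal stdlib sketch of that kind of check. It assumes the loader simply compares the model's expected keys against the keys stored in the checkpoint; the real diffusers check is more involved. The function name `missing_keys` is hypothetical. The key lists are also hypothetical, although the first two keys are copied from the error above.

```python
def missing_keys(expected_keys, checkpoint_keys):
    """Return keys the model expects but the checkpoint does not provide,
    in the model's own order (a simplified version of a 'keys are missing' check)."""
    provided = set(checkpoint_keys)
    return [k for k in expected_keys if k not in provided]


# Hypothetical key lists: what a full UNet expects vs. what a checkpoint stores.
expected = [
    "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight",
    "up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_k.weight",
    "conv_in.weight",
]
checkpoint = ["conv_in.weight"]

# Lists the two attention weights absent from the checkpoint.
print(missing_keys(expected, checkpoint))
```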

Please tell me how to fix it. This is the code I used:

import torch
from trl import DefaultDDPOStableDiffusionPipeline

pipeline = DefaultDDPOStableDiffusionPipeline("Nguyen17/my_DDPO")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# memory optimization
pipeline.vae.to(device, torch.float16)
pipeline.text_encoder.to(device, torch.float16)
pipeline.unet.to(device, torch.float16)

prompts = ["squirrel", "crab", "starfish", "whale", "sponge", "plankton"]
results = pipeline(prompts)

for prompt, image in zip(prompts, results.images):
    image.save(f"{prompt}.png")

nguyenhoa-uit commented 3 months ago

I found the solution here: https://github.com/huggingface/trl/issues/1404. Thanks.

github-actions[bot] commented 2 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.