CarperAI / DRLX

Diffusion Reinforcement Learning Library
MIT License

Local models don't work? #21

Open TingTingin opened 1 year ago

TingTingin commented 1 year ago

I was trying to use a local safetensors SD model and can't seem to get it to work. Does the current setup always try to download from Hugging Face, even when an explicit file path is given and use_safetensors is set to True?

The models work locally if they were initially downloaded from the Hub, but not if I give a file path to a local safetensors model.
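The symptom depends on how a `model_path` string gets interpreted. As far as I know, diffusers' `from_pretrained` only loads from disk when it is given a directory; any other string is passed on as a Hugging Face Hub repo id. The sketch below is a hypothetical helper (`classify_model_path` and `ModelSource` are not DRLX or diffusers APIs) that makes that distinction explicit.

```python
import os
from enum import Enum


class ModelSource(Enum):
    LOCAL_DIR = "local_dir"       # e.g. the output of save_pretrained()
    SINGLE_FILE = "single_file"   # e.g. a standalone .safetensors / .ckpt checkpoint
    HUB_REPO_ID = "hub_repo_id"   # anything else


SINGLE_FILE_EXTENSIONS = (".safetensors", ".ckpt")


def classify_model_path(model_path: str) -> ModelSource:
    """Hypothetical helper: guess how a from_pretrained()-style loader sees model_path."""
    if os.path.isdir(model_path):
        # A pipeline directory: model_index.json plus one subfolder per component.
        return ModelSource.LOCAL_DIR
    if os.path.isfile(model_path) and model_path.lower().endswith(SINGLE_FILE_EXTENSIONS):
        # A single checkpoint file is not a pipeline directory, so it needs a
        # single-file loader or a one-time conversion step.
        return ModelSource.SINGLE_FILE
    # Anything else (including a path that does not exist) would be treated
    # as a Hub repo id, which is when a download gets attempted.
    return ModelSource.HUB_REPO_ID
```

Under this reading, a path to a single `.safetensors` file lands in the `SINGLE_FILE` case, which a plain `from_pretrained` call does not handle.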

shahbuland commented 1 year ago

I will try to reproduce this tomorrow. Just to be sure: are you using the trainer's load method, or are you trying to load a local model into the trainer as the denoiser? I haven't tested the latter case, but I can try to figure it out.

shahbuland commented 1 year ago

@TingTingin I tried to reproduce this locally and had no issues. Does this code do what you want? It runs for me. (I did add a config option in #22 to load the model only from local files, but even without it, local files should still be found.)

```python
from drlx.trainer.ddpo_trainer import DDPOTrainer
from drlx.configs import DRLXConfig

# Load a base config from YAML and request safetensors weights
config = DRLXConfig.load_yaml("configs/my_cfg.yml")
config.model.use_safetensors = True
trainer = DDPOTrainer(config)

# Save the pipeline in diffusers format and also save a training checkpoint
fp = "./checkpoints_saving_test"
trainer.save_pretrained("./output/saving_test")
trainer.save_checkpoint(fp)

# Restoring the checkpoint into the same trainer works
trainer.load_checkpoint(fp)

# Build a fresh trainer whose model is loaded from the saved local directory
del trainer
config.model.model_path = "./output/saving_test"
config.model.local_model = True  # option added in #22 (optional)
trainer = DDPOTrainer(config)
print("Successfully loaded pipeline")
```
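The thread doesn't show how the `local_model` option from #22 is implemented. As a hedged sketch of the general idea only, a flag like this can be forwarded to diffusers' `from_pretrained` as `local_files_only=True`, which stops it from contacting the Hub. `ModelSettings` and `pipeline_load_kwargs` are hypothetical stand-ins, not DRLX's actual config class or loader.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class ModelSettings:
    """Hypothetical stand-in for the model section of a DRLX config."""
    model_path: str
    use_safetensors: bool = False
    local_model: bool = False


def pipeline_load_kwargs(settings: ModelSettings) -> Dict[str, Any]:
    """Build keyword arguments for a diffusers from_pretrained() call."""
    kwargs: Dict[str, Any] = {"use_safetensors": settings.use_safetensors}
    if settings.local_model:
        # Only read files already on disk or in the cache; never download.
        kwargs["local_files_only"] = True
    return kwargs


# Usage (requires diffusers and a directory written by save_pretrained()):
# from diffusers import StableDiffusionPipeline
# settings = ModelSettings("./output/saving_test", use_safetensors=True, local_model=True)
# pipe = StableDiffusionPipeline.from_pretrained(settings.model_path, **pipeline_load_kwargs(settings))
```

As far as I can tell, `local_files_only` only prevents network access. It does not make `from_pretrained` accept a single `.safetensors` file in place of a pipeline directory like `./output/saving_test`.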
TingTingin commented 1 year ago

Sorry for taking so long to respond. I thought the problem was an outdated huggingface_hub, so I updated it, but it still didn't work. I'm not using a config file; I'm setting all the configs directly in code. This is the model portion that's not working:


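```python
from drlx.configs import ModelConfig

# Model section built directly in code (no YAML config file);
# model_path points at a single local .safetensors checkpoint
model_config = ModelConfig(
    model_path=r"C:\StableDiffusion\Repos\automatic\models\Stable-diffusion\awportrait_v11.safetensors",
    model_arch_type="LDMUnet",
    attention_slicing=True,
    xformers_memory_efficient=True,
    gradient_checkpointing=True,
    use_safetensors=True,
)
```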
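One likely difference from the maintainer's reproduction (an assumption; the thread does not confirm it) is that `model_path` here points to a single `.safetensors` checkpoint, while the working example loads a directory written by `save_pretrained()`. A possible workaround is to convert the checkpoint once into a diffusers-format directory and point `model_path` at that. This sketch assumes a diffusers version that provides `StableDiffusionPipeline.from_single_file`; `convert_checkpoint` and the `_diffusers` naming scheme are hypothetical.

```python
import os
from typing import Optional


def converted_dir_for(checkpoint_path: str) -> str:
    """Hypothetical naming scheme: model.safetensors -> model_diffusers/ next to it."""
    stem, _ = os.path.splitext(checkpoint_path)
    return stem + "_diffusers"


def convert_checkpoint(checkpoint_path: str, out_dir: Optional[str] = None) -> str:
    """Convert a single-file SD checkpoint into a diffusers pipeline directory.

    Returns the directory, which can then be used as model_path.
    """
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(checkpoint_path)
    out_dir = out_dir or converted_dir_for(checkpoint_path)
    # Imported lazily so the path helpers above work without diffusers installed
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_single_file(checkpoint_path)
    # Writes model_index.json plus unet/, vae/, text_encoder/, ... subfolders
    pipe.save_pretrained(out_dir, safe_serialization=True)
    return out_dir
```

After conversion, `ModelConfig(model_path=<returned directory>, ...)` would point at a directory, the same kind of input as `./output/saving_test` in the example above.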