Is something like this OK? It allows resuming training.
def load_model_hook(models, input_dir):
    print(f"load_model_hook called - models: {models}, input_dir: {input_dir}")

    unet_ = None
    text_encoder_one_ = None
    text_encoder_two_ = None

    # Match each model that accelerate hands to the hook against the
    # training models by type, so the weights are loaded into the right one.
    while len(models) > 0:
        model = models.pop()
        if isinstance(model, type(accelerator.unwrap_model(unet))):
            unet_ = model
        elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
            text_encoder_one_ = model
        elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
            text_encoder_two_ = model
        else:
            raise ValueError(f"unexpected save model: {model.__class__}")

    lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
    print(f"lora_state_dict.keys(): {lora_state_dict.keys()}")

    # Strip the "unet." prefix and load the UNet LoRA weights into the
    # model popped above.
    unet_new_state_dict = {}
    for p, v in lora_state_dict.items():
        if p.startswith("unet."):
            unet_new_state_dict[p[len("unet."):]] = v
    missing, unexpected = unet_.load_state_dict(unet_new_state_dict, strict=False)
    filled = set(unet_.state_dict().keys()) - set(missing)
    print(f"filled: {filled}")
    # print(f"missing: {missing}")
    # print(f"unexpected: {unexpected}")

    if args.train_text_encoder:
        # Same for the first text encoder.
        text_encoder_one_new_state_dict = {}
        for p, v in lora_state_dict.items():
            if p.startswith("text_encoder_one."):
                text_encoder_one_new_state_dict[p[len("text_encoder_one."):]] = v
        missing, unexpected = text_encoder_one_.load_state_dict(text_encoder_one_new_state_dict, strict=False)
        filled = set(text_encoder_one_.state_dict().keys()) - set(missing)
        print(f"filled: {filled}")
        # print(f"missing: {missing}")
        # print(f"unexpected: {unexpected}")

        # And for the second text encoder.
        text_encoder_two_new_state_dict = {}
        for p, v in lora_state_dict.items():
            if p.startswith("text_encoder_two."):
                text_encoder_two_new_state_dict[p[len("text_encoder_two."):]] = v
        missing, unexpected = text_encoder_two_.load_state_dict(text_encoder_two_new_state_dict, strict=False)
        filled = set(text_encoder_two_.state_dict().keys()) - set(missing)
        print(f"filled: {filled}")
        # print(f"missing: {missing}")
        # print(f"unexpected: {unexpected}")
Thanks for reporting this. Will give it a look.
What is the solution for this? I am still facing the error "No inf checks were recorded for this optimizer."
I am using the following script:
!accelerate launch train_text_to_image_lora_sdxl.py \
--pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
--train_data_dir="/kaggle/input/training-data-updated" --caption_column="text" \
--resolution=1024 --random_flip \
--train_batch_size=1 \
--checkpointing_steps=300 \
--learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
--seed=42 \
--output_dir="sd-lora-sdxl" \
--gradient_checkpointing \
--mixed_precision="fp16" \
--num_train_epochs=100 \
--push_to_hub \
--hub_token $HF_TOKEN \
--hub_model_id $REPO_ID \
--resume="/kaggle/working/diffusers/examples/text_to_image/sd-lora-sdxl/checkpoint-1500"
and it shows "AssertionError: No inf checks were recorded for this optimizer." @sayakpaul, can you help me with this?
Describe the bug
I am able to train an SDXL LoRA without problems. However, when I try to resume from an existing checkpoint, I am faced with the error:
Looking at the error, it seems to be AMP-related.
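For context, this assertion comes from PyTorch's AMP GradScaler: scaler.step(optimizer) raises it whenever none of that optimizer's parameters carry gradients from a scaled backward pass, which is what happens if the parameters the optimizer references were swapped out while the checkpoint was loaded. A minimal, hypothetical repro outside the training script (assumes a CUDA device):

import torch

model = torch.nn.Linear(4, 4).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

# No scaler.scale(loss).backward() has run, so no parameter has a gradient
# and no inf checks were recorded for this optimizer:
scaler.step(optimizer)  # AssertionError: No inf checks were recorded for this optimizer.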
Reproduction
Logs