It appears that, during stage 2 training, I should disable saving both the PoseGuider and ReferenceNet state dicts in train.py, as below, to prevent the AttributeError. Am I correct?
```python
# Excerpt from the training loop in train.py; os and torch are imported at the top of the file.
# if is_main_process and (global_step % checkpointing_steps == 0 or step == len(train_dataloader) - 1):
if is_main_process and global_step % checkpointing_steps == 0:
    save_path = os.path.join(output_dir, "checkpoints")
    os.makedirs(save_path, exist_ok=True)  # make sure the checkpoint directory exists
    state_dict = {
        "epoch": epoch,
        "global_step": global_step,
        "unet_state_dict": unet.module.state_dict(),
        # "poseguider_state_dict": poseguider.module.state_dict(),
        # "referencenet_state_dict": referencenet.module.state_dict(),
    }
    if step == len(train_dataloader) - 1:
        torch.save(state_dict, os.path.join(save_path, f"checkpoint-epoch-{epoch+1}.ckpt"))
    else:
        ...  # remainder of the else branch not shown
```
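If the AttributeError comes from calling `.module` on a model that is not wrapped in `DistributedDataParallel` during stage 2 (an assumption, since the traceback is not shown), an alternative to commenting lines out is to unwrap each model defensively and only save the models you actually pass in. The helpers `unwrap_model` and `build_checkpoint` below are hypothetical, a minimal sketch rather than code from the repository; they rely only on the fact that DDP exposes the wrapped network as `.module`.

```python
from typing import Any, Dict, Optional


def unwrap_model(model: Any) -> Any:
    """Return the underlying network if it is wrapped (e.g. by DistributedDataParallel).

    DDP exposes the original model as `.module`; a plain nn.Module normally has no
    such attribute, so fall back to the object itself.
    """
    return getattr(model, "module", model)


def build_checkpoint(
    epoch: int,
    global_step: int,
    unet: Any,
    poseguider: Optional[Any] = None,
    referencenet: Optional[Any] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: collect state dicts only for the models that are passed in."""
    state_dict: Dict[str, Any] = {
        "epoch": epoch,
        "global_step": global_step,
        "unet_state_dict": unwrap_model(unet).state_dict(),
    }
    # Leave poseguider/referencenet as None (e.g. in stage 2) to skip saving them.
    if poseguider is not None:
        state_dict["poseguider_state_dict"] = unwrap_model(poseguider).state_dict()
    if referencenet is not None:
        state_dict["referencenet_state_dict"] = unwrap_model(referencenet).state_dict()
    return state_dict
```

In stage 2 this would be called as `state_dict = build_checkpoint(epoch, global_step, unet)` and then saved with `torch.save` exactly as before; in stage 1 the PoseGuider and ReferenceNet models would be passed as well.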