aigc-apps / EasyAnimate

📺 An End-to-End Solution for High-Resolution and Long Video Generation Based on Transformer Diffusion

Bug about resume from checkpoint #62

Closed: majiajiong closed this issue 1 month ago

majiajiong commented 1 month ago

In the train_lora.py script, there is a bug when using resume_from_checkpoint to load a checkpoint saved by accelerator. I am not sure whether the changes below are correct.

  1. In the load_model_hook function, https://github.com/aigc-apps/EasyAnimate/blob/main/scripts/train_lora.py#L755:

    def load_model_hook(models, input_dir):
        for i in range(len(models)):
            # pop models so that they are not loaded again
            model = models.pop()
    
            # load diffusers style into model
            load_model = Transformer3DModel.from_pretrained(
                input_dir, subfolder="transformer",
                transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs'])
            )
            model.register_to_config(**load_model.config)
    
            model.load_state_dict(load_model.state_dict())
            del load_model

    models contains both a Transformer3DModel and a LoRANetwork object, and LoRANetwork has no register_to_config method/attribute, so the hook fails with AttributeError: 'LoRANetwork' object has no attribute 'register_to_config'.

    Guarding the block with isinstance(model, Transformer3DModel), so that it only runs when model is a Transformer3DModel, makes it work; see the sketch after this list.

  2. Loading the model fails at https://github.com/aigc-apps/EasyAnimate/blob/main/scripts/train_lora.py#L1055. When executing accelerator.load_state(os.path.join(args.output_dir, path)), it raises RuntimeError: Error(s) in loading state_dict for T5EncoderModel: Missing key(s) in state_dict: "shared.weight", "encoder.embed_tokens.weight" ...... Unexpected key(s) in state_dict: "scale_shift_table", "pos_embed.proj.weight" ......

    The text encoder's parameters probably should not be loaded here in the first place. Adding strict=False to ignore the mismatched keys lets training continue: accelerator.load_state(os.path.join(args.output_dir, path), strict=False). See the second sketch below.
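For reference, a minimal sketch of the fix from point 1 (not an official patch), assuming Transformer3DModel, OmegaConf, and config are the same objects already in scope in train_lora.py:

    def load_model_hook(models, input_dir):
        for i in range(len(models)):
            # pop models so that they are not loaded again
            model = models.pop()

            # only the Transformer3DModel carries a diffusers-style config and a
            # "transformer" subfolder in the checkpoint; skip the LoRANetwork object
            if not isinstance(model, Transformer3DModel):
                continue

            # load diffusers style into model
            load_model = Transformer3DModel.from_pretrained(
                input_dir, subfolder="transformer",
                transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs'])
            )
            model.register_to_config(**load_model.config)

            model.load_state_dict(load_model.state_dict())
            del load_model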
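For point 2, the call itself stays the same; as far as I can tell, accelerate forwards extra keyword arguments of load_state to each registered model's load_state_dict, which is why strict=False makes the T5EncoderModel key mismatch non-fatal:

    # `args`, `accelerator`, and `path` are the names already used around #L1055.
    # strict=False is passed through to load_state_dict, so missing/unexpected
    # keys are ignored instead of raising a RuntimeError.
    accelerator.load_state(os.path.join(args.output_dir, path), strict=False)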

bubbliiiing commented 1 month ago

We will find time to fix this; I have not tried resuming from a LoRA checkpoint.

bubbliiiing commented 1 month ago

This has been fixed, please test it.

majiajiong commented 1 month ago

Resolved.