ayushtewari / DFM

Implementation of "Diffusion with Forward Models: Solving Stochastic Inverse Problems Without Direct Supervision"
https://diffusion-with-forward-models.github.io/

Pretrain pixel-nerf results in errors when loading the checkpoint #13

Open · lukasHoel opened this issue 1 year ago

lukasHoel commented 1 year ago

Hi,

I tried to follow the instructions: first train a pixel-nerf checkpoint, then fine-tune it. However, several errors come up when loading the state dict for the second-stage training.

Some sources of error: for example, model.enc.pos_embed changes shape between the 64x64 pretraining resolution and the 128x128 fine-tuning resolution, so strict loading fails with a size mismatch (see below).
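For illustration, here is a minimal, self-contained sketch of this class of failure; the module and parameter names are hypothetical, not the exact layers from this repo:

    import torch
    import torch.nn as nn

    # Hypothetical encoder whose positional embedding length depends on the input resolution.
    class Enc(nn.Module):
        def __init__(self, n_tokens):
            super().__init__()
            self.pos_embed = nn.Parameter(torch.zeros(1, n_tokens, 64))

    ckpt = Enc(n_tokens=64).state_dict()   # saved from the lower-resolution model
    model = Enc(n_tokens=256)              # the higher-resolution model

    try:
        model.load_state_dict(ckpt, strict=True)
    except RuntimeError as e:
        print(e)  # size mismatch for pos_embed: copying a param with shape ...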

tianweiy commented 1 year ago

Thank you for your interest. I will push a fix soon.

Alternatively, we could just start from the second-stage training directly. I wonder if it makes a huge difference; sometimes the diffusion pixel-nerf doesn't converge well without the pixel-nerf init, but we don't have a conclusive answer.

tianweiy commented 1 year ago

fixed in 6c01dc83ed584a6a86cf8a936903383895b5a595

lukasHoel commented 1 year ago

Thank you very much for the fast help, really appreciate it! May I also ask how to fix the error mentioned here? I get the same error message and suspect the problem could be similar:

https://github.com/ayushtewari/DFM/issues/7

tianweiy commented 1 year ago

fixed

lukasHoel commented 1 year ago

Now everything works. I just improved the loading functionality a bit more. The current implementation always throws away the weights for model.enc.pos_embed, even when we continue training a checkpoint at the same stage:

https://github.com/ayushtewari/DFM/blob/50c6e20db124147f37ba44b256000de6ce524270/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L1103

I believe we should instead first try to load everything (e.g., when continuing training at the same stage) and only apply this fix if loading fails (e.g., a 64x64 checkpoint loaded into a 128x128 model, where the positional embedding needs to be re-initialized):

        try:
            # First, try to load all parameters (e.g., when continuing training at the same stage).
            model.load_state_dict(data["model"], strict=True)
        except RuntimeError:
            print("Loading with strict=True failed. Assume we continue from a 64x64 checkpoint and skip certain layers.")
            # When loading a 64x64 checkpoint into a 128x128 model, the positional
            # embedding has a different shape and needs to be re-initialized.
            data["model"].pop("model.enc.pos_embed", None)
            print(model.load_state_dict(data["model"], strict=False))
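Two notes on the snippet: catching RuntimeError instead of a bare except keeps unrelated failures (e.g., a corrupt checkpoint file) visible, and popping model.enc.pos_embed is still required in the fallback branch, because PyTorch raises on size mismatches even with strict=False; the strict flag only relaxes missing and unexpected keys.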