JanaldoChen / Anim-NeRF

MIT License

About loading pre-trained model #19

Closed: samsara-ku closed this issue 1 year ago

samsara-ku commented 1 year ago

Hi. First, thank you for your great work and for sharing it.

By the way, I'm not familiar with the PyTorch Lightning package, but I get an error when loading your pre-trained checkpoint checkpoints/male-3-casual/last.ckpt in this line:

system = AnimNeRFSystem.load_from_checkpoint(args.ckpt_path).to(device)

The error is like this:

Error(s) in loading state_dict for AnimNeRFSystem:
    Missing key(s) in state_dict: "evaluator.lpips.scaling_layer.shift", "evaluator.lpips.scaling_layer.scale", "evaluator.lpips.net.slice1.0.weight", "evaluator.lpips.net.slice1.0.bias", "evaluator.lpips.net.slice2.3.weight", "evaluator.lpips.net.slice2.3.bias", "evaluator.lpips.net.slice3.6.weight", "evaluator.lpips.net.slice3.6.bias", "evaluator.lpips.net.slice4.8.weight", "evaluator.lpips.net.slice4.8.bias", "evaluator.lpips.net.slice5.10.weight", "evaluator.lpips.net.slice5.10.bias", "evaluator.lpips.lin0.model.1.weight", "evaluator.lpips.lin1.model.1.weight", "evaluator.lpips.lin2.model.1.weight", "evaluator.lpips.lin3.model.1.weight", "evaluator.lpips.lin4.model.1.weight", "evaluator.lpips.lins.0.model.1.weight", "evaluator.lpips.lins.1.model.1.weight", "evaluator.lpips.lins.2.model.1.weight", "evaluator.lpips.lins.3.model.1.weight", "evaluator.lpips.lins.4.model.1.weight".

I'm not sure, but maybe you intended to load only part of the AnimNeRFSystem class (i.e. self.models). However, PyTorch Lightning's load_from_checkpoint function could not load the weights and biases from the checkpoint file properly.
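The same error can be reproduced without Anim-NeRF. Below is a minimal sketch in plain PyTorch, assuming (as the key names in the error suggest) that the system keeps its networks under `models` and an LPIPS network under `evaluator.lpips`. The `ToySystem` class and its layer sizes are hypothetical stand-ins, not the repository's code. A state dict that lacks the `evaluator.*` entries fails to load under the default `strict=True`:

```python
import torch
from torch import nn


class ToySystem(nn.Module):
    """Hypothetical stand-in for AnimNeRFSystem: trainable models plus a metric network."""

    def __init__(self) -> None:
        super().__init__()
        self.models = nn.Linear(4, 4)  # weights that the checkpoint does store
        self.evaluator = nn.Module()
        self.evaluator.lpips = nn.Linear(4, 1)  # metric weights absent from the checkpoint


# Simulate a checkpoint that contains only the "models.*" entries.
checkpoint_state = {
    k: v for k, v in ToySystem().state_dict().items() if k.startswith("models.")
}

system = ToySystem()
error_message = ""
try:
    system.load_state_dict(checkpoint_state)  # strict=True is the default
except RuntimeError as err:
    # PyTorch lists every key the module expects but the state dict lacks.
    error_message = str(err)
    print(error_message)
```

The printed message has the same "Missing key(s) in state_dict" form as the one in this issue, listing only the `evaluator.lpips.*` keys.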

I changed the strict parameter of the load_from_checkpoint function from its default (True) to False, and the pre-trained model now loads. So my questions are:

Is this the right way to run the code?

If so, could you update the source code accordingly?

JanaldoChen commented 1 year ago

Thanks for your interest! You should set strict=False in load_from_checkpoint. The 'Missing key(s) in state_dict' error occurs because the weights of the 'lpips' metric are not stored in the pre-trained models; skipping them does not affect the results.
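In the issue's own code, the fix is `AnimNeRFSystem.load_from_checkpoint(args.ckpt_path, strict=False)`. A slightly more defensive variant is to load non-strictly and then confirm that the only missing keys belong to the LPIPS metric, so a genuinely incompatible checkpoint still fails loudly. The sketch below uses plain PyTorch and assumes a Lightning-style checkpoint file whose weights sit under the "state_dict" key. The helper `load_without_metric_weights`, its `metric_prefix` parameter, and the `ToySystem` class are hypothetical illustrations, not part of Anim-NeRF or PyTorch Lightning:

```python
import os
import tempfile

import torch
from torch import nn


class ToySystem(nn.Module):
    """Hypothetical stand-in: trainable models plus an LPIPS-like metric network."""

    def __init__(self) -> None:
        super().__init__()
        self.models = nn.Linear(4, 4)
        self.evaluator = nn.Module()
        self.evaluator.lpips = nn.Linear(4, 1)


def load_without_metric_weights(module, ckpt_path, metric_prefix="evaluator.lpips."):
    """Load weights with strict=False, but raise if anything besides metric weights mismatches."""
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    result = module.load_state_dict(checkpoint["state_dict"], strict=False)
    # Missing metric weights are tolerated; any other mismatch is treated as an error.
    other_missing = [k for k in result.missing_keys if not k.startswith(metric_prefix)]
    if other_missing or result.unexpected_keys:
        raise RuntimeError(
            f"Checkpoint mismatch: missing={other_missing}, "
            f"unexpected={result.unexpected_keys}"
        )
    return result.missing_keys  # the metric keys that were left at their initial values


# Write a Lightning-style checkpoint that omits the metric weights.
trained = ToySystem()
weights = {k: v for k, v in trained.state_dict().items() if k.startswith("models.")}
ckpt_path = os.path.join(tempfile.mkdtemp(), "last.ckpt")
torch.save({"state_dict": weights}, ckpt_path)

system = ToySystem()
skipped = load_without_metric_weights(system, ckpt_path)
print(skipped)
```

The trained `models` weights are restored, and only the `evaluator.lpips.*` keys are reported as skipped. In the real system those LPIPS weights would come from wherever the repository constructs its evaluator rather than from the checkpoint.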

samsara-ku commented 1 year ago

@JanaldoChen Thanks for your answer. I'm closing this issue.