mystorm16 opened 2 days ago
I am also facing this issue. The problem seems to arise because when PL restores a run from a checkpoint file, it also tries to read the optimizer state (such as the learning rate) saved when training was interrupted. However, the checkpoint released by the author has already been stripped of these states, leaving only the model parameters, which causes the error. Could the author additionally provide a checkpoint file that contains all the state needed for fine-tuning 🫡?
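For reference, this is roughly how a full PyTorch Lightning checkpoint differs from a weights-only one. A minimal sketch with toy values (not real weights); the key names `state_dict`, `optimizer_states`, and `lr_schedulers` are the ones Lightning uses:

```python
# Minimal sketch of the two checkpoint layouts (toy values, not real tensors).
full_ckpt = {
    "state_dict": {"encoder.backbone.weight": [0.1, 0.2]},    # model parameters
    "optimizer_states": [{"param_groups": [{"lr": 2e-4}]}],   # needed to resume
    "lr_schedulers": [{"last_epoch": 10}],
    "epoch": 10,
    "global_step": 30000,
}
# A released, weights-only checkpoint keeps just the model parameters:
weights_only_ckpt = {"state_dict": full_ckpt["state_dict"]}

missing = sorted(set(full_ckpt) - set(weights_only_ckpt))
print(missing)  # ['epoch', 'global_step', 'lr_schedulers', 'optimizer_states']
```

Those missing entries are exactly what Lightning fails to find when it tries to resume training from such a file.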
Hi @mystorm16 and @Langwenchong, thanks for your interest in our work.
To fine-tune from the released weights, you can initialize the model from the existing checkpoint and skip the checkpoint path in the `fit` function. Below is a workaround for your reference.
Change the model initialization from https://github.com/donydchen/mvsplat/blob/378ff818c0151719bbc052ac2797a2c769766320/src/main.py#L123-L132 to
```python
model_kwargs = {
    "optimizer_cfg": cfg.optimizer,
    "test_cfg": cfg.test,
    "train_cfg": cfg.train,
    "encoder": encoder,
    "encoder_visualizer": encoder_visualizer,
    "decoder": get_decoder(cfg.model.decoder, cfg.dataset),
    "losses": get_losses(cfg.loss),
    "step_tracker": step_tracker,
}
model_wrapper = ModelWrapper.load_from_checkpoint(
    checkpoint_path, **model_kwargs, strict=True, map_location="cpu",
)
```
Then, change the `fit` call at https://github.com/donydchen/mvsplat/blob/378ff818c0151719bbc052ac2797a2c769766320/src/main.py#L141 to
```python
trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=None)
```
You can confirm the setting by checking the first validation logged at step 0, which should show a good visual result. I will find time in the following weeks to update the code to support fine-tuning. Feel free to let me know if you have any other questions or suggestions.
Thanks for the quick reply! I made the following modification instead and also got good visual results. Is it equivalent to yours?
```python
model_state_dict = encoder.state_dict()
checkpoint = torch.load('checkpoints/acid.ckpt')
checkpoint_state_dict = checkpoint['state_dict']
for key in model_state_dict:
    if 'encoder.' + key in checkpoint_state_dict:
        if model_state_dict[key].shape == checkpoint_state_dict['encoder.' + key].shape:
            model_state_dict[key].copy_(checkpoint_state_dict['encoder.' + key])
        else:
            print(f"Shape mismatch for parameter {key}. Skipping...")
encoder.load_state_dict(model_state_dict)
```
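The loop above can also be condensed by filtering the checkpoint for `encoder.`-prefixed keys and renaming them to match `encoder.state_dict()`, then passing the result to `encoder.load_state_dict`. A sketch with toy stand-in values (note that `load_state_dict` still errors on shape mismatches, so the explicit shape check above remains useful):

```python
# Toy stand-in for checkpoint['state_dict'] (real values would be tensors).
checkpoint_state_dict = {
    "encoder.backbone.weight": [0.1],
    "encoder.backbone.bias": [0.2],
    "decoder.head.weight": [0.3],
}
prefix = "encoder."
# Keep only encoder entries and strip the prefix so keys match encoder.state_dict().
encoder_state = {
    k[len(prefix):]: v
    for k, v in checkpoint_state_dict.items()
    if k.startswith(prefix)
}
print(sorted(encoder_state))  # ['backbone.bias', 'backbone.weight']
```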
Hi @mystorm16, I think your solution does the same thing as the one I provided above since the decoder actually has no trainable parameters. Cheers.
Hi, thanks for the great work.
When I fine-tune with acid.ckpt or re10k.ckpt:
```shell
python3 -m src.main +experiment=custom checkpointing.load=checkpoints/re10k.ckpt mode=train data_loader.train.batch_size=4
```
an error occurs:
```
KeyError: 'Trying to restore optimizer state but checkpoint contains only the model. This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
```
Is my fine-tuning command incorrect?
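For context, the error amounts to a check like the following when a `ckpt_path` is passed to `fit` (a simplified paraphrase for illustration, not Lightning's actual code):

```python
def restore_optimizer_states(ckpt: dict) -> list:
    # Simplified paraphrase of the check Lightning performs when resuming:
    # a weights-only checkpoint has no 'optimizer_states' entry, hence the KeyError.
    if "optimizer_states" not in ckpt:
        raise KeyError(
            "Trying to restore optimizer state but checkpoint contains only the model."
        )
    return ckpt["optimizer_states"]

weights_only = {"state_dict": {"encoder.weight": [0.1]}}
try:
    restore_optimizer_states(weights_only)
except KeyError as err:
    print(err)
```

This is why initializing the model via `load_from_checkpoint` and then calling `fit` with `ckpt_path=None`, as suggested earlier in this thread, sidesteps the optimizer-restore path entirely.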