donydchen / mvsplat

🌊 [ECCV'24] MVSplat: Efficient 3D Gaussian Splatting from Sparse Multi-View Images
https://donydchen.github.io/mvsplat

Question about custom dataset finetuning #45

Open mystorm16 opened 2 days ago

mystorm16 commented 2 days ago

Hi, thanks for the great work.

When I fine-tune with acid.ckpt or re10k.ckpt:

`python3 -m src.main +experiment=custom checkpointing.load=checkpoints/re10k.ckpt mode=train data_loader.train.batch_size=4`

an error occurs:

`KeyError: 'Trying to restore optimizer state but checkpoint contains only the model. This is probably due to ModelCheckpoint.save_weights_only being set to True.'`

Is my finetune command incorrect?

Langwenchong commented 2 days ago

I am also facing this issue. It seems to arise because when PyTorch Lightning resumes training from a checkpoint file, it also needs to read the optimizer state (e.g. learning-rate scheduler state) saved when training was interrupted. However, the checkpoint files provided by the author have been stripped down to only the model parameters, which triggers the error. Could the author additionally provide a checkpoint file that contains the full training state required for fine-tuning? 🫡
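For anyone who wants to confirm this locally, here is a minimal sketch. The checkpoint below is a dummy created on the fly, not the real re10k.ckpt:

```python
import os
import tempfile

import torch

# Minimal illustration of the KeyError above: a weights-only checkpoint
# (like the released MVSplat weights) has a 'state_dict' entry but no
# 'optimizer_states' entry, which trainer.fit(ckpt_path=...) expects.
weights_only = {"state_dict": {"encoder.weight": torch.zeros(3)}}
path = os.path.join(tempfile.mkdtemp(), "weights_only.ckpt")
torch.save(weights_only, path)

ckpt = torch.load(path, map_location="cpu")
print(sorted(ckpt.keys()))         # ['state_dict']
print("optimizer_states" in ckpt)  # False, so resuming via ckpt_path fails
```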

donydchen commented 2 days ago

Hi @mystorm16 and @Langwenchong, thanks for your interest in our work.

To fine-tune from the released weights, you can initialize the model from the existing checkpoint and pass no checkpoint path to the fit function. I have provided a workaround below for your reference.

Change the model initialization from https://github.com/donydchen/mvsplat/blob/378ff818c0151719bbc052ac2797a2c769766320/src/main.py#L123-L132 to

model_kwargs = {
    "optimizer_cfg": cfg.optimizer,
    "test_cfg": cfg.test,
    "train_cfg": cfg.train,
    "encoder": encoder,
    "encoder_visualizer": encoder_visualizer,
    "decoder": get_decoder(cfg.model.decoder, cfg.dataset),
    "losses": get_losses(cfg.loss),
    "step_tracker": step_tracker,
}
model_wrapper = ModelWrapper.load_from_checkpoint(
    checkpoint_path, **model_kwargs, strict=True, map_location="cpu",
)

Then, change the fit function from https://github.com/donydchen/mvsplat/blob/378ff818c0151719bbc052ac2797a2c769766320/src/main.py#L141 to

trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=None)

You can confirm the setting by checking the first validation logged at step 0, which should show a good visual result. I will find time in the following weeks to update the code to support fine-tuning. Feel free to let me know if you have any other questions or suggestions.

mystorm16 commented 2 days ago

Thanks for the quick reply! I made the following modification and got a good visual result. Is it equivalent to yours?

import torch

# Initialize the encoder from the released checkpoint, copying only the
# parameters whose names and shapes match.
model_state_dict = encoder.state_dict()
checkpoint = torch.load('checkpoints/acid.ckpt', map_location='cpu')
checkpoint_state_dict = checkpoint['state_dict']
for key in model_state_dict:
    ckpt_key = 'encoder.' + key  # checkpoint keys carry an 'encoder.' prefix
    if ckpt_key in checkpoint_state_dict:
        if model_state_dict[key].shape == checkpoint_state_dict[ckpt_key].shape:
            model_state_dict[key].copy_(checkpoint_state_dict[ckpt_key])
        else:
            print(f"Shape mismatch for parameter {key}. Skipping...")
encoder.load_state_dict(model_state_dict)

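As a side note, the filtering loop can be sketched more compactly with `load_state_dict(strict=False)`, which tolerates missing and unexpected keys (shape mismatches still have to be filtered out beforehand). The encoder and checkpoint below are dummy stand-ins, not the real MVSplat objects:

```python
import torch
import torch.nn as nn

# Dummy stand-ins for the real encoder and checkpoint, for illustration only.
encoder = nn.Linear(4, 2)
checkpoint_state_dict = {
    "encoder.weight": torch.ones(2, 4),
    "encoder.bias": torch.zeros(2),
    "decoder.weight": torch.ones(5),  # unrelated key, dropped by the filter
}

# Strip the 'encoder.' prefix and keep only shape-compatible entries.
model_sd = encoder.state_dict()
filtered = {
    k[len("encoder."):]: v
    for k, v in checkpoint_state_dict.items()
    if k.startswith("encoder.")
    and v.shape == model_sd[k[len("encoder."):]].shape
}

# strict=False ignores missing/unexpected keys instead of raising.
missing, unexpected = encoder.load_state_dict(filtered, strict=False)
print(missing, unexpected)  # [] []
```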
donydchen commented 20 hours ago

Hi @mystorm16, I think your solution does the same thing as the one I provided above, since the decoder actually has no trainable parameters. Cheers.