slothfulxtx / MBPTrack3D

[ICCV2023] MBPTrack: Improving 3D Point Cloud Tracking with Memory Networks and Box Priors

How to load weights into the model? #5

Closed MaxTeselkin closed 1 year ago

MaxTeselkin commented 1 year ago

Hi! I am writing a script for MBPTrack inference on a custom tracklet, but I am unable to load the model state dict from your .ckpt files.

Here is my code for loading the model onto a device:

def load_on_device(
    self,
    model_dir: str,
    device: Literal["cuda", "cuda:0", "cuda:1", "cuda:2", "cuda:3"] = "cuda",
):
    seed_everything(42)
    # Load the model config and point it at the pretrained checkpoint
    model_path = configs_path + "mbptrack_kitti_car_cfg.yaml"
    checkpoint_path = checkpoints_path + "mbptrack_kitti_car.ckpt"
    with open(model_path, "r") as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)
    self.cfg = Dict(cfg)
    self.cfg.work_dir = "./work_dir/"
    self.cfg.resume_from = checkpoint_path
    self.cfg.save_test_result = True
    self.cfg.gpus = [0]
    os.makedirs(self.cfg.work_dir, exist_ok=True)
    # Dump the resolved config into the work dir for reproducibility
    with open(os.path.join(self.cfg.work_dir, "config.yaml"), "w") as f:
        yaml.dump(self.cfg.to_dict(), f)
    log_file_dir = os.path.join(self.cfg.work_dir, "3DSOT.log")
    log = Logger(name="3DSOT", log_file=log_file_dir)
    # Build the model from the config and move it to the target device
    self.model = create_model(self.cfg.model_cfg, log)
    self.device = torch.device(device)
    self.model = self.model.to(self.device)
    self.model.eval()

When I ran this code, the weights were initialised randomly, so setting self.cfg.resume_from = checkpoint_path has no effect here. So I decided to load the state dict as for a regular PyTorch model:

checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint["state_dict"])

but got multiple size mismatch errors.
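For reference, a quick way to see why the keys don't line up is to compare the checkpoint keys with the model's own keys. This is a minimal diagnostic sketch, assuming the .ckpt is a standard PyTorch Lightning checkpoint with a "state_dict" entry and that checkpoint_path and self.model are set up as above:

# Diagnostic sketch: compare checkpoint keys against the model's keys
checkpoint = torch.load(checkpoint_path, map_location="cpu")
ckpt_keys = set(checkpoint["state_dict"].keys())
model_keys = set(self.model.state_dict().keys())
print("only in checkpoint:", sorted(ckpt_keys - model_keys)[:5])
print("only in model:", sorted(model_keys - ckpt_keys)[:5])

Comparing the two sets makes a key-name mismatch, such as a shared prefix on the checkpoint side, visible immediately.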

After that I tried to load the state dict as for a PyTorch Lightning model using:

self.model = MBPTTask(self.cfg, log)
self.model = self.model.load_from_checkpoint(
    checkpoint_path=checkpoint_path,
    map_location=self.device,
    hparams_file=model_path,
    cfg=self.cfg,
    log=log,
)

and got size mismatch errors again.

What am I doing wrong?

P.S. You only provide a script for testing a pretrained model on a specific dataset via pl.Trainer.test, which is not suitable for inference on a custom tracklet; that's why I need to write my own script for loading the model weights.

MaxTeselkin commented 1 year ago

Update: I was able to load the model weights by modifying the original state dict.

When I tried to load the weights as for a regular PyTorch model, I got the error because every key in the original state dict has the prefix "model." (presumably added because the Lightning task module stores the network as self.model). So I decided to remove this prefix from all state dict keys using this function:

def preprocess_state_dict(self, state_dict):
    # Strip the leading "model." (6 characters) from every key
    preprocessed_state_dict = {}
    for key, value in state_dict.items():
        preprocessed_state_dict[key[6:]] = value
    return preprocessed_state_dict
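A slightly safer variant (a sketch of my own, not from the repo) only strips the prefix when it is actually present, so any key that doesn't start with "model." passes through unchanged:

def preprocess_state_dict(self, state_dict):
    # Remove the "model." prefix only where it exists; other keys pass through
    prefix = "model."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }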

After running this code, the weights were loaded successfully:

checkpoint = torch.load(checkpoint_path, map_location=self.device)
raw_state_dict = checkpoint["state_dict"]
preprocessed_state_dict = self.preprocess_state_dict(raw_state_dict)
self.model.load_state_dict(preprocessed_state_dict)
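To double-check that every weight actually matched, one can also load with strict=False and inspect what load_state_dict reports; this is a minimal sketch, and two empty lists mean everything was loaded:

# Sanity check: strict=False returns the keys that failed to match
result = self.model.load_state_dict(preprocessed_state_dict, strict=False)
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)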