abhishekkrthakur / tez

Tez is a super-simple and lightweight Trainer for PyTorch. It also comes with many utils that you can use to tackle over 90% of deep learning projects in PyTorch.
Apache License 2.0

Load models #16

Open · chammami opened this issue 3 years ago

chammami commented 3 years ago

Thanks for this great project.

I started using Tez in a Kaggle competition and found that model loading is not handled correctly when training and inference happen on different devices (for example, training on GPU and predicting on CPU). This is one possible solution:

    import torch

    def load(self, model_path, device="cuda"):
        self.device = device
        # map_location remaps tensors saved on another device (e.g. GPU -> CPU)
        model_dict = torch.load(model_path, map_location=torch.device(device))
        self.load_state_dict(model_dict["state_dict"])
        # Move the module to the target device; this is a no-op if it is already there
        self.to(self.device)
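A minimal, self-contained sketch of the GPU -> CPU round trip. It assumes the checkpoint is a dict with a `"state_dict"` key, as in the snippet above. `TinyModel` and `roundtrip_to_cpu` are hypothetical names used only for illustration and are not part of Tez's API. If no GPU is available, the sketch saves from CPU, so the round trip still runs.

```python
import io

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for a tez.Model subclass."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def load(self, model_path, device="cuda"):
        # Remap every saved tensor onto the requested device while unpickling
        model_dict = torch.load(model_path, map_location=torch.device(device))
        self.load_state_dict(model_dict["state_dict"])
        # Ensure parameters and buffers end up on the target device
        self.to(device)
        self.device = device


def roundtrip_to_cpu():
    # "Train" on GPU if one is available, otherwise fall back to CPU
    train_device = "cuda" if torch.cuda.is_available() else "cpu"
    trained = TinyModel().to(train_device)

    # Save in the {"state_dict": ...} layout that load() expects
    buffer = io.BytesIO()
    torch.save({"state_dict": trained.state_dict()}, buffer)
    buffer.seek(0)

    # Restore into a fresh model for CPU-only inference
    restored = TinyModel()
    restored.load(buffer, device="cpu")
    return trained, restored
```

Passing `map_location` is what keeps `torch.load` from trying to recreate CUDA tensors on a CPU-only machine. The final `.to(device)` call is cheap when the module is already on that device.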
abhishekkrthakur commented 3 years ago

Should be fixed now.