gengxiaomeng opened this issue 4 years ago
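In py_factory.py, `load_pretrained_params(self, pretrained_model)` should be modified as follows. Pretrained weights whose name or shape no longer matches the model (for example, layers resized for a different number of categories) are then skipped instead of making `load_state_dict` fail with a size-mismatch error: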
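```python
def load_pretrained_params(self, pretrained_model):
    print("loading from {}".format(pretrained_model))
    with open(pretrained_model, "rb") as f:
        params = torch.load(f)  # requires `import torch` at module level

    # Keep only pretrained tensors whose name and shape match the current model,
    # so layers resized for a different number of categories are skipped.
    model_state = self.model.state_dict()
    params = {k: v for k, v in params.items()
              if k in model_state and model_state[k].shape == v.shape}

    # strict=False leaves the skipped layers at their fresh initialization.
    self.model.load_state_dict(params, strict=False)
```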
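A standalone sketch of the same filtering step, assuming only that both state dicts map parameter names to objects with a `.shape` attribute (as PyTorch tensors do). The helper name `filter_compatible_params`, the `FakeTensor` stand-in, and the parameter names are hypothetical. It also returns the skipped names, so you can check that only the category-dependent layers were dropped.

```python
from typing import Any, Dict, List, Mapping, Tuple


def filter_compatible_params(
    pretrained: Mapping[str, Any], model_state: Mapping[str, Any]
) -> Tuple[Dict[str, Any], List[str]]:
    """Hypothetical helper: split pretrained params into loadable and skipped ones.

    A param is kept only if the model has a param with the same name and shape.
    Works with any values exposing `.shape` (e.g. torch tensors).
    """
    kept: Dict[str, Any] = {}
    skipped: List[str] = []
    for name, value in pretrained.items():
        target = model_state.get(name)
        if target is not None and tuple(target.shape) == tuple(value.shape):
            kept[name] = value
        else:
            skipped.append(name)
    return kept, skipped


class FakeTensor:
    """Stand-in for a tensor; only the shape matters for this check."""

    def __init__(self, *shape: int) -> None:
        self.shape = shape


# Hypothetical example: the checkpoint was trained with 80 categories, the new model has 1.
pretrained = {
    "backbone.weight": FakeTensor(256, 3, 7, 7),
    "heat.weight": FakeTensor(80, 256, 1, 1),
}
model_state = {
    "backbone.weight": FakeTensor(256, 3, 7, 7),
    "heat.weight": FakeTensor(1, 256, 1, 1),
}

kept, skipped = filter_compatible_params(pretrained, model_state)
print(sorted(kept))  # ['backbone.weight']
print(skipped)       # ['heat.weight']
```

Inside `load_pretrained_params`, this could replace the dict comprehension as `params, skipped = filter_compatible_params(params, self.model.state_dict())`, with `skipped` printed before calling `load_state_dict(params, strict=False)`.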
Dear all:
When I train on my own data, which has 1 category, I modified the code as follows:
1. Modification one: models/CornerNet.py, line 72. Change `out_dim` to your category count:

```python
class model(kp):
    def __init__(self, db):
        n       = 5
        dims    = [256, 256, 384, 384, 384, 512]
        modules = [2, 2, 2, 2, 2, 4]
        out_dim = 1  # <-- change to your category count
```

2. Modification two: db/detection.py, line 8. Change `categories` to your category count:

```python
self._configs["categories"] = 1  # <-- change to your category count
```
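The two values have to agree: CornerNet's corner heatmaps have one channel per object category, so `out_dim` sets how many classes the network predicts, while `categories` sets how many the dataset defines. A minimal sanity check one could run before training follows. The function name `check_category_config` is hypothetical and not part of the repository.

```python
def check_category_config(out_dim: int, categories: int) -> None:
    """Hypothetical check: out_dim (models/CornerNet.py) must equal categories (db/detection.py)."""
    # Each corner heatmap has one channel per category; a mismatch means the
    # network predicts a different number of classes than the dataset defines.
    if out_dim != categories:
        raise ValueError(
            "out_dim ({}) does not match categories ({})".format(out_dim, categories)
        )


# Single-category setup from this thread: both values are 1, so no error is raised.
check_category_config(out_dim=1, categories=1)

try:
    # e.g. out_dim left at 80 (COCO's category count) while categories was changed to 1
    check_category_config(out_dim=80, categories=1)
except ValueError as err:
    print(err)  # out_dim (80) does not match categories (1)
```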