ignacio-rocco / weakalign

End-to-end weakly-supervised semantic alignment
MIT License

Compatibility with PyTorch 0.4 #10


AziziShekoofeh commented 5 years ago

Hi, thanks for the nice package and code. I had a few issues running the code on PyTorch 0.4, especially when loading the model. I saw a few similar open issues where people suggested downgrading to PyTorch 0.2. Since most recent packages are based on PyTorch 0.4+, and I wasn't interested in a conda environment or downgrading, I spent some time finding a way to run the code on the newer version. This issue just shares the places you may need to change for recent versions of PyTorch:

1- Loading the pre-trained model:

Loading the checkpoint fails with an error about a missing OrderedDict key, "checkpoint['state_dict']['FeatureExtraction.model.1.num_batches_tracked']" (PyTorch 0.4 added num_batches_tracked buffers to BatchNorm layers, which older checkpoints don't contain).

To solve this you need to copy only the keys that are common between the pretrained checkpoint and the TwoStageCNNGeometric model's state dicts:

    if model_aff_tps_path != '':
        checkpoint = torch.load(model_aff_tps_path, map_location=lambda storage, loc: storage)
        # Rename the legacy 'vgg' prefix to 'model' to match the current class.
        checkpoint['state_dict'] = OrderedDict(
            [(k.replace('vgg', 'model'), v) for k, v in checkpoint['state_dict'].items()])

        # Copy only the entries present in both the checkpoint and the model,
        # skipping new buffers such as num_batches_tracked.
        for name, param in model.FeatureExtraction.state_dict().items():
            if 'FeatureExtraction.' + name in checkpoint['state_dict']:
                model.FeatureExtraction.state_dict()[name].copy_(checkpoint['state_dict']['FeatureExtraction.' + name])
        for name, param in model.FeatureRegression.state_dict().items():
            if 'FeatureRegression.' + name in checkpoint['state_dict']:
                model.FeatureRegression.state_dict()[name].copy_(checkpoint['state_dict']['FeatureRegression.' + name])
        for name, param in model.FeatureRegression2.state_dict().items():
            if 'FeatureRegression2.' + name in checkpoint['state_dict']:
                model.FeatureRegression2.state_dict()[name].copy_(checkpoint['state_dict']['FeatureRegression2.' + name])
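As a possible alternative to the explicit copy loops above (a sketch, assuming PyTorch >= 0.4): `load_state_dict(..., strict=False)` skips keys missing on either side, such as the `num_batches_tracked` buffers, so each submodule could be loaded with a small helper. The `load_partial` name and the prefix handling are my own illustration, not code from this repo:

```python
import torch
import torch.nn as nn
from collections import OrderedDict

def load_partial(module, state_dict, prefix):
    # Keep only entries belonging to this submodule and strip the prefix.
    sub = OrderedDict(
        (k[len(prefix):], v) for k, v in state_dict.items() if k.startswith(prefix)
    )
    # strict=False tolerates keys that exist in the module but not in the
    # checkpoint (e.g. num_batches_tracked) and vice versa.
    module.load_state_dict(sub, strict=False)
```

Usage would then be e.g. `load_partial(model.FeatureExtraction, checkpoint['state_dict'], 'FeatureExtraction.')`.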

A cleaner way would be to filter the keys directly when building the OrderedDict from the checkpoint, something like:

 [(k.replace('vgg', 'model'), v) for k, v in checkpoint['state_dict'].items() if v in model.FeatureExtraction.state_dict()]

but it definitely doesn't work as written, and I couldn't find a clean one-liner, so I ended up adding the explicit if statements.
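For what it's worth, one reason the comprehension above fails is that it tests the *value* `v` for membership in the state dict, where membership is checked on keys; the renamed key would have to be tested instead. A minimal sketch of that idea (the `filter_checkpoint` helper is my own illustration, and checking against the full model's keys is an assumption about what was intended):

```python
from collections import OrderedDict

def filter_checkpoint(checkpoint_state, model_state):
    # First apply the same 'vgg' -> 'model' renaming as the original code.
    renamed = OrderedDict(
        (k.replace('vgg', 'model'), v) for k, v in checkpoint_state.items()
    )
    # Then keep only keys that actually exist in the target model's state dict.
    return OrderedDict(
        (k, v) for k, v in renamed.items() if k in model_state
    )
```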

2- The second issue occurs later, in preprocess_image() / normalize_image() in ./image/normalization.py, line 38:

if isinstance(image, torch.autograd.variable.Variable):
....

The Tensor and Variable classes were merged in newer versions of PyTorch, so there is no longer any need to check whether image is a Variable.

So you can simply replace this whole line with "else:".
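If you prefer an explicit check that works across versions, a sketch (the helper name is my own, not from normalization.py): since PyTorch 0.4, a single tensor check covers what used to be two cases.

```python
import torch

def is_variable_or_tensor(image):
    # In PyTorch >= 0.4, Variable is merged into Tensor, so torch.is_tensor
    # (or isinstance(image, torch.Tensor)) covers both former cases.
    return torch.is_tensor(image)
```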

Hope this is helpful for others too.

yimengli46 commented 5 years ago

Thanks, your suggestions work perfectly.