Caoliangjie / pytorch-gradcam-resnet50

CAM图的resnet50版本
150 stars 42 forks source link

您好,我是从简书过来的,我的网络模型必须用GAP(类似的池化方式),所以我在class ModelOutputs():里面添加了一些代码,可以成功运行,产生结果,但是不知道是否正确(逻辑上) #9

Closed antecede closed 5 years ago

antecede commented 5 years ago
    class ModelOutputs():
        def __init__(self, model, target_layers):
            self.model = model
            self.feature_extractor = FeatureExtractor(self.model.features, target_layers)

        def get_gradients(self):
            return self.feature_extractor.gradients

        def __call__(self, x):
            target_activations, output = self.feature_extractor(x)

            **### output = self.model.representation(output)**

            output = output.view(output.size(0), -1)
            output = self.model.classifier(output)
            return target_activations, output
antecede commented 5 years ago

我看resnet50是包括avgpool的,但是您的代码中并没有用到,原因是
(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False) (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0) (fc): Linear(in_features=2048, out_features=1000, bias=True) fc和上面的卷积层大小一样,但是我的必须经过类似avgpool才能使得卷积层和fc层维度一样,所以不知道这样加上avgpool逻辑上是否讲得通

Caoliangjie commented 5 years ago

我看resnet50是包括avgpool的,但是您的代码中并没有用到,原因是 (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False) (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0) (fc): Linear(in_features=2048, out_features=1000, bias=True) fc和上面的卷积层大小一样,但是我的必须经过类似avgpool才能使得卷积层和fc层维度一样,所以不知道这样加上avgpool逻辑上是否讲得通

您好,这个具体可以根据自己的网络要求进行适配。

antecede commented 5 years ago

您好,我在您的源代码里添加了如下内容: def call(self, x): target_activations, output = self.feature_extractor(x)

        **### output = self.model.representation(output)**

        ### **output** = resnet.avgpool(output).cuda()
        output = output.view(output.size(0), -1)
        output = self.model.classifier(output)
        return target_activations, output

而后对比生成的图片,发现没有肉眼可见的变化,其中如果是pytorch1.0的版本avgpool会报错,这个bug已经在1.1的版本更新了,我所用的是cuda9.0,可以适配pytorch1.0和1.1。 更新pytorch1.1的代码如下:conda install pytorch torchvision cudatoolkit=9.0 -c pytorch