Tushar-N / pytorch-resnet3d

I3D Nonlocal ResNets in Pytorch
245 stars 39 forks

Functionality for feature extraction? #12

Closed (yushuinanrong closed this issue 4 years ago)

yushuinanrong commented 4 years ago

Hi @Tushar-N, first of all, thanks for sharing this great repo! I'm wondering if you know of an easy way to extract intermediate features (e.g., the features before the last fc layer) for a clip from a pretrained model?

Tushar-N commented 4 years ago

Yes, that's possible. The simplest way is to run the forward pass but stop before the flattening/fc layers. For example, in the I3Res50 class in models/resnet:

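```python
def forward(self, x):
    # Stem: conv -> BN -> ReLU -> pool
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool1(x)
    # Residual stages, with an extra pooling step after layer1
    x = self.layer1(x)
    x = self.maxpool2(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    # Global average pooling; the flatten/fc layers are skipped, so this
    # returns the pooled features that would normally feed the classifier.
    x = self.avgpool(x)
    return x
```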

You can either put this in a standalone function (taking the model as an argument instead of self) or create a subclass that overrides the forward method; either approach should work.
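A minimal sketch of both options, assuming PyTorch. To keep it self-contained it uses a hypothetical `StandInI3D` module (tiny, randomly initialised, not the real architecture) that only borrows the attribute names from the snippet above; with the actual repo you would pass or subclass I3Res50 instead and load its pretrained weights. The names `extract_features`, `FeatureExtractor` and `StandInI3D`, the channel sizes, and the `(batch, channels, frames, height, width)` input layout are all illustrative assumptions, not part of the repo.

```python
import torch
import torch.nn as nn


def extract_features(model, x):
    """Option 1 (hypothetical helper): run the backbone of `model` and stop
    before the flatten/fc layers, returning the pooled features."""
    x = model.conv1(x)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool1(x)
    x = model.layer1(x)
    x = model.maxpool2(x)
    x = model.layer2(x)
    x = model.layer3(x)
    x = model.layer4(x)
    return model.avgpool(x)


class StandInI3D(nn.Module):
    """Hypothetical stand-in for I3Res50: same attribute names, tiny layers,
    random weights. Used only so this example runs without the repo."""

    def __init__(self, num_classes=400):
        super().__init__()
        self.conv1 = nn.Conv3d(3, 8, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(8)
        self.relu = nn.ReLU()
        self.maxpool1 = nn.MaxPool3d(kernel_size=(1, 2, 2))  # halves H and W
        self.layer1 = nn.Conv3d(8, 16, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 1, 1))  # halves T
        self.layer2 = nn.Conv3d(16, 16, kernel_size=3, padding=1)
        self.layer3 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
        self.layer4 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool3d(1)  # pools to (N, C, 1, 1, 1)
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        # Assumed classification path: backbone -> flatten -> fc
        return self.fc(extract_features(self, x).flatten(1))


class FeatureExtractor(StandInI3D):
    """Option 2: a subclass that only overrides forward(). With the real repo
    the parent class would be I3Res50 rather than the stand-in."""

    def forward(self, x):
        return extract_features(self, x)


torch.manual_seed(0)
model = StandInI3D().eval()  # eval(): BatchNorm uses its running statistics
clip = torch.randn(2, 3, 8, 32, 32)  # (batch, channels, frames, height, width)

extractor = FeatureExtractor().eval()
extractor.load_state_dict(model.state_dict())  # identical keys, weights copy as-is

with torch.no_grad():
    logits = model(clip)  # (2, 400): normal classification output
    feats = extract_features(model, clip)  # (2, 64, 1, 1, 1): pooled features
    feats_sub = extractor(clip)  # same features via the subclass
    vectors = feats.flatten(1)  # (2, 64): one feature vector per clip
```

Because the subclass adds no new modules, its `state_dict` keys match the parent's, so a pretrained checkpoint should load with `load_state_dict` unchanged. Calling `.eval()` before extracting keeps BatchNorm (and any dropout) in inference mode, so the features are deterministic for a given clip.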