Lornatang / SRGAN-PyTorch

A simple and complete implementation of a super-resolution paper.
Apache License 2.0

When extracting features with the VGG, at what layer level do we stop on the VGG? #67

Closed by jojupiter 1 year ago

jojupiter commented 1 year ago

Thank you very much for this model. I have a question about feature extraction during training: when using VGG19 or VGG16 as the feature extractor, how do I know at which layer to stop?

Lornatang commented 1 year ago

Taking VGG19 as an example:

from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models import vgg19

model = vgg19()

# Get all layer names (index 0 selects the train-mode node list)
get_graph_node_names(model)[0]
['x',
  'features.0',
  'features.1',
  'features.2',
  'features.3',
  'features.4',
  'features.5',
  'features.6',
  'features.7',
  'features.8',
  'features.9',
  'features.10',
  'features.11',
  'features.12',
  'features.13',
  'features.14',
  'features.15',
  'features.16',
  'features.17',
  'features.18',
  'features.19',
  'features.20',
  'features.21',
  'features.22',
  'features.23',
  'features.24',
  'features.25',
  'features.26',
  'features.27',
  'features.28',
  'features.29',
  'features.30',
  'features.31',
  'features.32',
  'features.33',
  'features.34',
  'features.35',  # <- VGG5.4: ReLU after the 4th conv of block 5, before the 5th max-pool
  'features.36',
  'avgpool',
  'flatten',
  'classifier.0',
  'classifier.1',
  'classifier.2',
  'classifier.3',
  'classifier.4',
  'classifier.5',
  'classifier.6']

# Display the model architecture
model
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace=True)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace=True)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace=True)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace=True)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace=True)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace=True)  # <- VGG5.4: last activation before the 5th max-pool
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
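
The same index can be worked out for VGG16 or for another block without printing the model. The sketch below is an illustration under stated assumptions: the layer lists are written out by hand to mirror the non-batch-norm VGG16 and VGG19 layouts shown above, and relu_node_name is a hypothetical helper, not part of torchvision or this repository. It relies on each conv occupying one slot in features, followed by its ReLU, with each max-pool taking one slot.

```python
# Numbers are conv output channels, "M" is a max-pool (no batch norm).
VGG_CFGS = {
    "vgg16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
              512, 512, 512, "M", 512, 512, 512, "M"],
    "vgg19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
              512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}


def relu_node_name(arch, block, conv):
    """Return the 'features.N' name of the ReLU that follows the
    `conv`-th convolution of the `block`-th block (both 1-based)."""
    index = 0          # position inside model.features
    current_block = 1  # blocks are separated by max-pools
    current_conv = 0   # convs seen so far in the current block
    for item in VGG_CFGS[arch]:
        if item == "M":
            index += 1
            current_block += 1
            current_conv = 0
        else:
            current_conv += 1
            if current_block == block and current_conv == conv:
                return f"features.{index + 1}"  # conv at index, ReLU right after
            index += 2  # skip the conv and its ReLU
    raise ValueError(f"{arch} has no conv {conv} in block {block}")


print(relu_node_name("vgg19", 5, 4))  # features.35 (VGG5.4)
print(relu_node_name("vgg16", 5, 3))  # features.29
```

For VGG16 the fifth block has only three convolutions, so the last activation before the fifth max-pool is features.29 rather than features.35.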

Hope this helps.