Open mdit108 opened 2 years ago
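from torchvision.models import googlenet
import torch
model = googlenet(pretrained=True).eval() # eval mode: BatchNorm uses its running statistics
# keep everything up to avgpool (drops the final dropout and fc layers)
extractor = torch.nn.Sequential(*list(model.children())[:-2])
im = torch.randn(1,3,720,1280) # NCHW
with torch.no_grad(): # no autograd graph, so .numpy() does not raise
    feature = extractor(im).cpu().numpy().flatten() # [1,1024,1,1] -> [1024]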
I tried it like this.
I wonder whether he is using the pool5 layer of GoogLeNet for feature extraction. Is this the code you wrote yourself, or is there some additional code?
I wrote the code myself.
@ehdrndd Could you share your feature extraction code, or a link to it? Thanks! Feel free to add your contact details if that is convenient.
@ruanzhijian This may help you:
https://github.com/HERIUN/vsumm-reinforce_re/blob/main/generate_dataset.py
In which part of the code is GoogLeNet specified as the first part of the DSN?