pranoyr / cnn-lstm

CNN LSTM architecture implemented in Pytorch for Video Classification
MIT License
260 stars 46 forks

change #8

Open cuteboyqq opened 3 months ago

cuteboyqq commented 3 months ago

When I set --sample_duration 32, I get the error below. Do you know how to solve it? [image: screenshot of the error]

cuteboyqq commented 3 months ago

I removed validation from training, so the model can now train, just without validating. I also modified your model network, and it still gets high accuracy. I changed it because I could not understand the logic of the ResNet forward pass in the CNNLSTM model network:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet101

class CNNLSTM(nn.Module):
    def __init__(self, num_classes=2):
        super(CNNLSTM, self).__init__()
        self.resnet = resnet101(pretrained=True)
        self.resnet.fc = nn.Sequential(nn.Linear(self.resnet.fc.in_features, 300))
        self.lstm = nn.LSTM(input_size=300, hidden_size=256, num_layers=3, batch_first=True)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x_3d):
        batch_size, seq_len, c, h, w = x_3d.size()

        # Reshape the input to process all frames at once
        x = x_3d.view(-1, c, h, w)  # (batch_size * seq_len, c, h, w)

        # Forward pass through ResNet
        features = self.resnet(x)  # (batch_size * seq_len, 300)

        # Reshape features back to (batch_size, seq_len, 300)
        features = features.view(batch_size, seq_len, -1)

        # LSTM forward pass
        out, _ = self.lstm(features)  # (batch_size, seq_len, 256)

        # Use the output from the last time step
        x = out[:, -1, :]  # (batch_size, 256)

        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x
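
In case it helps anyone reading the forward pass above: folding the time dimension into the batch with view(-1, c, h, w) should give the same per-frame features as looping over the frames one time step at a time. The sketch below checks that on random data. The small conv encoder is only a hypothetical stand-in for the ResNet (so nothing needs to be downloaded), and the clip size is made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the ResNet backbone: any module that maps
# (N, C, H, W) -> (N, 300) behaves the same way for this check.
encoder = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 300),
).eval()

# Made-up clip: 2 videos, 32 frames each, 3x16x16 pixels per frame.
batch_size, seq_len, c, h, w = 2, 32, 3, 16, 16
clip = torch.randn(batch_size, seq_len, c, h, w)

with torch.no_grad():
    # Per-frame loop: encode one time step at a time, then stack along time.
    looped = torch.stack([encoder(clip[:, t]) for t in range(seq_len)], dim=1)

    # Batched version: fold time into the batch dimension, encode once, unfold.
    batched = encoder(clip.view(-1, c, h, w)).view(batch_size, seq_len, -1)

shapes_match = looped.shape == batched.shape == (batch_size, seq_len, 300)
values_match = torch.allclose(looped, batched, atol=1e-4)
print(shapes_match, values_match)
```

One caveat: this stand-in has no BatchNorm layers. A real ResNet in train() mode computes batch statistics, so the looped and batched versions would not be exactly identical there. In eval() mode the two should agree.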