[Open] cuteboyqq opened this issue 3 months ago
I removed the validation step during training, so the model now trains, but it is never validated. I also modified your model network, because I could not understand the logic of the ResNet forward pass in the CNNLSTM network; my modified version also reaches high accuracy:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
from torchvision.models import resnet101
class CNNLSTM(nn.Module):
    """Clip classifier: per-frame ResNet-101 features fed through an LSTM.

    Every frame of the input clip is encoded independently by a ResNet-101
    backbone whose ImageNet ``fc`` layer is replaced by a 300-unit linear
    projection. The per-frame feature sequence is passed through a 3-layer
    LSTM (hidden size 256), and the LSTM output at the last time step is
    mapped to ``num_classes`` logits by a two-layer MLP.

    Args:
        num_classes: Number of output logits produced per clip.

    Note:
        ``forward`` pushes all ``batch_size * seq_len`` frames through
        ResNet-101 in a single call, so activation memory grows linearly with
        the clip length. Doubling the number of frames per clip roughly
        doubles backbone memory at a fixed batch size.
    """

    def __init__(self, num_classes=2):
        super(CNNLSTM, self).__init__()
        # NOTE(review): torchvision >= 0.13 deprecates `pretrained=True` in
        # favour of `weights=...`; kept as-is for compatibility with older
        # torchvision versions.
        self.resnet = resnet101(pretrained=True)
        # Replace the ImageNet classifier with a 300-d projection. It stays
        # wrapped in nn.Sequential so state_dict keys ("resnet.fc.0.*") remain
        # compatible with checkpoints saved from the original definition.
        self.resnet.fc = nn.Sequential(nn.Linear(self.resnet.fc.in_features, 300))
        self.lstm = nn.LSTM(input_size=300, hidden_size=256, num_layers=3, batch_first=True)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x_3d):
        """Classify a batch of clips.

        Args:
            x_3d: 5-D tensor laid out as (batch_size, seq_len, c, h, w), i.e.
                time on dim 1 and image channels on dim 2.

        Returns:
            Tensor of shape (batch_size, num_classes) holding raw logits.

        Raises:
            ValueError: If ``x_3d`` is not 5-dimensional.
        """
        if x_3d.dim() != 5:
            raise ValueError(
                "expected a 5-D input (batch_size, seq_len, c, h, w), "
                f"got shape {tuple(x_3d.shape)}"
            )
        batch_size, seq_len, c, h, w = x_3d.size()

        # Fold time into the batch dimension so ResNet sees every frame in one
        # call: (batch_size * seq_len, c, h, w). `reshape` is used instead of
        # `view` because `view` fails on non-contiguous inputs (e.g. a tensor
        # produced by permuting a (B, C, T, H, W) clip); `reshape` only copies
        # when it has to.
        frames = x_3d.reshape(-1, c, h, w)

        features = self.resnet(frames)  # (batch_size * seq_len, 300)
        # ResNet output is freshly allocated and contiguous, so `view` is safe.
        features = features.view(batch_size, seq_len, -1)  # (batch_size, seq_len, 300)

        out, _ = self.lstm(features)  # (batch_size, seq_len, 256)
        # Summarize the clip with the LSTM output at the last time step.
        last_step = out[:, -1, :]  # (batch_size, 256)

        hidden = F.relu(self.fc1(last_step))  # (batch_size, 128)
        return self.fc2(hidden)  # (batch_size, num_classes)
When I set `--sample_duration 32`, I get the error below. Do you know how to solve it?