mlech26l / ncps

PyTorch and TensorFlow implementation of NCP, LTC, and CfC wired neural models
https://www.nature.com/articles/s42256-020-00237-3
Apache License 2.0
1.86k stars, 297 forks

Example for image sequence classifier #60

Open selcukyazarklu opened 6 months ago

selcukyazarklu commented 6 months ago

Hi,

Could you supply more detailed steps for image sequence classification?

I have 200x200 images with 3 channels, for about 49 classes.

Regards.

MalekWahidi commented 6 months ago

To classify image sequences with dimensions (200,200,3) for about 49 classes, you can adapt the Atari behavior cloning example from the documentation with some modifications:

Basic PyTorch template code:

Import libraries

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from ncps.torch import CfC
from ncps.wirings import AutoNCP

Define the Convolutional Block

class ConvBlock(nn.Module):
    def __init__(self):
        super(ConvBlock, self).__init__()
        # Adjust these layers to match your 200x200x3 input
        self.conv1 = nn.Conv2d(3, 64, 5, padding=2, stride=2)
        self.conv2 = nn.Conv2d(64, 128, 5, padding=2, stride=2)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, 5, padding=2, stride=2)
        self.bn3 = nn.BatchNorm2d(256)
        # Add more layers as needed

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Global average pooling
        x = x.mean((-1, -2))  
        return x
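As a quick sanity check on the block above, the spatial sizes can be worked out by hand with the standard Conv2d output-size formula. This is only a sketch: the helper name `conv_out` is made up for illustration, and it assumes all three layers keep the kernel size 5, padding 2, and stride 2 shown above.

```python
def conv_out(size, kernel=5, padding=2, stride=2):
    """Spatial size after one Conv2d layer: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

size = 200
sizes = [size]
for _ in range(3):  # conv1, conv2, conv3 share kernel, padding, and stride
    size = conv_out(size)
    sizes.append(size)

print(sizes)  # [200, 100, 50, 25]

# Global average pooling collapses the final 25x25 grid, so each frame
# ends up as a vector with one value per conv3 channel.
n_features = 256
```

So each 200x200x3 frame becomes a 256-dimensional feature vector, which is the input size the recurrent layer should be given.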

Define the combined Convolutional and CfC model

class ConvCfC(nn.Module):
    def __init__(self, n_classes, n_features, n_neurons):
        super(ConvCfC, self).__init__()
        self.conv_block = ConvBlock()
        # The wiring's output size (n_classes) is the size of the CfC readout,
        # so the model emits class logits directly; no extra linear layer is needed.
        wiring = AutoNCP(n_neurons, n_classes)
        # n_features must match the ConvBlock output (256 channels after pooling)
        self.rnn = CfC(n_features, wiring, batch_first=True, return_sequences=False)

    def forward(self, x, hx=None):
        batch_size, seq_len = x.size(0), x.size(1)
        # Reshape to combine batch and sequence dimensions
        x = x.reshape(batch_size * seq_len, *x.shape[2:])  # reshape also handles non-contiguous input
        x = self.conv_block(x)
        # Separate batch and sequence dimensions
        x = x.view(batch_size, seq_len, -1)
        x, hx = self.rnn(x, hx)
        return x, hx
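The fold-and-unfold of the time dimension in forward() can be checked in isolation. The sketch below uses a single small Conv2d as a stand-in for ConvBlock and 32x32 frames to keep it fast; those sizes are assumptions chosen for the demo, not values from the thread.

```python
import torch
import torch.nn as nn

batch_size, seq_len = 2, 3
x = torch.rand(batch_size, seq_len, 3, 32, 32)  # (batch, time, C, H, W)

conv = nn.Conv2d(3, 8, 5, padding=2, stride=2)  # stand-in for ConvBlock

# Fold time into the batch dimension so the 2D conv sees ordinary images
flat = x.reshape(batch_size * seq_len, *x.shape[2:])  # (6, 3, 32, 32)
feats = conv(flat).mean((-1, -2))                     # (6, 8) after global average pooling
# Unfold back into one feature vector per frame per sequence
feats = feats.reshape(batch_size, seq_len, -1)        # (2, 3, 8)

# Frame t of sample b should land at feats[b, t]
single = conv(x[1, 2].unsqueeze(0)).mean((-1, -2))[0]
same = torch.allclose(feats[1, 2], single, atol=1e-5)
print(feats.shape, same)
```

The comparison against a single frame confirms that the reshape keeps frames in order, so the recurrent layer receives each sequence's frames in their original sequence.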

Define your custom dataset

class MyDataset(Dataset):
    # Implement dataset loading here
    def __init__(self):
        pass

    def __len__(self):
        # Return the size of dataset
        pass

    def __getitem__(self, idx):
        # Implement logic to get a single item at idx
        pass
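To make the expected tensor shapes concrete, here is a hypothetical stand-in dataset that returns random frames and labels. The class name `RandomSequenceDataset` and the sequence length of 4 are assumptions for illustration; a real dataset would load and stack actual image frames instead.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class RandomSequenceDataset(Dataset):
    """Hypothetical placeholder: random data, only to show the expected shapes."""

    def __init__(self, n_samples=8, seq_len=4, n_classes=49, size=200):
        # One sample = seq_len frames, channels first: (seq_len, 3, H, W)
        self.frames = torch.rand(n_samples, seq_len, 3, size, size)
        # One integer class label per sequence, as nn.CrossEntropyLoss expects
        self.labels = torch.randint(0, n_classes, (n_samples,))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.frames[idx], self.labels[idx]

loader = DataLoader(RandomSequenceDataset(), batch_size=4, shuffle=True)
inputs, labels = next(iter(loader))
print(inputs.shape, labels.shape)
```

Each batch then has shape (batch, seq_len, 3, 200, 200) with one label per sequence, which is the layout the model's forward() above assumes.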

Instantiate model, criterion, optimizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
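# n_features=256 matches the ConvBlock output; n_neurons=128 is an illustrative choice to tune
model = ConvCfC(n_classes=49, n_features=256, n_neurons=128).to(device)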
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

Load your dataset

train_ds = MyDataset()  # Implement this
trainloader = DataLoader(train_ds, batch_size=32, shuffle=True)

Training loop

num_epochs = 10  # placeholder value; adjust for your data
for epoch in range(num_epochs):
    model.train()
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs, hx = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

Keep in mind that this is just a quick draft of the general layout and a lot should be modified based on best practices and trial-and-error.

by90 commented 2 months ago

Regarding outputs, hx = model(inputs): I couldn't find any example that uses hx. How do we use the hidden state? Here you don't pass hx in, as in outputs, hx = model(inputs, hx), and the returned hx isn't used either.
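One common pattern for a returned hidden state is to feed it back in when a single long sequence is processed in several chunks, for example during streaming inference. When every batch holds complete, independent sequences, as the training loop above appears to assume, leaving hx at its default and ignoring the returned value is usually enough. The sketch below illustrates the pattern with torch.nn.GRU as a stand-in so that it runs without ncps installed; the assumption is that the CfC layer can be called the same way, as model(inputs, hx) already does in the draft above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.GRU(input_size=8, hidden_size=16, batch_first=True)  # stand-in for the CfC layer
x = torch.rand(2, 10, 8)  # (batch, time, features)

# One pass over the full sequence; the hidden state starts at zeros when omitted
_, h_full = rnn(x)

# Same sequence in two chunks, carrying the hidden state across the boundary
_, h = rnn(x[:, :5])                       # first chunk starts from the default state
_, h_chunked = rnn(x[:, 5:], h.detach())   # second chunk resumes from h

# detach() cuts the gradient path between chunks but leaves the values unchanged
match = torch.allclose(h_full, h_chunked, atol=1e-5)
print(match)
```

Carrying the state this way gives the same final hidden state as a single pass, while detach() keeps backpropagation from reaching into earlier chunks.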

noorchauhan commented 2 months ago

I am unable to understand what exactly your end goal is regarding your recent statement, @by90. Can you elaborate on what exactly you are trying to achieve?