Closed CristianaTiago closed 3 months ago
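To classify time series data with more than two classes, you can adapt the approach used in script 06-ecg classification to handle multiple classes. The key changes are that the final linear layer outputs one score (logit) per class (`num_classes` units), and that the loss is `nn.CrossEntropyLoss`, which expects integer class labels. Suggested solution: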
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
class TimeSeriesDataset(Dataset):
    """Map-style dataset pairing each time series sample with its label.

    Args:
        data: Indexable collection of samples; must support ``len()``.
        labels: Indexable collection of labels, indexed in parallel with
            ``data``.
    """

    def __init__(self, data, labels):
        # The constructor must be named ``__init__``; a method called plain
        # ``init`` is never invoked when the class is instantiated.
        self.data = data
        self.labels = labels

    def __len__(self):
        """Return the number of samples, taken from ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        return self.data[idx], self.labels[idx]
class CNN_LSTM(nn.Module):
    """1-D convolution followed by an LSTM and a linear classification head.

    Args:
        input_dim: Number of input channels passed to the ``Conv1d`` layer.
        hidden_dim: Size of the LSTM hidden state.
        num_classes: Number of output scores (one logit per class).
    """

    def __init__(self, input_dim, hidden_dim, num_classes):
        # ``nn.Module.__init__`` must run before any sub-module is assigned,
        # otherwise parameter registration fails.
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, 32, kernel_size=5)
        self.relu = nn.ReLU()
        self.lstm = nn.LSTM(32, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of series.

        Args:
            x: Float tensor laid out as ``(batch, input_dim, seq_len)``,
                the layout ``Conv1d`` consumes. ``seq_len`` must be at
                least 5 (the convolution kernel size, no padding).

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding raw,
            un-normalised scores.
        """
        x = self.conv1(x)
        x = self.relu(x)
        # Conv1d emits (batch, channels, time); the batch_first LSTM wants
        # (batch, time, features), hence the axis swap.
        x, _ = self.lstm(x.permute(0, 2, 1))
        # Classify from the LSTM output at the last time step only.
        x = self.fc(x[:, -1, :])
        return x
# Hyper-parameters for the multi-class model.
input_dim = 1
hidden_dim = 64
num_classes = 5
# The original snippet used num_epochs without defining it; adjust as needed.
num_epochs = 10

model = CNN_LSTM(input_dim, hidden_dim, num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# CrossEntropyLoss takes raw logits and integer class indices.
criterion = nn.CrossEntropyLoss()

# NOTE: ``train_loader`` and ``testloader`` are not created in this snippet.
# They must be iterables yielding (inputs, labels) batches, e.g. DataLoader
# objects wrapping TimeSeriesDataset, defined before this point.

model.train()
for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs.float())  # Assuming inputs are tensors
        loss = criterion(outputs, labels.long())  # Assuming labels are LongTensor
        loss.backward()
        optimizer.step()

    # Evaluate on validation set if needed
    # Calculate accuracy, etc.

model.eval()
total_correct = 0
total_samples = 0

# Gradients are not needed for evaluation; skipping them saves memory.
with torch.no_grad():
    for inputs, labels in testloader:
        outputs = model(inputs.float())
        # torch.max over dim 1 returns (values, indices); the index of the
        # highest logit is the predicted class.
        _, predicted = torch.max(outputs, 1)
        total_correct += (predicted == labels).sum().item()
        total_samples += labels.size(0)

# Guard against an empty test loader to avoid ZeroDivisionError.
accuracy = total_correct / total_samples if total_samples else 0.0
print(f'Accuracy on test set: {accuracy}')
Hope this helps,
Thanks
Hello,
Thanks for the good examples in the repo. I need to classify some time series (similar to what you did in script 06-ecg classification) but I have more than just 2 classes. I can't seem to make it work. Do you have any input on how it should be done?
Thanks!