KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License
14.71k stars 1.35k forks source link

Hi, everyone I am trying to replace MLP with KAN in VIT for the MNIST images, but I am encountering this problem : 'TypeError: 'bool' object is not subscriptable' #274

Open ikramaha opened 3 months ago

ikramaha commented 3 months ago

import numpy as np import torch import torch.nn as nn import torch.optim as optim import time from torch.utils.data import DataLoader, TensorDataset from torchvision import datasets, transforms from kan import KAN, KANLayer

Define the Patches class

class Patches(nn.Module): def init(self, patch_size): super(Patches, self).init() self.patch_size = patch_size

def forward(self, images):
    batch_size, channels, height, width = images.size()
    patches = images.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
    patches = patches.contiguous().view(batch_size, channels, -1, self.patch_size * self.patch_size)
    return patches.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.patch_size * self.patch_size * channels)

Define the PatchEncoder class

class PatchEncoder(nn.Module): def init(self, num_patches, input_dim, output_dim): super(PatchEncoder, self).init() self.num_patches = num_patches self.projection = nn.Linear(input_dim, output_dim) self.position_embedding = nn.Parameter(torch.zeros(1, num_patches, output_dim))

def forward(self, patches):
    positions = torch.arange(self.num_patches).unsqueeze(0).repeat(patches.size(0), 1).to(patches.device)
    encoded = self.projection(patches) + self.position_embedding
    return encoded

Define the MultiHeadAttention class

class MultiHeadAttention(nn.Module): def init(self, dim, num_heads, dropout=0.1): super(MultiHeadAttention, self).init() self.num_heads = num_heads self.dim = dim self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5

    self.qkv = nn.Linear(dim, dim * 3)
    self.attention_dropout = nn.Dropout(dropout)
    self.projection = nn.Linear(dim, dim)
    self.projection_dropout = nn.Dropout(dropout)

def forward(self, x):
    batch_size, num_tokens, dim = x.size()
    qkv = self.qkv(x).reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)
    qkv = qkv.permute(2, 0, 3, 1, 4)
    queries, keys, values = qkv[0], qkv[1], qkv[2]

    attention_scores = (queries @ keys.transpose(-2, -1)) * self.scale
    attention_weights = torch.softmax(attention_scores, dim=-1)
    attention_weights = self.attention_dropout(attention_weights)

    out = (attention_weights @ values).transpose(1, 2).reshape(batch_size, num_tokens, dim)
    out = self.projection(out)
    out = self.projection_dropout(out)

    return out

class ViTWithKAN(nn.Module): def init(self, image_size=28, patch_size=14, num_classes=10, dim=96, depth=4, heads=4): super(ViTWithKAN, self).init() self.patch_size = patch_size self.num_patches = (image_size // patch_size) ** 2 self.dim = dim

    self.patches = Patches(patch_size)
    self.patch_encoder = PatchEncoder(self.num_patches, patch_size * patch_size * 1, dim)
    self.transformer_layers = nn.ModuleList([
        nn.ModuleList([
            nn.LayerNorm(dim),
            MultiHeadAttention(dim, heads),
            nn.LayerNorm(dim),
            KANLayer(in_dim=dim, out_dim=dim, num=4, k=3, noise_scale=0.1, scale_base=1.0, base_fun=torch.nn.SiLU(), grid_eps=1.0, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, device='cpu')
        ]) for _ in range(depth)
    ])
    self.flatten = nn.Flatten()
    self.kan = KAN(width=[dim, 128, num_classes], grid=5, k=3, seed=0, device='cpu')

def forward(self, x):
    x = self.patches(x)
    x = self.patch_encoder(x)
    for norm1, attn, norm2, kan in self.transformer_layers:
        x = x + attn(norm1(x))
        x = x + kan(norm2(x))
    x = x.mean(dim=1)
    x = self.flatten(x)
    x = self.kan(x)
    return x

def train(model, device, train_loader, optimizer, criterion, epoch): model.train() epoch_loss = 0 correct = 0 total = 0 start_time = time.time()

for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)

    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    epoch_loss += loss.item()
    pred = output.argmax(dim=1, keepdim=True)
    correct += pred.eq(target.view_as(pred)).sum().item()
    total += target.size(0)

    if batch_idx % 10 == 0:
        print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
epoch_loss /= len(train_loader.dataset)
accuracy = 100. * correct / total
end_time = time.time()
print(f"Time for epoch {epoch}: {end_time - start_time:.2f} seconds")
return epoch_loss, accuracy

def validate(model, device, val_loader, criterion): model.eval() val_loss = 0 correct = 0 with torch.no_grad(): for data, target in val_loader: data, target = data.to(device), target.to(device) output = model(data) val_loss += criterion(output, target).item() pred = output.argmax(dim=1, keepdim=True) correct += pred.eq(target.view_as(pred)).sum().item() val_loss /= len(val_loader.dataset) accuracy = 100. * correct / len(val_loader.dataset) return val_loss, accuracy

def main(): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = ViTWithKAN().to(device) optimizer = optim.Adam(model.parameters(), lr=learning_rate) criterion = nn.CrossEntropyLoss()

train_losses = []
val_losses = []
val_accuracies = []

for epoch in range(1, num_epochs + 1):
    train_loss, train_accuracy = train(model, device, train_loader, optimizer, criterion, epoch)
    val_loss, val_accuracy = validate(model, device, test_loader, criterion)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)
    print(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

test_loss, test_accuracy = validate(model, device, test_loader, criterion)
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

# Prepare data for KAN
train_data = torch.cat([data for data, _ in train_loader], dim=0)
train_labels = torch.cat([target for _, target in train_loader], dim=0).unsqueeze(1)
test_data = torch.cat([data for data, _ in test_loader], dim=0)
test_labels = torch.cat([target for _, target in test_loader], dim=0).unsqueeze(1)

# Convert tensors to numpy arrays
train_data_np = train_data.cpu().numpy()
train_labels_np = train_labels.cpu().numpy()
test_data_np = test_data.cpu().numpy()
test_labels_np = test_labels.cpu().numpy()

# Create dataset for KAN
dataset = {
    'train_input': train_data_np,
    'train_label': train_labels_np,
    'test_input': test_data_np,
    'test_label': test_labels_np
}

# Debug: Print dataset structure
print("Dataset structure:")
for key, value in dataset.items():
    print(f"{key}: {type(value)} - {value.shape}")

# Ensure dataset is correctly formed before passing to KAN train method
if not isinstance(dataset['train_input'], (np.ndarray, torch.Tensor)) or not isinstance(dataset['test_input'], (np.ndarray, torch.Tensor)):
    raise ValueError("Dataset inputs must be numpy arrays or torch tensors")

# Train KAN
print("Starting KAN training...")
model.kan.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10)
print("KAN training completed.")

# Plot KAN splines
print("Plotting KAN splines...")
model.kan.plot(beta=100, in_vars=[r'Input'], out_vars=['Output'], title='KAN Layer')

if name == "main":

Define transformations for MNIST data

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load MNIST data and define DataLoader
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=2000, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2000, shuffle=False)

# Define global parameters
learning_rate = 0.001
num_epochs = 5  # Increase number of epochs as necessary

main()

this is my code

Commit2Cosmos commented 3 months ago

Hi @ikramaha, could you please include the full error message?