Blealtan / efficient-kan

An efficient pure-PyTorch implementation of Kolmogorov-Arnold Network (KAN).
MIT License
3.49k stars 306 forks source link

I would like to add train and test functions to the KAN class #44

Open riteshshergill opened 3 weeks ago

riteshshergill commented 3 weeks ago

I can't seem to open a branch to raise a pull request, so I'm adding the code here:

def train_model(self, model, trainloader, valloader, optimizer, scheduler, criterion, device, epochs):
    """Train ``model`` for ``epochs`` epochs and validate after each epoch.

    Every batch is flattened with ``view(-1, 28 * 28)`` before it is fed to
    the model, so the loaders are expected to yield ``(images, labels)``
    pairs whose images hold 28x28 values per sample.

    Args:
        model: Module to optimise; it is moved to ``device`` first.
        trainloader: Iterable of ``(images, labels)`` training batches.
        valloader: Iterable of ``(images, labels)`` validation batches.
            May be empty, in which case the metrics are reported as NaN.
        optimizer: Optimizer that updates the parameters of ``model``.
        scheduler: Learning-rate scheduler stepped once per epoch, or
            ``None`` to keep the learning rate untouched.
        criterion: Callable ``criterion(output, labels)`` returning the loss.
        device: Device on which the model and the batches are placed.
        epochs: Number of passes over ``trainloader``.
    """
    model.to(device)
    for epoch in range(epochs):
        # Train
        model.train()
        with tqdm(trainloader) as pbar:
            for images, labels in pbar:
                images = images.view(-1, 28 * 28).to(device)
                # Move the labels once and reuse them for loss and accuracy.
                labels = labels.to(device)
                optimizer.zero_grad()
                output = model(images)
                loss = criterion(output, labels)
                loss.backward()
                optimizer.step()
                accuracy = (output.argmax(dim=1) == labels).float().mean()
                pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'])

        # Validation
        model.eval()
        loss_sum = 0.0
        num_batches = 0
        num_correct = 0
        num_samples = 0
        with torch.no_grad():
            for images, labels in valloader:
                images = images.view(-1, 28 * 28).to(device)
                labels = labels.to(device)
                output = model(images)
                loss_sum += criterion(output, labels).item()
                num_batches += 1
                # Count hits per sample so a short final batch is not
                # over-weighted the way an average of batch means would be.
                num_correct += (output.argmax(dim=1) == labels).sum().item()
                num_samples += labels.size(0)

        # The loss stays a per-batch average because the reduction used by
        # ``criterion`` is not known here.  Guard against an empty loader.
        val_loss = loss_sum / num_batches if num_batches else float("nan")
        val_accuracy = num_correct / num_samples if num_samples else float("nan")

        # Update learning rate
        if scheduler is not None:
            scheduler.step()

        print(
            f"Epoch {epoch + 1}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}"
        )

def test_model(self, model, testloader, device, num_samples=10):
    model.to(device)
    model.eval()
    predictions = []
    ground_truths = []
    images_to_show = []

    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for i, (images, labels) in enumerate(testloader):
            images = images.view(-1, 28 * 28).to(device)
            output = model(images)
            predictions.extend(output.argmax(dim=1).cpu().numpy())
            ground_truths.extend(labels.cpu().numpy())
            images_to_show.extend(images.view(-1, 28, 28).cpu().numpy())

            if len(predictions) >= num_samples:
                break

    # Print the predictions for the specified number of samples
    for i in range(num_samples):
        print(f"Ground Truth: {ground_truths[i]}, Prediction: {predictions[i]}")

    return predictions[:num_samples], ground_truths[:num_samples], images_to_show[:num_samples]