fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

Plotting training and validation curves #207

Open HassanAli545 opened 2 years ago

HassanAli545 commented 2 years ago

Hi @fangwei123456, how can I plot training and validation curves for this model? Please guide.

Best

fangwei123456 commented 2 years ago

You can refer to these tutorials on how to use TensorBoard:

https://spikingjelly.readthedocs.io/zh_CN/latest/clock_driven_en/3_fc_mnist.html
https://spikingjelly.readthedocs.io/zh_CN/latest/clock_driven_en/4_conv_fashion_mnist.html

https://github.com/fangwei123456/spikingjelly/blob/e43c2525040d439ad216f00863bfadfef2642168/spikingjelly/clock_driven/examples/conv_fashion_mnist.py#L299
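
For reference, this is the logging pattern those examples use, as a minimal self-contained sketch (the dummy metric values just stand in for the numbers your own training and validation loops compute):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('./logs')  # any log directory works

for epoch in range(10):
    # Replace these two lines with the metrics from your own training /
    # validation loops (see the linked examples).
    train_loss, train_acc = 1.0 / (epoch + 1), 0.5 + 0.04 * epoch
    val_loss, val_acc = 1.2 / (epoch + 1), 0.45 + 0.04 * epoch

    writer.add_scalar('train_loss', train_loss, epoch)
    writer.add_scalar('train_acc', train_acc, epoch)
    writer.add_scalar('val_loss', val_loss, epoch)
    writer.add_scalar('val_acc', val_acc, epoch)

writer.close()
# View the curves with:  tensorboard --logdir=./logs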

HassanAli545 commented 2 years ago

Thank you very much for such a quick reply.

HassanAli545 commented 2 years ago

I have one more question. How can I train a spiking ResNet on my dataset? Please refer me to a tutorial or code.

fangwei123456 commented 2 years ago

How can I train a spiking ResNet on my dataset?

It is the same as training an ANN on a custom dataset.

https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

HassanAli545 commented 2 years ago

Sorry, my question was not clear. I was trying to ask how I can train a spiking ResNet from SpikingJelly on my dataset. Is it the same as training pre-trained models in PyTorch?

fangwei123456 commented 2 years ago

Is it the same as training pre-trained models in PyTorch?

Yes, you build a custom dataset and then train your SNN on it, just as you would train an ANN on a custom dataset with PyTorch.

Refer to this tutorial about training a spiking ResNet: https://spikingjelly.readthedocs.io/zh_CN/latest/clock_driven_en/16_train_large_scale_snn.html
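
As a rough end-to-end sketch (the dataset path and num_classes are placeholders, and num_classes is assumed to be forwarded to the backbone like in torchvision's ResNet):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from spikingjelly.clock_driven import neuron, surrogate, functional
from spikingjelly.clock_driven.model import spiking_resnet

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Single-step spiking ResNet-18
net = spiking_resnet.spiking_resnet18(
    pretrained=False, single_step_neuron=neuron.IFNode,
    surrogate_function=surrogate.ATan(), num_classes=2).to(device)

dataset = ImageFolder('path/to/your_dataset',  # placeholder path
                      transform=transforms.Compose([transforms.Resize([64, 64]),
                                                    transforms.ToTensor()]))
loader = DataLoader(dataset, batch_size=32, shuffle=True)

optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

net.train()
for img, label in loader:
    img, label = img.to(device), label.to(device)
    optimizer.zero_grad()
    loss = criterion(net(img), label)
    loss.backward()
    optimizer.step()
    functional.reset_net(net)  # reset membrane potentials between batches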

HassanAli545 commented 2 years ago

Thank you very much, that really helped me. I have one more question. Is it possible to use the spiking LSTM from SpikingJelly to classify the Iris dataset?

Best

fangwei123456 commented 2 years ago

Is it possible to use the spiking LSTM from SpikingJelly to classify the Iris dataset?

Yes, Iris is a simple dataset, and I think the spiking LSTM can handle it.
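
One rough way to try it, assuming spikingjelly.clock_driven.rnn.SpikingLSTM follows the nn.LSTM interface (input of shape [T, N, input_size], returning (output, states)); please check the rnn module before relying on the exact signature:

import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from spikingjelly.clock_driven import rnn  # assumed location of SpikingLSTM

T = 8  # time steps; the 4 static features are simply repeated over time
x, y = load_iris(return_X_y=True)
x = torch.tensor(x, dtype=torch.float32)
x = (x - x.mean(0)) / x.std(0)                 # normalize the 4 features
y = torch.tensor(y, dtype=torch.long)          # 3 classes
x_seq = x.unsqueeze(0).repeat(T, 1, 1)         # [T, 150, 4]

lstm = rnn.SpikingLSTM(input_size=4, hidden_size=32, num_layers=1)
fc = nn.Linear(32, 3)
optimizer = torch.optim.Adam(list(lstm.parameters()) + list(fc.parameters()), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(200):
    optimizer.zero_grad()
    out, _ = lstm(x_seq)      # assumed to return (output, states) like nn.LSTM
    logits = fc(out[-1])      # classify from the last time step
    loss = criterion(logits, y)
    loss.backward()
    optimizer.step()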

HassanAli545 commented 2 years ago

I am sorry for asking too many questions. I am trying to work with SpikingJelly, and I want to work with as many models as possible before starting my own work. I have a question related to the ResNet-11 provided at the link below. I want to use IF neurons instead of LIF neurons. Do I need to replace the LIF neuron with an IF neuron, or do I also have to modify the forward function?

https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/clock_driven/examples/cifar10_r11_enabling_spikebased_backpropagation.py

fangwei123456 commented 2 years ago

Do I need to replace the LIF neuron with an IF neuron?

Yes, you only need to replace it:

https://github.com/fangwei123456/spikingjelly/blob/1f01c9c500f73e0814cdc0dca4fb2f9ac81ad57a/spikingjelly/clock_driven/examples/cifar10_r11_enabling_spikebased_backpropagation.py#L124

But note that this network is special: its backward calculation is different from the neurons in https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/clock_driven/neuron.py. You can refer to the original paper for more details.
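
For the standard neurons in spikingjelly.clock_driven.neuron, the swap is just the neuron constructor; a minimal sketch (layer sizes are placeholders, not taken from the example):

import torch.nn as nn
from spikingjelly.clock_driven import neuron, surrogate

# LIF version of a small conv block
block_lif = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(64),
    neuron.LIFNode(tau=2.0, surrogate_function=surrogate.ATan()),
)

# IF version: only the neuron layer changes
block_if = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(64),
    neuron.IFNode(surrogate_function=surrogate.ATan()),
)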

HassanAli545 commented 2 years ago

Thanks. I will check that paper first before implementing this with ResNet-11.

HassanAli545 commented 2 years ago

Hi, I saved a model in PyTorch. The saved model gave me a maximum accuracy of 89.5% on the validation dataset. But when I plotted the confusion matrix using the saved model, it gave me around 52% accuracy. I do not understand this. Any tips on what is happening here?

I am using this code:

import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import ImageFolder
from spikingjelly.clock_driven import neuron, surrogate, functional
from spikingjelly.clock_driven.model import spiking_resnet

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

net = spiking_resnet.spiking_resnet18(pretrained=False, progress=True, single_step_neuron=neuron.IFNode, v_threshold=1., surrogate_function=surrogate.ATan())
#print(net)

model = net.to(DEVICE)
#print(model)

def main():

    image_path = 'Desktop/isic_test2'

    #normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.557, 0.549, 0.5534])
    data = ImageFolder(root=image_path,
                       transform=transforms.Compose([transforms.ToTensor(), transforms.Resize([64, 64])]))

    train_data, valid_data, test_data = random_split(data, [int(len(data) * 0.8), int(len(data) * 0.20),
                                                            len(data) - (int(len(data) * 0.8) + int(len(data) * 0.20))])

    train_loader = DataLoader(train_data, batch_size=32, shuffle=True, pin_memory=True,drop_last=True)
    val_loader = DataLoader(valid_data, batch_size=32, shuffle=True, pin_memory=False,drop_last=True)
    test_loader = DataLoader(test_data, batch_size=32, shuffle=True, pin_memory=False,drop_last=True)

    net = model.to(DEVICE)

    optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

    start_epoch = 0
    max_test_acc = 0

    writer = SummaryWriter('spiking_plots')

    for epoch in range(150):
        start_time = time.time()
        net.train()
        train_loss = 0
        train_acc = 0
        train_samples = 0
        for img, label in train_loader:

            img = img.to(DEVICE)
            label = label.to(DEVICE)
            #label_onehot = F.one_hot(label, 2).float()
            out_fr = net(img)
            loss = nn.CrossEntropyLoss()
            loss = loss(out_fr, label.view((-1)))
            loss.backward()
            optimizer.step()

            train_samples += label.numel()
            train_loss += loss.item() * label.numel()
            train_acc += (out_fr.argmax(1) == label).float().sum().item()

            optimizer.zero_grad()
            functional.reset_net(net)
        train_loss /= train_samples
        train_acc /= train_samples

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('train_acc', train_acc, epoch)

        #lr_scheduler.step()

        net.eval()
        test_loss = 0
        test_acc = 0
        test_samples = 0
        with torch.no_grad():
            for img, label in val_loader:
                img = img.to(DEVICE)
                label = label.to(DEVICE)
                #label_onehot = F.one_hot(label, 10).float()
                out_fr = net(img)
                loss = nn.CrossEntropyLoss()
                loss = loss(out_fr, label.view((-1)))
                #loss = F.mse_loss(out_fr, label.view((-1)))

                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (out_fr.argmax(1) == label).float().sum().item()
                functional.reset_net(net)

        test_loss /= test_samples
        test_acc /= test_samples

        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        # %tensorboard --logdir={spiking_plots}

        save_max = False
        if test_acc > max_test_acc:
            max_test_acc = test_acc
            save_max = True

        PATH = "/Desktop/saved models/state_dict_resnet18_pre_trained1.pt"

        torch.save(net.state_dict(), PATH)

        print(
            f'epoch={epoch}, train_loss={train_loss}, train_acc={train_acc}, test_loss={test_loss}, test_acc={test_acc}, max_test_acc={max_test_acc}, total_time={time.time() - start_time}')

if __name__ == '__main__':
    main()

fangwei123456 commented 2 years ago

save_max is not used.

HassanAli545 commented 2 years ago

You are just awesome. Where do I have to use save_max?

fangwei123456 commented 2 years ago

If you want to save the model with the maximum accuracy, you can do it like this:

if test_acc > max_test_acc:
    max_test_acc = test_acc
    PATH = "/mnt/beegfs/home/sgilani2020/saved models/state_dict_resnet18_pre_trained1.pt"
    print(max_test_acc)
    torch.save(net.state_dict(), PATH)

HassanAli545 commented 2 years ago

Thank you very much for your help

HassanAli545 commented 2 years ago

Hi, I changed the saving code following your suggestion, like this:

save_max = False
if test_acc > max_test_acc:
    max_test_acc = test_acc
    PATH = "/mnt/beegfs/home/sgilani2020/saved models/state_dict_resnet18_pre_trained.pt"
    print(max_test_acc)
    torch.save(net.state_dict(), PATH)
    save_max = True

But I am still facing the same problem. I am getting this confusion matrix when using the best model. It still gives me an accuracy of about 55% on the test set.

tensor([[734.,  28.],
        [588.,  26.]])

fangwei123456 commented 2 years ago

The Iris dataset only has three classes. Does the output of the SNN you used have shape [batch_size, 3]?

fangwei123456 commented 2 years ago

The default num_classes is 1000: https://github.com/fangwei123456/spikingjelly/blob/72132ff2d147ceaefe1f68aa74d6286086c62750/spikingjelly/clock_driven/model/spiking_resnet.py#L184
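
So you need to pass your own class count when constructing the network; for example (assuming the keyword is forwarded to the backbone like in torchvision's ResNet):

from spikingjelly.clock_driven import neuron, surrogate
from spikingjelly.clock_driven.model import spiking_resnet

# num_classes=2 for a binary dataset; the other arguments match the ones
# already used above.
net = spiking_resnet.spiking_resnet18(
    pretrained=False, progress=True,
    single_step_neuron=neuron.IFNode, v_threshold=1.,
    surrogate_function=surrogate.ATan(), num_classes=2)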

HassanAli545 commented 2 years ago

Oh, got it. I am running this model on an image dataset; I am using ResNet-11 on the Iris dataset. I also had to change the number of classes. My mistake, sorry.

HassanAli545 commented 2 years ago

Hi @fangwei123456, is it possible to contact you via email? I want to discuss my work related to SNNs and would appreciate your suggestions.

fangwei123456 commented 2 years ago

My email is fangwei123456@pku.edu.cn. But I may not reply quickly because I am busy with other things.

HassanAli545 commented 2 years ago

Thanks for sending me your email. I can ask one quick question here. I trained VGG-6 from SpikingJelly on my dataset; it gave me around 89% accuracy. But when I trained ResNet-18, it also gave me about 89%. I was hoping ResNet-18 would give better results, but that was not the case. Your thoughts on this, please.

fangwei123456 commented 2 years ago

A larger/deeper network does not guarantee better accuracy. You can check the training accuracy of the two networks. If the training accuracy of the spiking ResNet-18 is lower than that of the spiking VGG-6, you can try training a SEW-ResNet-18. Refer to Deep Residual Learning in Spiking Neural Networks for more details.
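
A sketch of constructing one, assuming spikingjelly.clock_driven.model.sew_resnet mirrors the spiking_resnet interface with an extra cnf argument for the spike-element-wise function (please check the module for the exact keywords):

from spikingjelly.clock_driven import neuron, surrogate
from spikingjelly.clock_driven.model import sew_resnet  # assumed module path

net = sew_resnet.sew_resnet18(
    pretrained=False, cnf='ADD',  # 'ADD' is the SEW function from the paper
    single_step_neuron=neuron.IFNode,
    surrogate_function=surrogate.ATan(), num_classes=2)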

HassanAli545 commented 2 years ago

Hi @fangwei123456, I have a question related to feature visualization in SNNs. I understand the GitHub tutorials show the features at time step = 0 and time step = 1, but why do we have a 16 * 16 grid of features? Generally, we have one output after each layer. Your response will be appreciated.

Best

fangwei123456 commented 2 years ago

In this tutorial, the input has shape [N, 1, H, W], and the first conv layer has c_in=1, c_out=128. We get output spikes with shape [N, 128, H, W], so the 128 channels (one feature map each) can be plotted as an 8 * 16 grid.
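
A minimal sketch of tiling such an output into a grid with matplotlib (random spikes stand in for the real layer output):

import torch
import matplotlib.pyplot as plt

# Stand-in for the output spikes of the first conv layer at one time step:
# shape [N, C, H, W] with N=1, C=128.
spikes = (torch.rand(1, 128, 28, 28) > 0.8).float()

rows, cols = 8, 16  # 8 * 16 = 128 feature maps, one per output channel
fig, axes = plt.subplots(rows, cols, figsize=(cols, rows))
for c in range(rows * cols):
    ax = axes[c // cols][c % cols]
    ax.imshow(spikes[0, c].numpy(), cmap='gray')
    ax.axis('off')
plt.tight_layout()
plt.show()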