fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

Loss does not change during network training #513

Open bigtigerisme opened 3 months ago

bigtigerisme commented 3 months ago

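With the code below, the loss is identical at every iteration and the parameters are not being updated. The results are shown in the figure below.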

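import time

import torch
import torch.nn.functional as F
from spikingjelly.activation_based import functional

# net, encoder, train_loader, test_loader, optimizer_gd, optimizer_stdp and
# stdp_learners are defined earlier in the script (not shown in the issue).

for epoch in range(5):
    net.train()  # net.eval() is called at the end of every epoch, so switch back here
    start_time = time.time()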
    train_loss = 0
    train_acc = 0
    train_samples = 0
    counter = 0
    #with torch.no_grad():
    for img, label in train_loader:
        optimizer_gd.zero_grad()
        optimizer_stdp.zero_grad()
        label_onehot = F.one_hot(label, 10).float()
        out_fr = 0.
        for t in range(100):
            encoded_img = encoder(img)
            out_fr += net(encoded_img)
        out_fr = out_fr / 100
        loss = F.mse_loss(out_fr, label_onehot)
        loss.backward()

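        # Clear the backprop gradients of the STDP-trained parameters so that
        # optimizer_stdp applies only the updates the STDP learners write into .grad.
        optimizer_stdp.zero_grad()

        for i in range(len(stdp_learners)):
            stdp_learners[i].step(on_grad=True)

        optimizer_gd.step()
        optimizer_stdp.step()

        for i in range(len(stdp_learners)):
            stdp_learners[i].reset()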

        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

        functional.reset_net(net)

        #print(f"Iteration: {counter}")
        counter += 1

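        # Running averages over the samples seen so far in this epoch.
        # train_speed is only computed after the epoch ends, so it is not printed here.
        print('Iteration:', counter, 'epoch:', epoch, 'train_loss:', train_loss / train_samples)
        print('Iteration:', counter, 'epoch:', epoch, 'train_acc:', train_acc / train_samples)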

    train_time = time.time()
    train_speed = train_samples / (train_time - start_time)
    train_loss /= train_samples
    train_acc /= train_samples

    print('epoch',epoch,'train_loss',train_loss)
    print('epoch',epoch,'train_speed',train_speed)
    print('epoch',epoch,'train_acc',train_acc)

    net.eval()
    test_loss = 0
    test_acc = 0
    test_samples = 0
    with torch.no_grad():
        for img, label in test_loader:
            label_onehot = F.one_hot(label, 10).float()
            out_fr = 0.
            for t in range(100):
                encoded_img = encoder(img)
                out_fr += net(encoded_img)
            out_fr = out_fr / 100
            loss = F.mse_loss(out_fr, label_onehot)

            test_samples += label.numel()
            test_loss += loss.item() * label.numel()
            test_acc += (out_fr.argmax(1) == label).float().sum().item()
            functional.reset_net(net)

    test_time = time.time()
    test_speed = test_samples / (test_time - train_time)
    test_loss /= test_samples
    test_acc /= test_samples

    print('epoch',epoch,'test_loss',test_loss)
    print('epoch',epoch,'test_speed',test_speed)
    print('epoch',epoch,'test_acc',test_acc)

[Image: screenshot of the training output]

fangwei123456 commented 3 months ago

Try removing optimizer_stdp. The STDP learner is an unsupervised learning method; it is not guaranteed to improve performance (and most likely won't).

bigtigerisme commented 3 months ago

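Hi, after removing STDP and adjusting the optimizer, it does work now. I'd like to ask: if I want to quantize the network's weights into binary numbers with only a few bits, is there a method for that? Can you recommend any Python libraries for it?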

(Sent by email in reply to fangwei123456's comment of Fri, Mar 29, 2024, 5:32 PM.)
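For readers with the same question, here is a minimal sketch of what storing weights as few-bit binary numbers can look like. It assumes symmetric per-tensor post-training quantization with round-to-nearest and two's-complement codes; `quantize_symmetric`, `dequantize` and `to_bits` are hypothetical helpers, not part of SpikingJelly or PyTorch.

```python
# Minimal sketch: symmetric per-tensor k-bit quantization of a list of weights.
# Assumptions (not from the thread): signed two's-complement codes,
# one float scale per tensor derived from the largest |w|, round-to-nearest.

def quantize_symmetric(weights, k):
    """Map float weights to signed k-bit integer codes plus one shared scale."""
    if k < 2:
        raise ValueError("signed symmetric quantization needs k >= 2")
    qmax = 2 ** (k - 1) - 1                      # e.g. k=4 -> codes in [-7, 7]
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    # round() on a float returns an int; clamp to the representable range
    codes = [max(-qmax, min(qmax, round(w / scale))) for w in weights]
    return codes, scale


def dequantize(codes, scale):
    """Recover approximate float weights from the integer codes."""
    return [c * scale for c in codes]


def to_bits(code, k):
    """Return the k-bit two's-complement bit string of a signed integer code."""
    return format(code & ((1 << k) - 1), f"0{k}b")


if __name__ == "__main__":
    w = [0.7, -0.3, 0.04, -0.7]
    codes, scale = quantize_symmetric(w, k=4)
    print(codes, [to_bits(c, 4) for c in codes])
    print(dequantize(codes, scale))
```

Each weight is then stored as a k-bit code plus one shared float scale per tensor. Pure 1-bit (binary) weights need a different rule, for example keeping only the sign.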

fangwei123456 commented 3 months ago

You can try the quantizers in the framework:

https://github.com/fangwei123456/spikingjelly/blob/master/spikingjelly/activation_based/quantize.py
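One common scheme for training-time k-bit weight quantization is the one from DoReFa-Net: squash the weights with tanh, rescale them to [0, 1], round to 2^k - 1 evenly spaced steps, and map back to [-1, 1]. The pure-Python sketch below shows only that forward mapping; whether quantize.py implements exactly this formula should be checked in the file itself, and `quantize_k` / `dorefa_weights` are hypothetical names.

```python
import math


def quantize_k(x, k):
    """DoReFa-style k-bit quantizer for x in [0, 1]: 2**k evenly spaced levels."""
    n = 2 ** k - 1
    return round(x * n) / n


def dorefa_weights(weights, k):
    """DoReFa-Net-style weight quantization (forward pass only):
    tanh -> rescale to [0, 1] -> k-bit quantize -> map back to [-1, 1]."""
    t = [math.tanh(w) for w in weights]
    max_abs = max(abs(v) for v in t) or 1.0      # avoid dividing by zero
    return [2 * quantize_k(v / (2 * max_abs) + 0.5, k) - 1 for v in t]


if __name__ == "__main__":
    print(dorefa_weights([1.0, -1.0, 0.2], k=1))        # 1-bit: every weight is +-1
    print(dorefa_weights([1.0, -1.0, 0.2, -0.05], k=2))  # 2-bit: levels +-1, +-1/3
```

In a PyTorch training loop the rounding has zero gradient almost everywhere, so quantizers like this are normally paired with a straight-through estimator, for example `w + (w_q - w).detach()`, which uses the quantized value in the forward pass and passes the gradient through unchanged in the backward pass.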

Also, PyTorch itself seems to ship built-in quantization modules now.