fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

AttributeError: 'ResNet' object has no attribute 'delete_all_unused_submodules' #302

Closed xianyi11 closed 1 year ago

xianyi11 commented 1 year ago

When I run resnet18_cifar10.py (spikingjelly\spikingjelly\activation_based\ann2snn\examples\resnet18_cifar10.py), I get the error in the title. I am running the code in a Jupyter notebook, and it is shown below. I only changed some paths to fit my environment.

import torch
import torchvision
from tqdm import tqdm
import spikingjelly.activation_based.ann2snn as ann2snn

from spikingjelly.activation_based.ann2snn.sample_models import cifar10_resnet

def val(net, device, data_loader, T=None):
    net.eval().to(device)
    correct = 0.0
    total = 0.0
    with torch.no_grad():
        for batch, (img, label) in enumerate(tqdm(data_loader)):
            img = img.to(device)
            if T is None:
                out = net(img)
            else:
                for m in net.modules():
                    if hasattr(m, 'reset'):
                        m.reset()
                for t in range(T):
                    if t == 0:
                        out = net(img)
                    else:
                        out += net(img)
            correct += (out.argmax(dim=1) == label.to(device)).float().sum().item()
            total += out.shape[0]
        acc = correct / total
        print('Validating Accuracy: %.3f' % (acc))
    return acc

def main():
    torch.random.manual_seed(0)
    device = 'cpu'
    dataset_dir = r'D:\learn\SNN\my_spikingjelly\data'
    batch_size = 100
    T = 400

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    model = cifar10_resnet.ResNet18()
    model.load_state_dict(torch.load('SJ-cifar10-resnet18_model-sample.pth',map_location=torch.device('cpu')))

    train_data_dataset = torchvision.datasets.CIFAR10(
        root=dataset_dir,
        train=True,
        transform=transform,
        download=True)
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_data_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=False)
    test_data_dataset = torchvision.datasets.CIFAR10(
        root=dataset_dir,
        train=False,
        transform=transform,
        download=True)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=test_data_dataset,
        batch_size=50,
        shuffle=True,
        drop_last=False)

    print('ANN accuracy:')
    val(model, device, test_data_loader)
    print('Converting...')
    print(model)
    model_converter = ann2snn.Converter(mode='Max', dataloader=train_data_loader)
    snn_model = model_converter(model)
    print('SNN accuracy:')
    val(snn_model, device, test_data_loader, T=T)

if __name__ == '__main__':
    print('Downloading SJ-cifar10-resnet18_model-sample.pth')
    ann2snn.download_url("https://ndownloader.figshare.com/files/26676110",'./SJ-cifar10-resnet18_model-sample.pth')
    main()

The error is: [image]

Thanks.

Lyu6PosHao commented 1 year ago

Hi, I ran your code in my Jupyter notebook and it worked without errors.

The only small difference is that I copied the source of cifar10_resnet.py into a notebook cell (my notebook couldn't find ann2snn.sample_models for some reason). However, I don't think that is the key point :(

My suggestions are:

  1. What are your version details? My PyTorch is v1.13. Is your spikingjelly the latest version?

  2. delete_all_unused_submodules is a method of torch.fx.GraphModule. You can use fx_model = fx.symbolic_trace(model) to get a torch.fx.GraphModule, and then call fx_model.delete_all_unused_submodules() to check whether there is any bug.
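
The check in suggestion 2 could be sketched roughly as follows. This is only an illustration, not code from spikingjelly: the helper names fx_cleanup_available and trace_and_clean are made up for this example, and the torch imports are kept inside the functions so the snippet still loads on a machine where torch or torch.fx is missing.

```python
import importlib.util


def fx_cleanup_available():
    """Report whether torch.fx.GraphModule has delete_all_unused_submodules.

    Returns True or False when torch.fx can be imported, and None when
    torch (or its fx subpackage) is not available at all.
    """
    if importlib.util.find_spec("torch") is None:
        return None
    try:
        import torch.fx as fx
    except ImportError:
        # torch is installed but has no usable fx subpackage
        return None
    return hasattr(fx.GraphModule, "delete_all_unused_submodules")


def trace_and_clean(model):
    """Trace `model` with torch.fx and call the method from the error message.

    Raises AttributeError on a PyTorch build whose GraphModule lacks
    delete_all_unused_submodules.
    """
    import torch.fx as fx

    fx_model = fx.symbolic_trace(model)      # returns a torch.fx.GraphModule
    fx_model.delete_all_unused_submodules()  # the call named in the traceback
    return fx_model


if __name__ == "__main__":
    print("delete_all_unused_submodules available:", fx_cleanup_available())
```

If fx_cleanup_available() prints False, the installed PyTorch is the likely cause rather than the notebook code. With the model from the script above, trace_and_clean(cifar10_resnet.ResNet18()) would exercise the same call in isolation.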

xianyi11 commented 1 year ago

Thanks for your reply! As you said, the cause was that my PyTorch version was too old to run the code. My version was 1.8.0; after I upgraded to 1.13.0, the error went away.
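
For anyone who wants to guard a notebook against this, a small version check might look like the sketch below. The names version_tuple, is_at_least and KNOWN_GOOD are hypothetical, and 1.13.0 is simply the version that worked in this thread, not a documented minimum for spikingjelly.

```python
import re

# Version reported as working in this thread (an assumption, not an official minimum).
KNOWN_GOOD = (1, 13, 0)


def version_tuple(version):
    """Turn a version string such as '1.13.0+cu117' into (1, 13, 0)."""
    public = version.split("+")[0]           # drop a local build tag like '+cu117'
    nums = re.findall(r"\d+", public)
    nums = (nums + ["0", "0", "0"])[:3]      # pad short versions such as '2.1'
    return tuple(int(n) for n in nums)


def is_at_least(version, minimum=KNOWN_GOOD):
    """Compare numerically; plain string comparison would rank '1.8.0' above '1.13.0'."""
    return version_tuple(version) >= minimum


if __name__ == "__main__":
    for v in ("1.8.0", "1.13.0", "1.13.0+cu117"):
        print(v, is_at_least(v))
```

In a notebook this would be called as is_at_least(torch.__version__) before running the converter.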