fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

adaptive_avg_pool2d_backward_cuda error when classifying CIFAR-10 with spikingvgg11 #554

Closed. Jealousc11gx closed this issue 3 months ago.

Jealousc11gx commented 3 months ago

Read before creating a new issue

For faster response

You can @-mention the developers responsible for the feature your issue concerns. Here is the division:

Features: Developers
Neurons and Surrogate Functions: fangwei123456, Yanqi-Chen
CUDA Acceleration: fangwei123456, Yanqi-Chen
Reinforcement Learning: lucifer2859
ANN to SNN Conversion: DingJianhao, Lyu6PosHao
Biological Learning (e.g., STDP): AllenYolk
Others: Grasshlw, lucifer2859, AllenYolk, Lyu6PosHao, DingJianhao, Yanqi-Chen, fangwei123456

We are glad to add new developers who are volunteering to help solve issues to the above table.

Issue type

SpikingJelly version

0.0.0.0.2

Description

...RuntimeError: adaptive_avg_pool2d_backward_cuda does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True)'. You can turn off determinism just for this operation, or you can use the 'warn_only=True' option, if that's acceptable for your application. You can also file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation.

When I tried to train spiking-vgg11 on CIFAR-10, I got the error above. I don't know where to change torch.use_deterministic_algorithms.
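The error message itself names the two ways out. The script below never calls `torch.use_deterministic_algorithms` itself, so the global flag is presumably switched on inside `train_classify.Trainer` (an assumption, not checked against the SpikingJelly source). The following is a minimal sketch of both options on a small `adaptive_avg_pool2d` forward/backward pass, assuming PyTorch 1.11 or newer (where the `warn_only` argument exists). `determinism_disabled()` and `pool_backward()` are hypothetical helpers written for illustration. Note that this particular check fires in the backward CUDA kernel, so whatever region is relaxed must contain the `backward()` call, not just the forward pass.

```python
import contextlib

import torch
import torch.nn.functional as F


@contextlib.contextmanager
def determinism_disabled():
    """Hypothetical helper: temporarily turn off the global determinism flag,
    then restore whatever setting (including warn_only) was active before."""
    prev_enabled = torch.are_deterministic_algorithms_enabled()
    prev_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
    torch.use_deterministic_algorithms(False)
    try:
        yield
    finally:
        torch.use_deterministic_algorithms(prev_enabled, warn_only=prev_warn_only)


def pool_backward(device: str) -> torch.Tensor:
    # Hypothetical helper: forward + backward through adaptive_avg_pool2d.
    # Its CUDA backward is the op named in the error; on CPU it runs fine.
    x = torch.randn(2, 3, 8, 8, device=device, requires_grad=True)
    y = F.adaptive_avg_pool2d(x, (1, 1)).sum()
    y.backward()
    return x.grad


device = "cuda" if torch.cuda.is_available() else "cpu"

# Option 1: keep determinism on, but only warn for ops without a deterministic kernel.
torch.use_deterministic_algorithms(True, warn_only=True)
grad = pool_backward(device)

# Option 2: keep strict determinism globally and relax it only around the
# code that runs the offending backward kernel.
torch.use_deterministic_algorithms(True)
with determinism_disabled():
    grad = pool_backward(device)
print(torch.are_deterministic_algorithms_enabled())  # True: flag restored after the block
```

Option 1 is the one-line change the error message suggests; option 2 keeps the rest of the run strictly deterministic, but requires access to the code that calls `backward()`.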

Minimal code to reproduce the error/bug

import torch
from spikingjelly.activation_based import surrogate, neuron, functional
from spikingjelly.activation_based.model import train_classify, spiking_vgg
from datetime import datetime

class SResNetTrainer(train_classify.Trainer):
    def preprocess_train_sample(self, args, x: torch.Tensor):
        # define how to preprocess a training sample before sending it to the model
        return x.unsqueeze(0).repeat(args.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]

    def preprocess_test_sample(self, args, x: torch.Tensor):
        # define how to preprocess a test sample before sending it to the model
        return x.unsqueeze(0).repeat(args.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]

    def process_model_output(self, args, y: torch.Tensor):
        return y.mean(0)  # return firing rate

    def get_args_parser(self, add_help=True):
        parser = super().get_args_parser(add_help=add_help)
        parser.add_argument('--T', type=int, help="total time-steps")
        parser.add_argument('--cupy', action="store_true", help="set the neurons to use cupy backend")
        return parser

    def get_tb_logdir_name(self, args):
        now_time = datetime.now()
        return super().get_tb_logdir_name(args) + f'_T{args.T}' + '_' + str(now_time.day) + '_' + str(
            now_time.hour) + '_' + str(now_time.minute)

    def load_data(self, args):
        return self.load_CIFAR10(args)

    def load_model(self, args, num_classes):
        if args.model in spiking_vgg.__all__:
            model = spiking_vgg.__dict__[args.model](pretrained=args.pretrained, spiking_neuron=neuron.LIFNode,
                                                 surrogate_function=surrogate.ATan(), detach_reset=True,
                                                 num_classes=num_classes)
            functional.set_step_mode(model, step_mode='m')
            if args.cupy:
                functional.set_backend(model, 'cupy', neuron.LIFNode)
            return model
        else:
            raise ValueError(f"args.model should be one of {spiking_vgg.__all__}")

if __name__ == "__main__":
    trainer = SResNetTrainer()
    args = trainer.get_args_parser().parse_args()
    trainer.main(args)
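`preprocess_train_sample` and `preprocess_test_sample` feed the same static image to the network at every time step, and `process_model_output` averages the per-step outputs over time. The following is a minimal standalone sketch of those two tensor operations. The sizes (T=4, a batch of 2 CIFAR-10-sized images, 10 classes) are illustrative, and the random binary tensor is only a stand-in for the real multi-step model output.

```python
import torch

T, N, C, H, W = 4, 2, 3, 32, 32  # illustrative sizes, not taken from the issue
num_classes = 10

x = torch.rand(N, C, H, W)
# Same transform as preprocess_*_sample: [N, C, H, W] -> [T, N, C, H, W]
x_seq = x.unsqueeze(0).repeat(T, 1, 1, 1, 1)
print(x_seq.shape)  # torch.Size([4, 2, 3, 32, 32])

# Stand-in for the multi-step model output: one binary tensor per time step.
y_seq = (torch.rand(T, N, num_classes) > 0.5).float()
# Same reduction as process_model_output: average over the time dimension.
rate = y_seq.mean(0)
print(rate.shape)  # torch.Size([2, 10])
```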