fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Networks (SNNs) based on PyTorch.
https://spikingjelly.readthedocs.io

Error during training after adding spiking neurons to another model #531

Closed · CaoJu600 closed this 2 months ago

CaoJu600 commented 2 months ago

I'm using the original YOLOv9 and made the following modifications. In train.py, the part of the train function that loads the model:

Model

check_suffix(weights, '.pt')  # check weights
pretrained = weights.endswith('.pt')
if pretrained:
    with torch_distributed_zero_first(LOCAL_RANK):
        weights = attempt_download(weights)  # download if not found locally
    ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
    model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
    model2 = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # modification 1: create an extra copy of the model (also needed in this branch, since model2 is loaded below)
    exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
    csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
    csd = intersect_dicts(csd, model.state_dict(), exclude=exclude)  # intersect
    model.load_state_dict(csd, strict=False)  # load
    model2.load_state_dict(csd, strict=False)  # load
    LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}')  # report
else:
    model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
    model2 = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # modification 1: create an extra copy of the model
amp = check_amp(model)  # check AMP

......

EMA

ema = ModelEMA(model2) if RANK in {-1, 0} else None  # modification 2: pass the copy model to the EMA module

The ModelEMA class in torch_utils.py:

class ModelEMA:
    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Create EMA
        # print(model)
        self.ema = de_parallel(model).eval()  # modification 3: a model with spiking neurons added cannot be deepcopied (I'm not sure why), so a copy is created beforehand and passed in directly
        # self.ema = deepcopy(de_parallel(model)).eval()  # this line is commented out
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)
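On modification 3: a plausible reason the deepcopy fails is that SpikingJelly neurons are stateful. After a forward pass, the membrane potential self.v is a non-leaf tensor attached to the autograd graph, and PyTorch's deepcopy protocol only supports graph leaves. A minimal sketch of the failure and a possible workaround, assuming the activation_based API (my reading, not a confirmed diagnosis of this exact setup):

import copy

import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, functional

net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), neuron.IFNode())
net(torch.rand(1, 3, 32, 32))  # after this forward, IFNode's membrane potential v is part of the autograd graph

try:
    copy.deepcopy(net)  # RuntimeError: only leaf tensors support the deepcopy protocol
except RuntimeError as e:
    print(e)

functional.reset_net(net)      # reset membrane potentials back to their initial values
ema = copy.deepcopy(net)       # deepcopy now succeeds, which may make the model2 workaround unnecessary

If that holds, calling functional.reset_net on the model right before constructing ModelEMA might let the original deepcopy line stay as it was.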

In common.py, the Conv module, where the default activation is replaced with an IF neuron:

class Conv(nn.Module):
    # Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        # self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
        self.act = neuron.IFNode()  # modification 4: replace the activation with an IF spiking neuron

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        return self.act(self.conv(x))
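One thing to keep in mind about modification 4: unlike nn.SiLU, neuron.IFNode keeps state between calls. In its default single-step mode, each call adds the input to the membrane potential v, fires where v crosses v_threshold (1.0 by default), and v persists until reset. A small standalone illustration:

import torch
from spikingjelly.activation_based import neuron, functional

node = neuron.IFNode()        # single-step mode by default; v_threshold=1.0, reset to v_reset=0.0
x = torch.full((1, 4), 0.6)

print(node(x))  # v: 0.0 -> 0.6, below threshold, so the output spikes are all 0
print(node(x))  # v: 0.6 -> 1.2, above threshold, so the output spikes are all 1 and v resets

functional.reset_net(node)    # restore v to its initial value before the next input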

After the changes above, train.py loads the model and starts training normally (using gelan.yaml), but it fails shortly after training starts:

File "train.py", line 314, in train
    scaler.scale(loss).backward()
......
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument weight in method wrapper_CUDA__native_batch_norm_backward)

It looks like someone ran into a similar problem before in #210. Was the root cause ever found? Thanks!

CaoJu600 commented 2 months ago

Found the problem: the spiking neurons were never being reset. Resetting the neuron in the forward pass fixes it:

def forward(self, x):
    out = self.act(self.bn(self.conv(x)))
    self.act.reset()
    return out
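A note for anyone landing here later: stale membrane potential also plausibly explains the cpu/cuda mismatch above, since v is created during the first forward pass (YOLO's Model runs a dummy forward on CPU at construction time to compute strides, before .to(device)) and, depending on the SpikingJelly version, may not be moved by .to(device). Resetting per layer as above works, but it also wipes the state after every call, so the neurons never integrate across calls. The pattern the SpikingJelly docs use instead is one functional.reset_net(model) per batch. A sketch of that, with hypothetical names (dataloader, compute_loss) standing in for the YOLOv9 loop:

from spikingjelly.activation_based import functional

for imgs, targets in dataloader:                    # hypothetical stand-in for the YOLOv9 train loop
    preds = model(imgs.to(device))                  # neuron states accumulate during this forward
    loss = compute_loss(preds, targets.to(device))  # hypothetical loss helper
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    functional.reset_net(model)                     # clear all neuron states before the next batch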