Closed: Ikarosy closed this issue 2 years ago.
Well, this error occurs at `loss.backward()`, right before `functional.reset_net(net)`, due to, I think, an inplace operation somewhere in `spikingjelly/clock_driven/surrogate.py`.
`functional.reset_net(net)` should be called after `loss.backward()`.
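For reference, the per-iteration order this implies, condensed from the working script later in this thread (the names `model`, `criterion`, `optimizer`, and `train_loader` are the ones used there):

```python
for data, target in train_loader:
    data, target = data.to(device), target.to(device)
    output = model(data)               # forward pass over T time steps
    loss = criterion(output, target)
    optimizer.zero_grad()
    loss.backward()                    # backprop through the stored computation graph
    optimizer.step()
    functional.reset_net(model)        # reset neuron states only after backward()
```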
Certainly, `functional.reset_net(net)` is called after `loss.backward()`, but the error occurs at `loss.backward()`, so it has nothing to do with `functional.reset_net(net)`. The script failed right after the very first forward pass of the first batch. I didn't look into `spikingjelly/clock_driven/surrogate.py`. BUT, after setting `retain_graph=True` in `torch/_tensor`, the following issue occurs:
```
Traceback (most recent call last):
  File "train.py", line 465, in <module>
    main()
  File "train.py", line 428, in main
    train(args, model, device, train_loader, optimizer, epochs)
  File "train.py", line 135, in train
    loss.backward()
  File "/home/anaconda3/lib/python3.8/site-packages/torch/_tensor.py", line 255, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 147, in backward
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512]] is at version 3; expected version 2 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
```
Above all, I believe some inplace operations in `spikingjelly/clock_driven/surrogate.py` might not be handled correctly.
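(As the hint in the traceback suggests, `torch.autograd.set_detect_anomaly(True)` makes PyTorch attach a forward-pass traceback pointing at the operation whose saved tensor was later modified. A minimal, self-contained reproduction of this class of error:)

```python
import torch

torch.autograd.set_detect_anomaly(True)  # adds a forward-pass traceback to backward errors

a = torch.randn(3, requires_grad=True)
b = a * 2
c = b * b            # backward for this mul needs the saved value of b
b.add_(1)            # inplace op bumps b's version counter from 0 to 1
c.sum().backward()   # RuntimeError: ... modified by an inplace operation
                     # with anomaly detection on, a second traceback points at `c = b * b`
```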
Can you show me your code?
Sure, but only the relevant code is as follows:
Model:
```python
class CupyNet(nn.Module):
    def __init__(self, T=steps):  # `steps` is a global defined elsewhere in the script
        super().__init__()
        self.T = T
        self.static_conv = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
        )
        self.conv = nn.Sequential(
            neuron.MultiStepIFNode(surrogate_function=surrogate.ATan(), backend='cupy'),
            layer.SeqToANNContainer(
                nn.MaxPool2d(2, 2),  # 14 * 14
                nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(128),
            ),
            neuron.MultiStepIFNode(surrogate_function=surrogate.ATan(), backend='cupy'),
            layer.SeqToANNContainer(
                nn.AdaptiveAvgPool2d((1, 1)),  # 7 * 7
                nn.Flatten(),
            ),
        )
        self.fc = nn.Sequential(
            layer.SeqToANNContainer(nn.Linear(128, 128, bias=False)),
            neuron.MultiStepIFNode(surrogate_function=surrogate.ATan(), backend='cupy'),
            layer.SeqToANNContainer(nn.Linear(128, 10, bias=False)),
            neuron.MultiStepIFNode(surrogate_function=surrogate.ATan(), backend='cupy'),
        )

    def forward(self, x):
        # replicate the static conv output along a new leading time dimension:
        # (N, C, H, W) -> (T, N, C, H, W)
        x_seq = self.static_conv(x).unsqueeze(0).repeat(self.T, 1, 1, 1, 1)
        return self.fc(self.conv(x_seq)).mean(0)
```
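(A quick shape check of `forward()`: the static conv output is replicated along a new leading time dimension, so every one of the `T` steps sees the same input.)

```python
import torch

x = torch.randn(8, 128, 32, 32)               # (N, C, H, W): static_conv output for a batch of 8
x_seq = x.unsqueeze(0).repeat(4, 1, 1, 1, 1)  # (T, N, C, H, W) with T = 4
print(x_seq.shape)                            # torch.Size([4, 8, 128, 32, 32])
```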
Training function:
```python
def train(args, model, device, train_loader, optimizer, epoch, writer, criterion):
    t1 = time.time()
    Top1, Top5 = 0.0, 0.0
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        data, _ = torch.broadcast_tensors(data, torch.zeros((3,) + data.shape))
        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # accuracy() is a helper defined elsewhere in train.py
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        Top1 += prec1.item() / 100
        Top5 += prec5.item() / 100
        if batch_idx % args.display_interval == 0:
            # note: len(data / steps) equals len(data); the division only builds a throwaway tensor
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tTop-1 = {:.6f}\tTop-5 = {:.6f}\tTime = {:.6f}'.format(
                epoch, batch_idx * len(data / steps), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item(),
                Top1 / args.display_interval, Top5 / args.display_interval, time.time() - t1
            ))
            Top1, Top5 = 0.0, 0.0
```
```python
optimizer = torch.optim.SGD(model.parameters(),
                            lr=args.learning_rate,
                            momentum=0.9,
                            weight_decay=args.weight_decay
                            )
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
criterion = nn.CrossEntropyLoss()
```
Dataset: CIFAR-10
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
import torchvision
import time
from spikingjelly.clock_driven import neuron, surrogate, layer, functional


class CupyNet(nn.Module):
    def __init__(self, T):
        super().__init__()
        self.T = T
        self.static_conv = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
        )
        self.conv = nn.Sequential(
            neuron.MultiStepIFNode(surrogate_function=surrogate.ATan(), backend='cupy'),
            layer.SeqToANNContainer(
                nn.MaxPool2d(2, 2),  # 14 * 14
                nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(128),
            ),
            neuron.MultiStepIFNode(surrogate_function=surrogate.ATan(), backend='cupy'),
            layer.SeqToANNContainer(
                nn.AdaptiveAvgPool2d((1, 1)),  # 7 * 7
                nn.Flatten(),
            ),
        )
        self.fc = nn.Sequential(
            layer.SeqToANNContainer(nn.Linear(128, 128, bias=False)),
            neuron.MultiStepIFNode(surrogate_function=surrogate.ATan(), backend='cupy'),
            layer.SeqToANNContainer(nn.Linear(128, 10, bias=False)),
            neuron.MultiStepIFNode(surrogate_function=surrogate.ATan(), backend='cupy'),
        )

    def forward(self, x):
        x_seq = self.static_conv(x).unsqueeze(0).repeat(self.T, 1, 1, 1, 1)
        return self.fc(self.conv(x_seq)).mean(0)


def train(args, model, device, train_loader, optimizer, epoch, writer, criterion):
    t1 = time.time()
    Top1, Top5 = 0.0, 0.0
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        # data, _ = torch.broadcast_tensors(data, torch.zeros((3 ,) + data.shape))
        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        functional.reset_net(model)


if __name__ == '__main__':
    # (description strings are leftovers from the Fashion-MNIST example)
    parser = argparse.ArgumentParser(description='Classify Fashion-MNIST')
    parser.add_argument('-T', default=4, type=int, help='simulating time-steps')
    parser.add_argument('-device', default='cuda:0', help='device')
    parser.add_argument('-data_dir', type=str, help='root dir of Fashion-MNIST dataset', default='/datasets/CIFAR10')
    # python test.py -opt SGD -data_dir /userhome/datasets/CIFAR10/ -amp -cupy
    args = parser.parse_args()
    print(args)

    net = CupyNet(T=args.T)
    net.to(args.device)
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=0.1,
                                momentum=0.9,
                                weight_decay=0
                                )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=64)
    criterion = nn.CrossEntropyLoss()

    train_set = torchvision.datasets.CIFAR10(
        root=args.data_dir,
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True)
    test_set = torchvision.datasets.CIFAR10(
        root=args.data_dir,
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True)

    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=128,
        shuffle=True,
        drop_last=True,
        num_workers=4
    )
    test_data_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=128,
        shuffle=True,
        drop_last=False,
        num_workers=4
    )

    for epoch in range(64):
        train(args, net, args.device, train_data_loader, optimizer, epoch, None, criterion)
        print(epoch)
```
```
(pytorch-env) wfang@Precision-5820-Tower-X-Series:~/tempdir$ python w1.py
Namespace(T=4, device='cuda:0', data_dir='/datasets/CIFAR10')
Files already downloaded and verified
Files already downloaded and verified
0
1
2
```
This code works with no error. There are two changes:

1. The broadcast line is commented out: `# data, _ = torch.broadcast_tensors(data, torch.zeros((3 ,) + data.shape))`
2. `functional.reset_net(model)` is called at the end of each iteration.
`data, _ = torch.broadcast_tensors(data, torch.zeros((3 ,) + data.shape))` is meant for another model defined by myself. BUT, I wrongly applied it to `spikingjelly.clock_driven.examples.conv_fashion_mnist`, which already has `.repeat(self.T, 1, 1, 1, 1)` serving the same purpose. *I forgot to comment it out when running `spikingjelly.clock_driven.examples.conv_fashion_mnist`.*
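(Note: the two produce the same values but differ in memory layout, which may be why only the `broadcast_tensors` variant trips the version-counter check: `broadcast_tensors` returns a zero-stride view whose leading-dimension "copies" all alias one storage, while `.repeat` materializes independent copies. A small sketch:)

```python
import torch

data = torch.randn(2, 3, 4, 4)
b, _ = torch.broadcast_tensors(data, torch.zeros((3,) + data.shape))
print(b.shape)        # torch.Size([3, 2, 3, 4, 4])
print(b.stride()[0])  # 0 -> the leading dim is a zero-stride view; all "copies" share storage

r = data.unsqueeze(0).repeat(3, 1, 1, 1, 1)
print(r.stride()[0])  # 96 -> repeat allocates real, independent copies
```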
Sorry to waste your time, and sorry for wrongly reporting a 'bug' that was actually my own coding mistake. Thanks for your patience and for your excellent work on SpikingJelly!
I ran the `cupynet` in `spikingjelly.clock_driven.examples.conv_fashion_mnist` with the exact code from the tutorial under my environment. BUT, an annoying bug shows something wrong with the grad. The model is `spikingjelly.clock_driven.examples.conv_fashion_mnist` copied from the tutorial, and my model-training script is as written above. It works well when training other models.
I've tried changing `backend='cupy'` into `'torch'`, changing `neuron.MultiStepIFNode` into other nodes, and setting `retain_graph=True` in `neuron_kernel.py, line 335` or in the `torch` package itself. If `backend='torch'`, the bug report shows:

The bug just won't be solved. Do you have any clue where the issue lies and how to solve it?
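(Two of the debugging switches described above, as minimal sketches; note that `retain_graph` can be passed at the call site, with no need to edit `torch/_tensor.py`:)

```python
import torch
from spikingjelly.clock_driven import neuron, surrogate

# retain_graph can be passed directly to backward(); no need to patch torch internals
x = torch.randn(4, requires_grad=True)
loss = x.sum()
loss.backward(retain_graph=True)  # the graph is kept, so backward() could run again

# swapping the neuron backend from 'cupy' to 'torch' for debugging
sn = neuron.MultiStepIFNode(surrogate_function=surrogate.ATan(), backend='torch')
```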