Closed bluefier closed 1 year ago
目前我训练的模型在测试集上的准确率只有(71.49%)，在训练集上的准确率比较高，在 97% 左右。看起来好像有些过拟合，但是该怎样进行调参来优化模型的准确率呢？跪求大神帮忙。请问有可以参考的训练模型吗？现在的训练参数如下: T = 100, momentum = 0.9, lr = 1e-3, tau = 2.0 代码如下:
# SNN training script: spiking ResNet-18 on CIFAR-10 with Poisson rate encoding.
# NOTE(review): depends on SpikingJelly modules (spiking_resnet, neuron, surrogate,
# functional, encoding) and an `args` namespace parsed elsewhere in the file.

net = spiking_resnet.spiking_resnet18(
    pretrained=False,
    spiking_neuron=neuron.IFNode,
    surrogate_function=surrogate.ATan(),
    detach_reset=True,
)
print(net)
net.to(args.device)

train_dataset = torchvision.datasets.CIFAR10(
    root=args.data_dir,
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
test_dataset = torchvision.datasets.CIFAR10(
    root=args.data_dir,
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

train_data_loader = data.DataLoader(
    dataset=train_dataset,
    batch_size=args.b,
    shuffle=True,
    drop_last=True,
    num_workers=args.j,
    pin_memory=True,
)
test_data_loader = data.DataLoader(
    dataset=test_dataset,
    batch_size=args.b,
    shuffle=False,
    drop_last=False,
    num_workers=args.j,
    pin_memory=True,
)

# Mixed-precision gradient scaler, only when AMP is requested.
scaler = None
if args.amp:
    scaler = amp.GradScaler()

start_epoch = 0
max_test_acc = -1

# FIX: dropped the dead `optimizer = None` pre-assignment — the `else` branch
# raises, so `optimizer` is always bound when this chain falls through.
if args.opt == 'sgd':
    optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)
elif args.opt == 'adam':
    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
else:
    raise NotImplementedError(args.opt)

# if args.resume:
#     checkpoint = torch.load(args.resume, map_location='cpu')
#     net.load_state_dict(checkpoint['net'])
#     optimizer.load_state_dict(checkpoint['optimizer'])
#     start_epoch = checkpoint['epoch'] + 1
#     max_test_acc = checkpoint['max_test_acc']

out_dir = os.path.join(args.out_dir, f'T{args.T}_b{args.b}_{args.opt}_lr{args.lr}')
if args.amp:
    out_dir += '_amp'
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
    print(f'Mkdir {out_dir}.')

writer = SummaryWriter(out_dir, purge_step=start_epoch)

# FIX: args.txt was opened and written twice; the first write (str(args) only)
# was immediately clobbered by this fuller one. Keep the single, fuller write.
with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
    args_txt.write(str(args))
    args_txt.write('\n')
    args_txt.write(' '.join(sys.argv))

encoder = encoding.PoissonEncoder()

for epoch in range(start_epoch, args.epochs):
    start_time = time.time()
    net.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    for img, label in train_data_loader:
        optimizer.zero_grad()
        img = img.to(args.device)
        label = label.to(args.device)
        label_onehot = F.one_hot(label, 10).float()

        if scaler is not None:
            # AMP path: forward + loss under autocast, scaled backward outside.
            with amp.autocast():
                out_fr = 0.
                # Average the network's output firing rate over T timesteps;
                # the Poisson encoder re-samples a spike image each step.
                for t in range(args.T):
                    encoded_img = encoder(img)
                    out_fr += net(encoded_img)
                out_fr = out_fr / args.T
                loss = F.mse_loss(out_fr, label_onehot)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out_fr = 0.
            for t in range(args.T):
                encoded_img = encoder(img)
                out_fr += net(encoded_img)
            out_fr = out_fr / args.T
            loss = F.mse_loss(out_fr, label_onehot)
            loss.backward()
            optimizer.step()

        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

        # Spiking neurons are stateful; clear membrane potentials between batches.
        functional.reset_net(net)

    train_time = time.time()
    train_speed = train_samples / (train_time - start_time)
    train_loss /= train_samples
    train_acc /= train_samples

    writer.add_scalar('train_loss', train_loss, epoch)
    writer.add_scalar('train_acc', train_acc, epoch)

    net.eval()
    test_loss = 0
    test_acc = 0
    test_samples = 0
    # begin test
    with torch.no_grad():
        for img, label in test_data_loader:
            img = img.to(args.device)
            label = label.to(args.device)
            label_onehot = F.one_hot(label, 10).float()
            out_fr = 0.
            for t in range(args.T):
                encoded_img = encoder(img)
                out_fr += net(encoded_img)
            out_fr = out_fr / args.T
            loss = F.mse_loss(out_fr, label_onehot)
            test_samples += label.numel()
            test_loss += loss.item() * label.numel()
            test_acc += (out_fr.argmax(1) == label).float().sum().item()
            functional.reset_net(net)

    test_time = time.time()
    test_speed = test_samples / (test_time - train_time)
    test_loss /= test_samples
    test_acc /= test_samples
    writer.add_scalar('test_loss', test_loss, epoch)
    writer.add_scalar('test_acc', test_acc, epoch)

    save_max = False
    print('*****test_acc*****', test_acc)
    print('*****max_test_acc*****', max_test_acc)
    if test_acc > max_test_acc:
        max_test_acc = test_acc
        save_max = True
    checkpoint = {
        'net': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'max_test_acc': max_test_acc,
    }
    # NOTE(review): the visible chunk ends here; presumably a
    # torch.save(checkpoint, ...) follows outside this view — confirm.
resnet结构的SNN跑CIFAR10可以参考 https://github.com/ikarosy/gated-lif
感谢大佬!!!
目前我训练的模型在测试集上的准确率只有(71.49%)，在训练集上的准确率比较高，在 97% 左右。看起来好像有些过拟合，但是该怎样进行调参来优化模型的准确率呢？跪求大神帮忙。请问有可以参考的训练模型吗？现在的训练参数如下: T = 100, momentum = 0.9, lr = 1e-3, tau = 2.0 代码如下: