fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

ann2cnn conversion error #410

Closed: miaodd98 closed this issue 1 year ago

miaodd98 commented 1 year ago

Issue type

SpikingJelly version

latest

Description

When I try to convert a trained YOLOX ANN model, I get the following error:

TypeError: arange(): argument 'end' (position 1) must be Number, not Proxy

Full traceback:

Traceback (most recent call last):
  File "demo1.py", line 343, in <module>
    main(exp, args)  # the original test code here is no longer used
  File "demo1.py", line 322, in main
    snn_model = model_converter(model)
  File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\spikingjelly-0.0.0.0.14-py3.8.egg\spikingjelly\activation_based\ann2snn\converter.py", line 104, in forward
  File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\fx\symbolic_trace.py", line 859, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\fx\symbolic_trace.py", line 571, in trace
    self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
  File "D:\codes\YOLOX-snn\yolox\models\yolox.py", line 46, in forward
    outputs = self.head(fpn_outs)
  File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\fx\symbolic_trace.py", line 560, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\fx\symbolic_trace.py", line 391, in call_module
    return forward(*args, **kwargs)
  File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\fx\symbolic_trace.py", line 556, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\codes\YOLOX-snn\yolox\models\yolo_head.py", line 211, in forward
    return self.decode_outputs(outputs, dtype=xin[0].type())
  File "D:\codes\YOLOX-snn\yolox\models\yolo_head.py", line 239, in decode_outputs
    yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
TypeError: arange(): argument 'end' (position 1) must be Number, not Proxy
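A minimal sketch of why `torch.arange(hsize)` in `decode_outputs` received a Proxy: during `torch.fx.symbolic_trace` the module input is a Proxy, so anything derived from it, including `x.shape[-2]`, is also a Proxy rather than a Python int. `RecordSize` and `seen` are hypothetical names used only for this illustration.

```python
import torch
import torch.fx
from torch import nn

# Stores what the feature-map height looks like while tracing.
seen = {}


class RecordSize(nn.Module):
    def forward(self, x):
        hsize = x.shape[-2]  # during tracing this is a torch.fx Proxy, not an int
        seen["hsize"] = hsize  # plain Python side effect, runs only while tracing
        return x * 2


gm = torch.fx.symbolic_trace(RecordSize())
print(type(seen["hsize"]))  # a Proxy class: there is no concrete number to give torch.arange
print(gm(torch.ones(1, 1, 3, 3)).sum())  # the traced GraphModule still runs on real tensors
```

In the PyTorch 1.9.1 setup from this thread, `torch.arange` only accepts a real number for `end`, which is exactly the `TypeError` in the traceback.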

@Lyu6PosHao

Minimal code to reproduce the error/bug

To run the code, first download the YOLOX weights, then run:

python demo1.py image -n yolox-s -c .\yolox_s.pth --path assets/dog.jpg --conf 0.25 --nms 0.45 --tsize 640 --save_result --device gpu

import argparse
import os
import time
from loguru import logger

import cv2

import torch, torchvision

from torchvision import transforms
from torch.utils.data import DataLoader
from torch.fx import symbolic_trace
from yolox.data.data_augment import ValTransform
from yolox.data.datasets import COCO_CLASSES
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess, vis
from spikingjelly.activation_based import neuron, encoding, functional, surrogate, layer, ann2snn  # added: ANN2SNN components

IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]

def make_parser():
    parser = argparse.ArgumentParser("YOLOX Demo!")
    parser.add_argument(
        "demo", default="image", help="demo type, eg. image, video and webcam"
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")  # kept to select which model to use

    parser.add_argument(
        "--path", default="./assets/dog.jpg", help="path to images or video"  # optional; only used to run the demo
    )
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    parser.add_argument(
        "--save_result",
        action="store_true",
        help="whether to save the inference result of image/video",
    )

    # exp file
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="please input your experiment description file",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")  # path to the weights
    parser.add_argument(
        "--device",
        default="cpu",
        type=str,
        help="device to run our model, can either be cpu or gpu",
    )
    parser.add_argument("--conf", default=0.3, type=float, help="test conf")
    parser.add_argument("--nms", default=0.3, type=float, help="test nms threshold")
    parser.add_argument("--tsize", default=None, type=int, help="test img size")
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=False,
        action="store_true",
        help="Adopting mix precision evaluating.",
    )
    parser.add_argument(
        "--legacy",
        dest="legacy",
        default=False,
        action="store_true",
        help="To be compatible with older versions",
    )
    parser.add_argument(
        "--fuse",
        dest="fuse",
        default=False,
        action="store_true",
        help="Fuse conv and bn for testing.",
    )
    parser.add_argument(
        "--trt",
        dest="trt",
        default=False,
        action="store_true",
        help="Using TensorRT model for testing.",
    )
    return parser

def get_image_list(path):
    image_names = []
    for maindir, subdir, file_name_list in os.walk(path):
        for filename in file_name_list:
            apath = os.path.join(maindir, filename)
            ext = os.path.splitext(apath)[1]
            if ext in IMAGE_EXT:
                image_names.append(apath)
    return image_names

class Predictor(object):
    def __init__(
        self,
        model,
        exp,
        cls_names=COCO_CLASSES,
        trt_file=None,
        decoder=None,
        device="cpu",
        fp16=False,
        legacy=False,
    ):
        self.model = model
        self.cls_names = cls_names
        self.decoder = decoder
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.test_size = exp.test_size
        self.device = device
        self.fp16 = fp16
        self.preproc = ValTransform(legacy=legacy)
        if trt_file is not None:
            from torch2trt import TRTModule

            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))

            x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
            self.model(x)
            self.model = model_trt

    def inference(self, img):
        img_info = {"id": 0}
        if isinstance(img, str):
            img_info["file_name"] = os.path.basename(img)
            img = cv2.imread(img)
        else:
            img_info["file_name"] = None

        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
        img_info["ratio"] = ratio

        img, _ = self.preproc(img, None, self.test_size)
        img = torch.from_numpy(img).unsqueeze(0)
        img = img.float()
        if self.device == "gpu":
            img = img.cuda()
            if self.fp16:
                img = img.half()  # to FP16

        with torch.no_grad():
            t0 = time.time()
            outputs = self.model(img)
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())
            outputs = postprocess(
                outputs, self.num_classes, self.confthre,
                self.nmsthre, class_agnostic=True
            )
            logger.info("Infer time: {:.4f}s".format(time.time() - t0))
        return outputs, img_info

    def visual(self, output, img_info, cls_conf=0.35):
        ratio = img_info["ratio"]
        img = img_info["raw_img"]
        if output is None:
            return img
        output = output.cpu()

        bboxes = output[:, 0:4]

        # preprocessing: resize
        bboxes /= ratio

        cls = output[:, 6]
        scores = output[:, 4] * output[:, 5]

        vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
        return vis_res

def main(exp, args):
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)

    vis_folder = None
    if args.save_result:
        vis_folder = os.path.join(file_name, "vis_res")
        os.makedirs(vis_folder, exist_ok=True)

    if args.trt:
        args.device = "gpu"

    logger.info("Args: {}".format(args))

    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    model = exp.get_model()
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))

    if args.device == "gpu":
        model.cuda()
        if args.fp16:
            model.half()  # to FP16
    model.eval()

    if not args.trt:
        if args.ckpt is None:
            ckpt_file = os.path.join(file_name, "best_ckpt.pth")
        else:
            ckpt_file = args.ckpt
        logger.info("loading checkpoint")
        ckpt = torch.load(ckpt_file, map_location="cpu")
        # load the model state dict
        model.load_state_dict(ckpt["model"])  # the model weights are now loaded
        logger.info("loaded checkpoint done.")

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    if args.trt:
        assert not args.fuse, "TensorRT model is not support model fusing!"
        trt_file = os.path.join(file_name, "model_trt.pth")
        assert os.path.exists(
            trt_file
        ), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
        logger.info("Using TensorRT to inference")
    else:
        trt_file = None
        decoder = None

    # First, try the ANN2SNN conversion
    dataset_dir = r'D:\datasets'  # location of the MNIST dataset
    train_data_dataset = torchvision.datasets.MNIST(
        root=dataset_dir,
        train=True,
        transform=transforms.ToTensor(),
        download=True)
    train_data_loader = DataLoader(
        dataset=train_data_dataset,
        batch_size=2,
        shuffle=True,
        drop_last=False)
    model_converter = ann2snn.Converter(mode='max', dataloader=train_data_loader)
    snn_model = model_converter(model)
    snn_model.graph.print_tabular()

if __name__ == "__main__":
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)

    main(exp, args)

miaodd98 commented 1 year ago

Update: after some digging, the problem seems to be in torch.fx. I am currently using PyTorch 1.9.1.

Lyu6PosHao commented 1 year ago

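torch.fx builds a symbolic trace for every variable, and a Proxy error usually means the original model code uses a conditional. So I suggest checking the code that produces `end` (here, `hsize`/`wsize` passed to `torch.arange`); if it involves a conditional, remove it.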
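A minimal sketch of the limitation described above, assuming only stock `torch.fx`: a Python `if` on a tensor value cannot be traced, because the condition is itself a Proxy with no True/False value. The module names and `try_trace` helper are hypothetical; `WithoutBranch` simply keeps the one path we assume is wanted.

```python
import torch
import torch.fx
from torch import nn


class WithBranch(nn.Module):
    def forward(self, x):
        # While tracing, `x.sum() > 0` is a Proxy, so torch.fx cannot decide the branch.
        if x.sum() > 0:
            return torch.relu(x)
        return x


class WithoutBranch(nn.Module):
    def forward(self, x):
        # Branch removed using prior knowledge (hypothetical: the relu path is
        # the one we always want), leaving straight-line code that traces cleanly.
        return torch.relu(x)


def try_trace(module):
    """Return a GraphModule, or None if symbolic tracing fails."""
    try:
        return torch.fx.symbolic_trace(module)
    except torch.fx.proxy.TraceError as err:
        print("trace failed:", err)
        return None


print(try_trace(WithBranch()))
print(try_trace(WithoutBranch()).code)
```

Note that removing the branch changes behavior whenever the other path would have been taken, so this only works when you know in advance which path applies.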

miaodd98 commented 1 year ago


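Thanks for the reply! So any part that involves an `if` check will run into this kind of problem? In other words, the model should ideally run straight through, without any conditional branches?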

Lyu6PosHao commented 1 year ago

Yes, this is a limitation of how torch.fx builds the computation graph. It is best to remove the conditionals based on prior knowledge.
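One way to apply "prior knowledge" to the failing `torch.arange(hsize)` call is to build the grid once from a feature-map size known in advance, so `torch.arange` receives plain ints at construction time and tracing only sees a stored buffer. This is a sketch assuming a fixed test size (as with `--tsize 640` in the repro command); `StaticGrid` is a hypothetical, simplified stand-in for YOLOX's `decode_outputs`, not YOLOX code.

```python
import torch
import torch.fx
from torch import nn


class StaticGrid(nn.Module):
    def __init__(self, hsize, wsize):
        super().__init__()
        # hsize/wsize are plain Python ints known ahead of time (assumption).
        ys = torch.arange(hsize).view(hsize, 1).expand(hsize, wsize)
        xs = torch.arange(wsize).view(1, wsize).expand(hsize, wsize)
        # (hsize, wsize, 2) grid of (x, y) cell coordinates, stored as a buffer
        self.register_buffer("grid", torch.stack((xs, ys), dim=-1).float())

    def forward(self, feat):
        # feat: (batch, hsize, wsize, 2) centre offsets; add each cell's coordinates.
        return feat + self.grid


traced = torch.fx.symbolic_trace(StaticGrid(80, 80))  # no Proxy reaches torch.arange
print(traced.code)
```

The trade-off is that the module only works for the input size it was built for; a different test size needs a new instance.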

miaodd98 commented 1 year ago

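Thanks! The torch.fx problem is solved, but now there is a new one. I want to convert already-trained ANN weights directly with ann2snn.Converter, but when I set the `dataloader` argument of the converter's constructor to None, I get the following error: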

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(s)
0it [00:00, ?it/s]
Traceback (most recent call last):
  File "demo1.py", line 344, in <module>
    main(exp, args)  # the original test code here is no longer used
  File "demo1.py", line 323, in main
    snn_model = model_converter(model)
  File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\spikingjelly-0.0.0.0.14-py3.8.egg\spikingjelly\activation_based\ann2snn\converter.py", line 108, in forward
  File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\tqdm\std.py", line 1178, in __iter__
    for obj in iterable:
TypeError: 'NoneType' object is not iterable
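To see why a `None` dataloader fails, here is a hypothetical sketch of a max-activation calibration pass (not SpikingJelly's implementation or API): forward hooks on every `nn.ReLU` record the largest output seen while the loader's batches are run through the model. With no loader there is nothing to iterate, so the loop raises the same `TypeError` as in the traceback above.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def record_max_activations(model, dataloader):
    """Hypothetical helper: largest output of every nn.ReLU over all batches."""
    maxima = {}
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            # The default argument binds the current layer name to this hook.
            def hook(mod, inputs, output, name=name):
                current = output.detach().max().item()
                maxima[name] = max(maxima.get(name, float("-inf")), current)

            handles.append(module.register_forward_hook(hook))

    model.eval()
    try:
        with torch.no_grad():
            # dataloader=None raises TypeError: 'NoneType' object is not iterable
            for batch in dataloader:
                images = batch[0] if isinstance(batch, (list, tuple)) else batch
                model(images)
    finally:
        for handle in handles:  # always detach the hooks, even on failure
            handle.remove()
    return maxima


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2), nn.ReLU())
loader = DataLoader(TensorDataset(torch.randn(16, 4)), batch_size=4)
print(record_max_activations(model, loader))  # keys '1' and '3': the two ReLU layers
```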

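This error is obviously caused by the dataloader being None. Does the dataloader here have to load a dataset? When the dataloader is not None, the error becomes a problem with the model's convolution kernels instead, which I can deal with later.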
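The repro calibrates the YOLOX detector with MNIST, whose batches have shape `[2, 1, 28, 28]`, while the demo feeds the detector 3-channel images; a channel mismatch is a likely (unconfirmed) cause of the convolution-kernel error. A minimal sketch, assuming calibration images are already preprocessed to `(3, H, W)` float tensors and that the converter iterates over `(images, labels)` pairs like the MNIST loader does. `CalibrationImages` is hypothetical, and the random tensors only keep the example self-contained; real images from the training distribution would give meaningful maxima.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class CalibrationImages(Dataset):
    """Hypothetical calibration set yielding (image, dummy_label) pairs."""

    def __init__(self, images):
        self.images = list(images)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Dummy label 0 keeps the (images, labels) batch layout of the MNIST loader.
        return self.images[idx], 0


# Stand-in data: random tensors in place of real preprocessed photos (assumption).
fake_images = [torch.rand(3, 640, 640) for _ in range(4)]
calib_loader = DataLoader(CalibrationImages(fake_images), batch_size=2, shuffle=False)

imgs, labels = next(iter(calib_loader))
print(imgs.shape)  # torch.Size([2, 3, 640, 640]), unlike MNIST's [2, 1, 28, 28]
```

Such a loader could then be passed the same way as in the repro, `ann2snn.Converter(mode='max', dataloader=calib_loader)`.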

Lyu6PosHao commented 1 year ago

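Yes, a dataset must be loaded. Because of how the conversion works, it has to find the maximum activation values of the original ANN, and without a dataset there is no way to find them. The SpikingJelly documentation gives a short explanation of the conversion principle.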
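A toy, pure-Python illustration of why the maximum activation matters (assumptions: an integrate-and-fire neuron with threshold 1 and reset by subtraction, driven by a constant input for a fixed number of steps; this is the general rate-coding idea, not SpikingJelly's code). Dividing an activation by the recorded maximum keeps the per-step input at or below the threshold, so the firing rate times the maximum approximates the ReLU output; if the maximum is underestimated, the neuron saturates at one spike per step.

```python
def if_firing_rate(current, steps=1000, threshold=1.0):
    """Spikes per step of an integrate-and-fire neuron with a constant input.

    Reset by subtraction: after a spike the threshold is subtracted from v,
    so surplus charge carries over to the next step.
    """
    v = 0.0
    spikes = 0
    for _ in range(steps):
        v += current
        if v >= threshold:
            spikes += 1
            v -= threshold
    return spikes / steps


def rate_coded_relu(activation, max_activation, steps=1000):
    """Approximate ReLU(activation): normalize by the recorded maximum,
    run the IF neuron, then scale the firing rate back up."""
    normalized = max(activation, 0.0) / max_activation
    return if_firing_rate(normalized, steps) * max_activation


print(rate_coded_relu(3.0, max_activation=4.0))  # 3.0: good estimate of the maximum
print(rate_coded_relu(3.0, max_activation=2.0))  # 2.0: maximum underestimated, neuron saturates
```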

miaodd98 commented 1 year ago

Got it, thanks for clearing that up!