Closed miaodd98 closed 1 year ago
补充:目前经查找问题可能出现在torch.fx中,目前使用pytorch版本为1.9.1
torch.fx给每个变量建立符号追踪,出现proxy往往是原模型代码往往了条件判断。 所以建议检查一下有关“end”的代码,如果涉及了条件判断,就去掉它。
torch.fx给每个变量建立符号追踪,出现proxy往往是原模型代码往往了条件判断。 所以建议检查一下有关“end”的代码,如果涉及了条件判断,就去掉它。
感谢回复!请问是任何涉及到if判断的部分都会存在这类问题是嘛?也就是说模型结构最好就是全是直接不需要条件判断下来是吧
对,这是torch.fx构建计算图时的局限。最好根据先验知识把条件判断去掉。
感谢!torch.fx的问题已经解决了,但是又有一个新问题了 我个人是想直接给已经训练完成的ANN权重通过ann2snn.Converter转换,但是发现converter初始化参数里dataloader设置为None时,报错如下:
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(s)
0it [00:00, ?it/s]
Traceback (most recent call last):
File "demo1.py", line 344, in
这里报错明显问题是dataloader为None导致的,想请问下这里的dataloader必须加载数据集吗?dataloader不为空时再报错就是模型卷积核问题了,这个可以后续处理。
对,必须加载数据集。因为转换的原理,需要找到原来ANN的最大激活值,没数据集就没法找到最大激活值了。 可以去spikejelly的文档简单看下转换的原理。
明白了,感谢解惑!
Issue type
SpikingJelly version
latest
Description
尝试转换YOLOX已训练好的ANN模型时,出现报错:
TypeError: arange(): argument 'end' (position 1) must be Number, not Proxy
完整报错:
Traceback (most recent call last): File "demo1.py", line 343, in
main(exp, args) # 原来测试这里就不用了
File "demo1.py", line 322, in main
snn_model = model_converter(model)
File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
return forward_call(*input, kwargs)
File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\spikingjelly-0.0.0.0.14-py3.8.egg\spikingjelly\activation_based\ann2snn\converter.py", line 104, in forward
File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\fx\symbolic_trace.py", line 859, in symbolic_trace
graph = tracer.trace(root, concrete_args)
File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\fx\symbolic_trace.py", line 571, in trace
self.create_node('output', 'output', (self.create_arg(fn(args)),), {},
File "D:\codes\YOLOX-snn\yolox\models\yolox.py", line 46, in forward
outputs = self.head(fpn_outs)
File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\fx\symbolic_trace.py", line 560, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\fx\symbolic_trace.py", line 391, in call_module
return forward(args, kwargs)
File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\fx\symbolic_trace.py", line 556, in forward
return _orig_module_call(mod, *args, *kwargs)
File "C:\Users\admin\anaconda3\envs\yolo\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
return forward_call(input, **kwargs)
File "D:\codes\YOLOX-snn\yolox\models\yolo_head.py", line 211, in forward
return self.decode_outputs(outputs, dtype=xin[0].type())
File "D:\codes\YOLOX-snn\yolox\models\yolo_head.py", line 239, in decode_outputs
yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
TypeError: arange(): argument 'end' (position 1) must be Number, not Proxy
@Lyu6PosHao
Minimal code to reproduce the error/bug
运行代码,需要配合下载YOLOX权重: python demo1.py image -n yolox-s -c .\yolox_s.pth --path assets/dog.jpg --conf 0.25 --nms 0.45 --tsize 640 --save_result --device gpu
import argparse import os import time from loguru import logger
import cv2
import torch, torchvision
from torchvision import transforms from torch.utils.data import DataLoader from torch.fx import symbolic_trace from yolox.data.data_augment import ValTransform from yolox.data.datasets import COCO_CLASSES from yolox.exp import get_exp from yolox.utils import fuse_model, get_model_info, postprocess, vis from spikingjelly.activation_based import neuron, encoding, functional, surrogate, layer, ann2snn # 加进来ANN2SNN的东西
IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]
def make_parser(): parser = argparse.ArgumentParser("YOLOX Demo!") parser.add_argument( "demo", default="image", help="demo type, eg. image, video and webcam" ) parser.add_argument("-expn", "--experiment-name", type=str, default=None) parser.add_argument("-n", "--name", type=str, default=None, help="model name") # 留着确定模型是哪个
def get_image_list(path): image_names = [] for maindir, subdir, file_name_list in os.walk(path): for filename in file_name_list: apath = os.path.join(maindir, filename) ext = os.path.splitext(apath)[1] if ext in IMAGE_EXT: image_names.append(apath) return image_names
class Predictor(object): def init( self, model, exp, cls_names=COCO_CLASSES, trt_file=None, decoder=None, device="cpu", fp16=False, legacy=False, ): self.model = model self.cls_names = cls_names self.decoder = decoder self.num_classes = exp.num_classes self.confthre = exp.test_conf self.nmsthre = exp.nmsthre self.test_size = exp.test_size self.device = device self.fp16 = fp16 self.preproc = ValTransform(legacy=legacy) if trt_file is not None: from torch2trt import TRTModule
def main(exp, args): if not args.experiment_name: args.experiment_name = exp.exp_name
if name == "main": args = make_parser().parse_args() exp = get_exp(args.exp_file, args.name)