ZK-Zhou / spikformer

ICLR 2023, Spikformer: When Spiking Neural Network Meets Transformer
MIT License

Error from "enumerate(loader)" at line 647 of "imagenet/train.py" #6

Closed: navfour closed this issue 1 year ago

navfour commented 1 year ago

At line 647 of "train.py", "enumerate(loader)" is used to iterate over the loader and unpack the two variables "batch_idx, (input, target)". However, this raises an error in timm/data/transforms.py: "line 18, in __call__: np_img = np.array(pil_img, dtype=np.uint8)", followed by "TypeError: __array__() takes 1 positional argument but 2 were given".
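One way to narrow this down is to run the same PIL-to-NumPy conversion on a synthetic image, outside the DataLoader. The following is a minimal sketch, assuming Pillow and NumPy are installed; `pil_to_chw_uint8` is a hypothetical helper that mirrors the failing call from the traceback, and the HWC-to-CHW step after it is an assumption about what the transform does next, not code copied from timm 0.5.4. If this raises the same TypeError, the problem lies in the installed image/array libraries rather than in the dataset or the loader.

```python
import numpy as np
from PIL import Image


def pil_to_chw_uint8(pil_img: Image.Image) -> np.ndarray:
    """Hypothetical helper: PIL image -> (C, H, W) uint8 array."""
    # Same call as line 18 of timm/data/transforms.py in the traceback.
    np_img = np.array(pil_img, dtype=np.uint8)
    if np_img.ndim < 3:
        # Single-channel images (e.g. mode "L") have no channel axis; add one.
        np_img = np.expand_dims(np_img, axis=-1)
    # Reorder HWC -> CHW (assumed downstream layout).
    return np.transpose(np_img, (2, 0, 1))


if __name__ == "__main__":
    # PIL sizes are (width, height), so this is an 8x6 RGB image.
    img = Image.new("RGB", (8, 6), color=(255, 0, 128))
    try:
        arr = pil_to_chw_uint8(img)
        print("conversion OK:", arr.shape, arr.dtype)
    except TypeError as exc:
        print("conversion failed as in the issue:", exc)
```

On a working setup this prints a shape of (3, 6, 8) with dtype uint8.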

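I'm not sure whether this is caused by the timm library version or by the loader failing to load the data. timm is version 0.5.4. With a breakpoint at line 647 of "train.py", the "loader" object contains: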

result: <timm.data.loader.PrefetchLoader object at 0x7f8f79a7d100>
    dataset: <timm.data.dataset.ImageDataset object at 0x7f8ffeb20040>
    fp16: False
    loader: <torch.utils.data.dataloader.DataLoader object at 0x7f9002cd8d30>
    mean: tensor([[[[123.6750]], [[116.2800]], [[103.5300]]]], device='cuda:0')  // torch.Size([1, 3, 1, 1])
    mixup_enabled: True
    random_erasing: RandomErasing(p=0.25, mode=const, count=(1, 1))
    sampler: <torch.utils.data.sampler.RandomSampler object at 0x7f8f79a7d040>
    std: tensor([[[[58.3950]], [[57.1200]], [[57.3750]]]], device='cuda:0')  // torch.Size([1, 3, 1, 1])
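The mean and std tensors in this dump are the standard ImageNet statistics scaled to the 0-255 range (for example, 0.485 × 255 = 123.675 and 0.229 × 255 = 58.395), stored with shape [1, 3, 1, 1] so they broadcast over an (N, 3, H, W) batch. Assuming the prefetcher normalizes each batch as (x - mean) / std, the arithmetic can be sketched in NumPy as follows; `normalize_batch` is a hypothetical name and this is not timm's own implementation. Because the traceback points at transforms.py, the error appears to be raised earlier, while individual images are converted, before any such normalization runs.

```python
import numpy as np

# Per-channel statistics on the 0-255 scale, copied from the loader dump.
MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32).reshape(1, 3, 1, 1)
STD = np.array([58.395, 57.12, 57.375], dtype=np.float32).reshape(1, 3, 1, 1)


def normalize_batch(batch_uint8: np.ndarray) -> np.ndarray:
    """Hypothetical helper: normalize an (N, 3, H, W) uint8 batch.

    The (1, 3, 1, 1) shapes of MEAN and STD broadcast across the batch,
    height, and width axes, so each channel gets its own mean and std.
    """
    return (batch_uint8.astype(np.float32) - MEAN) / STD


if __name__ == "__main__":
    batch = np.full((2, 3, 4, 4), 255, dtype=np.uint8)
    out = normalize_batch(batch)
    print(out.shape, out[0, :, 0, 0])
```

With these statistics, a pixel value of 255 in the first channel maps to (255 - 123.675) / 58.395 ≈ 2.249.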

ZK-Zhou commented 1 year ago

What version of torch are you using?

navfour commented 1 year ago

> What version of torch are you using?

torch 1.7.1+cu110, torchvision 0.8.2+cu101. The torch version may indeed be too old. I'll upgrade torch this afternoon and try again. Thanks!
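When comparing environments, it can help to report every relevant package version at once. Note that the two builds quoted above carry different CUDA tags (cu110 vs cu101). Since the error is raised while converting a PIL image to a NumPy array, Pillow and NumPy versions may matter too; this was not raised in the thread, but Pillow 8.3.0 shipped an `Image.__array__` that did not accept a dtype argument, and 8.3.1 restored it. The sketch below uses only the standard library's `importlib.metadata` (Python 3.8+); the helper name `installed_versions` and the package list are assumptions for illustration.

```python
from importlib import metadata

# Distribution names to report (assumed list; adjust as needed).
PACKAGES = ("torch", "torchvision", "timm", "numpy", "Pillow")


def installed_versions(packages=PACKAGES) -> dict:
    """Hypothetical helper: map each distribution name to its version, or None."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Package is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:12s} {version or 'not installed'}")
```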