06/12 03:27:47 PM - ==> Starting training with 1 nodes x 1 GPUs
Traceback (most recent call last):
File "/home/ubuntu/code_space/ADer/run.py", line 31, in <module>
main()
File "/home/ubuntu/code_space/ADer/run.py", line 27, in main
trainer.run()
File "/home/ubuntu/code_space/ADer/trainer/_base_trainer.py", line 252, in run
self.train()
File "/home/ubuntu/code_space/ADer/trainer/_base_trainer.py", line 177, in train
train_data = next(train_loader)
File "/home/ubuntu/anaconda3/envs/ADer/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
data = self._next_data()
File "/home/ubuntu/anaconda3/envs/ADer/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
return self._process_data(data)
File "/home/ubuntu/anaconda3/envs/ADer/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
data.reraise()
File "/home/ubuntu/anaconda3/envs/ADer/lib/python3.9/site-packages/torch/_utils.py", line 694, in reraise
raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/ubuntu/anaconda3/envs/ADer/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
data = fetcher.fetch(index)
File "/home/ubuntu/anaconda3/envs/ADer/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/ubuntu/anaconda3/envs/ADer/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/ubuntu/code_space/ADer/data/ad_dataset.py", line 919, in __getitem__
anomaly_source_idx = torch.randint(0, len(self.anomaly_source_paths), (1,)).item()
RuntimeError: random_ expects 'from' to be less than 'to', but got from=0 >= to=0
错误出现在 ad_dataset.py 的第 919 行,相关代码如下:
def __getitem__(self, index):
    """Return one training sample with a synthetically augmented counterpart.

    In training mode ``index`` is ignored: a record is drawn at random from
    ``self.data_all`` and a path is drawn at random from
    ``self.anomaly_source_paths``; both are handed to ``self.transform_image``.

    Args:
        index: Position requested by the data loader (unused in the
            training branch, which samples randomly instead).

    Returns:
        A dict with the keys ``'img'``, ``'img_mask'``, ``'cls_name'``,
        ``'augmented_image'`` and ``'anomaly'``.

    Raises:
        RuntimeError: If ``self.data_all`` or ``self.anomaly_source_paths``
            is empty, so that no sample can be drawn.
    """
    # NOTE(review): only the training branch is visible in this excerpt; when
    # ``self.train`` is falsy this code falls through and returns None.
    if self.train:
        # Fail early with an actionable message. Without these checks,
        # torch.randint(0, 0, ...) raises the opaque
        # "random_ expects 'from' to be less than 'to'" error from inside a
        # DataLoader worker, which hides the real cause (nothing was loaded).
        if len(self.data_all) == 0:
            raise RuntimeError(
                "self.data_all is empty: no training records were loaded. "
                "Check the dataset root and the generated meta.json."
            )
        if len(self.anomaly_source_paths) == 0:
            raise RuntimeError(
                "self.anomaly_source_paths is empty: no anomaly source images "
                "were found. Check that the configured anomaly source "
                "directory (e.g. data.anomaly_source_path) exists and "
                "contains image files."
            )

        idx = torch.randint(0, len(self.data_all), (1,)).item()
        anomaly_source_idx = torch.randint(
            0, len(self.anomaly_source_paths), (1,)
        ).item()

        data = self.data_all[idx]
        img_path = data['img_path']
        cls_name = data['cls_name']

        image, augmented_image, anomaly_mask, has_anomaly = self.transform_image(
            os.path.join(self.root, img_path),
            self.anomaly_source_paths[anomaly_source_idx],
        )

        # The same optional transform is applied to both images; the mask is
        # returned exactly as produced by transform_image.
        if self.transform is not None:
            image = self.transform(image)
            augmented_image = self.transform(augmented_image)

        sample = {'img': image, "img_mask": anomaly_mask, 'cls_name': cls_name,
                  'augmented_image': augmented_image, 'anomaly': has_anomaly}
        return sample
我使用
destseg_256_100e.py
,模型为destseg
数据集为mvtec
,已经生成了meta.json
。其中一个参数
self.data.anomaly_source_path = 'data/dtd/images/'
中,这个参数我不知道应该如何设置,也不知道如何获得它所指向的数据文件,因此导致了以下错误:
错误出现在
ad_dataset.py
的第 919 行:anomaly_source_idx = torch.randint(0, len(self.anomaly_source_paths), (1,)).item()
其中 self.anomaly_source_paths
的长度为 0(即列表为空)。请问应该如何解决呢?