Open zw81929 opened 2 weeks ago
abstract_dataloader.py 中如下部分的代码可能是有问题,self.sample_size 没有初始化
def __init__(self, config, dataset, sampler, shuffle=False):
self.shuffle = shuffle
self.config = config
self._dataset = dataset
self._sampler = sampler
self._batch_size = self.step = self.model = None
self._init_batch_size_and_step()
index_sampler = None
self.generator = torch.Generator()
self.generator.manual_seed(config["seed"])
self.transform = construct_transform(config)
self.is_sequential = config["MODEL_TYPE"] == ModelType.SEQUENTIAL
if not config["single_spec"]:
index_sampler = torch.utils.data.distributed.DistributedSampler(
list(range(self.sample_size)), shuffle=shuffle, drop_last=False
)
self.step = max(1, self.step // config["world_size"])
shuffle = False
super().__init__(
dataset=list(range(self.sample_size)),
batch_size=self.step,
collate_fn=self.collate_fn,
num_workers=config["worker"],
shuffle=shuffle,
sampler=index_sampler,
generator=self.generator,
)
报错信息如下
这块