mtk380 opened 3 years ago
Hi mtk380! I got the same error as you and tried to fix it. The error occurs because the dataset_path entry in the metadata (kwargs) is never passed to the dataloader that builds the real images for FID evaluation.
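The failure mode can be reproduced in isolation. The sketch below is a toy stand-in, not the repository's code: `ToyDataset`, `DATASETS`, `get_dataset`, and `try_load` here are all hypothetical. It assumes, as the fix implies, that `get_dataset` forwards its keyword arguments to a dataset class whose constructor requires `dataset_path`.

```python
from typing import Any, Dict, Type


class ToyDataset:
    """Hypothetical dataset whose constructor requires dataset_path."""

    def __init__(self, dataset_path: str, img_size: int, **kwargs: Any) -> None:
        self.dataset_path = dataset_path
        self.img_size = img_size


# Hypothetical registry mapping dataset names to classes.
DATASETS: Dict[str, Type[ToyDataset]] = {"ToyDataset": ToyDataset}


def get_dataset(name: str, **kwargs: Any) -> ToyDataset:
    # Keyword arguments are forwarded unchanged; nothing supplies a default path.
    return DATASETS[name](**kwargs)


def try_load(**kwargs: Any) -> str:
    """Return 'ok' if the dataset builds, otherwise the exception type name."""
    try:
        get_dataset("ToyDataset", **kwargs)
        return "ok"
    except TypeError as exc:
        return type(exc).__name__


if __name__ == "__main__":
    # Mirrors the original call: dataset_path is missing, so construction fails.
    print(try_load(img_size=128))
    # Mirrors the fix: dataset_path is forwarded, so construction succeeds.
    print(try_load(img_size=128, dataset_path="/path/to/images/*.jpg"))
```

Run as a script, this prints `TypeError` and then `ok`: the only difference between the two calls is whether `dataset_path` reaches the constructor.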
So, in fid_evaluation.py (around line 32):
def setup_evaluation(dataset_name, generated_dir, data_path, target_size=128, num_imgs=8000):  # add data_path
    # Only make real images if they haven't been made yet
    real_dir = os.path.join('EvalImages', dataset_name + '_real_images_' + str(target_size))
    if not os.path.exists(real_dir):
        os.makedirs(real_dir)
        dataloader, CHANNELS = datasets.get_dataset(dataset_name, img_size=target_size, dataset_path=data_path)  # add dataset_path
        print('outputting real images...')
        output_real_images(dataloader, num_imgs, real_dir)
        print('...done')
    if generated_dir is not None:
        os.makedirs(generated_dir, exist_ok=True)
    return real_dir
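One thing worth checking, judging from the function above: `os.makedirs(real_dir)` runs before `get_dataset`, so if an earlier run crashed at that point, `real_dir` already exists but is empty, and later runs skip generating the real images even after the fix. A minimal sketch of a stricter check, assuming the real images are written as files directly inside `real_dir`; `needs_real_images` is a hypothetical helper, not part of the repository.

```python
import os
import tempfile


def needs_real_images(real_dir: str) -> bool:
    """Hypothetical helper: True if real_dir is missing or contains no files.

    Meant to replace `not os.path.exists(real_dir)` so that an empty directory
    left behind by a crashed run counts as "not generated yet".
    """
    if not os.path.isdir(real_dir):
        return True
    with os.scandir(real_dir) as entries:
        return not any(entry.is_file() for entry in entries)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        real_dir = os.path.join(root, "EvalImages", "CelebA_real_images_128")
        print(needs_real_images(real_dir))  # directory missing
        os.makedirs(real_dir)
        print(needs_real_images(real_dir))  # directory exists but is empty
        open(os.path.join(real_dir, "00000.jpg"), "wb").close()
        print(needs_real_images(real_dir))  # directory holds a file
```

If adopted, the `os.makedirs(real_dir)` inside the branch would need `exist_ok=True`, since the directory may already exist when it is empty. Deleting the empty `EvalImages/..._real_images_128` folder by hand achieves the same thing without code changes.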
And in train.py (around line 366):
if opt.eval_freq > 0 and (discriminator.step + 1) % opt.eval_freq == 0:
    generated_dir = os.path.join(opt.output_dir, 'evaluation/generated')
    if rank == 0:
        fid_evaluation.setup_evaluation(metadata['dataset'], generated_dir, data_path=metadata["dataset_path"], target_size=128)  # add data_path
    dist.barrier()
    ema.store(generator_ddp.parameters())
    ema.copy_to(generator_ddp.parameters())
    generator_ddp.eval()
    fid_evaluation.output_images(generator_ddp, metadata, rank, world_size, generated_dir)
    ema.restore(generator_ddp.parameters())
    dist.barrier()
    if rank == 0:
        fid = fid_evaluation.calculate_fid(metadata['dataset'], generated_dir, target_size=128)
        with open(os.path.join(opt.output_dir, f'fid.txt'), 'a') as f:
            f.write(f'\n{discriminator.step}:{fid}')
    torch.cuda.empty_cache()
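Because the crash only surfaces the first time evaluation runs, it can help to validate the metadata before training starts. A sketch, assuming `metadata` is the plain dict that train.py passes around and that the key is `'dataset_path'` as in the fix above; `check_eval_metadata` is a hypothetical helper you would call once before the training loop when `opt.eval_freq > 0`.

```python
import glob
from typing import Any, Dict


def check_eval_metadata(metadata: Dict[str, Any]) -> str:
    """Hypothetical helper: fail fast if FID evaluation could not find its data.

    Returns the dataset path so it can be passed on as data_path.
    """
    path = metadata.get("dataset_path")
    if not path:
        raise KeyError("metadata has no 'dataset_path'; FID evaluation would fail")
    # glob.glob handles both a pattern such as '/data/celeba/*.jpg' and a plain
    # existing path (which it returns as-is), so one check covers both cases.
    if not glob.glob(path):
        raise FileNotFoundError(f"no files match dataset_path={path!r}")
    return path
```

The design choice is simply to move the failure from step 5000 to step 0, so a misconfigured curriculum costs seconds instead of a partial training run.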
This works in my case.
When I train with the script, this problem occurs when fid_evaluation.setup_evaluation runs (Progress to next stage: 2%|▊ | 5000/200000).
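The 5000/200000 in the progress bar is consistent with the evaluation condition in train.py: `(discriminator.step + 1) % opt.eval_freq == 0` is first true at step `eval_freq - 1`, that is, after `eval_freq` steps. A small check, assuming `eval_freq` is 5000 in this run (the value the progress bar suggests, not one stated in the thread) and that steps count from 0; `first_eval_step` is a hypothetical helper.

```python
def first_eval_step(eval_freq: int, start_step: int = 0) -> int:
    """Return the first step >= start_step at which the evaluation branch runs."""
    if eval_freq <= 0:
        raise ValueError("evaluation is disabled when eval_freq <= 0")
    step = start_step
    # Same condition as in train.py: evaluate when (step + 1) is a multiple of eval_freq.
    while (step + 1) % eval_freq != 0:
        step += 1
    return step


if __name__ == "__main__":
    step = first_eval_step(5000)
    print(step, step + 1)  # 4999 5000
```

So with that setting, setup_evaluation is first called once 5000 steps have completed, which matches where the error appears.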