lucidrains / lightweight-gan

Implementation of 'lightweight' GAN, proposed at ICLR 2021, in PyTorch. High-resolution image generation that can be trained within a day or two.
MIT License
1.62k stars, 220 forks

Torchvision Assertion error while importing custom data #131

Closed: demiahmed closed this 2 years ago

demiahmed commented 2 years ago

Getting this error when trying to run the `lightweight_gan` CLI on my own image folder:

```
Traceback (most recent call last):
  File "C:\Anaconda3\envs\aerialweb\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "C:\Anaconda3\envs\aerialweb\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "C:\Anaconda3\envs\aerialweb\Scripts\lightweight_gan.exe\__main__.py", line 7, in <module>
  File "C:\Anaconda3\envs\aerialweb\lib\site-packages\lightweight_gan\cli.py", line 195, in main
    fire.Fire(train_from_folder)
  File "C:\Anaconda3\envs\aerialweb\lib\site-packages\fire\core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "C:\Anaconda3\envs\aerialweb\lib\site-packages\fire\core.py", line 471, in _Fire
    target=component.__name__)
  File "C:\Anaconda3\envs\aerialweb\lib\site-packages\fire\core.py", line 681, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "C:\Anaconda3\envs\aerialweb\lib\site-packages\lightweight_gan\cli.py", line 186, in train_from_folder
    run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash)
  File "C:\Anaconda3\envs\aerialweb\lib\site-packages\lightweight_gan\cli.py", line 63, in run_training
    model.set_data_src(data)
  File "C:\Anaconda3\envs\aerialweb\lib\site-packages\lightweight_gan\lightweight_gan.py", line 1145, in set_data_src
    self.dataset = ImageDataset(folder, self.image_size, transparent = self.transparent, greyscale = self.greyscale, aug_prob = self.dataset_aug_prob)
  File "C:\Anaconda3\envs\aerialweb\lib\site-packages\lightweight_gan\lightweight_gan.py", line 391, in __init__
    transforms.Lambda(convert_image_fn),
  File "C:\Anaconda3\envs\aerialweb\lib\site-packages\torchvision\transforms\transforms.py", line 228, in __init__
    assert isinstance(lambd, types.LambdaType)
AssertionError
```
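The last frame shows the installed torchvision checking `isinstance(lambd, types.LambdaType)`. Since `types.LambdaType` is just an alias of `types.FunctionType`, any callable that is not a plain Python function fails that check. The sketch below assumes (this is not visible in the traceback) that `convert_image_fn` is built with `functools.partial`; `convert_image_to` here is a hypothetical stand-in, not the library's actual helper. It shows why a `partial` object trips the assertion, and how wrapping it in an ordinary function would satisfy it.

```python
import types
from functools import partial


def convert_image_to(img_type, image):
    # Hypothetical stand-in for the dataset's image-conversion helper;
    # it just returns its arguments so the call is easy to observe.
    return (img_type, image)


# Assumption: convert_image_fn is a functools.partial object.
convert_image_fn = partial(convert_image_to, "RGB")

# A partial is callable, but it is not a FunctionType/LambdaType,
# so `assert isinstance(lambd, types.LambdaType)` fails for it.
print(callable(convert_image_fn))                      # True
print(isinstance(convert_image_fn, types.LambdaType))  # False


# Workaround sketch: wrap the partial in a plain function (a lambda
# would also work), which *is* an instance of types.LambdaType.
def wrapped_fn(image):
    return convert_image_fn(image)


print(isinstance(wrapped_fn, types.LambdaType))        # True
print(wrapped_fn("img"))                               # ('RGB', 'img')
```

Patching the installed package this way is only a stopgap. Newer torchvision releases validate the argument with `callable(lambd)` instead of this `isinstance` assertion, so upgrading torchvision is likely the cleaner fix.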