You should add these two lines to `main.py` after line 237, because the dataset is not processed with `transform` when `--segmentation_guided` is used:

```python
dataset_train.set_transform(transform)
dataset_eval.set_transform(transform)
```
Otherwise, when running `main.py`, you will get this error:

```
Traceback (most recent call last):
  File "main.py", line XXX, in <module>
    main(
  File "main.py", line XXX, in main
    train_loop(
  File "/home/diffusion_models3/segmentation-guided-diffusion/training.py", line 87, in train_loop
    for step, batch in enumerate(train_dataloader):
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.PngImagePlugin.PngImageFile'>
```
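To illustrate why the missing call produces the traceback above, here is a stdlib-only toy model. It is not the repository's code and not the Hugging Face `datasets` implementation. It assumes, as the use of `set_transform` suggests, that the dataset applies its transform lazily, only when an item is accessed, and it mimics `default_collate` with a collate function that rejects types it cannot stack. `FakePngImage`, `LazyDataset`, `toy_collate`, `Batch`, and the `"images"` column are all hypothetical names.

```python
from typing import Any, Callable, Dict, List, Optional

# Hypothetical alias: a batch maps column names to lists of values.
Batch = Dict[str, List[Any]]


class FakePngImage:
    """Hypothetical stand-in for PIL.PngImagePlugin.PngImageFile: an opaque image object."""

    def __init__(self, pixels: List[float]) -> None:
        self.pixels = pixels


class LazyDataset:
    """Toy dataset whose transform runs lazily on item access (not the real datasets library)."""

    def __init__(self, rows: Batch) -> None:
        self.rows = rows
        self._transform: Optional[Callable[[Batch], Batch]] = None

    def set_transform(self, transform: Callable[[Batch], Batch]) -> None:
        # Nothing is converted here; the function is only stored for later use.
        self._transform = transform

    def __len__(self) -> int:
        return len(next(iter(self.rows.values())))

    def __getitem__(self, index: int) -> Dict[str, Any]:
        # Build a batch of size 1, apply the transform if one is set, then unwrap it.
        batch = {key: [values[index]] for key, values in self.rows.items()}
        if self._transform is not None:
            batch = self._transform(batch)
        return {key: values[0] for key, values in batch.items()}


def toy_collate(samples: List[Dict[str, Any]]) -> Batch:
    """Toy collate: like default_collate, it refuses types it does not know how to stack."""
    collated: Batch = {}
    for key in samples[0]:
        values = [sample[key] for sample in samples]
        for value in values:
            if not isinstance(value, (int, float, list)):
                raise TypeError(
                    f"toy_collate: batch must contain numbers or lists; found {type(value)}"
                )
        collated[key] = values
    return collated


def transform(batch: Batch) -> Batch:
    # Hypothetical transform: turn each image object into a plain list of floats,
    # standing in for something like torchvision's ToTensor.
    return {"images": [image.pixels for image in batch["images"]]}


if __name__ == "__main__":
    dataset_train = LazyDataset(
        {"images": [FakePngImage([0.0, 1.0]), FakePngImage([0.5, 0.5])]}
    )
    samples = [dataset_train[i] for i in range(len(dataset_train))]
    try:
        toy_collate(samples)  # raw image objects reach the collate step and are rejected
    except TypeError as err:
        print("without set_transform:", err)

    dataset_train.set_transform(transform)
    samples = [dataset_train[i] for i in range(len(dataset_train))]
    print("with set_transform:", toy_collate(samples))
    # with set_transform: {'images': [[0.0, 1.0], [0.5, 0.5]]}
```

In this model, `set_transform` only registers the function. If that call is skipped on one code path, every item still comes back as a raw image object, which matches the `PngImageFile` that `default_collate` reports in the traceback.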