kyotovision-public / multimodal-material-segmentation

MIT License

from dataloaders import make_data_loader #2

Open OscarMind opened 2 years ago

OscarMind commented 2 years ago

  File "test.py", line 9, in <module>
    from dataloaders import make_data_loader
ImportError: cannot import name 'make_data_loader'

jamesyoung0623 commented 2 years ago

Same problem.

wonjunior commented 1 year ago

It seems a file was removed from the /dataloaders directory.
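To confirm this locally, here is a minimal sketch that checks whether a module imports and whether it exposes a given name. The helper `module_exports` is hypothetical (it is not part of the repository) and uses only the standard library. If the missing file is dataloaders/__init__.py, Python 3.3+ still imports the bare directory as a namespace package (PEP 420). The import itself then succeeds and only the name lookup fails, which is consistent with the error above.

```python
import importlib


def module_exports(module_name, name):
    """Hypothetical diagnostic: return (importable, has_name, file_location)."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:  # also covers ModuleNotFoundError
        return False, False, None
    # A namespace package (a directory without __init__.py) has no file
    # location, so the third value is None and the name is usually missing.
    return True, hasattr(module, name), getattr(module, '__file__', None)


if __name__ == '__main__':
    # Run from the repository root; expect (True, False, None) if
    # dataloaders/__init__.py is missing.
    print(module_exports('dataloaders', 'make_data_loader'))
```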

Looking at the diff with the original repository, one workaround is to add a __init__.py file under /dataloaders and implement make_data_loader yourself. Something like this would work:

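# dataloaders/__init__.py
from torch.utils.data import DataLoader

from .datasets.multimodal_dataset import MultimodalDatasetSegmentation


def make_data_loader(args, **kwargs):
    """Return (train_loader, val_loader, test_loader, num_class) for args.dataset.

    Extra keyword arguments (e.g. num_workers, pin_memory) are passed through to DataLoader.
    """
    if args.dataset == 'multimodal_dataset':
        train_set = MultimodalDatasetSegmentation(args, split='train')
        val_set = MultimodalDatasetSegmentation(args, split='val')
        num_class = train_set.NUM_CLASSES
        # Shuffle only the training split; keep validation order fixed.
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = None  # no separate test loader is built here
        return train_loader, val_loader, test_loader, num_class
    else:
        raise NotImplementedError(f"Dataset '{args.dataset}' is not supported")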

You can extend the function to support additional datasets by adding an elif branch for each new dataset name.
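If the if/elif chain grows long, one alternative is a name-to-class registry. This is a minimal sketch under stated assumptions: each dataset class takes (args, split=...) and exposes NUM_CLASSES the way MultimodalDatasetSegmentation is used above, and `DATASETS`, `register_dataset`, and `ToySegmentation` are hypothetical names, not part of the repository. To keep it runnable without PyTorch, the loader constructor is passed in as `loader_cls`.

```python
from types import SimpleNamespace

# Hypothetical registry mapping a dataset name to its dataset class.
DATASETS = {}


def register_dataset(name):
    """Decorator that records a dataset class under `name`."""
    def decorator(cls):
        DATASETS[name] = cls
        return cls
    return decorator


@register_dataset('toy_dataset')
class ToySegmentation:
    """Hypothetical stand-in mirroring the interface assumed above."""
    NUM_CLASSES = 3

    def __init__(self, args, split='train'):
        # args is accepted only to match the assumed constructor signature.
        self.split = split
        self.items = list(range(4 if split == 'train' else 2))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


def make_data_loader(args, loader_cls, **kwargs):
    """Registry-based variant: look up args.dataset instead of branching on it."""
    try:
        dataset_cls = DATASETS[args.dataset]
    except KeyError:
        raise NotImplementedError(f"Dataset '{args.dataset}' is not supported") from None
    train_set = dataset_cls(args, split='train')
    val_set = dataset_cls(args, split='val')
    train_loader = loader_cls(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
    val_loader = loader_cls(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
    return train_loader, val_loader, None, train_set.NUM_CLASSES


# Usage with a fake loader that simply records what it was given.
args = SimpleNamespace(dataset='toy_dataset', batch_size=2)
train_loader, val_loader, test_loader, num_class = make_data_loader(
    args, loader_cls=lambda ds, **kw: (ds, kw))
```

In the repository you would pass torch.utils.data.DataLoader as loader_cls (or hard-code it as in the function above), and adding a dataset then only requires decorating its class with register_dataset.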

Best, Ivan

wonjunior commented 1 year ago

You can refer to https://github.com/kyotovision-public/multimodal-material-segmentation/pull/3