learnables / learn2learn

A PyTorch Library for Meta-learning Research
http://learn2learn.net
MIT License

how to create a normal pytorch data loader from a concatenation of all the data? #355

Closed by brando90 1 year ago

brando90 commented 1 year ago

I want to reproduce this setup (ref: https://arxiv.org/abs/1903.03096):

"The non-episodic baselines are trained to solve the large classification problem that results from ‘concatenating’ the training classes of all datasets."

Would it work to do something like this?

  1. Get the normal PyTorch datasets.
  2. Wrap each PyTorch dataset in an l2l MetaDataset, which gives a list of meta-datasets.
  3. Pass those meta-datasets to a UnionMetaDataset.
  4. Pass the union to a normal PyTorch DataLoader.

If the union dataset exposes the right interface/API, this should work. Currently I get an error that makes no sense to me:

Traceback (most recent call last):
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'get_omniglot_datasets.<locals>.<lambda>'

but should this work?
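
A note on the traceback above: the `AttributeError` is raised by `pickle`, not by the union dataset. The `popen_spawn_posix.py` frame shows a process being started with the spawn method, which pickles the objects handed to the child process (an assumption here is that this happens when a DataLoader with `num_workers > 0` starts its workers). A lambda defined inside a function, such as the one in `get_omniglot_datasets`, cannot be pickled. The sketch below reproduces the failure in isolation and shows one common workaround; `add_offset`, `make_local_lambda`, and `make_picklable_transform` are hypothetical names used only for illustration.

```python
import pickle
from functools import partial


def make_local_lambda():
    # A lambda created inside a function is a "local object":
    # pickle cannot look it up by name, so it cannot serialize it.
    return lambda x: x + 1


def add_offset(x, offset):
    # Module-level function: pickle can reference it by qualified name.
    return x + offset


def make_picklable_transform():
    # functools.partial over a module-level function replaces the lambda.
    return partial(add_offset, offset=1)


def can_pickle(obj):
    """Return True if obj can be serialized with pickle."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError):
        return False


lambda_ok = can_pickle(make_local_lambda())          # the failing case
partial_ok = can_pickle(make_picklable_transform())  # the workaround
```

If the lambda cannot be replaced, setting `num_workers=0` on the DataLoader avoids starting worker processes, at the cost of loading data in the main process.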

I could test this code:

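    import learn2learn as l2l
    import torchvision
    from torch.utils.data import DataLoader

    # CIFARFS ships with learn2learn (l2l.vision.datasets), not torchvision.
    # ToTensor lets the default collate function batch the images.
    to_tensor = torchvision.transforms.ToTensor()
    train = l2l.vision.datasets.CIFARFS(root="/tmp/mnist", mode="train", transform=to_tensor)
    train = l2l.data.MetaDataset(train)
    valid = l2l.vision.datasets.CIFARFS(root="/tmp/mnist", mode="validation", transform=to_tensor)
    valid = l2l.data.MetaDataset(valid)
    test = l2l.vision.datasets.CIFARFS(root="/tmp/mnist", mode="test", transform=to_tensor)
    test = l2l.data.MetaDataset(test)
    union = l2l.data.UnionMetaDataset([train, valid, test])
    assert len(union.labels) == 100  # 64 train + 16 validation + 20 test classes
    union_loader = DataLoader(union)
    next(iter(union_loader))

seba-1511 commented 1 year ago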

Hello @brando90, I think you can just concatenate the pytorch datasets directly without using learn2learn.

brando90 commented 1 year ago

> Hello @brando90, I think you can just concatenate the pytorch datasets directly without using learn2learn.

Hi Seb, I think this is not the entire solution, because one still needs to re-index the class labels after concatenating. PyTorch's ConcatDataset does not do that, but your wonderful union dataset does.
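
To make the relabeling point concrete, here is a minimal sketch of what "concatenate and re-index" means, written without learn2learn. `RelabeledConcat` is a hypothetical class, not part of PyTorch or learn2learn, and it assumes each dataset yields `(x, y)` pairs whose labels already run from 0 to `n_classes - 1`. It only implements `__len__` and `__getitem__`, the map-style protocol a DataLoader relies on, so it is kept free of a torch import here.

```python
from bisect import bisect_right


class RelabeledConcat:
    """Concatenate map-style datasets and shift labels so classes never collide.

    Hypothetical helper: assumes labels in dataset k lie in [0, n_classes_k).
    """

    def __init__(self, datasets, num_classes_per_dataset):
        self.datasets = list(datasets)
        self.cumulative_sizes = []  # running total of samples
        self.label_offsets = []     # first global label of each dataset
        total, offset = 0, 0
        for ds, n_classes in zip(self.datasets, num_classes_per_dataset):
            total += len(ds)
            self.cumulative_sizes.append(total)
            self.label_offsets.append(offset)
            offset += n_classes
        self.num_classes = offset

    def __len__(self):
        return self.cumulative_sizes[-1] if self.cumulative_sizes else 0

    def __getitem__(self, idx):
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # Find which dataset the global index falls into.
        ds_idx = bisect_right(self.cumulative_sizes, idx)
        start = self.cumulative_sizes[ds_idx - 1] if ds_idx > 0 else 0
        x, y = self.datasets[ds_idx][idx - start]
        # Shift the local label into the global label space.
        return x, int(y) + self.label_offsets[ds_idx]


# Toy stand-ins for two datasets with 2 and 3 classes.
toy_a = [("a0", 0), ("a1", 1)]
toy_b = [("b0", 0), ("b1", 1), ("b2", 2)]
merged = RelabeledConcat([toy_a, toy_b], num_classes_per_dataset=[2, 3])
```

With plain concatenation, `("b0", 0)` would share label 0 with `("a0", 0)`. Here the second dataset's labels are shifted by 2, so the merged dataset has 5 distinct classes.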

brando90 commented 1 year ago

Would using union.dataset give me what I want, i.e. the concatenation of the datasets with the relabeling done correctly?

brando90 commented 1 year ago

For this to work with any standard PyTorch dataset, I think the workflow would be:

pytorch dataset -> l2l meta data set -> union data set -> .dataset field -> data loader

For l2l datasets:

l2l meta data set -> union data set -> .dataset field -> data loader

brando90 commented 1 year ago

Odd: iterating over the union dataset doesn't work:

    assert isinstance(dataset, Dataset), f'Expect dataset to be of type Dataset but got {type(dataset)=}.'
    counts: dict = {}
    iter_dataset = iter(dataset)
    for datapoint in iter_dataset:
        x, y = datapoint

see:

Connected to pydev debugger (build 221.5080.212)
/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  ../torch/csrc/utils/tensor_numpy.cpp:180.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
Traceback (most recent call last):
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevd.py", line 1491, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/brandomiranda/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/dataset/concate_dataset.py", line 240, in <module>
    # check_cifar100_is_100_in_usl()
  File "/Users/brandomiranda/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/dataset/concate_dataset.py", line 223, in check_mi_usl

  File "/Users/brandomiranda/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/dataset/concate_dataset.py", line 102, in get_relabling_counts
    x, y = datapoint
TypeError: cannot unpack non-iterable NoneType object

brando90 commented 1 year ago

Really weird, because indexing it does return something:

len(union[0])
2

Maybe the iterator's `__next__` doesn't call `__getitem__`?
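
One way to sidestep the question is to loop over indices instead of calling `iter()` on the dataset, since indexing is known to work here. The sketch below is not a claim about how UnionMetaDataset is implemented; `IndexableOnly` is a hypothetical stand-in that mimics the observed symptom (indexing returns a pair, iteration yields `None`), and `count_labels_by_index` is an illustrative replacement for the counting loop.

```python
from collections import Counter


class IndexableOnly:
    """Hypothetical stand-in: indexing works, but iteration yields None."""

    def __init__(self, samples):
        self._samples = list(samples)

    def __len__(self):
        return len(self._samples)

    def __getitem__(self, idx):
        return self._samples[idx]

    def __iter__(self):
        # Mimics the symptom in the traceback: each item is None.
        for _ in self._samples:
            yield None


def count_labels_by_index(dataset):
    """Count samples per label using only len() and indexing."""
    counts = Counter()
    for i in range(len(dataset)):
        _, y = dataset[i]  # goes through __getitem__, never __iter__
        counts[int(y)] += 1
    return dict(counts)


toy = IndexableOnly([("x0", 0), ("x1", 1), ("x2", 1)])
label_counts = count_labels_by_index(toy)
```

A DataLoader over a map-style dataset also fetches samples by index, so this matches how the data would be read during training.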