Closed brando90 closed 1 year ago
Hello @brando90, I think you can just concatenate the pytorch datasets directly without using learn2learn.
Hi Seb, I don't think this is the entire solution, because one still needs to concatenate and re-index the class labels -- which `ConcatDataset` does not do, but your wonderful union data set does.
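To make the distinction concrete, here is a minimal pure-Python sketch (no torch, just lists of `(x, y)` pairs) of why plain concatenation is not enough: the samples are joined, but labels are not re-indexed, so class 0 of one data set collides with class 0 of the other.

```python
# Plain concatenation: samples joined, labels NOT re-indexed.
# Class 0 of ds_a and class 0 of ds_b collide under the same label.
ds_a = [("a0", 0), ("a1", 1)]  # toy stand-in for dataset A
ds_b = [("b0", 0), ("b1", 1)]  # toy stand-in for dataset B
concat = ds_a + ds_b
labels = {y for _, y in concat}
print(labels)  # {0, 1} -- only 2 "classes" even though there are really 4
```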
Would doing `union.dataset` give me what I want, i.e. the concatenation of the data sets plus the relabeling done correctly?
For it to work with any standard PyTorch data set, I think the workflow would be:

pytorch dataset -> l2l meta data set -> union data set -> .dataset field -> data loader

For l2l data sets:

l2l meta data set -> union data set -> .dataset field -> data loader
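For reference, here is a minimal pure-Python sketch of what the union step in that pipeline has to do: concatenate the samples AND shift each data set's labels by the number of classes that came before it. (`UnionWithRelabeling` is a hypothetical illustration, not the learn2learn API; it assumes each data set's labels are `0..k-1`.)

```python
class UnionWithRelabeling:
    """Concatenate datasets of (x, y) pairs and re-index labels so that
    classes from different datasets never collide."""

    def __init__(self, datasets):
        self.datasets = datasets
        # offsets[i] = total number of distinct classes in datasets[0..i-1]
        self.offsets = []
        total = 0
        for ds in datasets:
            self.offsets.append(total)
            total += len({y for _, y in ds})

    def __len__(self):
        return sum(len(ds) for ds in self.datasets)

    def __getitem__(self, idx):
        # Walk the datasets until idx falls inside one of them.
        for ds, off in zip(self.datasets, self.offsets):
            if idx < len(ds):
                x, y = ds[idx]
                return x, y + off  # shift label past earlier datasets' classes
            idx -= len(ds)
        raise IndexError(idx)

ds_a = [("a0", 0), ("a1", 1)]              # 2 classes: 0, 1
ds_b = [("b0", 0), ("b1", 1), ("b2", 2)]   # 3 classes: 0, 1, 2
union = UnionWithRelabeling([ds_a, ds_b])
print(len(union))   # 5
print(union[2])     # ('b0', 2): ds_b's class 0 is shifted past ds_a's 2 classes
```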
odd, the union data set's iterator doesn't work:

```python
assert isinstance(dataset, Dataset), f'Expect dataset to be of type Dataset but got {type(dataset)=}.'
counts: dict = {}
iter_dataset = iter(dataset)
for datapoint in iter_dataset:
    x, y = datapoint
```
see:

```
Connected to pydev debugger (build 221.5080.212)
/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/site-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:180.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
Traceback (most recent call last):
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevd.py", line 1491, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/brandomiranda/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/dataset/concate_dataset.py", line 240, in <module>
    # check_cifar100_is_100_in_usl()
  File "/Users/brandomiranda/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/dataset/concate_dataset.py", line 223, in check_mi_usl
  File "/Users/brandomiranda/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/dataset/concate_dataset.py", line 102, in get_relabling_counts
    x, y = datapoint
TypeError: cannot unpack non-iterable NoneType object
```
really weird, because when I index it, it does return something:

```python
>>> len(union[0])
2
```
maybe the iterator's `__next__` doesn't call `__getitem__` or something?
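Since indexing works, one way to sidestep the iterator entirely is an explicit index-based loop, which only relies on `__len__` and `__getitem__`. A sketch (`count_labels` is a hypothetical stand-in for the `get_relabling_counts` function in the traceback, and the toy list stands in for the union dataset):

```python
def count_labels(dataset):
    """Count samples per label using index-based access, bypassing any
    custom (possibly broken) __iter__ on the dataset."""
    counts = {}
    for i in range(len(dataset)):  # relies only on __len__ / __getitem__
        x, y = dataset[i]
        counts[y] = counts.get(y, 0) + 1
    return counts

toy = [("img0", 0), ("img1", 1), ("img2", 0)]  # toy (x, y) dataset
print(count_labels(toy))  # {0: 2, 1: 1}
```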
I wanted to do this (ref: https://arxiv.org/abs/1903.03096):
would it work to do something like:
if union data sets have the right interface/API, it should work. Currently I have a weird bug that makes no sense:
but should this work?
I could test this code: