AntreasAntoniou / HowToTrainYourMAMLPytorch

The original code for the paper "How to train your MAML" along with a replication of the original "Model Agnostic Meta Learning" (MAML) paper in Pytorch.
https://arxiv.org/abs/1810.09502

Unable to use Custom Dataset #19

Open alex-service-ml opened 5 years ago

alex-service-ml commented 5 years ago

maybe_unzip_dataset in HowToTrainYourMAMLPytorch/utils/dataset_tools.py doesn't appear to support datasets other than Omniglot and mini-ImageNet. I have created my own dataset (following the two-layer directory structure) with 100 classes. Following the instructions in the README, I created a config for the custom dataset (and experiment). I then generated the config and the script and tried to run the script, which fails with the following error:

...
inner_loop_optimizer.names_learning_rates_dict.layer_dict-linear-weights torch.Size([6]) cuda:0 True
inner_loop_optimizer.names_learning_rates_dict.layer_dict-linear-bias torch.Size([6]) cuda:0 True
datasets/my_dataset
count stuff________________________________________ 13190
datasets/my_dataset
Not found dataset folder structure.. searching for .tar.bz2 file
Traceback (most recent call last):
  File "train_maml_system.py", line 12, in <module>
    maybe_unzip_dataset(args=args)
  File "/home/aservice/HowToTrainYourMAMLPytorch/utils/dataset_tools.py", line 44, in maybe_unzip_dataset
    maybe_unzip_dataset(args)
  File "/home/aservice/HowToTrainYourMAMLPytorch/utils/dataset_tools.py", line 19, in maybe_unzip_dataset
    "place dataset in datasets folder as explained in README".format(os.path.abspath(zip_directory))
AssertionError: /home/aservice/HowToTrainYourMAMLPytorch/datasets/my_dataset.tar.bz2 dataset zip file not foundplace dataset in datasets folder as explained in README

I suspect this section of maybe_unzip_dataset is the issue: my dataset appears to get counted, then deleted, and then the function calls itself again and, naturally, no longer finds the dataset folder.
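One way to avoid that delete-and-recurse loop is to only enforce a file count for datasets whose expected size is known, and to accept any non-empty folder otherwise. The sketch below is hypothetical and is not the change the maintainer actually made; `EXPECTED_FILE_COUNTS`, its keys, and `dataset_is_usable` are illustrative names. It counts files with `os.walk`, so it does not depend on how many directory levels the dataset uses.

```python
import os

# Hypothetical registry of expected image counts for built-in datasets.
# The real dataset names and checks in dataset_tools.py may differ.
EXPECTED_FILE_COUNTS = {
    "omniglot": 1623 * 20,       # 1623 characters x 20 drawings each
    "mini_imagenet": 100 * 600,  # 100 classes x 600 images each
}

IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def count_image_files(dataset_path):
    """Recursively count image files below dataset_path, at any depth."""
    total = 0
    for _root, _dirs, files in os.walk(dataset_path):
        total += sum(1 for f in files if f.lower().endswith(IMAGE_EXTENSIONS))
    return total


def dataset_is_usable(dataset_name, dataset_path):
    """Return True if the extracted folder can be used as-is.

    Known datasets must match their expected count. Unknown (custom)
    datasets only need to be non-empty, so they are never deleted just
    because no expected count is registered for them.
    """
    if not os.path.isdir(dataset_path):
        return False
    total = count_image_files(dataset_path)
    expected = EXPECTED_FILE_COUNTS.get(dataset_name)
    if expected is None:
        return total > 0
    return total == expected
```

Under this design, the caller would only delete the folder and fall back to the .tar.bz2 archive when `dataset_is_usable` returns False, so a custom dataset without an archive is left in place.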

In addition to the issue above, I wanted to ask: do all tasks/classes have to have the same number of samples for training, e.g. 600? Or can one task have 600 samples and another have 257?

AntreasAntoniou commented 5 years ago

I see. You are correct. I will amend the dataset_tools.py file to ensure that new datasets are treated differently.

As for your second question, since I am using dictionaries to store the samples, in theory it should be able to handle classes with varying numbers of samples.
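To illustrate why a dictionary keyed by class makes uneven class sizes workable, here is a minimal, hypothetical episode sampler. It is not the repo's data loader; `sample_episode` and its parameter names are assumptions. Each class only needs at least `num_samples_per_class + num_target_samples` entries, so a class with 257 samples can be drawn alongside one with 600.

```python
import random


def sample_episode(class_to_samples, num_classes_per_set, num_samples_per_class,
                   num_target_samples, rng):
    """Sample an N-way K-shot episode from a {class: [samples]} dict.

    Classes may hold different numbers of samples; a class is eligible
    as long as it has enough samples for the support and target sets.
    """
    needed = num_samples_per_class + num_target_samples
    eligible = [c for c, samples in class_to_samples.items() if len(samples) >= needed]
    if len(eligible) < num_classes_per_set:
        raise ValueError("not enough classes with %d or more samples" % needed)

    chosen = rng.sample(eligible, num_classes_per_set)
    support, target = [], []
    for label, cls in enumerate(chosen):
        # Draw without replacement so support and target never overlap.
        picks = rng.sample(class_to_samples[cls], needed)
        support += [(p, label) for p in picks[:num_samples_per_class]]
        target += [(p, label) for p in picks[num_samples_per_class:]]
    return support, target


rng = random.Random(0)
data = {"a": list(range(600)), "b": list(range(257))}
support, target = sample_episode(data, 2, 1, 1, rng)
```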

AntreasAntoniou commented 5 years ago

Just updated the dataset_tools.py file. Have another go at this and let me know if it works.

alex-service-ml commented 5 years ago

Wow, thank you. I'll be giving it another try today and I'll let you know!

alex-service-ml commented 5 years ago

I identified a couple more issues with using a custom dataset. I decided to use Fashion MNIST, which is grayscale (28, 28, 1), but load_image in data.py defaults to converting images to 3 channels, which then causes shape mismatches during the forward pass. Similarly, get_transforms_for_dataset returns nothing (None) for a custom dataset name. With these two issues fixed, Fashion MNIST training seems to have kicked off.
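A minimal sketch of both fixes, assuming the image has already been decoded into a NumPy array (if the loader uses PIL, calling `convert("L")` instead of `convert("RGB")` keeps one channel). `to_channels`, `make_normalize`, and the `DATASET_TRANSFORMS` registry are hypothetical names, not functions from data.py; the point is that the channel count comes from the config rather than being hard-coded to 3, and that unknown dataset names get a default transform list instead of None.

```python
import numpy as np


def to_channels(image, num_channels):
    """Return image as an H x W x num_channels array."""
    image = np.asarray(image)
    if image.ndim == 2:  # H x W grayscale -> H x W x 1
        image = image[:, :, np.newaxis]
    current = image.shape[2]
    if current == num_channels:
        return image
    if current == 1 and num_channels == 3:
        # Replicate the single gray channel into R, G and B.
        return np.repeat(image, 3, axis=2)
    if current == 3 and num_channels == 1:
        # ITU-R 601-2 luma weights (the ones PIL documents for convert("L")).
        weights = np.array([0.299, 0.587, 0.114])
        gray = np.rint(image.astype(np.float64) @ weights)
        return gray[:, :, np.newaxis].astype(image.dtype)
    raise ValueError("cannot convert %d channels to %d" % (current, num_channels))


def make_normalize(mean, std):
    """Per-channel normalization for H x W x C arrays."""
    mean = np.asarray(mean, dtype=np.float32)
    std = np.asarray(std, dtype=np.float32)

    def normalize(image):
        return (np.asarray(image, dtype=np.float32) - mean) / std

    return normalize


# Hypothetical registry: dataset name -> factory(mean, std) returning transforms.
DATASET_TRANSFORMS = {}


def get_transforms_for_dataset(dataset_name, mean, std):
    """Always return a list of transforms, never None."""
    if dataset_name in DATASET_TRANSFORMS:
        return DATASET_TRANSFORMS[dataset_name](mean, std)
    # Fallback for custom datasets: per-channel normalization only.
    return [make_normalize(mean, std)]
```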

alex-service-ml commented 5 years ago

A couple more issues in generate_configs.py: you must create a corresponding entry for the new dataset in hyper_config_dict and add a matching elif branch that assigns search_name.
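If a new dataset has no matching elif branch, search_name presumably never gets assigned and the script fails later with a less obvious error. A minimal sketch of a fail-fast alternative follows. The real structure of hyper_config_dict and the existing search_name values are not shown in this thread, so `SEARCH_NAMES`, its keys and values, and `get_search_name` are all hypothetical.

```python
# Hypothetical mapping from a dataset-name substring to a search name.
# Checked in insertion order, so put more specific keys first.
SEARCH_NAMES = {
    "omniglot": "omniglot",
    "imagenet": "mini-imagenet",
    "fashion_mnist": "fashion-mnist",  # example entry for a custom dataset
}


def get_search_name(dataset_name):
    """Replace an if/elif chain with a lookup that fails loudly."""
    for key, search_name in SEARCH_NAMES.items():
        if key in dataset_name:
            return search_name
    raise KeyError(
        "no search_name registered for %r; add an entry to SEARCH_NAMES "
        "(and a matching entry in hyper_config_dict)" % dataset_name
    )
```

Adding a dataset then means adding one dictionary entry, and forgetting to do so produces an explicit error naming the dataset.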

tnaren3 commented 4 years ago

Hello, I know it's been a while, but if you managed to get a custom dataset working, would you mind explaining the changes you had to make? I'm a bit confused about what exactly the entries in hyper_config_dict refer to and how I should adjust them for my dataset. Also, how did you adjust load_image to handle single-channel images?