mboudiaf / pytorch-meta-dataset

A non-official 100% PyTorch implementation of the META-DATASET benchmark for few-shot classification

How to compute dataset size from TFRecords for MDS (within Python)? #26

Open: brando90 opened this issue 1 year ago

patricks-lab commented 1 year ago

I'm also collaborating with brando on this issue and found a potential fix. I'd like to check whether it actually solves the problem.

To summarize the issue: initially we tried calling len(metadataset_pipeline), where metadataset_pipeline is an object created with metadataset_pipeline = pipeline.make_episodic_pipeline(...). However, as you may know, the pipeline is a PyTorch IterableDataset built on top of TFRecords that implements only __iter__(), so calling len() on it raises a TypeError.
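To illustrate why len() fails, here is a minimal plain-Python sketch. EpisodeStream is a hypothetical stand-in for the pipeline, assumed here to yield episodes without end; like the pipeline, it defines __iter__() but no __len__(). Counting by exhausting such a stream would never finish, so only a bounded prefix can be taken.

```python
from itertools import islice


class EpisodeStream:
    """Hypothetical stand-in for an iterable pipeline: it defines __iter__()
    but not __len__(), so len() cannot be called on it."""

    def __iter__(self):
        episode = 0
        while True:  # assumed endless, as an episodic sampler may be
            yield episode
            episode += 1


stream = EpisodeStream()

try:
    len(stream)
except TypeError as err:
    # Python raises TypeError for objects that do not define __len__.
    print(f"len() failed: {err}")

# Iterating to the end would never terminate, so consume only a bounded prefix.
first_five = list(islice(stream, 5))
print(first_five)  # [0, 1, 2, 3, 4]
```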

I just found a workaround using the dataset_spec.images_per_class metadata field of a given dataset, which is a dictionary mapping each class to the number of images in that class. I then sum the image counts over the classes belonging to a given split, which I obtain via dataset_spec.get_classes(Split[split]).
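The counting idea can be checked in isolation with a toy example. ToySpec and this Split enum are hypothetical mocks, not the repo's real dataset spec or Split class; they only mimic the two pieces used here (an images_per_class dictionary and a get_classes(split) lookup) so the per-split sum can be run without any TFRecords.

```python
import enum
from dataclasses import dataclass, field
from typing import Dict, List


class Split(enum.Enum):
    """Hypothetical stand-in for the repo's Split enum."""
    TRAIN = 0
    VALID = 1
    TEST = 2


@dataclass
class ToySpec:
    """Hypothetical mock of a dataset spec: class id -> image count, plus
    the list of class ids assigned to each split."""
    images_per_class: Dict[int, int]
    classes_per_split: Dict[Split, List[int]] = field(default_factory=dict)

    def get_classes(self, split: Split) -> List[int]:
        return self.classes_per_split[split]


def count_split_images(spec: ToySpec, split: Split) -> int:
    # Sum the per-class image counts over the classes that belong to `split`.
    return sum(spec.images_per_class[c] for c in spec.get_classes(split))


spec = ToySpec(
    images_per_class={0: 40, 1: 35, 2: 50, 3: 20},
    classes_per_split={Split.TRAIN: [0, 1], Split.VALID: [2], Split.TEST: [3]},
)
print(count_split_images(spec, Split["TRAIN"]))  # 75
print(count_split_images(spec, Split["VALID"]))  # 50
```

Split["TRAIN"] looks the member up by name, which is the same string-to-enum step as Split[split] in the snippet below.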

I was wondering whether the following code works with your API. Here is my snippet along with an example run:

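import os

# Module paths below are assumed; adjust them to where these live in your checkout.
from src.datasets import config as config_lib
from src.datasets import dataset_spec as dataset_spec_lib
from src.datasets.utils import Split


def get_num_images(args, split: str = 'VALID') -> int:
    # First, get the sources to figure out which datasets we use.
    data_config = config_lib.DataConfig(args)
    datasets = data_config.sources
    num_images = 0

    for dataset_name in datasets:
        # Each source's records and dataset spec live in their own subdirectory.
        dataset_records_path = os.path.join(data_config.path, dataset_name)
        dataset_spec = dataset_spec_lib.load_dataset_spec(dataset_records_path)

        # Maps each class id to the number of images in that class.
        all_class_sizes = dataset_spec.images_per_class

        # Keep only the classes belonging to the requested split.
        class_set = dataset_spec.get_classes(Split[split])

        for c in class_set:
            # Ignore classes with fewer images than our n-way k-shot task needs.
            if all_class_sizes[c] >= args.min_examples_in_class:
                num_images += all_class_sizes[c]

    return num_images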

from argparse import Namespace

args: Namespace = parse_args_standard_sl()  # our own helper; the meta-dataset configuration args live here
args.sources = ['dtd', 'cu_birds']

print(get_num_images(args, 'TRAIN'))  # outputs 12199
print(get_num_images(args, 'VALID'))  # outputs 2619
print(get_num_images(args, 'TEST'))   # outputs 2610
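If some downstream code insists on calling len(), one option is to wrap the pipeline and report the precomputed count. SizedIterable is a hypothetical helper, not part of this repo; it forwards iteration unchanged and returns a length supplied from outside, such as the output of get_num_images(). Note that this length is an image count, not a number of episodes, so whether it is the right value depends on what the calling code expects len() to mean.

```python
from typing import Any, Iterable, Iterator


class SizedIterable:
    """Hypothetical wrapper: iterates over `source` unchanged and reports
    a length computed elsewhere (e.g. by summing images_per_class)."""

    def __init__(self, source: Iterable[Any], size: int) -> None:
        self._source = source
        self._size = size

    def __iter__(self) -> Iterator[Any]:
        # Delegate iteration to the wrapped object.
        return iter(self._source)

    def __len__(self) -> int:
        # Return the externally supplied count; nothing is iterated here.
        return self._size


wrapped = SizedIterable(range(3), size=12199)
print(len(wrapped))   # 12199
print(list(wrapped))  # [0, 1, 2]
```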