OML-Team / open-metric-learning

Metric learning and retrieval pipelines, models and zoo.
https://open-metric-learning.readthedocs.io/en/latest/index.html
Apache License 2.0

Allow changing Dataset (class) in Pipelines #528

Open AlekseySh opened 2 months ago

AlekseySh commented 2 months ago

Let's allow changing Dataset in Pipelines. It also assumes we need registry for Datasets.

Note: let's keep backward compatibility with the previous format. We can have a condition which checks whether the dataset config is in the old format, for example, whether one of the old keys is present (like dataframe_name or dataset_root). If you see such keys, first reorganize the yaml/dict, and then process it with the updated parser.
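A minimal sketch of that detection step, assuming the legacy keys mentioned above (dataframe_name, dataset_root); the helper name is just a placeholder:

OLD_FORMAT_KEYS = ("dataframe_name", "dataset_root")

def is_old_format(cfg: dict) -> bool:
    # old configs are recognized by the presence of any legacy key
    return any(key in cfg for key in OLD_FORMAT_KEYS)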

leoromanovich commented 3 weeks ago

I've started working on it. A WIP PR will be attached shortly.

AlekseySh commented 3 weeks ago

@leoromanovich great, waiting for it

leoromanovich commented 3 weeks ago

Started the work here: #585

AlekseySh commented 3 weeks ago

Let's start with a first PR where we don't add text support yet but refactor the way image datasets are processed. In particular, we had a hardcoded get_retrieval_datasets function; now we introduce a registry of builder functions like the one below.

Registry

import pandas as pd
from typing import Tuple


def build_img_dataset(cfg) -> Tuple[ILD, IQGLD]:
    df = pd.read_csv(cfg["df_path"])
    # (pseudocode) enumerate labels, then split by the "split" column
    df_train = df[df["split"] == "train"]
    df_val = df[df["split"] == "validation"]

    dataset_train = ImageLD(df_train)
    dataset_val = ImageQGLD(df_val)

    # or just reuse get_retrieval_datasets

    return dataset_train, dataset_val


def build_txt_dataset(cfg) -> Tuple[ILD, IQGLD]:
    ...


# maps a builder name from the config to a function returning (train, val) datasets
DATASETS_BUILDER_REGISTRY = {
    "oml_img_datasets": build_img_dataset,
    "oml_txt_datasets": build_txt_dataset,
}

...
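A possible sketch of the lookup helper used in the pipeline snippet below (the name get_datasets_builder is taken from that snippet; the exact signature is an assumption):

def get_datasets_builder(cfg):
    # resolve the builder function by the name given in the config
    name = cfg["dataset_builder"]
    if name not in DATASETS_BUILDER_REGISTRY:
        raise KeyError(f"Unknown dataset builder: {name}, expected one of {list(DATASETS_BUILDER_REGISTRY)}")
    return DATASETS_BUILDER_REGISTRY[name]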

Config.yaml

dataset_builder: oml_img_datasets
args:
  df_path: df_full.csv
  cache_size: 100
  transforms_train:
    name: hypvit_resize
    args:
      im_size: 224
  transforms_val:
    name: hypvit_resize
    args:
      im_size: 224
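For comparison, a config in the old format (the one the converter below has to detect) looks roughly like this; only the keys mentioned in this issue are shown:

dataset_root: /path/to/dataset
dataframe_name: df.csv
transforms_train:
  name: hypvit_resize
  args:
    im_size: 224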

Back compatibility

def convert_to_oml_three_format_if_needed(cfg):
    # detect the old format by the presence of legacy keys
    if "dataset_root" in cfg and "transforms_train" in cfg and ...:
        cfg["dataset_train"] = {"name": "image_label_dataset", "args": {"df": ..., "transform": ...}}
        # don't forget to delete the refactored keys
        ...
    return cfg

def extractor_training_pipeline(cfg):
    cfg = dictconfig_to_dict(cfg)
    cfg = convert_to_oml_three_format_if_needed(cfg)

    builder = get_datasets_builder(cfg)
    dataset_train, dataset_val = builder(cfg["args"])
    assert isinstance(dataset_train, ILD) and isinstance(dataset_val, IQGLD)
    assert check_consistency(dataset_train, dataset_val)
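Since the goal of the issue is to let users swap the Dataset class, it probably makes sense to also document how a custom builder plugs in. A hypothetical example (the builder and dataset class names are placeholders; any classes implementing ILD / IQGLD would do):

# user code: register a custom builder and reference it from the config
def build_my_dataset(cfg):
    df = pd.read_csv(cfg["df_path"])
    df_train = df[df["split"] == "train"]
    df_val = df[df["split"] == "validation"]
    return MyImageLD(df_train), MyImageQGLD(df_val)

DATASETS_BUILDER_REGISTRY["my_datasets"] = build_my_dataset

# config.yaml then only needs:
# dataset_builder: my_datasets
# args:
#   df_path: my_df.csv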

Update the mock dataset and the Pipelines tests

@hydra.main(config_path="configs", config_name="train_postprocessor.yaml", version_base=HYDRA_BEHAVIOUR)
def main_hydra(cfg: DictConfig) -> None:
    cfg = dictconfig_to_dict(cfg)
    download_mock_dataset(MOCK_DATASET_PATH)
    # point the builder's args at the downloaded mock dataset
    cfg["args"]["dataset_root"] = str(MOCK_DATASET_PATH)
    extractor_training_pipeline(cfg)

if __name__ == "__main__":
    main_hydra()
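With Hydra already in place, the builder and its args could also be overridden from the command line in tests (the script name below is a placeholder, the keys follow the config layout above):

python pipeline_script.py dataset_builder=oml_img_datasets args.cache_size=10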

And, finally, the Pipelines tests need to be updated accordingly.