LMZimmer / Auto-PyTorch_refactor

Dataset properties to be calculated in the dataset class #41

Closed ravinkohli closed 3 years ago

ravinkohli commented 3 years ago

So, in terms of the dataset properties that the pipeline needs, they are (an illustrative sketch follows below):

- X (train, valid, test) and Y (train, valid, test)
- numerical and categorical columns
- is_sparse
- categories: an array of size (n_columns, num_categories_in_column) that contains, column by column, the set of categories in that column
- num_features
- task_type
- output_type
- is_small_preprocess
- image_height, image_width

For an idea, we could also check this.
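For concreteness, here is a hypothetical example of what such a dataset_properties dict could look like for a small tabular task (key names loosely follow the list above and are illustrative, not the final API):

```python
# Hypothetical sketch of the dataset properties discussed above for a tabular task.
# Key names and values are illustrative only, not the final Auto-PyTorch API.
dataset_properties = {
    'task_type': 'tabular_classification',
    'output_type': 'multiclass',
    'is_sparse': False,
    'is_small_preprocess': True,
    'num_features': 4,
    'numerical_columns': [0, 2],
    'categorical_columns': [1, 3],
    # one entry per column: the set of categories observed in that column
    'categories': [[], ['red', 'green', 'blue'], [], ['low', 'high']],
    # image tasks would instead carry e.g. 'image_height' / 'image_width'
}
```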

Also, currently these keys are all in the FitRequirements, which can be queried from the pipeline object; it's just that we need a way to distinguish dataset_properties from other fit requirements, such as the autoPyTorchComponents that are needed to fit another component (for example, the LR schedulers need the optimizers). So we could make a separate attribute on each component that contains the keys for all dataset properties, and create a framework similar to the fit_requirements one that is already in the pipeline. For example, in a preprocessor:

self.dataset_properties = ['X', 'train_indices', 'is_sparse']
self.add_dataset_properties([
    FitRequirement('X', [np.ndarray, pd.DataFrame, torch.Tensor, csr_matrix]),
    FitRequirement('train_indices', [List[int]]),
    FitRequirement('is_sparse', [bool]),
])
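For illustration, here is a minimal, self-contained sketch of how such a framework could mirror the existing fit_requirements mechanism. Everything in it is an assumption for the sake of the example: only the two-argument FitRequirement form is taken from the snippet above, and PreprocessorComponent, add_dataset_properties and get_dataset_properties are hypothetical names:

```python
from typing import Any, List, NamedTuple

import numpy as np
import pandas as pd


class FitRequirement(NamedTuple):
    """Two-argument form as used in the snippet above (assumed, simplified)."""
    name: str
    supported_types: List[Any]


class PreprocessorComponent:
    """Hypothetical component that declares which dataset properties it needs."""

    def __init__(self) -> None:
        self._dataset_requirements: List[FitRequirement] = []
        self.add_dataset_properties([
            FitRequirement('X', [np.ndarray, pd.DataFrame]),
            FitRequirement('train_indices', [List[int]]),
            FitRequirement('is_sparse', [bool]),
        ])

    def add_dataset_properties(self, requirements: List[FitRequirement]) -> None:
        # collect dataset-derived requirements separately from other fit requirements
        self._dataset_requirements.extend(requirements)

    def get_dataset_properties(self) -> List[FitRequirement]:
        return self._dataset_requirements
```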
franchuterivera commented 3 years ago

In the case of image classification, n_columns and categories_in_column don't make sense (same for image height/width in the tabular case).

How is it gonna be handled? Will it always be returned as None? Or will it be dynamic, so that get_dataset_properties has a task-aware argument?

ravinkohli commented 3 years ago

> In the case of image classification, n_columns and categories_in_column don't make sense (same for image height/width in the tabular case).
>
> How is it gonna be handled? Will it always be returned as None? Or will it be dynamic, so that get_dataset_properties has a task-aware argument?

I think there would be different dataset classes for tabular, image and time series data. They would get the required dataset_properties from the pipeline, and the pipeline would only contain those properties that make sense in its context.
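A rough sketch of that idea (class names and returned values are hypothetical): each task-specific dataset class computes only the properties that exist for its task, so an image dataset never has to return None for tabular-only keys:

```python
from typing import Any, Dict


class TabularDataset:
    """Hypothetical tabular dataset: exposes only tabular properties."""

    def get_dataset_properties(self) -> Dict[str, Any]:
        return {
            'task_type': 'tabular_classification',
            'num_features': 4,
            'numerical_columns': [0, 2],
            'categorical_columns': [1, 3],
            'is_sparse': False,
        }


class ImageDataset:
    """Hypothetical image dataset: exposes only image properties."""

    def get_dataset_properties(self) -> Dict[str, Any]:
        return {
            'task_type': 'image_classification',
            'image_height': 32,
            'image_width': 32,
        }
```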

LMZimmer commented 3 years ago

> How is it gonna be handled? Will it always be returned as None? Or will it be dynamic, so that get_dataset_properties has a task-aware argument?

It should be handled the same way the fit_requirements are handled now for different tasks: by the pipeline automatically parsing its components.

Thinking about it, I think that is the way we want to do it. Separating the pipeline config into dataset properties and config makes sense, also so the user knows what they have to specify. The other option I could think of would be a delimiter prefix on the fit_requirements keys, such as data:, but that seems less elegant to me. So basically, as Ravin suggested, I would argue for another type of requirement, dataset_requirement, that is the same as fit_requirement but expects its keys in an X["dataset_properties"] dict rather than directly in the X dict. What do you think?
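As a hedged sketch of how that could behave at fit time (Requirement and check_requirements are made-up names here, not the existing classes): a dataset_requirement is looked up in X['dataset_properties'], while a plain fit_requirement is looked up directly in X:

```python
from typing import Any, Dict, List, NamedTuple


class Requirement(NamedTuple):
    """Hypothetical: one flag decides where the key is expected."""
    name: str
    supported_types: List[Any]
    dataset_requirement: bool = False


def check_requirements(X: Dict[str, Any], requirements: List[Requirement]) -> None:
    """Raise if a required key is missing from X or from X['dataset_properties']."""
    for req in requirements:
        container = X.get('dataset_properties', {}) if req.dataset_requirement else X
        if req.name not in container:
            where = "X['dataset_properties']" if req.dataset_requirement else "X"
            raise ValueError(f"Missing requirement {req.name!r} in {where}")
```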

LMZimmer commented 3 years ago

I implemented this in a couple of ways today and I didn't really like any of them. I'd therefore be in favor of just removing the dataset_properties dict and having all of its keys in the X dict. Then it is the job of the API to match the dataset and the pipeline properly, which also ensures complete modularization.

bastiscode commented 3 years ago

But the get_dataset_properties method on the dataset should still stay the same, right? So you would want to pass the dataset_properties one gets from calling it into the X dict as keyword arguments, e.g. **dataset_properties?
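Something like this minimal usage sketch, assuming the flat X dict from the previous comment (build_fit_dict is a hypothetical helper, not part of the code base):

```python
from typing import Any, Dict


def build_fit_dict(dataset: Any, splits: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: merge dataset properties and data splits into one flat X dict."""
    return {
        **dataset.get_dataset_properties(),  # e.g. task_type, is_sparse, num_features, ...
        **splits,                            # e.g. X_train, y_train, train_indices, ...
    }
```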

LMZimmer commented 3 years ago

Exactly. I actually implemented a new method in the pipeline that allows requesting the keys that are expected to come from the dataset. That should make it easier for the API to ensure it gets the correct inputs from the correct sources.
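A sketch of what such a method could look like (method and attribute names here are assumptions, not the actual implementation): the pipeline walks its steps, collects every key that a component expects to come from the dataset, and the API can then validate its inputs up front:

```python
from typing import Any, Dict, List


class Pipeline:
    """Hypothetical pipeline that aggregates dataset-derived keys from its steps."""

    def __init__(self, steps: List[Any]) -> None:
        self.steps = steps

    def get_dataset_requirements(self) -> List[str]:
        # each step optionally declares a 'dataset_requirements' list of key names
        keys: List[str] = []
        for step in self.steps:
            keys.extend(getattr(step, 'dataset_requirements', []))
        return keys

    def check_dataset_input(self, X: Dict[str, Any]) -> None:
        missing = [key for key in self.get_dataset_requirements() if key not in X]
        if missing:
            raise ValueError(f'Missing dataset-provided keys in X: {missing}')
```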

LMZimmer commented 3 years ago

Closing as this is addressed by #43 and #46