Kaszanas / SC2_Datasets

https://sc2-datasets.readthedocs.io/
GNU General Public License v3.0

Verify Dataset __getitem__ return value #4

Closed. Kaszanas closed this issue 2 years ago.

Kaszanas commented 2 years ago

There is no clear documentation on what the __getitem__ method of a PyTorch Dataset class should return.

This means there is no clear interface for any machine learning model that will be used downstream.

After some research, it seems that the most common convention is return x, y, where x is a tensor or NumPy array containing the features and y is a tensor or NumPy array containing the labels.
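A minimal sketch of that convention, assuming a map-style dataset that already holds precomputed feature vectors and integer labels. The class name `PairDataset` is hypothetical, and plain Python lists stand in for tensors or NumPy arrays so the example runs without extra dependencies:

```python
from typing import List, Sequence, Tuple

try:
    # Subclassing torch.utils.data.Dataset is the usual PyTorch convention.
    from torch.utils.data import Dataset
except ImportError:
    # Fallback so this sketch still runs when PyTorch is not installed.
    Dataset = object


class PairDataset(Dataset):
    """Hypothetical map-style dataset whose __getitem__ returns (x, y)."""

    def __init__(self, features: Sequence[List[float]], labels: Sequence[int]) -> None:
        if len(features) != len(labels):
            raise ValueError("features and labels must have the same length")
        self.features = features
        self.labels = labels

    def __len__(self) -> int:
        return len(self.features)

    def __getitem__(self, index: int) -> Tuple[List[float], int]:
        # x is the feature vector and y is the label; in a real pipeline
        # these would usually be tensors or NumPy arrays, not lists.
        x = self.features[index]
        y = self.labels[index]
        return x, y
```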

For SC2ReplayData there is no obvious list of features and labels. Some logic needs to be introduced at dataset initialization that allows features and labels to be selected. PyTorch transforms seem to be a good way to pass such logic to the dataset.
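One way to express that idea, sketched under stated assumptions: the dataset receives an optional `transform` callable at initialization and applies it in `__getitem__`, so the caller decides which features and labels are produced. `ReplayRecord`, its fields, `ReplayDataset`, and `select_features_and_label` are all hypothetical illustrations, not part of SC2_Datasets or the real SC2ReplayData class:

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Sequence, Tuple

try:
    from torch.utils.data import Dataset
except ImportError:
    # Fallback so this sketch still runs when PyTorch is not installed.
    Dataset = object


@dataclass
class ReplayRecord:
    """Hypothetical stand-in for a parsed replay (NOT the real SC2ReplayData)."""
    map_name: str
    game_loops: int
    apm: float
    winner: int


class ReplayDataset(Dataset):
    """Hypothetical dataset whose output format is chosen by a transform."""

    def __init__(
        self,
        replays: Sequence[ReplayRecord],
        transform: Optional[Callable[[ReplayRecord], Any]] = None,
    ) -> None:
        self.replays = replays
        self.transform = transform

    def __len__(self) -> int:
        return len(self.replays)

    def __getitem__(self, index: int) -> Any:
        replay = self.replays[index]
        # Feature/label selection lives in the transform, not in the dataset.
        if self.transform is not None:
            return self.transform(replay)
        return replay


def select_features_and_label(replay: ReplayRecord) -> Tuple[List[float], int]:
    """Hypothetical transform: numeric fields as features, winner as label."""
    x = [float(replay.game_loops), replay.apm]
    y = replay.winner
    return x, y
```

Defaulting `transform` to `None` follows the convention used by torchvision datasets; passing a different transform changes what a model receives without modifying the dataset class.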

This would imply that the rest of the processing logic is applied within the model.

If you have any thoughts on this, let me know: @leafnode

Kaszanas commented 2 years ago

I think the reasonable decision is to return SC2ReplayData by default. Everything else, such as tensor creation or feature and label extraction, can be done through transforms.
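Under a return-by-default design, further processing is layered on by chaining transforms. A minimal sketch, assuming a hypothetical `ReplayRecord` in place of SC2ReplayData (field names invented) and a hand-written `compose` helper; torchvision offers `transforms.Compose` for the same pattern. The pipeline below extracts features and a label, then applies a hypothetical fixed scaling:

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple


@dataclass
class ReplayRecord:
    """Hypothetical stand-in for SC2ReplayData; field names are invented."""
    game_loops: int
    apm: float
    winner: int


def compose(*transforms: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Chain transforms left to right, similar in spirit to torchvision's Compose."""
    def pipeline(sample: Any) -> Any:
        for transform in transforms:
            sample = transform(sample)
        return sample
    return pipeline


def extract_features_and_label(replay: ReplayRecord) -> Tuple[List[float], int]:
    # Hypothetical choice of features and label.
    return [float(replay.game_loops), replay.apm], replay.winner


def scale_features(pair: Tuple[List[float], int]) -> Tuple[List[float], int]:
    # Hypothetical fixed scaling: game loops to thousands, APM to hundreds.
    features, label = pair
    return [features[0] / 1000.0, features[1] / 100.0], label


to_model_input = compose(extract_features_and_label, scale_features)
```

A final step converting the features to tensors (for example with `torch.as_tensor`) could be appended to the same chain, keeping the dataset itself framework-agnostic.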

Maybe such a decision would make it easier to interface with other frameworks?

Still, this needs further research.

Kaszanas commented 2 years ago

This is solved with transforms.

They expose the logic that decides what custom information the dataset returns to a model.