kiudee / cs-ranking

Context-sensitive ranking and choice in Python with PyTorch
https://cs-ranking.readthedocs.io
Apache License 2.0
66 stars · 15 forks

FATENetwork#fit crashes when X is a dict #143

Open daanvdn opened 4 years ago

daanvdn commented 4 years ago

hi @kiudee, first off: thanks for sharing this very interesting project with the community. I would very much like to experiment with it, esp. because of its support for learning discrete choices.

While playing around with the API, I seem to have encountered a potential bug in the csrank.core.fate_network.FATENetwork#fit method.

The data for which I would like to learn a discrete choice model has a variable number of objects (i.e. every instance may have a different value for n_objects). According to the documentation, the csrank.FATENetwork#fit method supports this scenario by allowing X to be a dict that maps n_objects to numpy arrays:

    X : numpy array or dict
        Feature vectors of the objects
        (n_instances, n_objects, n_features) if numpy array or map from n_objects to numpy arrays
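
To make the documented dict format concrete, here is a minimal sketch (my own illustration, not taken from the library) of what such an X could look like, with keys giving n_objects and values being arrays of shape (n_instances, n_objects, n_features):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical variable-length dataset: two groups of instances,
# one with 3 objects per instance, one with 5 objects per instance.
X = {
    3: rng.random((100, 3, 2)),  # 100 instances, 3 objects, 2 features
    5: rng.random((80, 5, 2)),   # 80 instances, 5 objects, 2 features
}

# Each key matches the n_objects axis of its array.
for n_objects, arr in X.items():
    assert arr.shape[1] == n_objects
```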

I am using the csrank.DiscreteChoiceDatasetGenerator to create some synthetic data. More specifically, I am using the csrank.DiscreteChoiceDatasetGenerator#get_dataset_dictionaries method. However, when I pass the resulting X_train and Y_train to the fit method, it causes the error below:

Traceback (most recent call last):
  File "C:/Users/Daan_Vandennest/Git/landc-working-dx-ml/src/main/python/models/neural_ranking.py", line 21, in <module>
    fate.fit(X_train, Y_train, verbose=True, epochs=1)
  File "C:\Users\Daan_Vandennest\.virtualenvs\landc-working-dx-ml-eYqgHMIQ\lib\site-packages\csrank\objectranking\fate_object_ranker.py", line 98, in fit
    super().fit(X, Y, **kwd)
  File "C:\Users\Daan_Vandennest\.virtualenvs\landc-working-dx-ml-eYqgHMIQ\lib\site-packages\csrank\core\fate_network.py", line 539, in fit
    _n_instances, self.n_objects_fit_, self.n_object_features_fit_ = X.shape
AttributeError: 'dict' object has no attribute 'shape'

It seems that line 539 of fate_network.py attempts to access the shape attribute of X regardless of whether X is a numpy array or a dict.
This error can be reproduced by running the code below:

from csrank import DiscreteChoiceDatasetGenerator
from csrank import FATEObjectRanker
from csrank.losses import smooth_rank_loss

seed = 123
n_train = 10000
n_test = 10000
n_features = 2
n_objects = 5
gen = DiscreteChoiceDatasetGenerator(dataset_type='medoid', random_state=seed,
                                     n_train_instances=n_train,
                                     n_test_instances=n_test,
                                     n_objects=n_objects,
                                     n_features=n_features)

# X_train, Y_train, X_test, Y_test = gen.get_single_train_test_split()
X_train, Y_train, X_test, Y_test = gen.get_dataset_dictionaries()

fate = FATEObjectRanker(loss_function=smooth_rank_loss)
fate.fit(X_train, Y_train, verbose=True, epochs=1) 

Can you confirm that this is a bug, or am I using the API in the wrong way? If it is indeed a bug, could you give me some guidance on how to fix it? If it's a relatively straightforward fix, I can implement it and open a pull request.

Thanks

kiudee commented 4 years ago

Thank you for the detailed report. That does indeed appear to be a bug: we originally supported both fixed-length input using NumPy arrays and variable-length input using dicts, but the dict interface was not actively maintained.

Supporting variable-length inputs is of course desirable, which is why we will work on restoring that interface.

daanvdn commented 4 years ago

Thanks for the feedback @kiudee. Could I work around this by padding my inputs so that n_objects is constant? Or would these dummy objects confuse the FATENetwork and prevent it from converging?

kiudee commented 4 years ago

Yes, a common workaround is to pad the inputs with zeros up to the maximum length. At prediction time, you would then need to use predict_scores and select the non-dummy object with the highest score.
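
A minimal sketch of that workaround (my own illustration, assuming the dict format from get_dataset_dictionaries; the helper `pad_to_max` is hypothetical, and the random scores stand in for the output of predict_scores):

```python
import numpy as np

def pad_to_max(X_dict):
    # Zero-pad every instance to the maximum n_objects and keep a
    # boolean mask marking which positions are real (non-dummy) objects.
    max_n = max(X_dict.keys())
    X_parts, mask_parts = [], []
    for n, X in X_dict.items():
        n_inst = X.shape[0]
        X_parts.append(np.pad(X, ((0, 0), (0, max_n - n), (0, 0))))
        mask = np.zeros((n_inst, max_n), dtype=bool)
        mask[:, :n] = True
        mask_parts.append(mask)
    return np.concatenate(X_parts), np.concatenate(mask_parts)

X_dict = {3: np.random.rand(4, 3, 2), 5: np.random.rand(6, 5, 2)}
X_padded, mask = pad_to_max(X_dict)  # X_padded.shape == (10, 5, 2)

# At prediction time: mask out dummy objects before taking the argmax.
scores = np.random.rand(10, 5)  # stand-in for model.predict_scores(X_padded)
best = np.argmax(np.where(mask, scores, -np.inf), axis=1)
```

The masked argmax at the end is the "select the non-dummy object with the highest score" step: dummy positions are set to -inf so they can never win.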