Open daanvdn opened 4 years ago
Thank you for the detailed report. That indeed appears to be a bug, in the sense that we originally supported both fixed-length input using NumPy arrays and variable-length input using dicts, and the latter was not actively maintained.
Passing variable length inputs is of course desirable, which is why we will work on restoring that interface.
Thanks for the feedback, @kiudee. Could I work around this by padding my inputs to make `n_objects` constant? Or will these dummy objects confuse the FATENetwork so much that it fails to converge?
Yes, one common workaround would be to pad the inputs with zeros to the maximum length. At prediction time, you would then need to use `predict_scores` and select the non-dummy object with the highest score.
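A minimal sketch of that workaround in plain NumPy, in case it helps. The helper names `pad_instances` and `choose_real_object` are hypothetical and not part of csrank. The sketch also assumes that `predict_scores` returns one score per object with shape `(n_instances, max_objects)`, so a stand-in score array is used here instead of a trained model.

```python
import numpy as np


def pad_instances(instances, n_features):
    """Zero-pad a list of (n_objects_i, n_features) arrays to a common length.

    Returns the padded array of shape (n_instances, max_objects, n_features)
    and a boolean mask that is True for real (non-dummy) objects.
    Hypothetical helper, not part of csrank.
    """
    max_objects = max(len(inst) for inst in instances)
    X = np.zeros((len(instances), max_objects, n_features))
    mask = np.zeros((len(instances), max_objects), dtype=bool)
    for i, inst in enumerate(instances):
        X[i, : len(inst)] = inst
        mask[i, : len(inst)] = True
    return X, mask


def choose_real_object(scores, mask):
    """Return the index of the highest-scoring non-dummy object per instance."""
    # Dummy objects get -inf so they can never win the argmax.
    masked_scores = np.where(mask, scores, -np.inf)
    return masked_scores.argmax(axis=1)


# Two instances with 2 and 4 objects, 3 features each.
instances = [np.ones((2, 3)), np.ones((4, 3))]
X_padded, mask = pad_instances(instances, n_features=3)

# Stand-in for the scores a fitted model might return (assumed shape).
scores = np.array([[0.1, 0.2, 0.9, 0.8],
                   [0.5, 0.1, 0.2, 0.7]])
choices = choose_real_object(scores, mask)
```

In the first instance the two highest scores belong to padded slots, so the mask is what keeps the prediction on a real object. The training labels would presumably need to be padded to the same length as well.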
Hi @kiudee, first off: thanks for sharing this very interesting project with the community. I would very much like to experiment with it, especially because of its support for learning discrete choices.
While playing around with the API, I seem to have encountered a potential bug in the `csrank.core.fate_network.FATENetwork#fit` method.
The data for which I would like to learn a discrete choice model has a variable number of objects (i.e. every instance may have a different value for `n_objects`). According to the documentation, the `csrank.FATENetwork#fit` method supports this scenario by allowing `X` to be a `dict` that maps `n_objects` to numpy arrays:
I am using the `csrank.DiscreteChoiceDatasetGenerator` to create some synthetic data. More specifically, I am using the `csrank.DiscreteChoiceDatasetGenerator#get_dataset_dictionaries` method. However, when I pass the resulting `X_train` and `y_train` to the `fit` method, this causes the error below:
It seems that on line 539 in `fate_network.py` the code tries to access a `shape` attribute of `X` regardless of whether `X` is a `numpy.array` or a `dict`.
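For illustration only, the failure mode described above can be sketched without csrank. This is not the library's actual code. The function `input_shapes` is hypothetical, and the sketch assumes each dict value is an array of shape `(n_instances, n_objects, n_features)`.

```python
import numpy as np


def input_shapes(X):
    """Return (n_objects, n_features) for each block of input data.

    Hypothetical helper: dispatches on the input type instead of
    assuming that X always has a .shape attribute.
    """
    if isinstance(X, dict):
        # Variable-length input: one array per n_objects key (assumed layout).
        return [(arr.shape[1], arr.shape[2]) for arr in X.values()]
    # Fixed-length input: a single 3-D array.
    return [(X.shape[1], X.shape[2])]


X_fixed = np.zeros((5, 3, 2))
X_variable = {3: np.zeros((5, 3, 2)), 4: np.zeros((2, 4, 2))}

# A plain dict has no .shape, so unconditional access fails.
try:
    X_variable.shape
    dict_has_shape = True
except AttributeError:
    dict_has_shape = False

fixed_shapes = input_shapes(X_fixed)
variable_shapes = input_shapes(X_variable)
```

A type check like this is one possible direction for a fix, although the rest of `fit` would presumably also need to handle each dict entry separately.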
This error can be reproduced by running the code below:
Can you confirm that this is a bug, or am I using the API in the wrong way? If it is indeed a bug, could you give me some guidance on how to fix it? If it's a relatively straightforward fix, I can implement it and open a pull request.
Thanks