timokau closed 4 years ago
Example:
[ins] In [1]: from csrank.objectranking import FATEObjectRanker; fate = FATEObjectRanker()
Using TensorFlow backend.
/nix/store/kax45bpa01hh152r06d0x049yb5pjwxn-python3-3.7.9-env/lib/python3.7/site-packages/sklearn/utils/deprecation.py:143: FutureWarning: The sklearn.datasets.samples_generator module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.datasets. Anything that cannot be imported from sklearn.datasets is now part of the private API.
warnings.warn(message, FutureWarning)
[ins] In [2]: fate.get_params()
Out[2]:
{'activation': 'selu',
'batch_size': 256,
'kernel_initializer': 'lecun_normal',
'kernel_regularizer': <function keras.regularizers.l2(l=0.01)>,
'loss_function': <function csrank.losses.identifiable.<locals>.wrap_loss(y_true, y_pred)>,
'metrics': (<function csrank.metrics.zero_one_rank_loss_for_scores_ties(y_true, s_pred)>,),
'n_hidden_joint_layers': 32,
'n_hidden_joint_units': 32,
'n_hidden_set_layers': 2,
'n_hidden_set_units': 2,
'optimizer': keras.optimizers.SGD,
'random_state': None,
'optimizer__learning_rate': 0.01,
'optimizer__momentum': 0.0,
'optimizer__nesterov': False}
[ins] In [3]: fate.set_params(optimizer__learning_rate=0.02)
[ins] In [4]: fate.get_params()
Out[4]:
{'activation': 'selu',
'batch_size': 256,
'kernel_initializer': 'lecun_normal',
'kernel_regularizer': <function keras.regularizers.l2(l=0.01)>,
'loss_function': <function csrank.losses.identifiable.<locals>.wrap_loss(y_true, y_pred)>,
'metrics': (<function csrank.metrics.zero_one_rank_loss_for_scores_ties(y_true, s_pred)>,),
'n_hidden_joint_layers': 32,
'n_hidden_joint_units': 32,
'n_hidden_set_layers': 2,
'n_hidden_set_units': 2,
'optimizer': keras.optimizers.SGD,
'random_state': None,
'optimizer__learning_rate': 0.02,
'optimizer__momentum': 0.0,
'optimizer__nesterov': False}
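The transcript above shows optimizer arguments exposed as `optimizer__`-prefixed parameters that `get_params` and `set_params` round-trip. A minimal sketch of one way such routing could work follows; it is not the csrank implementation. `ToyRanker`, `SGD`, and `build_optimizer` are hypothetical stand-ins invented for illustration, and no Keras or scikit-learn code is involved.

```python
class SGD:
    """Hypothetical stand-in for an optimizer class (not keras.optimizers.SGD)."""

    def __init__(self, learning_rate=0.01, momentum=0.0, nesterov=False):
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.nesterov = nesterov


class ToyRanker:
    """Hypothetical estimator that keeps optimizer arguments under a prefix."""

    _PLAIN = ("batch_size", "optimizer")
    _PREFIX = "optimizer__"

    def __init__(self, batch_size=256, optimizer=SGD, **kwargs):
        self.batch_size = batch_size
        self.optimizer = optimizer  # an uninitialized class, as in the transcript
        # Defaults for the prefixed parameters; explicit kwargs override them.
        self._nested = {
            "optimizer__learning_rate": 0.01,
            "optimizer__momentum": 0.0,
            "optimizer__nesterov": False,
        }
        self.set_params(**kwargs)

    def get_params(self):
        # Report plain and prefixed parameters in one flat dictionary.
        params = {name: getattr(self, name) for name in self._PLAIN}
        params.update(self._nested)
        return params

    def set_params(self, **params):
        for key, value in params.items():
            if key in self._PLAIN:
                setattr(self, key, value)
            elif key in self._nested:
                self._nested[key] = value
            else:
                raise ValueError(f"Unknown parameter: {key!r}")
        return self

    def build_optimizer(self):
        # Strip the prefix and pass the remaining names to the optimizer class.
        kwargs = {
            key[len(self._PREFIX):]: value
            for key, value in self._nested.items()
            if key.startswith(self._PREFIX)
        }
        return self.optimizer(**kwargs)


ranker = ToyRanker()
ranker.set_params(optimizer__learning_rate=0.02)
optimizer = ranker.build_optimizer()
```

In this sketch the prefixed values live in a plain dictionary and are only handed to the optimizer class when it is instantiated, so changing one through `set_params` is visible in the next `get_params` call.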
Description
See #169 for context. This fix is hacky, but it works for now. We will likely have to rewrite or significantly modify this code anyway as part of #125.
How Has This Been Tested?
Lints & tests, CI.
Does this close/impact existing issues?
Fixes #169.