chezou closed this issue 3 months ago.
What happened?
When I try to fit `LightFMWrapperModel` with a LightFM model configured with `loss="warp-kos"`, fitting fails with the following error:
`NotImplementedError: k-OS loss with sample weights not implemented.`
This is because the `sample_weight` argument that `LightFMWrapperModel` passes to the LightFM model is hard-coded to `ui_coo`, and LightFM does not support sample weights with the k-OS loss.
A possible fix:
```diff
diff --git a/rectools/models/lightfm.py b/rectools/models/lightfm.py
index 693d7ae..32dad45 100644
--- a/rectools/models/lightfm.py
+++ b/rectools/models/lightfm.py
@@ -77,12 +77,13 @@ class LightFMWrapperModel(FixedColdRecoModelMixin, VectorModel):
         ui_coo = dataset.get_user_item_matrix(include_weights=True).tocoo(copy=False)
         user_features = self._prepare_features(dataset.get_hot_user_features(), dataset.n_hot_users)
         item_features = self._prepare_features(dataset.get_hot_item_features(), dataset.n_hot_items)
+        sample_weight = None if self._model.loss == "warp-kos" else ui_coo
 
         self.model.fit(
             ui_coo,
             user_features=user_features,
             item_features=item_features,
-            sample_weight=ui_coo,
+            sample_weight=sample_weight,
             epochs=self.n_epochs,
             num_threads=self.n_threads,
             verbose=self.verbose > 0,
```
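The decision made by the patch (pass sample weights only when the loss accepts them) can also be isolated into a small helper. This is a hypothetical sketch: `select_sample_weight` and `LOSSES_WITHOUT_SAMPLE_WEIGHTS` are illustrative names, not part of RecTools or LightFM, and it assumes `warp-kos` is the only LightFM loss that rejects sample weights, as the error above suggests.

```python
from typing import FrozenSet, Optional, TypeVar

W = TypeVar("W")

# Hypothetical constant (assumption): LightFM losses that raise
# NotImplementedError when sample weights are supplied. Per the error in
# this issue, that is "warp-kos".
LOSSES_WITHOUT_SAMPLE_WEIGHTS: FrozenSet[str] = frozenset({"warp-kos"})


def select_sample_weight(loss: str, weights: W) -> Optional[W]:
    """Return `weights` if `loss` accepts sample weights, otherwise None.

    Returning None means LightFM trains without per-interaction weights
    instead of raising NotImplementedError for the k-OS loss.
    """
    if loss in LOSSES_WITHOUT_SAMPLE_WEIGHTS:
        return None
    return weights


if __name__ == "__main__":
    ui_coo = "weighted user-item matrix"  # stand-in for the scipy COO matrix
    print(select_sample_weight("warp", ui_coo))      # weighted user-item matrix
    print(select_sample_weight("warp-kos", ui_coo))  # None
```

Inside `LightFMWrapperModel` this would replace the inline conditional with `select_sample_weight(self._model.loss, ui_coo)`. Keeping the unsupported losses in a set makes it a one-line change if other losses turn out to reject sample weights too.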
Expected behavior
Training succeeds.
Relevant logs and/or screenshots
Detailed Traceback
```
Traceback (most recent call last)
/Users/ariga/src/td-misc/aki/rectools-test/./src/rectools_test/training.py:362 in run_all

   359
   360     s_time = time.time()
   361     model = make_base_model(algorithm=algorithm, config=config)
 ❱ 362     model.fit(prep_data.dataset)
   363     train_time = time.time() - s_time
   364     logger.info(f"Finished training: elapsed time: {train_time:.4f} seconds")
   365

locals:
  add_cold      = False
  algorithm     = 'lightfm'
  alpha         = 10.0
  ann           = False
  arg_params    = {
                      'algorithm': 'lightfm',
                      'factors': 32,
                      'num_threads': 16,
                      'regularization': 0.1,
                      'alpha': 10.0,
                      'iterations': 10,
                      'verify_negative_samples': True,
                      'max_k': 100,
                      'similarity': 'bm25',
                      'k': 5,
                      ... +12
                  }
  begin_from    = None
  config        = LightFMConfig(
                      no_components=32,
                      learning_rate=0.05,
                      k=5,
                      n=10,
                      item_alpha=0.0,
                      user_alpha=0.0,
                      loss='warp-kos',
                      learning_schedule='adagrad',
                      max_sampled=10,
                      random_state=42,
                      num_threads=16,
                      epochs=1
                  )
  df            =
                          user_id                 item_id  weight             datetime        time
      0        3zBJUlWtPNoZ0uN83ODbyg  2bXm0SynOfxDzfrdrCyXqg     4.0  2005-02-16 03:23:22  1706854726
      1        3zBJUlWtPNoZ0uN83ODbyg  3g6XqkBikTTbZmTukbeGnw     4.0  2005-02-16 03:29:39  1706854726
      2        3zBJUlWtPNoZ0uN83ODbyg  PP3BBaVxZLcJU54uP_wL6Q     5.0  2005-02-16 04:06:26  1706854726
      3        XCsZ3hWa_6oP1WkWvK7pmg  U3grYFIeu6RgAAQgdriHww     5.0  2005-03-01 16:57:17  1706854726
      4        XCsZ3hWa_6oP1WkWvK7pmg  Aes-0Q_guDeYewMapFs_vg     2.0  2005-03-01 16:59:37  1706854726
      ...                         ...                     ...     ...                  ...         ...
      6990275  lmiiFd9KC15fs4xtEoXRvw  XDMno4l95AXgYOd0yDtHZA     5.0  2022-01-19 19:48:13  1706854726
      6990276  2Mb0st9WVyccaz6sKNLHWw  M88FFZZ2o_7QKpCFA_8RtA     5.0  2022-01-19 19:48:16  1706854726
      6990277  3TQKP7KlNRdrI2gOkG7slg  jVg-KTXEFIeAq47DTp4Hrw     5.0  2022-01-19 19:48:19  1706854726
      6990278  i1PMqye40QWNkJ0MYGHuzg  J0joPXxmN-_9Lzafspqdbw     5.0  2022-01-19 19:48:25  1706854726
      6990279  IH0ToaZ8hJXO2pVieN7dpQ  VItkA7pL82rCZdxHH8vBGA     5.0  2022-01-19 19:48:45  1706854726

      [6990280 rows x 5 columns]
  elapsed_time  = 9.071298122406006
  epochs        = 1
  factors       = 32
  inverse       = False
  item_alpha    = 0.0
  item_col      = 'item_id'
  iterations    = 10
  k             = 5
  learning_rate = 0.05
  loss          = 'warp-kos'
  max_k         = 100
  max_sampled   = 10
  model         =
```
Operating System
macOS Sonoma
Python Version
Python 3.11.9
RecTools version
0.7.0