Goals :soccer:
This PR fixes the retrieval encoder methods (e.g. `to_top_k_model()`, `batch_predict()`), which failed in some cases depending on the input features, e.g. multi-hot non-ragged item features.
Implementation Details :construction:
The retrieval models (i.e. those based on `RetrievalModelV2`) are composed of two towers that encode the item features and the query/user features separately. This makes it possible to run each tower on its own to generate item or query embeddings.
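As a rough illustration of why separate towers allow separate encoding, here is a stdlib-only toy two-tower model. The class, its method names and the single linear layer per tower are hypothetical and are not the `RetrievalModelV2` API; the point is only that item embeddings can be computed without any query features, and vice versa.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def linear(weights: Matrix, x: Vector) -> Vector:
    """Dense layer without bias: returns weights @ x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


class TwoTowerSketch:
    """Hypothetical two-tower model: each tower maps its own features to an embedding."""

    def __init__(self, query_weights: Matrix, item_weights: Matrix) -> None:
        self.query_weights = query_weights
        self.item_weights = item_weights

    def encode_query(self, query_features: Vector) -> Vector:
        # Only the query tower runs; no item features are needed.
        return linear(self.query_weights, query_features)

    def encode_item(self, item_features: Vector) -> Vector:
        # Only the item tower runs; no query features are needed.
        return linear(self.item_weights, item_features)

    def score(self, query_features: Vector, item_features: Vector) -> float:
        # Retrieval score: dot product of the two tower outputs.
        q = self.encode_query(query_features)
        i = self.encode_item(item_features)
        return sum(a * b for a, b in zip(q, i))


model = TwoTowerSketch([[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
item_embedding = model.encode_item([1.0, 2.0, 3.0])  # item tower alone
```

Note that the two towers take inputs of different sizes (2 query features, 3 item features) but project them into the same embedding space, so their outputs can be compared with a dot product.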
Internally, these encoding methods use Dask `DataFrame.map_partitions()` to call the encoding function on every partition and return its output (i.e., the output of the tower).
If the `meta` argument is not passed to Dask `DataFrame.map_partitions()`, Dask infers the output dataframe schema from fake data generated from the input dataframe schema. That fake data may differ from the real data: in particular, the fake data generated for dense (non-ragged) list columns, such as multi-hot or embedding features, causes an error when the model's encode function is called on it.
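The failure mode can be reproduced without Dask in a toy setting. In the hypothetical, stdlib-only sketch below, a partition is a dict of column lists, `encode_items` plays the item tower and requires each `genres` value to be a 3-element list (a non-ragged multi-hot feature), and `fake_partition` builds placeholder rows from column types alone. This is not how Dask actually builds its fake data; it only shows that a type without a length cannot reproduce a fixed-length list.

```python
from typing import Callable, Dict, List

# Toy partition: column name -> list of row values (not a real DataFrame).
Partition = Dict[str, list]


def encode_items(part: Partition) -> Partition:
    """Hypothetical item tower: needs a dense list feature 'genres' of length 3."""
    embeddings: List[float] = []
    for genres in part["genres"]:
        if not isinstance(genres, list) or len(genres) != 3:
            raise ValueError(f"expected 3 genre values, got {genres!r}")
        embeddings.append(sum(genres) / 3)
    return {"item_id": list(part["item_id"]), "embedding": embeddings}


def fake_partition(column_types: Dict[str, str]) -> Partition:
    """One placeholder row per column, built from types only.

    The type 'list' carries no length, so the (assumed) placeholder cannot
    match the fixed length the encoder expects.
    """
    placeholders = {"int": 0, "float": 0.0, "list": []}
    return {name: [placeholders[kind]] for name, kind in column_types.items()}


def infer_output_columns(
    func: Callable[[Partition], Partition], column_types: Dict[str, str]
) -> List[str]:
    """Discover the output columns by running func on fake data."""
    return list(func(fake_partition(column_types)))


real = {"item_id": [10, 11], "genres": [[1, 0, 1], [0, 1, 1]]}
encoded = encode_items(real)  # works on real rows
# infer_output_columns(encode_items, {"item_id": "int", "genres": "list"})
# raises ValueError: the fake 'genres' value is not a 3-element list.
```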
This PR sets the `meta` argument of `DataFrame.map_partitions()` explicitly, computing the expected output dataframe schema from a sample batch of real data, which makes the encoding more robust to different types of inputs.
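In a toy, stdlib-only setting (hypothetical names, no Dask), the fix amounts to deriving the output schema once from a few real rows and passing it in, so no fake data is built. In Dask itself, `meta` can be an empty DataFrame or an iterable of `(name, dtype)` pairs; the list of `(name, type)` tuples below only mimics the latter.

```python
from typing import Callable, Dict, List, Tuple

Partition = Dict[str, list]          # column name -> row values (toy partition)
Meta = List[Tuple[str, str]]         # (output column, element type name)


def encode_items(part: Partition) -> Partition:
    """Hypothetical item tower: mean of a fixed-length list feature 'genres'."""
    return {
        "item_id": list(part["item_id"]),
        "embedding": [sum(g) / len(g) for g in part["genres"]],
    }


def meta_from_sample(func: Callable[[Partition], Partition], sample: Partition) -> Meta:
    """Run func on a small batch of real rows and keep only the output schema."""
    out = func(sample)
    return [(name, type(values[0]).__name__) for name, values in out.items()]


def map_partitions(
    func: Callable[[Partition], Partition], partitions: List[Partition], meta: Meta
) -> List[Partition]:
    """Apply func to every partition; meta is given, so nothing is inferred."""
    results = []
    for part in partitions:
        out = func(part)
        # Check each partition's output columns against the precomputed meta.
        if [name for name, _ in meta] != list(out):
            raise ValueError("output columns do not match meta")
        results.append(out)
    return results


partitions = [
    {"item_id": [1, 2], "genres": [[1, 0, 1], [0, 0, 1]]},
    {"item_id": [3], "genres": [[1, 1, 1]]},
]
sample = {name: values[:1] for name, values in partitions[0].items()}  # one real row
meta = meta_from_sample(encode_items, sample)
results = map_partitions(encode_items, partitions, meta)
```

The sample only needs to be large enough to exercise the encoder once; its values are discarded and only the column names and types are kept.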
The PR also changes the `data_iterator_func()` used by the model encoder to pass the schema directly, rather than the old `Loader` arguments that set `categorical`, `continuous` and `targets` separately, since the previous code did not handle list features correctly.
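A minimal sketch of why a schema is a richer input than separate `categorical`/`continuous`/`targets` name lists: a plain list of column names says which role a column plays, but not whether it is a list or whether that list is ragged, while a schema entry can carry that information. `ColumnSchema` and `to_batch` are hypothetical and only mimic the idea; they are not the `merlin.schema` or `Loader` API.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ColumnSchema:
    """Hypothetical column schema entry (not the merlin.schema API)."""
    name: str
    is_list: bool = False
    is_ragged: bool = False
    value_count: Optional[int] = None  # fixed length for non-ragged lists


def to_batch(rows: List[Dict[str, object]], schema: List[ColumnSchema]) -> Dict[str, object]:
    """Build one batch, letting the schema decide how each column is shaped."""
    batch: Dict[str, object] = {}
    for col in schema:
        values = [row[col.name] for row in rows]
        if not col.is_list:
            batch[col.name] = values  # scalar column: 1-D
        elif col.is_ragged:
            # Ragged list: flattened values plus row offsets.
            flat: List[object] = []
            offsets = [0]
            for v in values:
                flat.extend(v)
                offsets.append(len(flat))
            batch[col.name] = (flat, offsets)
        else:
            # Non-ragged list: every row must have value_count elements (2-D).
            for v in values:
                if len(v) != col.value_count:
                    raise ValueError(
                        f"{col.name}: expected {col.value_count} values, got {len(v)}"
                    )
            batch[col.name] = [list(v) for v in values]
    return batch


schema = [
    ColumnSchema("item_id"),
    ColumnSchema("genres", is_list=True, value_count=2),  # multi-hot, non-ragged
    ColumnSchema("tags", is_list=True, is_ragged=True),   # multi-hot, ragged
]
rows = [
    {"item_id": 1, "genres": [3, 4], "tags": [7]},
    {"item_id": 2, "genres": [5, 6], "tags": [8, 9, 10]},
]
batch = to_batch(rows, schema)
```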
Testing Details :mag:
Added the `test_two_tower_v2_export_item_tower_embeddings_with_seq_item_features` test, which uses the `music_streaming_data` synthetic dataset. That dataset contains multi-hot list features (ragged and non-ragged), for which the encoding functions failed before this fix.