Closed zhoutong-fu closed 4 years ago
Wide_ftrs also need to match among tasks right?
Comments left. Thanks for implementing multitask training!
Let's add more details on what this PR is not implementing here. I.e., although different losses are computed, these losses are using the same loss_fn. Also, we do not support mixing ranking tasks with classification tasks.
Updated in the PR description. Let me know if it looks good to you.
> Wide_ftrs also need to match among tasks right?
Updated the PR and added a section on how to run multitask training.
Description
DeText clients have asked for multitask learning on ranking tasks so that the model can benefit from related subtasks. After reviewing surveys of the multitask learning literature, we've decided to add the basic yet very popular hard-parameter-sharing neural network structure to the DeText framework.
This implementation assumes that all subtasks share the same deep features while wide features and labels differ, which means that the MLP and LTR layers are task-specific while all other layers are shared across subtasks. It also assumes that individual task losses are optimized separately and that each training record contains features/labels for only one subtask. Training records that cover more than one subtask should be split into multiple records.
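The splitting requirement above can be illustrated with a small sketch in plain Python. The field names (`deep_ftrs`, `wide_ftrs`, `label`, `tasks`) are illustrative, not the actual DeText schema; only the `task_id` single-entry int64-list convention comes from this PR:

```python
def split_multitask_record(record):
    """Split a record that carries several subtasks into one record per
    subtask. Shared (deep) fields are copied into every output record;
    wide features and labels stay task-specific, and task_id becomes a
    single-entry int list, as the trainer expects.
    """
    shared = {k: v for k, v in record.items() if k != "tasks"}
    out = []
    for task in record["tasks"]:
        rec = dict(shared)
        rec["task_id"] = [int(task["task_id"])]  # single-entry list
        rec["wide_ftrs"] = task["wide_ftrs"]     # task-specific wide features
        rec["label"] = task["label"]             # task-specific labels
        out.append(rec)
    return out

# Example: one record covering two subtasks becomes two records.
multi = {
    "query": "software engineer",
    "deep_ftrs": [0.1, 0.2, 0.3],
    "tasks": [
        {"task_id": 0, "wide_ftrs": [1.0, 2.0], "label": [1, 0]},
        {"task_id": 1, "wide_ftrs": [0.5], "label": [0, 1]},
    ],
}
records = split_multitask_record(multi)
```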
In summary, this implementation supports multitask training when:

- all subtasks share the same deep features, while wide features and labels are task-specific
- each training record carries features/labels for exactly one subtask
- all subtasks use the same loss_fn (the individual task losses are computed and optimized separately)

This implementation does NOT support:

- a different loss_fn per subtask
- mixing ranking tasks with classification tasks
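The hard-parameter-sharing structure described above, with a shared deep tower feeding task-specific MLP/LTR and wide layers selected by `task_id`, can be sketched in NumPy. Layer sizes and variable names here are made up for illustration and are not the actual DeText code:

```python
import numpy as np

rng = np.random.default_rng(0)

# Shared deep layer: every subtask reuses these parameters.
W_shared = rng.normal(size=(8, 4))

# Task-specific heads: each subtask has its own MLP/scoring weights and
# its own wide-feature weights (wide features differ per task, so the
# wide dimensions may differ too).
heads = {
    0: {"W_mlp": rng.normal(size=(4, 1)), "w_wide": rng.normal(size=3)},
    1: {"W_mlp": rng.normal(size=(4, 1)), "w_wide": rng.normal(size=2)},
}

def score(deep_ftrs, wide_ftrs, task_id):
    """Forward pass: shared deep tower plus the head chosen by task_id."""
    hidden = np.tanh(deep_ftrs @ W_shared)          # shared representation
    head = heads[task_id]
    deep_score = float(hidden @ head["W_mlp"])      # task-specific MLP/LTR part
    wide_score = float(wide_ftrs @ head["w_wide"])  # task-specific wide part
    return deep_score + wide_score

deep = rng.normal(size=8)
s0 = score(deep, rng.normal(size=3), task_id=0)
s1 = score(deep, rng.normal(size=2), task_id=1)
# Same deep features run through different heads -> different scores.
```

During training, only the head matching the record's `task_id` receives gradient from that record, while `W_shared` is updated by every record; that is the essence of hard parameter sharing.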
How to run multitask training?
Provide `task_id` in the training data as a list of int64. The list should contain only one entry, representing the single subtask for the underlying training record. See `run_detext_multitask.py` for other input parameters.

Type of change
List all changes
Please list all changes in the commit.
- `data_fn.py`: `task_id` is a list of type int64 and only contains 1 entry (same as `query` or `uid`).
- `deep_match.py`
- `misc_utils.py`: to specify and organize required parameters/features
- `train.py`
- `multitask_examples.tfrecord` / `test.tfrecord`: sample multitask training data
- `run_detext_multitask.sh`: sample script to run multitask training

Testing
Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration.
- `pytest` passed
- `bash run_detext.sh` showed consistent results
- `bash run_detext_multitask.sh` partial outputs shown below:

Trainable variables
Eval results of best model on test data
Test Configuration:
Checklist