namisan / mt-dnn

Multi-Task Deep Neural Networks for Natural Language Understanding
MIT License

Prediction: How to find the task id? #234

Closed: hoangthangta closed this issue 2 years ago

hoangthangta commented 2 years ago

I trained 5 tasks at the same time. At prediction time I use the command below, and --task_id=1 triggers an error. When I checked, the output is only an array like [0]. Shouldn't the output be something like [0,1,0,0,1]?

python3 predict.py --task accident --task_def dataset/all_task_def.yml --max_seq_len 128 --batch_size_eval 8 --checkpoint checkpoint/model_0.pt --prep_input dataset/bert_base_cased/test_train.json --task_id 1 --score dataset/bert_base_cased/test_pred_1.json

namisan commented 2 years ago

If you have, e.g., three tasks A, B, and C during training, their task ids are 0, 1, and 2 respectively. The task id is used to select the correct task-specific head.
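A minimal sketch of that mapping, assuming the task id is simply the task's position in the order the tasks were specified for training. `TRAIN_TASKS` and `task_id_for` are hypothetical names for illustration, not part of mt-dnn:

```python
from typing import List

# Hypothetical training setup: three tasks, listed in the order used for training.
TRAIN_TASKS = ["A", "B", "C"]


def task_id_for(task_name: str, train_tasks: List[str]) -> int:
    """Return the task id (head index) for `task_name`: its position in the training order."""
    try:
        return train_tasks.index(task_name)
    except ValueError:
        raise KeyError(f"{task_name!r} was not one of the training tasks") from None


if __name__ == "__main__":
    for name in TRAIN_TASKS:
        print(name, "->", task_id_for(name, TRAIN_TASKS))  # A -> 0, B -> 1, C -> 2
```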

hoangthangta commented 2 years ago

I use this command:

python3 predict.py --task homicide --task_def dataset/all_task_def.yml --max_seq_len 128 --batch_size_eval 8 --checkpoint checkpoint/model_2.pt --prep_input dataset/bert-base-multilingual-cased/test.json --with_label --task_id 1 --score dataset/bert-base-multilingual-cased/test_pred.json

and got this error:

Traceback (most recent call last):
  File "predict.py", line 108, in <module>
    with_label=args.with_label,
  File "/home/thang/mt-dnn/mt_dnn/inference.py", line 97, in eval_model
    score, pred, gold = model.predict(batch_info, batch_data)
  File "/home/thang/mt-dnn/mt_dnn/model.py", line 421, in predict
    score = self.mnetwork(*inputs)
  File "/home/thang/.env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/thang/.env/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 166, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/thang/.env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/thang/mt-dnn/mt_dnn/matcher.py", line 204, in forward
    decoder_opt = self.decoder_opt[task_id]
IndexError: list index out of range
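The last frame points at the cause: `self.decoder_opt` is indexed by `task_id`, so it needs at least `task_id + 1` entries. If the prediction script builds it from a single task definition (an assumption, but consistent with the fix later in this thread), only id 0 is valid. A small sketch with plain lists; `check_task_id` and `select_decoder_opt` are hypothetical helpers and the list values are placeholders:

```python
from typing import List


def check_task_id(task_id: int, num_heads: int) -> None:
    """Fail early with a readable message instead of a bare IndexError."""
    if not 0 <= task_id < num_heads:
        raise ValueError(
            f"task_id={task_id}, but the model was built with {num_heads} task head(s); "
            f"valid ids are 0..{num_heads - 1}"
        )


def select_decoder_opt(decoder_opt: List[int], task_id: int) -> int:
    """Mimic the failing line: one decoder option per task head, indexed by task_id."""
    check_task_id(task_id, len(decoder_opt))
    return decoder_opt[task_id]


if __name__ == "__main__":
    one_head = [0]          # built from a single task definition
    five_heads = [0] * 5    # built from five task definitions
    print(select_decoder_opt(five_heads, 1))  # works: index 1 exists
    try:
        select_decoder_opt(one_head, 1)
    except ValueError as err:
        print(err)
```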

I hope you can help me figure out how to fix it.

hoangthangta commented 2 years ago

In predict.py, I had to change this line:

# 5 is the number of tasks. All 5 tasks here are binary classification, so they can share one
# definition; otherwise define your own list here (or pass it in via arguments).
# Equivalent to: task_def_list = [task_def, task_def, task_def, task_def, task_def]
task_def_list = [task_def] * 5
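`[task_def] * 5` works here only because all five heads share the same definition. A more general sketch, assuming you can look up each training task's definition by name and know the training order: build one entry per head so that list index equals task id. `build_task_def_list`, the task names after `homicide`, and the placeholder strings are hypothetical, not mt-dnn API:

```python
from typing import Dict, List, TypeVar

TaskDef = TypeVar("TaskDef")


def build_task_def_list(train_tasks: List[str], task_defs: Dict[str, TaskDef]) -> List[TaskDef]:
    """Return one task definition per head, ordered so that list index == task id."""
    missing = [name for name in train_tasks if name not in task_defs]
    if missing:
        raise KeyError(f"no task definition found for: {missing}")
    return [task_defs[name] for name in train_tasks]


if __name__ == "__main__":
    # Hypothetical training order; placeholder strings stand in for real task definitions.
    train_tasks = ["accident", "homicide", "task_c", "task_d", "task_e"]
    defs = {name: f"<def of {name}>" for name in train_tasks}
    task_def_list = build_task_def_list(train_tasks, defs)
    print(task_def_list[1])  # <def of homicide>
```

Keeping each entry aligned with its training position means every head is rebuilt with the configuration it was trained with, rather than relying on all tasks happening to be identical.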

It seems predict.py needs a bit more rewriting to support multi-head prediction.
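Until predict.py can score several heads in one run, one workaround is to invoke it once per task with the matching `--task_id`. Each run then yields predictions from a single head, so a per-example vector like [0,1,0,0,1] would have to be assembled from the separate score files. The sketch below only builds the command lines, reusing the flags from the commands above; the training order and the `{task}_test.json` / `{task}_test_pred.json` file naming are hypothetical:

```python
import shlex
from typing import List


def predict_command(task: str, task_id: int, checkpoint: str, data_dir: str, task_def: str) -> List[str]:
    """Build one predict.py invocation that scores `task` with head `task_id`."""
    return [
        "python3", "predict.py",
        "--task", task,
        "--task_def", task_def,
        "--max_seq_len", "128",
        "--batch_size_eval", "8",
        "--checkpoint", checkpoint,
        "--prep_input", f"{data_dir}/{task}_test.json",  # hypothetical per-task file name
        "--with_label",
        "--task_id", str(task_id),
        "--score", f"{data_dir}/{task}_test_pred.json",  # hypothetical per-task output
    ]


if __name__ == "__main__":
    train_tasks = ["accident", "homicide"]  # hypothetical training order
    for task_id, task in enumerate(train_tasks):
        cmd = predict_command(
            task, task_id,
            "checkpoint/model_2.pt",
            "dataset/bert-base-multilingual-cased",
            "dataset/all_task_def.yml",
        )
        print(shlex.join(cmd))  # could instead be passed to subprocess.run(cmd, check=True)
```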