OSU-NLP-Group / TableLlama

[NAACL'24] Dataset, code and models for "TableLlama: Towards Open Large Generalist Models for Tables".
https://osu-nlp-group.github.io/TableLlama/
MIT License

The result of wikisql_test generated by the model downloaded from huggingface #5

Open lucky20020327 opened 4 months ago

lucky20020327 commented 4 months ago

I ran the model downloaded from Hugging Face on the wikisql_test dataset and got results like the following:

{
  "idx": 44,
  "instruction": "This is a table QA task. The goal of this task is to answer the question given the table.",
  "input_seg": " [TAB] col: | scheme | tariff code | bts retail price (regulated) | approx premium | prefixes | [SEP] | pence per call, fixed at all times | ff15 | 5.10p/call | 7p/call | 0871, 0872 | [SEP] | pence per call, fixed at all times | ff28 nts | 8.50p/call | 10p/call | 0871, 0872 | [SEP] | pence per call, fixed at all times | ff0 prs | 8.60p/call | 10p/call | 0871, 0872 | [SEP] | pence per minute, fixed at all times | g21 | geographic rate or inclusive | nil | 030, 033, 034, 037 | [SEP] | pence per minute, fixed at all times | g22 | 0.425p/min | 2p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g8 | 0.851p/min | 2p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g27 | 1.277p/min | 3p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g9 | 1.702p/min | 4p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | (no code) | 2p/min or inclusive | 2p/min | 0845 only | [SEP] | pence per minute, fixed at all times | g28 | 2.128p/min | 3p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g10 | 2.553p/min | 4p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g11 | 3.404p/min | 5p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g6 and p27 | 4.255p/min | 7p/min | 0843, 0844, 0871, 0872 | [SEP] | pence per minute, fixed at all times | (no code) | 8p/min or inclusive | ~10p/min 0p/min | 0870 only |",
  "question": "how many tariff codes have a bts retail price of 2.553p/min?",
  "output": "1.0",
  "predict": "15/2/1/5/2/1/2/1/5/2/2/1/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2/2"
},
{
  "idx": 45,
  "instruction": "This is a table QA task. The goal of this task is to answer the question given the table.",
  "input_seg": " [TAB] col: | scheme | tariff code | bts retail price (regulated) | approx premium | prefixes | [SEP] | pence per call, fixed at all times | ff15 | 5.10p/call | 7p/call | 0871, 0872 | [SEP] | pence per call, fixed at all times | ff28 nts | 8.50p/call | 10p/call | 0871, 0872 | [SEP] | pence per call, fixed at all times | ff0 prs | 8.60p/call | 10p/call | 0871, 0872 | [SEP] | pence per minute, fixed at all times | g21 | geographic rate or inclusive | nil | 030, 033, 034, 037 | [SEP] | pence per minute, fixed at all times | g22 | 0.425p/min | 2p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g8 | 0.851p/min | 2p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g27 | 1.277p/min | 3p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g9 | 1.702p/min | 4p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | (no code) | 2p/min or inclusive | 2p/min | 0845 only | [SEP] | pence per minute, fixed at all times | g28 | 2.128p/min | 3p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g10 | 2.553p/min | 4p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g11 | 3.404p/min | 5p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g6 and p27 | 4.255p/min | 7p/min | 0843, 0844, 0871, 0872 | [SEP] | pence per minute, fixed at all times | (no code) | 8p/min or inclusive | ~10p/min 0p/min | 0870 only |",
  "question": "what prefixes are priced at pence per minute, fixed at all times with a premium of 3p/min?",
  "output": "0843, 0844, 0843, 0844",
  "predict": "<0.0:0>, <0:0> <0:0:0:0:0:0:0:0:0:0:0:0:0:0:0:0:0:0:000:000:00000000000000000000000000000000000000000000000000000000000000000000000"
},
{
  "idx": 46,
  "instruction": "This is a table QA task. The goal of this task is to answer the question given the table.",
  "input_seg": " [TAB] col: | scheme | tariff code | bts retail price (regulated) | approx premium | prefixes | [SEP] | pence per call, fixed at all times | ff15 | 5.10p/call | 7p/call | 0871, 0872 | [SEP] | pence per call, fixed at all times | ff28 nts | 8.50p/call | 10p/call | 0871, 0872 | [SEP] | pence per call, fixed at all times | ff0 prs | 8.60p/call | 10p/call | 0871, 0872 | [SEP] | pence per minute, fixed at all times | g21 | geographic rate or inclusive | nil | 030, 033, 034, 037 | [SEP] | pence per minute, fixed at all times | g22 | 0.425p/min | 2p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g8 | 0.851p/min | 2p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g27 | 1.277p/min | 3p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g9 | 1.702p/min | 4p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | (no code) | 2p/min or inclusive | 2p/min | 0845 only | [SEP] | pence per minute, fixed at all times | g28 | 2.128p/min | 3p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g10 | 2.553p/min | 4p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g11 | 3.404p/min | 5p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g6 and p27 | 4.255p/min | 7p/min | 0843, 0844, 0871, 0872 | [SEP] | pence per minute, fixed at all times | (no code) | 8p/min or inclusive | ~10p/min 0p/min | 0870 only |",
  "question": "what is the bts retail price (regulated) for tariff code g10?",
  "output": "2.553p/min",
  "predict": "2013:3:2:1:4:2#3:1:3:1:1:2#3:1:2:1:3:1:3:1:3:1:3:1:3:1:3:1:3:1:3:1:3:1:3:1:3:1:3:1:3:1:3:1:3:1:3:1:3:1:3:1:3:1:3:1:3:1:3:1:3:1:3"
}

There are a lot of meaningless tokens in the prediction results. Is this normal, or did I make a mistake somewhere?

The shell command I used:

MODEL_DIR=...
INPUT_DIR=...
OUTPUT_DIR=...

INPUT_FILE=wikisql_test.json

python3 inference_hitab_tabfact_fetaqa.py  \
        --base_model $MODEL_DIR \
        --context_size 8192 \
        --max_gen_len 128 \
        --flash_attn True \
        --input_data_file $INPUT_DIR/$INPUT_FILE \
        --output_data_file $OUTPUT_DIR/$INPUT_FILE
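As a side note, this is the heuristic I used to flag such degenerate predictions when scanning the output JSON (a hypothetical helper, not part of the TableLlama repo):

```python
# Hypothetical sanity check: flag predictions that look degenerate, i.e. a
# short pattern repeated many times (like "2/2/2/..." in the dump above).
def looks_degenerate(text: str, min_repeats: int = 10) -> bool:
    """Heuristic: True if the prediction is mostly one token repeated."""
    for sep in ("/", ":", ","):
        parts = [p for p in text.split(sep) if p]
        # Many separator-delimited pieces but almost no distinct values.
        if len(parts) >= min_repeats and len(set(parts)) <= 2:
            return True
    return False

print(looks_degenerate("2/2/2/2/2/2/2/2/2/2/2/2"))  # True
print(looks_degenerate("2.553p/min"))                # False
```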

The machine I'm using:

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.06              Driver Version: 545.23.06    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla V100-PCIE-32GB           Off | 00000000:00:05.0 Off |                    0 |
| N/A   50C    P0              42W / 250W |  26118MiB / 32768MiB |      8%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Tesla V100-PCIE-32GB           Off | 00000000:00:06.0 Off |                    0 |
| N/A   54C    P0              43W / 250W |   5874MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   2  Tesla V100-PCIE-32GB           Off | 00000000:00:07.0 Off |                    0 |
| N/A   54C    P0              45W / 250W |   5874MiB / 32768MiB |     66%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   3  Tesla V100-PCIE-32GB           Off | 00000000:00:08.0 Off |                    0 |
| N/A   55C    P0             224W / 250W |  15684MiB / 32768MiB |     53%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
zhangtianshu commented 4 months ago

Did you encounter any issues during inference? The normal predictions for your examples should look like this:

{
    "idx": 44,
    "instruction": "This is a table QA task. The goal of this task is to answer the question given the table.",
    "input_seg": " [TAB] col: | scheme | tariff code | bts retail price (regulated) | approx premium | prefixes | [SEP] | pence per call, fixed at all times | ff15 | 5.10p/call | 7p/call | 0871, 0872 | [SEP] | pence per call, fixed at all times | ff28 nts | 8.50p/call | 10p/call | 0871, 0872 | [SEP] | pence per call, fixed at all times | ff0 prs | 8.60p/call | 10p/call | 0871, 0872 | [SEP] | pence per minute, fixed at all times | g21 | geographic rate or inclusive | nil | 030, 033, 034, 037 | [SEP] | pence per minute, fixed at all times | g22 | 0.425p/min | 2p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g8 | 0.851p/min | 2p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g27 | 1.277p/min | 3p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g9 | 1.702p/min | 4p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | (no code) | 2p/min or inclusive | 2p/min | 0845 only | [SEP] | pence per minute, fixed at all times | g28 | 2.128p/min | 3p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g10 | 2.553p/min | 4p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g11 | 3.404p/min | 5p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g6 and p27 | 4.255p/min | 7p/min | 0843, 0844, 0871, 0872 | [SEP] | pence per minute, fixed at all times | (no code) | 8p/min or inclusive | ~10p/min 0p/min | 0870 only |",
    "question": "how many tariff codes have a bts retail price of 2.553p/min?",
    "output": "1.0",
    "predict": "1.0"
  },
  {
    "idx": 45,
    "instruction": "This is a table QA task. The goal of this task is to answer the question given the table.",
    "input_seg": " [TAB] col: | scheme | tariff code | bts retail price (regulated) | approx premium | prefixes | [SEP] | pence per call, fixed at all times | ff15 | 5.10p/call | 7p/call | 0871, 0872 | [SEP] | pence per call, fixed at all times | ff28 nts | 8.50p/call | 10p/call | 0871, 0872 | [SEP] | pence per call, fixed at all times | ff0 prs | 8.60p/call | 10p/call | 0871, 0872 | [SEP] | pence per minute, fixed at all times | g21 | geographic rate or inclusive | nil | 030, 033, 034, 037 | [SEP] | pence per minute, fixed at all times | g22 | 0.425p/min | 2p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g8 | 0.851p/min | 2p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g27 | 1.277p/min | 3p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g9 | 1.702p/min | 4p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | (no code) | 2p/min or inclusive | 2p/min | 0845 only | [SEP] | pence per minute, fixed at all times | g28 | 2.128p/min | 3p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g10 | 2.553p/min | 4p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g11 | 3.404p/min | 5p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g6 and p27 | 4.255p/min | 7p/min | 0843, 0844, 0871, 0872 | [SEP] | pence per minute, fixed at all times | (no code) | 8p/min or inclusive | ~10p/min 0p/min | 0870 only |",
    "question": "what prefixes are priced at pence per minute, fixed at all times with a premium of 3p/min?",
    "output": "0843, 0844, 0843, 0844",
    "predict": "<g22>, <g27>, <g28>"
  },
  {
    "idx": 46,
    "instruction": "This is a table QA task. The goal of this task is to answer the question given the table.",
    "input_seg": " [TAB] col: | scheme | tariff code | bts retail price (regulated) | approx premium | prefixes | [SEP] | pence per call, fixed at all times | ff15 | 5.10p/call | 7p/call | 0871, 0872 | [SEP] | pence per call, fixed at all times | ff28 nts | 8.50p/call | 10p/call | 0871, 0872 | [SEP] | pence per call, fixed at all times | ff0 prs | 8.60p/call | 10p/call | 0871, 0872 | [SEP] | pence per minute, fixed at all times | g21 | geographic rate or inclusive | nil | 030, 033, 034, 037 | [SEP] | pence per minute, fixed at all times | g22 | 0.425p/min | 2p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g8 | 0.851p/min | 2p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g27 | 1.277p/min | 3p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g9 | 1.702p/min | 4p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | (no code) | 2p/min or inclusive | 2p/min | 0845 only | [SEP] | pence per minute, fixed at all times | g28 | 2.128p/min | 3p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g10 | 2.553p/min | 4p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g11 | 3.404p/min | 5p/min | 0843, 0844 | [SEP] | pence per minute, fixed at all times | g6 and p27 | 4.255p/min | 7p/min | 0843, 0844, 0871, 0872 | [SEP] | pence per minute, fixed at all times | (no code) | 8p/min or inclusive | ~10p/min 0p/min | 0870 only |",
    "question": "what is the bts retail price (regulated) for tariff code g10?",
    "output": "2.553p/min",
    "predict": "2.553p/min"
  }
lucky20020327 commented 4 months ago

There was a warning:

UserWarning: Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward.ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593

No other useful information.

zhangtianshu commented 4 months ago

Here is the list of GPU types that FlashAttention supports: https://github.com/Dao-AILab/flash-attention/issues/148. You can check whether your GPU is included. If it is not, try the following script (which simply sets flash attention to False) and see whether you get correct predictions:

python3 inference_hitab_tabfact_fetaqa.py  \
        --base_model $MODEL_DIR \
        --context_size 8192 \
        --max_gen_len 128 \
        --flash_attn False \
        --input_data_file $INPUT_DIR/$INPUT_FILE \
        --output_data_file $OUTPUT_DIR/$INPUT_FILE
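For reference, FlashAttention's kernels require an Ampere-or-newer GPU (CUDA compute capability 8.0+), while a V100 is sm_70. A minimal sketch of a capability check you could run before choosing the `--flash_attn` flag (the helper name is hypothetical; the live PyTorch call is shown as a comment):

```python
# Hypothetical helper: decide whether --flash_attn True is safe based on the
# GPU's CUDA compute capability. FlashAttention requires Ampere or newer
# (compute capability >= 8.0); a V100 is (7, 0), so it is not supported.
def flash_attn_supported(major: int, minor: int) -> bool:
    """Return True if a GPU with this compute capability can run FlashAttention."""
    return (major, minor) >= (8, 0)

print(flash_attn_supported(7, 0))  # V100 -> False
print(flash_attn_supported(8, 0))  # A100 -> True

# With PyTorch available, the live check would be:
#   import torch
#   major, minor = torch.cuda.get_device_capability(0)
#   use_flash = flash_attn_supported(major, minor)
```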
lucky20020327 commented 4 months ago

Still the same result. Should I fine-tune the model first?