google-research / tapas

End-to-end neural table-text understanding models.
Apache License 2.0

How to get inference on multiple tables and their queries by passing them in batch? #129

Closed: shabbie closed this issue 3 years ago

shabbie commented 3 years ago

I am trying to pass multiple tables and their corresponding questions as a batch to the predict function inside the wtq_predictions.ipynb notebook provided for inference.

def predict(table_data, queries):
  table = [list(map(lambda s: s.strip(), row.split("|")))
           for row in table_data.split("\n") if row.strip()]
  examples = convert_interactions_to_examples([(table, queries)])
  write_tf_example("results/wtq/tf_examples/test.tfrecord", examples)
  write_tf_example("results/wtq/tf_examples/random-split-1-dev.tfrecord", [])
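
For context, the notebook invokes this with a single pipe-delimited table string and a list of questions, roughly like the following (the table and question here are made up):

predict("""Name|Age
Alice|30
Bob|25""", ["How old is Alice?"])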

How can I pass multiple tables and their associated queries to "convert_interactions_to_examples" at once?

eisenjulian commented 3 years ago

Hi there, thanks for the interest in Tapas. You could modify the function as follows (you may have to adjust the batch size to something smaller):

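# Note: csv, pandas (as pd), IPython, prediction_utils, and
# aggregation_to_string are assumed to be imported/defined earlier in the
# wtq_predictions.ipynb notebook, as in the original predict cell.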
def predict(table_data_and_queries):
  tables = []
  queries = []
  for table_data, query in table_data_and_queries:
    table = [list(map(lambda s: s.strip(), row.split("|"))) 
             for row in table_data.split("\n") if row.strip()]
    tables.append(table)
    queries.append([query])

  examples = convert_interactions_to_examples(zip(tables, queries))
  write_tf_example("results/wtq/tf_examples/test.tfrecord", examples)
  write_tf_example("results/wtq/tf_examples/random-split-1-dev.tfrecord", [])

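  # Jupyter shell escape: `!` runs a shell command, and {len(queries)} is
  # interpolated by IPython before the command executes, so this cell only
  # works inside a notebook, not in a plain Python script.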
  ! python -m tapas.run_task_main \
    --task="WTQ" \
    --output_dir="results" \
    --noloop_predict \
    --test_batch_size={len(queries)} \
    --tapas_verbosity="ERROR" \
    --compression_type= \
    --reset_position_index_per_cell \
    --init_checkpoint="tapas_model/model.ckpt" \
    --bert_config_file="tapas_model/bert_config.json" \
    --mode="predict" 2> error

  results_path = "results/wtq/model/test.tsv"
  all_coordinates = []
  with open(results_path) as csvfile:
    reader = csv.DictReader(csvfile, delimiter='\t')
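    # Assumes the prediction TSV has one row per input query and preserves
    # input order, so `index` lines up with `tables` and `queries`.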
    for index, row in enumerate(reader):
      table = tables[index]
      df = pd.DataFrame(table[1:], columns=table[0])
      display(IPython.display.HTML(df.to_html(index=False)))
      print()

      coordinates = sorted(prediction_utils.parse_coordinates(row["answer_coordinates"]))
      all_coordinates.append(coordinates)
      # Use r/c for the coordinates so the TSV `row` dict is not shadowed.
      answers = ', '.join([table[r + 1][c] for r, c in coordinates])
      aggregation = aggregation_to_string(int(row["pred_aggr"]))
      print(">", queries[index][0])  # each entry in `queries` is a one-element list
      answer_text = str(answers)
      if aggregation != "NONE":
        answer_text = f"{aggregation} of {answer_text}"
      print(answer_text)
  return all_coordinates
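
With this change the function takes a list of (table_string, query) pairs. A minimal usage sketch (the tables and questions below are made up):

table_a = """Name|Age
Alice|30
Bob|25"""

table_b = """City|Country
Paris|France
Rome|Italy"""

all_coordinates = predict([
    (table_a, "How old is Alice?"),
    (table_b, "Which country is Rome in?"),
])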
shabbie commented 3 years ago

Thank you @eisenjulian. This helped.