Closed shabbie closed 3 years ago
Hi there, thanks for the interest in Tapas. You could modify the function as follows (you may have to adjust the batch size to something smaller):
def predict(table_data_and_queries):
tables = []
queries = []
for table_data, query in table_data_and_queries:
table = [list(map(lambda s: s.strip(), row.split("|")))
for row in table_data.split("\n") if row.strip()]
tables.append(table)
queries.append([query])
examples = convert_interactions_to_examples(zip(tables, queries))
write_tf_example("results/wtq/tf_examples/test.tfrecord", examples)
write_tf_example("results/wtq/tf_examples/random-split-1-dev.tfrecord", [])
! python -m tapas.run_task_main \
--task="WTQ" \
--output_dir="results" \
--noloop_predict \
--test_batch_size={len(queries)} \
--tapas_verbosity="ERROR" \
--compression_type= \
--reset_position_index_per_cell \
--init_checkpoint="tapas_model/model.ckpt" \
--bert_config_file="tapas_model/bert_config.json" \
--mode="predict" 2> error
results_path = "results/wtq/model/test.tsv"
all_coordinates = []
with open(results_path) as csvfile:
reader = csv.DictReader(csvfile, delimiter='\t')
for index, row in enumerate(reader):
table = tables[index]
df = pd.DataFrame(table[1:], columns=table[0])
display(IPython.display.HTML(df.to_html(index=False)))
print()
coordinates = sorted(prediction_utils.parse_coordinates(row["answer_coordinates"]))
all_coordinates.append(coordinates)
answers = ', '.join([table[row + 1][col] for row, col in coordinates])
aggregation = aggregation_to_string(int(row["pred_aggr"]))
print(">", queries[index])
answer_text = str(answers)
if aggregation != "NONE":
answer_text = f"{aggregation} of {answer_text}"
print(answer_text)
return all_coordinates
Thank you @eisenjulian. This helped.
I am trying to pass multiple tables and their corresponding questions as a batch to the predict function inside the wtq_predictions.ipynb notebook provided for inference.
def predict(table_data, queries):
    """Convert one table and its queries into TF examples on disk.

    Args:
        table_data: String with one table row per line and cells separated
            by ``|``; blank lines are ignored.
        queries: The questions to ask about the table, passed through
            unchanged together with the parsed table.

    NOTE(review): only the example-writing steps are shown in this snippet;
    it looks like an excerpt of a longer notebook function -- confirm.
    """
    parsed_rows = []
    for line in table_data.split("\n"):
        if not line.strip():
            continue
        parsed_rows.append([cell.strip() for cell in line.split("|")])

    interactions = [(parsed_rows, queries)]
    examples = convert_interactions_to_examples(interactions)
    write_tf_example("results/wtq/tf_examples/test.tfrecord", examples)
    write_tf_example("results/wtq/tf_examples/random-split-1-dev.tfrecord", [])
How can I pass multiple tables and their associated queries to "convert_interactions_to_examples" at once?