SSamDav opened this issue 3 years ago
Do you have more details on this, including the command you ran, the error you got, etc.?
I ran your notebook with the DBs I mentioned and got this error:
  File "/media/disk/workspace/gap-text2sql/rat-sql-gap/seq2struct/commands/infer.py", line 90, in _infer_one
    model, data_item, preproc_item, beam_size=beam_size, max_steps=1000, from_cond=False)
  File "/media/disk/workspace/gap-text2sql/rat-sql-gap/seq2struct/models/spider/spider_beam_search.py", line 21, in beam_search_with_heuristics
    inference_state, next_choices = model.begin_inference(orig_item, preproc_item)
  File "/media/disk/workspace/gap-text2sql/rat-sql-gap/seq2struct/models/enc_dec.py", line 133, in begin_inference
    enc_state, = self.encoder([enc_input])
  File "/home/ubuntu/anaconda3/envs/gap-text2sql/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/media/disk/workspace/gap-text2sql/rat-sql-gap/seq2struct/models/spider/spider_enc.py", line 1405, in forward
    padded_token_lists, att_mask_lists, tok_type_lists = self.pad_sequence_for_bert_batch(batch_token_lists)
  File "/media/disk/workspace/gap-text2sql/rat-sql-gap/seq2struct/models/spider/spider_enc.py", line 1546, in pad_sequence_for_bert_batch
    max_len = max([len(it) for it in tokens_lists])
ValueError: max() arg is an empty sequence
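The last frame shows the direct cause: `max()` is called on an empty `tokens_lists`, so the encoder reached the padding step with no token lists at all. The sketch below is a simplified, hypothetical stand-in for `pad_sequence_for_bert_batch` (the name `pad_sequence_for_batch` and its behavior are assumptions, not the repository's code, and it omits the token-type lists). It reproduces the padding logic and fails with a clearer message when the batch is empty.

```python
from typing import List, Tuple


def pad_sequence_for_batch(
    tokens_lists: List[List[int]], pad_id: int = 0
) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical, simplified padding helper: right-pads every token-id
    list to the longest length in the batch and builds attention masks."""
    if not tokens_lists:
        # max() over an empty list raises ValueError, as in the traceback.
        # Fail early with a message that points at the likely cause instead.
        raise ValueError(
            "empty batch: every item was dropped before padding "
            "(e.g. question + schema longer than the encoder accepts)"
        )
    max_len = max(len(it) for it in tokens_lists)
    padded = [it + [pad_id] * (max_len - len(it)) for it in tokens_lists]
    att_mask = [[1] * len(it) + [0] * (max_len - len(it)) for it in tokens_lists]
    return padded, att_mask


padded, mask = pad_sequence_for_batch([[5, 6, 7], [8]])
print(padded)  # [[5, 6, 7], [8, 0, 0]]
print(mask)    # [[1, 1, 1], [1, 0, 0]]
```

The guard only changes the error message. The real question is upstream: what removed the single item from the batch before padding.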
Here is the code:
import torch

# model, inferer, registry, SpiderItem, dump_db_json_schema and
# load_tables_from_schema_dict are set up in earlier notebook cells.
db_id = "baseball_1"
my_schema = dump_db_json_schema("data/sqlite_files/{db_id}/{db_id}.sqlite".format(db_id=db_id), db_id)
schema, eval_foreign_key_maps = load_tables_from_schema_dict(my_schema)
dataset = registry.construct('dataset_infer', {
    "name": "spider", "schemas": schema, "eval_foreign_key_maps": eval_foreign_key_maps,
    "db_path": "data/sqlite_files/"
})
for _, schema in dataset.schemas.items():
    model.preproc.enc_preproc._preprocess_schema(schema)
spider_schema = dataset.schemas[db_id]

def infer(question):
    data_item = SpiderItem(
        text=None,  # intentionally None -- should be ignored when the tokenizer is set correctly
        code=None,
        schema=spider_schema,
        orig_schema=spider_schema.orig,
        orig={"question": question}
    )
    model.preproc.clear_items()
    enc_input = model.preproc.enc_preproc.preprocess_item(data_item, None)
    preproc_data = enc_input, None
    with torch.no_grad():
        output = inferer._infer_one(model, data_item, preproc_data, beam_size=1, use_heuristic=True)
    return output[0]["inferred_code"]

code = infer("How many players are there?")
print(code)
I tried to run the eval script on the baseball_1 and cre_Drama_Workshop_Groups DBs, but I got an error related to the input size. How can this error be overcome?
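One plausible reading of "an error related to the input size" is that the question plus the serialized schema of a large database exceeds the encoder's maximum sequence length, so the item never makes it into the batch. A quick way to check that hypothesis before calling `infer()` is to estimate the input length. In this sketch the helper names, the whitespace tokenizer, the one-separator-per-segment layout, and the 512 limit are all assumptions. Substitute the model's own tokenizer and its configured maximum.

```python
from typing import Callable, List


def estimate_input_length(
    question: str,
    column_names: List[str],
    table_names: List[str],
    tokenize: Callable[[str], List[str]] = str.split,
) -> int:
    """Hypothetical rough count of tokens for the question plus every
    column and table name, adding one separator token per segment."""
    segments = [question] + column_names + table_names
    return sum(len(tokenize(s)) + 1 for s in segments)  # +1: separator per segment


def fits_encoder(length: int, max_len: int = 512) -> bool:
    # max_len is an assumed limit; use the encoder's configured maximum.
    return length <= max_len


n = estimate_input_length("How many players are there?", ["player id", "name"], ["player"])
print(n, fits_encoder(n))  # 13 True
```

With a real subword tokenizer the count will be higher than with `str.split`, so treat the whitespace estimate as a lower bound. If the estimate for baseball_1 is already over the limit, that supports the length hypothesis.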