LYH-YF / MWPToolkit

MWPToolkit is an open-source framework for math word problem (MWP) solvers.
MIT License

How to perform inference for only one math problem? #26

Closed: ghost closed this issue 2 years ago

ghost commented 2 years ago

I also have another question. When I run the evaluation on the MAWPS dataset, the parentheses are gone and the signs are messed up in the predicted answer:

Example:

    {
      "id": 460,
      "prediction": "= x - + 47.0 40.0 25.0",    --> this should be X=((47.0+40.0)-25.0)
      "target": "= x - + 47.0 40.0 25.0",
      "number list": ["47.0", "25.0", "40.0"],
      "value acc": true,
      "equ acc": true
    },

ghost commented 2 years ago

@LYH-YF Can you help me please?

LYH-YF commented 2 years ago

Sorry, we haven't fully implemented a function for predicting arbitrary math problems yet.

To support prediction on arbitrary problems, we have implemented dataloader.build_batch_for_predict() and model.predict(). Once a problem has been preprocessed, you can call these two functions to predict its equation.
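Preprocessing here essentially means finding the numbers in the problem text and replacing them with mask tokens; after prediction, the mask tokens in the output equation are mapped back to those numbers. The standalone sketch below illustrates that round trip without MWPToolkit. It uses the same number regex as the sample code in this comment, but the `NUM_i` mask style and the helper names `mask_numbers` and `unmask_equation` are assumptions for illustration only; the toolkit itself masks tokens via `get_num_pos()` and `config['mask_symbol']`.

```python
import re

# Same number pattern as the sample code, written as a raw string.
NUM_PATTERN = re.compile(r"\d*\(\d+/\d+\)\d*|\d+\.\d+%?|\d+%?|(-\d+)")


def mask_numbers(text, mask_prefix="NUM_"):
    """Replace each number in `text` with NUM_0, NUM_1, ... in order of appearance.

    Hypothetical helper: MWPToolkit itself masks tokens via get_num_pos().
    """
    num_list = []

    def _replace(match):
        num_list.append(match.group(0))  # whole match, not the capture group
        return f"{mask_prefix}{len(num_list) - 1}"

    return NUM_PATTERN.sub(_replace, text), num_list


def unmask_equation(equ_tokens, num_list, mask_prefix="NUM_"):
    """Map NUM_i tokens of a predicted equation back to the i-th extracted number."""
    result = []
    for token in equ_tokens:
        if token.startswith(mask_prefix):
            index = int(token[len(mask_prefix):])
            # Keep the mask token when the model points past the number list.
            result.append(num_list[index] if index < len(num_list) else token)
        else:
            result.append(token)
    return result


masked, nums = mask_numbers("a math word problem : 5 + 7 = ?")
print(masked)  # a math word problem : NUM_0 + NUM_1 = ?
print(unmask_equation(["+", "NUM_0", "NUM_1"], nums))  # ['+', '5', '7']
```

Using `match.group(0)` matters because the pattern contains a capture group: `re.findall` would return only that group's contents (mostly empty strings), whereas the whole match is the number we want.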

Here is some sample code that performs inference on a single math problem. You can adapt it to your requirements.

import os
import re
from copy import deepcopy

import nltk
import torch

from mwptoolkit.config import Config
from mwptoolkit.data.utils import get_dataset_module, get_dataloader_module
from mwptoolkit.utils.enum_type import SpecialTokens, NumMask
from mwptoolkit.utils.preprocess_tool.number_operator import english_word_2_num
from mwptoolkit.utils.preprocess_tool.number_transfer import get_num_pos
from mwptoolkit.utils.utils import get_model, str2float

def main():
    # a MWP sample
    problem = 'a math word problem : 5 + 7 = ?'

    # load config
    temp_config = Config()
    config = Config.load_from_pretrained(temp_config['trained_model_dir'])
    # load dataset parameters
    dataset = get_dataset_module(config).load_from_pretrained(config['trained_model_dir'])
    dataset.dataset_load()
    dataloader = get_dataloader_module(config)(config, dataset)
    # load model parameters
    model = get_model(config['model'])(config, dataset)
    model_file = os.path.join(config['trained_model_dir'], 'model.pth')
    state_dict = torch.load(model_file, map_location=config["map_location"])
    model.load_state_dict(state_dict["model"], strict=False)

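    # preprocess (NLTK's Punkt tokenizer has no Chinese model, so this sample targets English datasets)
    language = 'english' if config['language'] == 'en' else 'zh'
    word_list = nltk.word_tokenize(problem, language)
    word_list = english_word_2_num(word_list, fraction_acc=2)
    # raw string so that \d, \( and \. are not treated as (invalid) string escapes
    pattern = re.compile(r"\d*\(\d+/\d+\)\d*|\d+\.\d+%?|\d+%?|(-\d+)")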

    # input_seq, num_list, final_pos, all_pos, nums, num_pos_dict, nums_for_ques, nums_fraction
    process_data = get_num_pos(word_list, config['mask_symbol'], pattern)
    source = deepcopy(process_data[0])
    # put the float-normalized number back at every position where a number was masked
    for pos in process_data[3]:
        for key, value in process_data[5].items():
            if pos in value:
                num_str = key
                break
        num = str(str2float(num_str))
        source[pos] = num
    source = ' '.join(source)
    data_dict = {'question': process_data[0], 'number list': process_data[1], 'number position': process_data[2],
                 'ques source 1': source}

    # build batch data
    batch = dataloader.build_batch_for_predict([data_dict])

    # predict
    token_logits, symbol_outputs, _ = model.predict(batch)

    # output process
    symbol_list = dataloader.convert_idx_2_symbol(symbol_outputs[0])
    equation = []
    for symbol in symbol_list:
        if symbol not in [SpecialTokens.SOS_TOKEN, SpecialTokens.EOS_TOKEN, SpecialTokens.PAD_TOKEN]:
            equation.append(symbol)
        else:
            break

    def trans_symbol_2_number(equ_list, num_list):
        symbol_list = NumMask.number
        new_equ_list = []
        for symbol in equ_list:
            if 'NUM' in symbol:
                index = symbol_list.index(symbol)
                if index >= len(num_list):
                    new_equ_list.append(symbol)
                else:
                    new_equ_list.append(str(num_list[index]))
            else:
                new_equ_list.append(symbol)
        return new_equ_list

    equation = trans_symbol_2_number(equation, data_dict['number list'])
    # final equation
    print(equation)

if __name__ == '__main__':
    main()

I hope this helps.

LYH-YF commented 2 years ago

A correct equation can be represented as a binary tree, and different traversal orders of that tree produce different token sequences. The "messed-up" signs are simply the tree's nodes listed in preorder (prefix notation). X=((47.0+40.0)-25.0) is the inorder traversal, where the parentheses are needed to fix the order of calculation. In preorder, each operator comes before its two operands, so the order of calculation is already unambiguous and no parentheses are needed.
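To recover the parenthesized form, the prefix sequence can be parsed back into a binary tree and printed inorder. The sketch below assumes every operator is binary and that the operator set is `+ - * / ^ =` (an assumption; check the operators your dataset actually uses). `build_tree`, `preorder`, `to_infix` and `evaluate` are hypothetical helpers, not MWPToolkit functions.

```python
import operator
from dataclasses import dataclass
from typing import List, Optional

# Assumed operator set; every operator is treated as binary.
ARITHMETIC = {
    "+": operator.add,
    "-": operator.sub,
    "*": operator.mul,
    "/": operator.truediv,
    "^": operator.pow,
}
OPERATORS = set(ARITHMETIC) | {"="}


@dataclass
class Node:
    token: str
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def build_tree(prefix_tokens: List[str]) -> Node:
    """Rebuild the expression tree from a preorder (prefix) token list."""
    pos = 0

    def _build() -> Node:
        nonlocal pos
        if pos >= len(prefix_tokens):
            raise ValueError("incomplete prefix expression")
        token = prefix_tokens[pos]
        pos += 1
        if token in OPERATORS:
            left = _build()  # an operator consumes two subtrees, left first
            right = _build()
            return Node(token, left, right)
        return Node(token)  # leaf: a number or a variable such as x

    root = _build()
    if pos != len(prefix_tokens):
        raise ValueError("extra tokens after a complete expression")
    return root


def preorder(node: Node) -> List[str]:
    """Preorder traversal: reproduces the original prefix sequence."""
    if node.left is None:
        return [node.token]
    return [node.token] + preorder(node.left) + preorder(node.right)


def to_infix(node: Node) -> str:
    """Inorder traversal with parentheses around every operator node except '='."""
    if node.left is None:
        return node.token
    left, right = to_infix(node.left), to_infix(node.right)
    if node.token == "=":
        return f"{left}={right}"
    return f"({left}{node.token}{right})"


def evaluate(node: Node) -> float:
    """Compute the value of an arithmetic subtree (no variables, no '=')."""
    if node.left is None:
        return float(node.token)
    return ARITHMETIC[node.token](evaluate(node.left), evaluate(node.right))


tokens = "= x - + 47.0 40.0 25.0".split()
root = build_tree(tokens)
print(to_infix(root))        # x=((47.0+40.0)-25.0)
print(evaluate(root.right))  # 62.0
```

This inorder printer wraps every operator node in parentheses, which is always correct but sometimes redundant; dropping unnecessary ones would require comparing operator precedence between parent and child.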

LYH-YF commented 2 years ago

@Gabriel11101

ghost commented 2 years ago

Thank you very much!