arkilpatel / SVAMP

NAACL 2021: Are NLP Models really able to Solve Simple Math Word Problems?
MIT License
116 stars 34 forks source link

Could you provide data preprocessing code? #4

Closed pocca2048 closed 3 years ago

pocca2048 commented 3 years ago

Hi, I noticed that the data preprocessing applied here (MAWPS, ASDiv-a, SVAMP) is a bit different from what is used in other codebases.

This is especially true of the transfer_num function.

For example, 7 -> number0 rather than 7 -> NUM,

as can be seen here and there.

Could you provide your data preprocessing code please?

Thank you.

arkilpatel commented 3 years ago

Please find below the various functions for preprocessing:

import numpy as np
import pandas as pd
from sympy import Eq, solve
from sympy.parsing.sympy_parser import parse_expr
import sympy as sp
import unicodedata
import re
import nltk
nltk.download('punkt')
nltk.download('stopwords')
from nltk.corpus import stopwords 
from nltk.tokenize import word_tokenize

import xml.etree.ElementTree as et

import os

import spacy
import random

import stanza
# Stanza English pipeline. add_group_nums() runs it on each question and reads
# the dependency parse (word.deprel / word.head) it produces. Built at import
# time, so importing this module loads the Stanza models.
nlp_stanza = stanza.Pipeline(lang='en', processors='tokenize, pos, lemma, depparse')

# spaCy models. In the code shown here `nlp` is only used for its stop-word
# list below, and `nlp2` is not referenced at all.
# NOTE(review): 'en_vectors_web_lg' looks like a spaCy 2.x-era package; confirm
# it is installable for the spaCy version in use, or drop it if truly unused.
nlp = spacy.load("en_core_web_lg")
nlp2 = spacy.load('en_vectors_web_lg')

# spaCy's English stop-word set (not referenced by the functions shown here).
stopw = nlp.Defaults.stop_words

def unicodeToAscii(s):
    """Strip combining marks (accents) from *s*.

    The text is decomposed with Unicode NFD and every code point whose
    general category is ``Mn`` (non-spacing mark) is dropped, so ``'é'``
    becomes ``'e'``. Despite the name, other non-ASCII characters are kept.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)

def normalizeString(s):
    """Lower-case, trim, de-accent and drop every comma from *s*.

    Commas are removed outright (not replaced by a space), so ``"1,000"``
    becomes ``"1000"``.
    """
    text = s.lower().strip()
    # Same accent stripping as unicodeToAscii: NFD, then drop non-spacing marks.
    text = ''.join(
        ch for ch in unicodedata.normalize('NFD', text)
        if unicodedata.category(ch) != 'Mn'
    )
    return text.replace(',', '')

def stack_to_string(stack):
    """Join the string tokens of *stack* with single spaces.

    Leading empty-string tokens are skipped, so they never produce a leading
    separator. Empty tokens after the first non-empty one still contribute a
    separator, as a left-to-right concatenation would.
    """
    tokens = list(stack)
    first = 0
    while first < len(tokens) and tokens[first] == "":
        first += 1
    return " ".join(tokens[first:])

# Spelled-out numbers recognised by words2num, mapped to their digit strings.
# NOTE: insertion order matters. words2num matches keys as *prefixes* of a
# token and keeps the LAST matching key, so each teen must come after the
# single digit it starts with ('six' before 'sixteen') for 'sixteen' to map
# to '16' rather than '6'.
singles = {'zero': '0', 'one': '1', 'two': '2', 'three': '3', 'four': '4', 'five': '5', 'six': '6',
           'seven': '7', 'eight': '8', 'nine': '9', 'ten': '10',
           'eleven': '11', 'twelve': '12', 'thirteen': '13', 'fourteen': '14', 'fifteen': '15',
           'sixteen': '16', 'seventeen': '17', 'eighteen': '18', 'nineteen': '19'}
# Tens words mapped to their leading digit. words2num appends the units digit
# (or '0'); 'hundred' maps to '10' so that it becomes '100' the same way.
doubles = {'twenty': '2', 'thirty': '3', 'forty': '4', 'fifty': '5', 'sixty': '6', 'seventy': '7',
           'eighty': '8', 'ninety': '9', 'hundred': '10'}


def words2num(ques):
    """Replace spelled-out numbers in a whitespace-separated string with digits.

    Handles:

    * single words: 'seven' -> '7', 'sixteen' -> '16';
    * bare tens: 'forty' -> '40', 'hundred' -> '100';
    * fused tens: 'twenty-five' / 'twentyfive' -> '25';
    * split tens: 'twenty five' / 'twenty - five' -> '25'.

    Matching is by prefix, so inflected forms convert too ('sevenths' -> '7').
    NOTE(review): prefix matching also converts unrelated words that start
    with a number word (e.g. 'tennis' -> '10'); 'hundreds' and 'tent' are
    explicitly exempt, presumably for that reason.

    A single-word number directly after 'each' (as in 'each one') is kept as
    a word. A tens word with a suffix that is not a units word ('hundredth',
    'sixtyish') is left unchanged.

    Args:
        ques: text whose tokens are separated by whitespace.

    Returns:
        The (partially converted) tokens joined with single spaces.
    """
    converted = []
    words = ques.split()
    skip = 0  # following tokens already consumed as the units part of a tens word
    for w, word in enumerate(words):
        if word == "hundreds" or word == "tent":
            converted.append(word)
            continue
        if skip > 0:
            skip -= 1
            continue

        nums = ""
        is_tens_word = False
        for d, tens_digit in doubles.items():
            if not word.startswith(d):
                continue
            is_tens_word = True
            if len(word) > len(d):
                # Fused form: 'twenty-five' or 'twentyfive'.
                rest = word[len(d) + 1:] if word[len(d)] == "-" else word[len(d):]
                # A suffix that is not a units word leaves the token unchanged
                # instead of raising KeyError and aborting preprocessing.
                if rest in singles:
                    nums = tens_digit + singles[rest]
            else:
                nxt = words[w + 1] if w + 1 < len(words) else None
                if nxt in singles:
                    nums = tens_digit + singles[nxt]
                    skip = 1
                elif nxt == "-" and w + 2 < len(words) and words[w + 2] in singles:
                    nums = tens_digit + singles[words[w + 2]]
                    skip = 2
                else:
                    # Bare tens word, including 'twenty -' not followed by a
                    # units word: keep the magnitude ('20', not '2').
                    nums = tens_digit + "0"

        # Explicit bounds check: words[w - 1] with w == 0 would wrap around to
        # the last token and wrongly test the end of the sentence.
        after_each = w > 0 and words[w - 1] == "each"
        if not nums and not is_tens_word and not after_each:
            for s, digit in singles.items():
                if word.startswith(s):
                    nums = digit  # keep scanning: a longer key (a teen) may match later

        converted.append(nums if nums else word)
    return " ".join(converted)

def format_eq(eq):
    """Re-tokenise an equation string into space-separated tokens.

    *eq* is scanned one character at a time:

    * runs of digits and '.' are collected into a single number token;
    * 'n' plus the ONE character after it (whatever that character is)
      becomes a placeholder token, so 'n0'..'n9' survive intact but 'n12'
      comes out as 'n1 2';
    * a space only terminates a pending number;
    * every other character (operators, parentheses, letters such as 'x')
      becomes its own token.

    Quirks visible in the code: a pending number is not flushed when an 'n'
    arrives (so '3n0' yields 'n0 3'), and a trailing number is always
    preceded by a space, so an input that is a single number such as '5'
    yields ' 5'.

    Example: '(3+n1)*2.5' -> '( 3 + n1 ) * 2.5'
    """
    fin_eq = ""
    ls = ['0','1','2','3','4','5','6','7','8','9','.']
    temp_num = ""
    # flag == 1 means the previous character was 'n', so the current one is
    # glued onto it to form a placeholder such as 'n3'.
    flag = 0
    for i in eq:
        if flag == 1:
            fin_eq = fin_eq + i
            flag = 0
        elif i == 'n':
            flag = 1
            if fin_eq == "":
                fin_eq = fin_eq + i
            else:
                fin_eq = fin_eq + ' ' + i
        elif i in ls:
            # Part of a number literal; emitted when a separator is reached.
            temp_num = temp_num + i
        elif i == ' ':
            if temp_num == "":
                continue
            else:
                if fin_eq == "":
                    fin_eq = fin_eq + temp_num
                else:
                    fin_eq = fin_eq + ' ' + temp_num

            temp_num = ""
        else:
            # Any other character: flush the pending number, then emit the
            # character as its own token.
            if fin_eq == "":
                if temp_num == "":
                    fin_eq = fin_eq + i
                else:
                    fin_eq = fin_eq + temp_num + ' ' + i
            else:
                if temp_num == "":
                    fin_eq = fin_eq + ' ' + i
                else:
                    fin_eq = fin_eq + ' ' + temp_num + ' ' + i
            temp_num = ""
    # Flush a number that ends the string (always space-prefixed, see docstring).
    if temp_num != "":
        fin_eq = fin_eq + ' ' + temp_num
    return fin_eq

def reverse_eq(eq):
    """Return the whitespace tokens of *eq* in reverse order, mirroring parentheses.

    '(' and ')' are swapped while reversing, so the reversed token list is
    still well-bracketed. make_prefix uses this before and after the postfix
    conversion.

    Returns:
        A list of tokens (not a string).
    """
    mirror = {'(': ')', ')': '('}
    return [mirror.get(token, token) for token in reversed(eq.split())]

def num_align(eq, list_num):
    """Replace numeric tokens of *eq* with positional placeholders ``'n<j>'``.

    Each whitespace token of *eq* that parses as a float is matched against
    the first still-unused entry of *list_num* with an equal float value. It
    is then replaced by ``'n' + str(j)``, where j is that entry's position.
    Each entry is used at most once; tokens without a match are left as-is.
    *list_num* itself is not modified.

    Returns:
        The list of (possibly replaced) tokens.
    """
    remaining = list(list_num)  # private copy; matched slots are blanked out
    tokens = eq.split()
    for pos, token in enumerate(tokens):
        try:
            value = float(token)
        except Exception:
            continue
        for j, candidate in enumerate(remaining):
            try:
                if float(candidate) != value:
                    continue
            except Exception:
                # Includes blanked-out slots: float('') raises ValueError.
                continue
            tokens[pos] = 'n' + str(j)
            remaining[j] = ''
            break
    return tokens

def op_precedence(op1, op2):
    """Return True if operator *op1* binds strictly tighter than *op2*.

    Ranking, tightest first: '/', then '*', then '+' and '-' (equal).
    Raises KeyError for any other operator string.
    """
    rank = {'/': 1, '*': 2, '+': 3, '-': 3}  # lower rank = higher precedence
    return rank[op1] < rank[op2]


def infix_to_postfix(eq_list):
    """Convert a list of infix tokens to postfix (RPN) with the shunting-yard algorithm.

    '+', '-', '*', '/' are operators and '(' / ')' group. Every other token
    is an operand: the aligned placeholders 'n0', 'n1', ..., but also a
    numeric constant that num_align could not map to a question number.

    An incoming operator is pushed without popping when the stack top has
    equal precedence. make_prefix runs this on a reversed expression, and
    that rule is what keeps 'a - b + c' grouped as '(a - b) + c' after the
    final reversal.

    Raises:
        IndexError: on an unmatched ')'.
    """
    operators = ['/', '*', '+', '-']
    postfix = []
    stack = []
    for e in eq_list:
        if e not in operators and e not in ('(', ')'):
            # Operand. (Testing only e[0] == 'n' sent constants such as '100'
            # into the operator branch, which crashed with IndexError/KeyError.)
            postfix.append(e)
        elif e == '(':
            stack.append(e)
        elif e == ')':
            while stack[-1] != '(':
                postfix.append(stack.pop())
            stack.pop()
        elif not stack or stack[-1] == '(' or (stack[-1] in operators and op_precedence(e, stack[-1])):
            stack.append(e)
        else:
            # Pop operators that bind strictly tighter than e, then push e.
            while stack[-1] in operators and op_precedence(stack[-1], e):
                postfix.append(stack.pop())
                if not stack:
                    break
            stack.append(e)
    while stack:
        postfix.append(stack.pop())
    return postfix

def make_prefix(aligned_eq):
    """Convert a space-separated infix equation string into prefix notation.

    The classic reverse / postfix / reverse scheme is used. Tokens are
    reversed with mirrored parentheses, converted to postfix, and the
    postfix string is reversed again.

    Returns:
        The prefix tokens joined with single spaces.
    """
    mirrored_tokens = reverse_eq(aligned_eq)
    postfix_text = stack_to_string(infix_to_postfix(mirrored_tokens))
    return stack_to_string(reverse_eq(postfix_text))

def _is_wrapped(expr):
    """Return True if *expr* is one parenthesised group, '( ... )' as a whole."""
    if not expr or expr[0] != '(' or expr[-1] != ')':
        return False
    depth = 0
    for pos, ch in enumerate(expr):
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
            # The opening '(' closed before the end: e.g. '( a ) + ( b )'.
            if depth == 0 and pos != len(expr) - 1:
                return False
    return depth == 0


def change_side(eq):
    """Move everything to one side of an equation ``'lhs = rhs'``.

    The side containing 'x' stays. If neither side contains 'x', the
    right-hand side stays. The other side is subtracted, giving
    ``'( stable ) - ( moved )'``. A side that is already one parenthesised
    group is not wrapped again.

    Only the first two parts of ``eq.split(' = ')`` are used; an *eq*
    without ' = ' raises IndexError.
    """
    sides = eq.split(' = ')
    if 'x' not in sides[0]:
        moving, stable = sides[0], sides[1]
    else:
        moving, stable = sides[1], sides[0]
    # Checking only the first/last characters treated '( a ) + ( b )' as
    # wrapped and produced '- ( a ) + ( b )', flipping the sign of b.
    mov_side = '- ' + (moving if _is_wrapped(moving) else '( ' + moving + ' )')
    stable_side = stable if _is_wrapped(stable) else '( ' + stable + ' )'
    return stable_side + ' ' + mov_side

def evaluator(eq):
    """Solve the SymPy expression *eq* (taken as ``eq == 0``) for ``x``.

    Symbol names such as n0..n9 are created by ``parse_expr`` itself, so only
    ``x`` needs to be declared here.

    Returns:
        ``str`` of the first solution returned by ``sympy.solve``.

    Raises:
        IndexError: if SymPy finds no solution.

    NOTE: ``parse_expr`` evaluates its input with Python ``eval``; only pass
    trusted equation strings.
    """
    unknown = sp.Symbol('x')
    parsed = parse_expr(eq, evaluate=True)
    solutions = solve(parsed, unknown)
    return str(solutions[0])

def find_operands(eq):
    """Split a prefix-notation expression into its first complete operand and the rest.

    Binary operators ('+', '-', '*', '/') each need two operands; any other
    token is an operand. Tokens are consumed until the first
    sub-expression is complete.

    Examples:
        'n0 n1'          -> ('n0', 'n1')
        '+ n0 n1 n2'     -> ('+ n0 n1', 'n2')

    For malformed input whose first sub-expression never completes, the
    first token alone is returned as the operand (as before). '' gives
    ('', '').

    Returns:
        (first_operand, rest) as space-joined strings.
    """
    operators = ['+', '-', '*', '/']
    elements = eq.split()
    end = 0
    # Operands still required to complete the first sub-expression. An
    # operand-first input closes on token 0; the old code started counting at
    # token 1 and wrongly merged 'n0 n1' into one operand.
    open_slots = 1
    for i, token in enumerate(elements):
        open_slots += 1 if token in operators else -1
        if open_slots == 0:
            end = i
            break
    return " ".join(elements[:end + 1]), " ".join(elements[end + 1:])

def fin_make_prefix(eq):
    """Convert a tokenised infix equation to prefix, rewriting a leading minus.

    The branch is chosen by the first characters of *eq*:

    1. *eq* starts with '-': the first character is dropped and the rest is
       converted with make_prefix(format_eq(...)). The text after the first
       '-' in that prefix string is split by find_operands into its first
       operand and the remainder. The output is the prefix up to and
       including that '-', followed by the remainder, then the operand
       (i.e. swapped).
    2. *eq* starts with '(' and character 2 is '-' (i.e. '( - ...'): the
       whitespace tokens are edited directly (format_eq is not called).
       The unary '-' is removed, the token after the first operand is
       overwritten with '-', and the two operands around it are swapped.
       For example, '( - n0 + n1 )' becomes '( n1 - n0 )', which then
       goes to make_prefix.
    3. Otherwise: make_prefix(format_eq(eq)).

    NOTE(review): in case 2 the overwritten token is not checked, so
    '( - n0 * n1 )' also becomes '( n1 - n0 )'. Confirm only '+' can occur
    there.
    NOTE(review): case 1 slices with the loop variable ``i`` rather than
    ``index``. They agree whenever a '-' is found (the loop breaks with
    i == index). If the prefix has no '-', ``i`` is the last position while
    ``index`` stays 0, and an empty prefix raises NameError. The intended
    rewrite for case 1 cannot be determined from this file, so confirm
    before changing it.
    """
    if eq[0] == '-':
        prefix = make_prefix(format_eq(eq[1:]))
        # Character position of the first '-' in the prefix string.
        index = 0
        for i in range(len(prefix)):
            if prefix[i] == '-':
                index = i
                break
        # index+2 skips the '-' and the space after it.
        op1, op2 = find_operands(prefix[index+2:])
        # Local name shadows the module-level final_prefix() function here.
        final_prefix = prefix[:i+1] + ' ' + op2 + ' ' + op1
        return final_prefix
    elif (eq[0] == '(' and eq[2] == '-'):
        elements = eq.split()
        elements.pop(1)
        elements[2] = '-'
        # Swap the two operands around the new '-'.
        temp = elements[1]
        elements[1] = elements[3]
        elements[3] = temp
        return make_prefix(stack_to_string(elements))
    else:
        return make_prefix(format_eq(eq))

def final_prefix(eq, num_list):
    """Turn a raw infix equation into a prefix equation over 'n<j>' placeholders.

    The equation is tokenised with format_eq. Its numbers are mapped onto
    positions in *num_list* by num_align. The re-tokenised result is then
    converted by fin_make_prefix.
    """
    placeholder_tokens = num_align(format_eq(eq), num_list)
    return fin_make_prefix(format_eq(stack_to_string(placeholder_tokens)))

def add_group_nums(sent):
    """Return the token indices that form the context ("group") of the numbers in *sent*.

    Hyphens are removed and 'mrs.' is normalised to 'mrs'. The text is then
    tokenised with NLTK ``word_tokenize`` and parsed with the module-level
    Stanza pipeline ``nlp_stanza``. Indices refer to positions in the NLTK
    token list. Collected indices are:

    * every number token and its immediate neighbours (unless ',', '.', ';');
    * for numbers attached with deprel 'nummod', the word they modify;
    * the rate words 'each', 'every', 'per';
    * 'amod' modifiers of those modified words, and the head of such a word
      when it is an 'obj' or 'nsubj';
    * the 4th-, 3rd- and 2nd-to-last tokens, unless punctuation.

    NOTE(review): indices assume Stanza's words line up one-to-one with the
    NLTK tokens. They may drift where the two tokenisers disagree (e.g. on
    contractions); confirm.

    Returns:
        A de-duplicated list of token indices (order not guaranteed).
    """
    sent = re.sub(r"-", r"", sent)
    # Escaped dot: the bare pattern 'mrs.' also matched 'mrs ' and glued the
    # title onto the next word ('mrs smith' -> 'mrssmith').
    sent = re.sub(r"mrs\.", r"mrs", sent)
    # Raw string: '\d' in a plain literal is an invalid escape sequence.
    sent_nums = re.findall(r'\d*\.?\d+', sent)
    doc = nlp_stanza(sent)
    sent = word_tokenize(sent)
    punct = [',', '.', ';']

    final_ids = []
    assoc_nouns = []

    # Pass 1: numbers, their neighbours, the words they modify, rate words.
    # Stanza word ids restart at 1 in every sentence, so `offset` converts
    # them to document-level positions.
    offset = 0
    for s in doc.sentences:
        last_id = 0
        for word in s.words:
            if word.text in sent_nums:
                pos = offset + word.id - 1
                final_ids.append(pos)
                if pos - 1 >= 0 and sent[pos - 1] not in punct:
                    final_ids.append(pos - 1)
                if pos + 1 < len(sent) and sent[pos + 1] not in punct:
                    final_ids.append(pos + 1)
                # NOTE(review): 'nmode' is not a UD relation (probably meant
                # 'nmod'), so in practice only 'nummod' matches. Left unchanged
                # so the generated group_nums do not shift; confirm.
                if word.deprel in ['nummod', 'nmode']:
                    assoc_nouns.append(s.words[word.head - 1].text)
                    final_ids.append(offset + word.head - 1)
            if word.text in ['each', 'every', 'per']:
                final_ids.append(offset + word.id - 1)
            last_id = word.id
        offset += last_id

    # Pass 2: adjectives of the collected words and the verbs governing them.
    # The offset was previously computed here but never applied, so indices
    # pointed at the wrong tokens for every sentence after the first.
    offset = 0
    for s in doc.sentences:
        last_id = 0
        for word in s.words:
            if word.deprel == 'amod' and s.words[word.head - 1].text in assoc_nouns:
                final_ids.append(offset + word.id - 1)
            if word.text in assoc_nouns and word.deprel in ['obj', 'nsubj']:
                final_ids.append(offset + word.head - 1)
            last_id = word.id
        offset += last_id

    # Tokens near the end of the question (usually the asked-for quantity).
    n_tokens = len(sent)
    for pos in (n_tokens - 4, n_tokens - 3, n_tokens - 2):
        if pos >= 0 and sent[pos] not in punct:
            final_ids.append(pos)

    return list(set(final_ids))

def load_data(df):
    """Turn raw problem rows into records with aligned numbers and a prefix equation.

    For every row of *df*, the body, question statement and full question
    are normalised (normalizeString), tokenised (NLTK ``word_tokenize``) and
    have spelled-out numbers converted (words2num). The question's numbers
    are then aligned with the equation's numbers to build a prefix equation
    over 'n<j>' placeholders.

    *df* is read with ``df.loc[d]`` for d in ``range(len(df))``. It therefore
    needs a default 0..n-1 index and the columns ID, Equation, Body,
    Ques_Statement, Question, Answer, Grade and Type.

    A row is skipped when its equation cannot be read, is not a string, or
    contains more numbers than the question. Prints the number of kept rows
    whose number counts differ, then the number of rows returned.

    Returns:
        A list of rows ``[ID, question, numbers, equation, prefix_equation,
        indexes, sni_labels, answer, group_nums, grade, type, body,
        ques_statement]``. *indexes* holds the token positions of the
        question numbers. *sni_labels* holds one '1'/'0' per question number,
        saying whether its value occurs in the equation.
    """
    # Raw string: '\d' in a plain literal is an invalid escape sequence
    # (DeprecationWarning; SyntaxWarning since Python 3.12).
    number_re = re.compile(r'\d*\.?\d+')
    fin_data = []
    cnt = 0  # kept rows whose question/equation number counts differ
    for d in range(len(df)):
        try:
            eq = df.loc[d]['Equation']
        except Exception:
            print("Equation read exception at: ", df.loc[d]['ID'])
            continue

        row = df.loc[d]
        body = words2num(stack_to_string(word_tokenize(normalizeString(row['Body']))))
        ques_stmt = words2num(stack_to_string(word_tokenize(normalizeString(row['Ques_Statement']))))
        ques = words2num(stack_to_string(word_tokenize(normalizeString(row['Question']))))
        group_nums = add_group_nums(ques)
        word_ques = ques.split()
        ques_nums = number_re.findall(ques)
        try:
            eq_nums = number_re.findall(eq)
        except TypeError:  # equation is not a string (e.g. NaN for a missing value)
            print("Equation numbers exception at: ", row['ID'])
            continue

        counts_differ = len(ques_nums) != len(eq_nums)
        if len(eq_nums) > len(ques_nums):
            continue  # more equation numbers than question numbers: cannot align

        # INDEXES: for each question number, the first not-yet-used token
        # that contains it as a substring.
        used = []
        for num in ques_nums:
            for j, token in enumerate(word_ques):
                if num in token and j not in used:
                    used.append(j)
                    break
        idx = " ".join(str(j) for j in used)

        ques_values = [float(q) for q in ques_nums]
        eq_values = [float(q) for q in eq_nums]

        # SNI labels: '1' if the question number's value occurs in the equation.
        sni_label = " ".join("1" if q in eq_values else "0" for q in ques_values)

        aligned = num_align(format_eq(eq), ques_values)
        final_eq = make_prefix(stack_to_string(aligned))

        if counts_differ:
            cnt += 1

        fin_data.append([row['ID'], ques, stack_to_string(ques_nums), eq, final_eq, idx,
                         sni_label, row['Answer'], group_nums, row['Grade'], row['Type'],
                         body, ques_stmt])

    print(cnt)
    print(len(fin_data))
    return fin_data

def preprocess(data):
    """Mask numbers in questions and equations with ``'number<k>'`` placeholders.

    *data* is read with ``data.loc[i]`` for i in ``range(len(data))``. It
    needs the columns ID, Question, Numbers, Indexes, Prefix_Equation,
    Answer, group_nums, Grade, Type and Body. The k-th position listed in
    'Indexes' of the question tokens becomes 'number<k>'. Each equation
    token 'n<j>' becomes 'number<j>'. The first ``len(Body.split())`` masked
    question tokens are returned as the body and the remainder as the
    question statement.

    Returns:
        A list of rows ``[ID, masked_question, Numbers, masked_equation,
        Answer, group_nums, Grade, Type, masked_body, masked_ques_statement]``.
    """
    new_data = []
    for i in range(len(data)):
        row = data.loc[i]
        words = row['Question'].split()
        for k, pos in enumerate(row['Indexes'].split()):
            words[int(pos)] = 'number' + str(k)

        # Keep every digit after the 'n': taking only component[1] turned
        # 'n10' into 'number1', colliding with 'n1'.
        components = ['number' + c[1:] if c[0] == 'n' else c
                      for c in row['Prefix_Equation'].split()]

        body_len = len(row['Body'].split())

        new_data.append([row['ID'], " ".join(words), row['Numbers'], " ".join(components),
                         row['Answer'], row['group_nums'], row['Grade'], row['Type'],
                         " ".join(words[:body_len]), " ".join(words[body_len:])])

    return new_data

For example, the code to apply the above functions to the ASDiv-a dataset is:

# Build the preprocessed ASDiv-a DataFrame from ASDiv.xml and the ID lists in ./folds.
xtree = et.parse("ASDiv.xml")
xroot = xtree.getroot()

# IDs of the arithmetic (ASDiv-a) problems, one per line in each file under
# ./folds. A set makes the per-problem membership test below O(1).
arith_ids = set()

for filename in os.listdir(os.path.join(os.getcwd(), 'folds')):
    try:
        with open(os.path.join(os.getcwd(), 'folds', filename), 'r') as f:
            for l in f:
                arith_ids.add(l.split('\n')[0])
            print(len(arith_ids))
    except (OSError, UnicodeDecodeError):
        # Skip entries that cannot be read as text (e.g. sub-directories).
        continue

df_cols = ["ID", "Question", "Equation", "Answer", "Grade", "Type", "Body", "Ques_Statement"]
rows = []

cnt = 0  # problems whose answer is recomputed from the formula's left side

types = {}  # Solution-Type -> number of arithmetic problems of that type

for node in xroot[0]:
    prob_id = node.attrib.get("ID")
    if prob_id not in arith_ids:
        continue
    grade = node.attrib.get("Grade")
    body = node.find("Body").text
    ques = node.find("Question").text
    answer = node.find("Answer").text
    eq = node.find("Formula").text
    type1 = node.find("Solution-Type").text

    types[type1] = types.get(type1, 0) + 1

    question = body + " " + ques

    eq_parts = eq.split("=")
    eq1, eq2 = eq_parts[0], eq_parts[1]

    if len(eq2.split(" ")) == 2 and len(eq2.split(" ")[1]) > 0:
        # The right-hand side has two space-separated parts (presumably a
        # value plus an extra annotation; confirm), so evaluate the left side.
        cnt += 1
        ans = float(parse_expr(eq1, evaluate=True))
        print(prob_id, eq1, ans)
    else:
        ans = float(answer.split(" ")[0])
        if float(eq2) != ans:
            print("ALERT", prob_id)

    rows.append({"ID": prob_id, "Question": question, "Equation": eq1, "Answer": ans,
                 "Grade": grade, "Type": type1, "Body": body, "Ques_Statement": ques})

df = pd.DataFrame(rows, columns=df_cols)
print(cnt)

fin_data = load_data(df)

# Build each frame in one call; growing an empty DataFrame row by row with
# .loc is quadratic in the number of rows.
prelim_df = pd.DataFrame(fin_data, columns=['ID', 'Question', 'Numbers', 'Orig_Equation', 'Prefix_Equation',
                                            'Indexes', 'SNI_Labels', 'Answer', 'group_nums', 'Grade', 'Type',
                                            'Body', 'Ques_Statement'])

full_data = preprocess(prelim_df)
print(len(full_data))

full_df = pd.DataFrame(full_data, columns=['ID', 'Question', 'Numbers', 'Equation', 'Answer', 'group_nums',
                                           'Grade', 'Type', 'Body', 'Ques_Statement'])

Here, `full_df` is the final preprocessed DataFrame.

pocca2048 commented 3 years ago

Thank you very much!