boostcampaitech4lv23nlp1 / final-project-level3-nlp-03

Multi-Modal Model for DocVQA(Document Visual Question Answering)
3 stars 0 forks source link

bounding box check #11

Status: Closed — Ssunbell closed this issue 1 year ago

Ssunbell commented 1 year ago
from PIL import Image, ImageDraw
from IPython.core.display import display
import pandas as pd
import json
import editdistance
from nltk.tag import pos_tag
import re
from typing import *

class FindCategory():
    """Visualize DocVQA predictions: for each result row, draw blue boxes around
    OCR lines related to the question and red boxes around lines containing the
    answer, using fuzzy (normalized-Levenshtein) word matching.
    """

    def __init__(self, data_path='data/test/', result_path='./', start_idx=2582, end_idx=3874):
        """
        Args:
            data_path: directory holding ``test_v1.0.json``, images, and ``ocr_results/``.
            result_path: directory holding the model's ``result.csv``.
            start_idx, end_idx: inclusive row range of result.csv to visualize.
        """
        self.data_path = data_path
        self.result_path = result_path
        self.dataset = self.make_dataset(start_idx, end_idx)

    def NLD(self, s1: str, s2: str) -> float:
        """Normalized Levenshtein distance (case-insensitive).

        The raw edit distance is divided by the mean of the two lengths, so
        0.0 means identical and larger values mean more dissimilar.
        """
        total_len = len(s1) + len(s2)
        if total_len == 0:
            return 0.0  # two empty strings are identical (avoid ZeroDivisionError)
        return editdistance.eval(s1.lower(), s2.lower()) / (total_len / 2)

    def clean_text(self, raw_string: str) -> str:
        """Strip a fixed set of punctuation/special characters from the text."""
        return re.sub('[-=+,#?^$.@*※~!…]', '', raw_string)

    def find_noun_ngram(self, questions: str, ngram: int) -> Iterator[Tuple[str, ...]]:
        """Yield unique n-grams of the question whose every token carries a
        content-word POS tag (Penn Treebank tags via ``nltk.pos_tag``).

        Args:
            questions: the raw question string.
            ngram: n-gram size; unigrams use a stricter tag set than longer n-grams.
        """
        if ngram == 1:  # unigrams: content words only
            part_of_speech = {'NN', 'NNS', 'NNP', 'NNPS', 'POS', 'RP', 'CD', 'FW', 'VBG'}
        else:  # longer n-grams may also contain prepositions / comparatives
            part_of_speech = {'NN', 'NNS', 'NNP', 'NNPS', 'IN', 'POS', 'RP', 'CD', 'FW',
                              'VBG', 'JJR', 'JJS', 'RBR', 'RBS'}

        result = set()
        tokens: List[str] = questions.split()
        for start in range(len(tokens) - (ngram - 1)):
            window = tokens[start:start + ngram]
            kept = [self.clean_text(word)
                    for word, tag in pos_tag(window) if tag in part_of_speech]
            # keep the n-gram only if every token passed the POS filter
            if len(kept) == ngram:
                result.add(tuple(kept))
        yield from result

    def _fuzzy_match_count(self, targets, words_list) -> int:
        """Count targets that fuzzily match (NLD < 0.2) at least one OCR word."""
        count = 0
        for target in targets:
            if any(self.NLD(target, word) < 0.2 for word in words_list):
                count += 1
        return count

    def _find_line_indices(self, lines, question_words, answer) -> Tuple[List[int], List[int]]:
        """Return (question line indices, answer line indices) for one OCR page.

        A line is question-related when >60% of its words are covered by fuzzy
        matches against the question n-grams, and answer-related when >30% are
        covered by the answer's words.
        """
        indices_blue, indices_red = [], []
        for i, line in enumerate(lines):
            words_list = [word['text'] for word in line['words']]
            if not words_list:
                continue  # guard: empty OCR line would divide by zero

            # accumulate matches over ALL n-gram tuples, then test ONCE per line
            # (the original tested inside the loop, appending duplicate indices)
            question_hits = sum(self._fuzzy_match_count(words_q, words_list)
                                for words_q in question_words)
            if question_hits / len(words_list) > 0.6:
                indices_blue.append(i)

            answer_hits = self._fuzzy_match_count(answer.split(), words_list)
            if answer_hits / len(words_list) > 0.3:
                indices_red.append(i)
        return indices_blue, indices_red

    def make_dataset(self, start_idx, end_idx):
        """Build the visualization samples.

        Returns:
            list of dicts with keys 'question' (str), 'answer' (list of str),
            and 'img' (PIL.Image with boxes drawn on it).
        """
        df_raw = pd.read_csv(self.result_path + 'result.csv', index_col=0)
        df_raw = (df_raw.drop(columns=['0', '6', '7'])
                        .reset_index(drop=True)
                        .rename(columns={'1': 'question_id',
                                         '2': 'score',
                                         '3': 'question',
                                         '4': 'prediction',
                                         '5': 'answer'}))
        with open(self.data_path + 'test_v1.0.json', 'r') as f:
            json_data = json.load(f)
        # json.load already returns a plain dict; no eval/json round-trip needed
        df_test = pd.DataFrame(json_data['data'])

        result = []
        for _, row_test in df_raw.iloc[start_idx:end_idx + 1].iterrows():
            row = df_test[df_test['questionId'] == row_test['question_id']].iloc[0]
            # use the configured data_path instead of a hard-coded "data/test/"
            img = Image.open(self.data_path + row['image']).convert('RGB')
            draw = ImageDraw.Draw(img)
            ocr_file = (self.data_path + 'ocr_results/' + row['ucsf_document_id']
                        + '_' + row['ucsf_document_page_no'] + '.json')
            with open(ocr_file, 'r') as f:
                ocr_data = json.load(f)

            answers = set(row_test['answer'].replace("\'", '')
                                            .replace('[', '').replace(']', '').split(','))
            ngrams = 3
            question_words = [ngram_question
                              for n in range(1, ngrams + 1)
                              for ngram_question in self.find_noun_ngram(row_test['question'], n)]

            for answer in answers:
                # fix: draw per OCR page, reading boxes from that same page
                # (the original kept only the last page's indices and always
                # read bounding boxes from page [0])
                for page in ocr_data['recognitionResults']:
                    indices_blue, indices_red = self._find_line_indices(
                        page['lines'], question_words, answer)
                    for i in indices_blue:  # question-related lines -> blue
                        box = page['lines'][i]['boundingBox']
                        draw.rectangle((box[0], box[1], box[4], box[5]),
                                       outline=(0, 0, 255), width=3)
                    for i in indices_red:  # answer lines -> red
                        box = page['lines'][i]['boundingBox']
                        draw.rectangle((box[0], box[1], box[4], box[5]),
                                       outline=(255, 0, 0), width=3)

            result.append({'question': row_test['question'],
                           'answer': list(answers),
                           'img': img})
        return result

    def __getitem__(self, idx):
        """Print the question/answer of sample ``idx`` and display its image at half size."""
        data = self.dataset[idx]
        print('question :', data['question'])
        print('answer :', data['answer'])
        img = data['img']
        display(img.resize((img.width // 2, img.height // 2)))

The blue boxes mark OCR bounding boxes containing words related to the question; the red boxes mark bounding boxes containing the answer.

스크린샷 2023-01-15 오후 3 45 14