boostcampaitech6 / level2-klue-nlp-01

level2-klue-nlp-01 created by GitHub Classroom
2 stars, 1 fork, source link

Max Length 확인 코드 / Results: 241 #2

Open ceo21ckim opened 10 months ago

ceo21ckim commented 10 months ago

Max Length를 확인하는 코드 공유합니다.

import ast
import os

import pandas as pd
from tqdm.auto import tqdm
from transformers import AutoTokenizer

from settings import *

# Load the raw training split.
# NOTE(review): TRAIN_DIR is not defined in the visible code; presumably it is
# provided by `from settings import *` — verify.
dataset = pd.read_csv(filepath_or_buffer=os.path.join(TRAIN_DIR, 'train.csv'))

# Preprocess: reduce the raw entity columns to plain entity words.
def preprocessing_dataset(dataset):
    """Convert the freshly loaded csv DataFrame into the shape used downstream.

    Each ``subject_entity`` / ``object_entity`` cell is parsed as the string
    form of a Python dict literal, and only the value of its first key is kept.
    NOTE(review): that first key is presumably ``'word'`` in the KLUE RE csv
    (e.g. ``"{'word': '비틀즈', 'start_idx': 24, ...}"``) — verify against the data.

    Args:
        dataset: DataFrame with ``id``, ``sentence``, ``subject_entity``,
            ``object_entity`` and ``label`` columns.

    Returns:
        A new DataFrame with the same five columns, where the two entity
        columns hold just the extracted entity value.
    """
    def _first_field(raw):
        # literal_eval safely parses the literal (no code execution). Unlike the
        # old split(',') / split(':') approach, it does not truncate values that
        # themselves contain ',' or ':', and it strips the quotes.
        return next(iter(ast.literal_eval(raw).values()))

    return pd.DataFrame({
        'id': dataset['id'],
        'sentence': dataset['sentence'],
        'subject_entity': [_first_field(s) for s in dataset['subject_entity']],
        'object_entity': [_first_field(o) for o in dataset['object_entity']],
        'label': dataset['label'],
    })

# Tokenize: "subject <sep> object <sep> sentence" per row.
def tokenized_dataset(dataset, model_name='klue/roberta-small'):
    """Tokenize every row as ``subject_entity <sep> object_entity <sep> sentence``.

    Args:
        dataset: DataFrame with ``subject_entity``, ``object_entity`` and
            ``sentence`` string columns.
        model_name: Model id or path passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        A list with one tokenizer output per row, in the row order of *dataset*.
        Each output holds PyTorch tensors (``return_tensors='pt'``).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Use the tokenizer's own separator token so the concatenation matches
    # whatever model_name is chosen. Fall back to the literal that was
    # hard-coded before, for tokenizers that define no sep token.
    sep = tokenizer.sep_token or '[SEP]'

    rows = zip(dataset['subject_entity'], dataset['object_entity'], dataset['sentence'])
    concat_entity = []
    for sub_entity, obj_entity, sentence in tqdm(rows, total=len(dataset), desc='tokenizing...'):
        text = sub_entity + sep + obj_entity + sep + sentence
        concat_entity.append(
            tokenizer(text, return_tensors='pt', add_special_tokens=True)
        )
    return concat_entity

# Run: tokenize the *preprocessed* data (entity values), not the raw entity
# dict strings read straight from the csv.
preprocessed_data = tokenized_dataset(preprocessing_dataset(dataset))

# Check: tokenized length of every row, including special tokens.
# shape[-1] is the sequence length. Unlike len(t.squeeze()), it also works for
# a length-1 sequence, where squeeze() would produce a 0-d tensor.
length = [token['input_ids'].shape[-1] for token in preprocessed_data]
# NOTE: the 241 originally reported here was measured without
# preprocessing_dataset (raw entity strings). Re-run to get the value for the
# preprocessed input.
print(max(length))

ceo21ckim commented 10 months ago

EDA / Histogram

def plot_hist(li: list, bins=50, title=None, xlabel=None, ylabel=None, f_name='figure.png', save=False) -> None:
    """Draw a histogram of *li* with a KDE curve, optionally save it, then show it.

    NOTE(review): ``sns``, ``plt`` and ``FIG_DIR`` are not defined in the visible
    code; presumably they come from ``settings`` or earlier notebook cells — verify.

    Args:
        li: Values to bin.
        bins: Number of histogram bins.
        title: Figure title (optional).
        xlabel: X-axis label (optional).
        ylabel: Y-axis label (optional).
        f_name: File name written under ``FIG_DIR`` when *save* is true.
        save: If true, write the figure at 200 dpi before showing it.
    """
    sns.histplot(li, bins=bins, kde=True)
    # Apply the labels in the same order as before: title, x label, y label.
    for set_label, text in ((plt.title, title), (plt.xlabel, xlabel), (plt.ylabel, ylabel)):
        set_label(text)
    if save:
        out_path = os.path.join(FIG_DIR, f'{f_name}')
        plt.savefig(out_path, dpi=200)
    plt.show()

# Histogram of tokenized lengths, saved as sentence_length_histogram.png under FIG_DIR.
plot_hist(
    length,
    save=True,
    f_name='sentence_length_histogram.png',
    title='Sentence Length',
    xlabel='length',
    ylabel='count',
)

(Figure: sentence_length_histogram.png — histogram of tokenized sentence lengths)