boostcampaitech5 / level2_klue-nlp-04

level2_klue-nlp-04 created by GitHub Classroom
1 stars 0 forks source link

[Data] Train / Validation 데이터 Split #2

Closed dbsrlskfdk closed 1 year ago

dbsrlskfdk commented 1 year ago

현재

Try

기대사항

dbsrlskfdk commented 1 year ago

train.py

from torch.utils.data import random_split

RE_total_dataset = RE_Dataset(tokenized_train, train_label)
RE_train_dataset, RE_val_dataset = random_split(RE_total_dataset, [int(len(RE_total_dataset)*0.8), int(len(RE_total_dataset)*0.2)])

load_data.py

def preprocessing_dataset(dataset):
  # # BaseLine Code
  # """ 처음 불러온 csv 파일을 원하는 형태의 DataFrame으로 변경 시켜줍니다."""
  # subject_entity = []
  # object_entity = []
  # for i,j in zip(dataset['subject_entity'], dataset['object_entity']):
  #   i = i[1:-1].split(',')[0].split(':')[1]
  #   j = j[1:-1].split(',')[0].split(':')[1]
  #
  #   subject_entity.append(i)
  #   object_entity.append(j)
  # out_dataset = pd.DataFrame({'id':dataset['id'], 'sentence':dataset['sentence'],'subject_entity':subject_entity,'object_entity':object_entity,'label':dataset['label'],})
  out_dataset = dataset
  return out_dataset
halimx2 commented 1 year ago

load_data.py

from sklearn.model_selection import train_test_split

def load_data(dataset_dir, use_data):
    pd_dataset = pd.read_csv(dataset_dir)
    train_dataset, dev_dataset = train_test_split(data, test_size=0.2, random_state=42)

    if use_data == 'train':
        dataset = preprocessing_dataset(train_dataset)
        return dataset
    else :
        dataset = preprocessing_dataset(dev_dataset)
        return dataset

이렇게 하면 마지막 random_state 변수로 seed도 고정되는 것 같습니다.

dbsrlskfdk commented 1 year ago

load_data.py

from sklearn.model_selection import train_test_split

def load_data(dataset_dir, use_data):
    pd_dataset = pd.read_csv(dataset_dir)
    train_dataset, dev_dataset = train_test_split(data, test_size=0.2, random_state=42)

    if use_data == 'train':
        dataset = preprocessing_dataset(train_dataset)
        return dataset
    else :
        dataset = preprocessing_dataset(dev_dataset)
        return dataset

이렇게 하면 마지막 random_state 변수로 seed도 고정되는 것 같습니다.

@4-trees-in-summer

3 이슈를 통하면, random_split seed가 같이 고정됩니다. sklearn 쓰기 귀찮아서... ㅋㅋㅋㅋ torch 모듈 안에서 다 해결하려고 이렇게 작성했습니다. 제가 나중에 seed 고정이랑 다 해서 PR 올리겠습니다.