Open vmtmxmf5 opened 3 years ago
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import config
# NOTE(review): the class defined below rebinds the name ``PandasDataset``,
# so this import is shadowed. Confirm which definition is intended and drop
# the other one.
from src.utils import PandasDataset


class PandasDataset(Dataset):
    """Map-style dataset backed by a CSV file loaded with pandas.

    The first CSV column is discarded. Of the remaining columns, the first
    four are used as targets (``Y``) and all others as features (``X``).
    """

    def __init__(self, path):
        """Load the CSV at *path* and split it into feature/target arrays.

        Args:
            path: Location of the CSV file, passed straight to
                ``pd.read_csv``.
        """
        super().__init__()
        # Drop the first column (the index column written into the CSV).
        train = pd.read_csv(path).iloc[:, 1:]
        self.X_train, self.y_train = train.iloc[:, 4:], train.iloc[:, 0:4]
        # Plain NumPy arrays so samples can be wrapped as tensors on access.
        self.tmp_x, self.tmp_y = self.X_train.values, self.y_train.values

    def __len__(self):
        """Return the number of rows (samples) in the dataset."""
        return len(self.X_train)

    def __getitem__(self, idx):
        """Return the sample at *idx* so that ``dataset[idx]`` works.

        The original signature took no parameters although the body uses
        both ``self`` and ``idx``, so every indexing call raised
        ``TypeError``.

        Args:
            idx: Index of the requested sample.

        Returns:
            A dict with keys ``'X'`` (features) and ``'Y'`` (targets), each a
            tensor created from the underlying NumPy array.
        """
        # torch.from_numpy wraps the existing buffer without copying, so
        # calling it per access is cheap.
        # NOTE(review): it presumably requires numeric columns - confirm the
        # CSV contains no string/object columns.
        return {
            'X': torch.from_numpy(self.tmp_x)[idx],
            'Y': torch.from_numpy(self.tmp_y)[idx],
        }


train_path = config.TRAIN_PATH
train_dataset = PandasDataset(train_path)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=None,
)