vmtmxmf5 / Pytorch-

PyTorch로 머신러닝부터 딥러닝까지 구현
3 stars 0 forks source link

Pandas DataLoader #13

Open vmtmxmf5 opened 3 years ago

vmtmxmf5 commented 3 years ago
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import config
# NOTE(review): this import is shadowed by the local PandasDataset class
# defined in this file; one of the two is redundant.
from src.utils import PandasDataset

class PandasDataset(Dataset):
    """Map-style dataset backed by a CSV file.

    The CSV's first column (the saved index) is dropped. Of the remaining
    columns, the first four become the targets ``Y`` and every column from
    the fifth onward becomes the features ``X``.

    Args:
        path: Anything ``pd.read_csv`` accepts (e.g. a file path).
    """

    def __init__(self, path):
        super().__init__()
        # Drop the CSV index column.
        train = pd.read_csv(path).iloc[:, 1:]
        self.X_train, self.y_train = train.iloc[:, 4:], train.iloc[:, 0:4]
        # NumPy views of the frames; converted to tensors on access.
        self.tmp_x, self.tmp_y = self.X_train.values, self.y_train.values

    def __len__(self):
        """Return the number of samples (CSV rows)."""
        return len(self.X_train)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors: ``{'X': ..., 'Y': ...}``.

        ``torch.from_numpy`` wraps the array without copying; it requires a
        numeric dtype, so this assumes every CSV column is numeric (TODO confirm).
        """
        return {
            'X': torch.from_numpy(self.tmp_x)[idx],
            'Y': torch.from_numpy(self.tmp_y)[idx],
        }

# Build the training dataset from the configured CSV and wrap it in a
# shuffling loader with batches of 8. sampler, batch_sampler, num_workers and
# collate_fn are left at DataLoader's defaults (None, None, 0, None).
train_path = config.TRAIN_PATH
train_dataset = PandasDataset(train_path)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
)