naver-airush / NAVER-AI-RUSH


Will a PyTorch baseline also be provided later? #26

Closed jjxxmiin closed 4 years ago

jjxxmiin commented 4 years ago

Information

CLI

WEB

What is your login ID?

jjeamin

Question

Will a PyTorch baseline also be provided later?

nsml-admin commented 4 years ago

The hate speech task is written in PyTorch. For the NSML-related parts, please refer to https://github.com/AI-RUSH-Operation/NAVER-AI-RUSH/blob/master/hate_speech/main.py#L21-L39.

Thank you.
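
For reference, the NSML integration at those lines follows the usual bind_model pattern. Below is a minimal sketch, assuming the standard nsml.bind(save=..., load=..., infer=...) interface; the checkpoint filename and the infer body are placeholders, not the repository's exact code.

import os
import torch
import nsml

def bind_model(model):
    def save(dirname, **kwargs):
        # NSML calls this when checkpointing; 'model.pt' is a placeholder name.
        torch.save(model.state_dict(), os.path.join(dirname, 'model.pt'))

    def load(dirname, **kwargs):
        model.load_state_dict(torch.load(os.path.join(dirname, 'model.pt')))

    def infer(test_dir, **kwargs):
        # Placeholder: run the model over test_dir and return predictions
        # in the format the task's evaluator expects.
        raise NotImplementedError

    nsml.bind(save=save, load=load, infer=infer)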

sooperset commented 4 years ago

Sharing the PyTorch dataset code that was posted in the Slack channel! I hope it helps those working in PyTorch.

from pathlib import Path
import shutil
from tempfile import mkdtemp
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from warnings import warn
from nsml.constants import DATASET_PATH

CLASSES = ['normal', 'monotone', 'screenshot', 'unknown', 'unlabeled']

UNLABELED = -1

# From baseline code: rearrange the flat NSML training data into per-class folders.
def prepare(base_dir: Path):
    def _initialize_directory(dataset: str) -> None:
        dataset_path = base_dir / dataset
        dataset_path.mkdir()
        for c in CLASSES:
            (dataset_path / c).mkdir()

    def _rearrange(dataset: str) -> None:
        output_dir = base_dir / dataset
        src_dir = Path(DATASET_PATH) / dataset
        metadata = pd.read_csv(src_dir / f'{dataset}_label')
        for _, row in metadata.iterrows():
            if row['annotation'] == UNLABELED:
                row['annotation'] = CLASSES.index('unlabeled')
            src = src_dir / 'train_data' / row['filename']
            if not src.exists():
                raise FileNotFoundError(src)
            dst = output_dir / CLASSES[row['annotation']] / row['filename']
            if dst.exists():
                warn(f'File {dst} already exists, this should not happen. Please notify 서동필 or 방지환.')
            else:
                shutil.copy(src=src, dst=dst)

    dataset = 'train'
    _initialize_directory(dataset)
    _rearrange(dataset)

def preprocess_train_info(base_dir: Path, sup: bool = True):
    prepare(base_dir)
    dataset_info = {
        'img_path': [],
        'label': []
    }
    for label, kind in enumerate(CLASSES):
        # Skip GIFs, which PIL may open as multi-frame images.
        paths = [path for path in Path(base_dir / 'train').glob(f'{kind}/*.*')
                 if path.suffix not in ['.gif', '.GIF']]
        for path in paths:
            dataset_info['img_path'].append(str(path))
            dataset_info['label'].append(label)
    # Shuffle before splitting into train/valid.
    dataset_info = pd.DataFrame(dataset_info).sample(frac=1.)
    if sup:
        # Remove unlabeled samples for supervised training.
        dataset_info = dataset_info[dataset_info.label != CLASSES.index('unlabeled')].reset_index(drop=True)
    train_info, valid_info = train_test_split(dataset_info, test_size=0.2)
    return train_info, valid_info

def preprocess_test_info(test_dir: str):
    dataset_info = {
        'img_path': []
    }
    paths = [path for path in (Path(test_dir) / 'test_data').glob('*.*') if path.suffix not in ['.gif', '.GIF']]
    for path in paths:
        dataset_info['img_path'].append(str(path))
    dataset_info = pd.DataFrame(dataset_info)
    return dataset_info

class SpamDataset(Dataset):
    def __init__(self, img_paths: list, labels: list,
                 num_classes: int = 4, tfms=None, test=False):
        self.img_paths = img_paths
        self.labels = labels
        self.num_classes = num_classes
        self.tfms = tfms
        self.test = test

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        label = self.labels[idx]
        # Convert to RGB so grayscale/RGBA files match the 3-channel Normalize stats.
        image = Image.open(img_path).convert('RGB')
        if self.tfms:
            image = self.tfms(image)
        if self.test:
            # At test time, return the filename so predictions can be matched to files.
            return image, img_path.split('/')[-1]
        else:
            return image, label

    def __len__(self):
        return len(self.img_paths)

    def get_labels(self):
        return list(self.labels)

if __name__ == '__main__':
    from torchvision import transforms
    mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
    tfms = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    """ Train & Valid """
    base_dir = Path(mkdtemp())
    train_info, valid_info = preprocess_train_info(base_dir, sup=True)
    train_dataset = SpamDataset(train_info.img_path.values,
                                train_info.label.values,
                                tfms=tfms)
    valid_dataset = SpamDataset(valid_info.img_path.values,
                                valid_info.label.values,
                                tfms=tfms)
    print(train_info, valid_info)
    print(next(iter(train_dataset)))
    """ Test at bind_model.infer(test_dir, **kwargs) """
    # test_info = preprocess_test_info(test_dir)
    # test_dataset = SpamDataset(test_info.img_path.values,
    #                            test_info.index.values,
    #                            tfms=tfms,
    #                            test=True)
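
Continuing the __main__ example above, here is a minimal sketch of batching these datasets with torch.utils.data.DataLoader. Note that the transform defined earlier has no Resize, so images of different sizes cannot be stacked into one batch; the 224x224 size below is an assumption, not part of the shared snippet.

    from torch.utils.data import DataLoader

    # Assumed input size: without a Resize, variable-size images fail to collate.
    batch_tfms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    train_dataset = SpamDataset(train_info.img_path.values,
                                train_info.label.values,
                                tfms=batch_tfms)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
    for images, labels in train_loader:
        # images: float tensor of shape (B, 3, 224, 224)
        # labels: class indices 0-3 when sup=True
        break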