orobix / Prototypical-Networks-for-Few-shot-Learning-PyTorch

Implementation of Prototypical Networks for Few Shot Learning (https://arxiv.org/abs/1703.05175) in Pytorch
MIT License
986 stars 210 forks source link

How to train and test with own dataset? #2

Closed happsky closed 6 years ago

happsky commented 6 years ago

How to train and test with own dataset?

dnlcrl commented 6 years ago

Hi @happsky,

you can try implementing your custom dataset class as described in the official PyTorch documentation; then you can instantiate your dataset object in init_dataset(): https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch/blob/de8eb029237a950f2aea2e78e48bc79b45d48316/src/train.py#L28-L29 and create the sampler by passing your dataset's list of labels: https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch/blob/de8eb029237a950f2aea2e78e48bc79b45d48316/src/train.py#L40-L43 Let us know if you made it work, and thank you for pointing this out; I'll update the README with a brief guide asap.

dnlcrl commented 6 years ago

I'm closing this due to no updates. Please re-open this issue if you run into any trouble with your own dataset.

pranay-ar commented 3 years ago

@dnlcrl I am trying to implement your code on a custom dataset. Do you happen to have a sample block of code for this?

JMYok commented 8 months ago

@dnlcrl I am trying to implement your code on a custom dataset. Do you happen to have a sample block of code for this?

Here is my sample with comments; I hope it will be helpful. I think the key is to understand the meaning of every step in the code; then you can change a small amount of code to support your custom dataset.

from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import numpy as np
import shutil
import errno
import torch
import os

# Module-level cache used by load_img: maps an image file path to the PIL
# image object opened for it, so each file is opened at most once per process.
IMG_CACHE = {}

class AirSARDataset(data.Dataset):
    '''
    Dataset that loads every image of one split into memory at construction.

    Expected layout under ``root``:
    - ``splits/<mode>.txt``: one class per line, e.g. ``train/A220/rot000``
    - ``data/``: the image files, searched recursively
    - ``raw/``: created by make_dirs but not read by this class
    '''
    splits_folder = 'splits'
    raw_folder = 'raw'
    processed_folder = 'data'

    def __init__(self, mode='train', root='..' + os.sep + 'airsar_dataset', transform=None):
        '''
        Build the item list for the requested split and load all its images.

        Each entry of self.all_items is a tuple whose element [0] is the file
        name, [1] the class label, [2] the directory holding the file and
        [-1] the rotation suffix. The integer index of every class can be
        found in self.idx_classes.
        Args:
        - mode: name of the split; selects the file splits/<mode>.txt
        - root: the directory where the dataset is stored
        - transform: optional callable applied to each image in __getitem__
        Raises:
        - RuntimeError: if the processed data folder does not exist
        - ValueError: if no item matches the classes of the requested split
        '''
        super(AirSARDataset, self).__init__()
        self.root = root
        self.transform = transform
        self.make_dirs()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.')

        split_file = os.path.join(self.root, self.splits_folder, mode + '.txt')
        self.classes = get_current_classes(split_file)
        self.all_items = find_items(os.path.join(self.root, self.processed_folder), self.classes)
        self.idx_classes = index_classes(self.all_items)

        # zip(*[]) cannot be unpacked into two names; fail with a message
        # that points at the actual problem instead of an unpacking error.
        if not self.all_items:
            raise ValueError('Dataset is empty: no items match the classes listed in %s' % split_file)

        # Paths and integer labels of all images, in item order.
        paths, self.y = zip(*[self.get_path_label(i) for i in range(len(self))])

        # Load every image eagerly; the position is passed along as load_img's
        # second argument.
        self.x = [load_img(path, i) for i, path in enumerate(paths)]

    def __len__(self):
        '''Return the number of items in the split.'''
        return len(self.all_items)

    def __getitem__(self, idx):
        '''Return (image, label) for position idx, applying transform if set.'''
        x = self.x[idx]
        if self.transform:
            x = self.transform(x)
        return x, self.y[idx]

    def _check_exists(self):
        '''Return True if the processed data folder exists under root.'''
        return os.path.exists(os.path.join(self.root, self.processed_folder))

    def get_path_label(self, index):
        '''
        Return (path, target) for the item at position index.

        path is "<directory><sep><filename>" followed by the rotation suffix;
        target is the integer index of the item's class.
        '''
        item = self.all_items[index]
        filename = item[0]
        rot = item[-1]
        # Full image path with the rotation suffix appended.
        img = os.sep.join([item[2], filename]) + rot
        # Integer index of the class (label + rotation suffix).
        target = self.idx_classes[item[1] + rot]

        return img, target

    def make_dirs(self):
        '''
        Create the splits, raw and processed folders under root.

        Does nothing when the processed folder already exists. Each folder is
        created independently so that one pre-existing folder does not prevent
        the remaining ones from being created.
        '''
        if self._check_exists():
            return
        for folder in (self.splits_folder, self.raw_folder, self.processed_folder):
            try:
                os.makedirs(os.path.join(self.root, folder))
            except OSError as e:
                # An already existing folder is fine; anything else is an error.
                if e.errno != errno.EEXIST:
                    raise

# Read <mode>.txt; every line is one class, for example 'train/A220/rot000'.
def get_current_classes(fname):
    """Return the lines of the split file fname as a list of class names.

    Forward slashes are replaced with the platform path separator so the
    names can be compared against paths built with os.sep.
    """
    with open(fname) as split_file:
        content = split_file.read()
    normalized = content.replace('/', os.sep)
    return normalized.splitlines()

# Walk the data folder and keep only the files whose class is listed in classes.
# Returned format: [(xxx.jpg, label such as train/A220, directory of the image,
# rotation suffix such as /rot000)]
def find_items(root_dir, classes):
    """Collect the jpg files under root_dir that belong to one of classes.

    Args:
        root_dir: directory that is searched recursively.
        classes: class names of the form "<parent dir><sep><dir><sep>rot000".

    Returns:
        A list of (filename, label, directory, rotation suffix) tuples, where
        label is made of the last two components of the file's directory.
        Directories and files are visited in sorted order, so the result (and
        any class numbering derived from it) does not depend on the order in
        which the file system happens to list entries.
    """
    items = []
    # Rotation suffixes; unlike Omniglot there is no rotated data here.
    rots = [os.sep + 'rot000']
    # Build the lookup set once instead of scanning a list for every file.
    wanted = set(classes)
    for (root, dirs, files) in os.walk(root_dir):
        # Sorting in place also fixes the order in which os.walk descends.
        dirs.sort()
        # label: <mode, e.g. train or test>/<class folder, e.g. A220>.
        # It depends only on the directory, so compute it once per directory.
        parts = root.split(os.sep)
        count = len(parts)
        label = parts[count - 2] + os.sep + parts[count - 1]
        for f in sorted(files):
            if not f.endswith("jpg"):
                continue
            for rot in rots:
                if label + rot in wanted:
                    items.append((f, label, root, rot))
    print("== Dataset: Found %d items " % len(items))
    return items

# Assign an integer index to every class.
def index_classes(items):
    """Map each distinct class key to a consecutive integer index.

    The key of an item is item[1] + item[-1] (label followed by the rotation
    suffix). Keys are numbered 0, 1, 2, ... in order of first appearance.
    """
    class_to_index = {}
    for item in items:
        key = item[1] + item[-1]
        # The current size of the mapping is the next free index; setdefault
        # leaves keys that were already seen untouched.
        class_to_index.setdefault(key, len(class_to_index))
    print("== Dataset: Found %d classes" % len(class_to_index))
    return class_to_index

def load_img(path, idx):
    """Load one image and return it as a normalized 3-channel float tensor.

    Args:
        path: image file path followed by a "<sep>rot<degrees>" suffix,
            e.g. ".../001.jpg/rot000".
        idx: position of the image in the dataset; not used here, accepted so
            the function can be mapped over (path, index) pairs.

    Returns:
        A contiguous float32 tensor of shape (3, 100, 100). Single-channel
        images are replicated across the three channels; the spatial axes are
        swapped relative to the NumPy array, matching the (3, width, height)
        shape built from the PIL size.

    Raises:
        ValueError: if path does not end with a "<sep>rot<degrees>" suffix.
    """
    # Split on the LAST occurrence only, so that directories whose names start
    # with "rot" do not break the unpacking.
    path, rot = path.rsplit(os.sep + 'rot', 1)
    if path in IMG_CACHE:
        x = IMG_CACHE[path]
    else:
        x = Image.open(path)
        IMG_CACHE[path] = x
    x = x.rotate(float(rot))
    x = x.resize((100, 100))

    shape = 3, x.size[0], x.size[1]
    # np.asarray copies only when needed; np.array(..., copy=False) raises on
    # NumPy >= 2.0 because the float32 conversion always requires a copy.
    x = np.asarray(x, dtype=np.float32)
    # Normalize to zero mean and unit variance.
    mean = np.mean(x)
    std = np.std(x)
    if std == 0:
        # Constant image: avoid 0/0 producing NaNs; the result is all zeros.
        std = 1.0
    x = torch.from_numpy((x - mean) / std)

    # Give single-channel images an explicit trailing channel axis.
    if x.dim() == 2:
        x = x.unsqueeze(-1)
    # (rows, cols, channels) -> (channels, cols, rows). Expanding before
    # swapping axes and then calling view() would interleave image rows
    # across channels instead of replicating the image.
    x = x.permute(2, 1, 0)
    # Expand to three channels and materialize independent memory.
    return x.expand(shape).contiguous()