Closed happsky closed 6 years ago
Hi @happsky,
you can try by implementing your custom dataset class as described on the official PyTorch documentation, then you can instantiate your dataset object in init_dataset()
: https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch/blob/de8eb029237a950f2aea2e78e48bc79b45d48316/src/train.py#L28-L29 and create the sampler by passing your dataset's labels list: https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch/blob/de8eb029237a950f2aea2e78e48bc79b45d48316/src/train.py#L40-L43
Let us know if you made it work, and thank you for pointing it out; I'll update the readme with a brief guide asap.
I'm closing this due to no updates. Please re-open this issue if you run into any trouble with your own dataset.
@dnlcrl I am trying to implement your code on a custom dataset. Do you happen to have a sample block of code for this?
@dnlcrl I am trying to implement your code on a custom dataset. Do you happen to have a sample block of code for this?
Here is my sample with comments; I hope it will be helpful. I think the key is to understand the meaning of every step of the code; then you can change a little bit of the code to implement your custom dataset.
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import numpy as np
import shutil
import errno
import torch
import os
# Module-level cache of opened PIL images keyed by file path; load_img reads
# from it and fills it so each file is opened at most once.
IMG_CACHE = {}
class AirSARDataset(data.Dataset):
    '''
    Image dataset that is read from disk once and then served from memory.

    Expected layout under ``root``:
    - ``splits/<mode>.txt``: one class per line, e.g. ``train/A220/rot000``
    - ``data/``: the image files, walked recursively by ``find_items``
    - ``raw/``: created alongside the others, not read by this class

    After construction:
    - ``self.all_items``: list of ``(filename, label, directory, rot)`` tuples
    - ``self.idx_classes``: dict mapping ``label + rot`` to an integer index
    - ``self.x``: list of image tensors produced by ``load_img``
    - ``self.y``: tuple of integer class indices, aligned with ``self.x``
    '''
    splits_folder = 'splits'
    raw_folder = 'raw'
    processed_folder = 'data'

    def __init__(self, mode='train', root='..' + os.sep + 'airsar_dataset', transform=None):
        '''
        Args:
        - mode: name of the split; classes are read from ``splits/<mode>.txt``
        - root: the directory where the dataset is stored
        - transform: optional callable applied to each image in __getitem__
        Raises:
        - RuntimeError: if the processed data folder does not exist
        '''
        super(AirSARDataset, self).__init__()
        self.root = root
        self.transform = transform
        self.make_dirs()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.')

        self.classes = get_current_classes(os.path.join(self.root, self.splits_folder, mode + '.txt'))
        self.all_items = find_items(os.path.join(self.root, self.processed_folder), self.classes)
        self.idx_classes = index_classes(self.all_items)

        # Tagged path and class index of every item.
        pairs = [self.get_path_label(i) for i in range(len(self))]
        if pairs:
            paths, self.y = zip(*pairs)
        else:
            # zip(*[]) yields nothing to unpack; keep an empty dataset usable.
            paths, self.y = (), ()
        # Every image is loaded eagerly and kept in memory.
        self.x = [load_img(p, i) for i, p in enumerate(paths)]

    def __len__(self):
        return len(self.all_items)

    def __getitem__(self, idx):
        '''Return ``(image, class_index)``; the transform, if any, is applied to the image.'''
        x = self.x[idx]
        if self.transform:
            x = self.transform(x)
        return x, self.y[idx]

    def _check_exists(self):
        '''Return True when the processed data folder exists under root.'''
        return os.path.exists(os.path.join(self.root, self.processed_folder))

    def get_path_label(self, index):
        '''
        Return ``(tagged_path, class_index)`` for the item at ``index``.

        ``tagged_path`` is the image path with the rotation tag appended
        (``<directory><sep><filename><sep>rotXXX``), the form load_img expects.
        '''
        item = self.all_items[index]
        filename = item[0]
        rot = item[-1]
        # Full image path followed by the rotation tag.
        img = os.sep.join([item[2], filename]) + rot
        # Integer index of the item's class.
        target = self.idx_classes[item[1] + rot]
        return img, target

    def make_dirs(self):
        '''
        Create the splits/raw/data folders under root unless the data folder
        already exists. Folders that are already present are left untouched.
        '''
        if self._check_exists():
            return
        # Each folder gets its own try block: with a single shared one, an
        # already existing first folder would skip the creation of the others.
        for folder in (self.splits_folder, self.raw_folder, self.processed_folder):
            try:
                os.makedirs(os.path.join(self.root, folder))
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise
# Read <mode>.txt; each line names one class, e.g. 'train/A220/rot000'
def get_current_classes(fname):
    '''
    Read the class names of a split file, one class per line.

    Forward slashes are replaced by ``os.sep`` so the names can be compared
    with labels built from filesystem paths. Surrounding whitespace is
    stripped and blank lines are skipped, so a stray space or empty line does
    not produce a class name that matches nothing.

    Args:
    - fname: path of the split file, e.g. ``<root>/splits/train.txt``
    Returns:
    - list of class names such as ``train<sep>A220<sep>rot000``
    '''
    with open(fname) as f:
        lines = f.read().replace('/', os.sep).splitlines()
    return [line.strip() for line in lines if line.strip()]
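# Walk the data folder and keep only the files whose class is listed in classes
# Returned format: [(xxx.jpg, class label (e.g. train/A220), directory containing the image, rotation tag (os.sep + 'rotxxx'))]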
def find_items(root_dir, classes, extensions=('jpg',)):
    '''
    Collect the image files under ``root_dir`` that belong to ``classes``.

    The label of a file is made of the last two components of its directory
    (``<mode><sep><class>``, e.g. ``train<sep>A220``). A file is kept when
    ``label + rot`` is listed in ``classes`` and its name ends with one of
    ``extensions``.

    Args:
    - root_dir: directory walked recursively
    - classes: iterable of class names, e.g. ``train<sep>A220<sep>rot000``
    - extensions: accepted file-name suffixes (defaults to ``jpg`` only)
    Returns:
    - list of ``(filename, label, directory, rot)`` tuples, in sorted
      directory / file-name order so that the result is reproducible
    '''
    wanted = set(classes)  # O(1) membership instead of scanning a list per file
    suffixes = tuple(extensions)
    # Rotation tags; unlike Omniglot there is no rotated data, so only rot000.
    rots = [os.sep + 'rot000']
    items = []
    for root, dirs, files in os.walk(root_dir):
        # Sorting in place makes os.walk descend in a stable order; the
        # default order depends on the filesystem.
        dirs.sort()
        # label: mode[train,test]/A220 -- the same for every file in this directory
        label = os.sep.join(root.split(os.sep)[-2:])
        matching_rots = [rot for rot in rots if label + rot in wanted]
        if not matching_rots:
            continue
        for f in sorted(files):
            if f.endswith(suffixes):
                items.extend((f, label, root, rot) for rot in matching_rots)
    print("== Dataset: Found %d items " % len(items))
    return items
# Assign an integer index to each class
def index_classes(items):
    '''
    Assign consecutive integer indices to the classes found in ``items``.

    Args:
    - items: sequence of tuples whose element 1 is the class label and whose
      last element is the rotation tag (as returned by find_items)
    Returns:
    - dict mapping ``label + rot`` to its index, numbered from 0 in order of
      first appearance
    '''
    idx = {}
    for item in items:
        # item[1] is the class label, item[-1] the rotation tag: train/A220 + /rot000
        key = item[1] + item[-1]
        if key not in idx:
            # The current size of the dict is the next free index:
            # first class seen -> 0, second -> 1, ...
            idx[key] = len(idx)
    print("== Dataset: Found %d classes" % len(idx))
    return idx
def load_img(path, idx):
    '''
    Load one image and return it as a normalised, channel-first float tensor.

    The image is rotated by the angle encoded in the path, resized to
    100x100, converted to float32 and standardised (zero mean, unit standard
    deviation over all of its values).

    Args:
    - path: image path followed by the rotation tag, i.e.
      ``<image path><sep>rot<degrees>``
    - idx: position of the item in the dataset; not used here
    Returns:
    - contiguous float32 tensor. A single-channel image is replicated to
      shape (3, 100, 100); a multi-channel image keeps its own channels,
      moved to the first dimension. The two spatial axes are swapped with
      respect to the array produced from the PIL image.
    '''
    # Split on the last rotation tag only, so that a directory whose name
    # starts with "rot" does not break the unpacking.
    path, rot = path.rsplit(os.sep + 'rot', 1)
    if path in IMG_CACHE:
        x = IMG_CACHE[path]
    else:
        x = Image.open(path)
        IMG_CACHE[path] = x
    x = x.rotate(float(rot))
    x = x.resize((100, 100))
    # asarray copies when the dtype conversion needs it; np.array(..., copy=False)
    # raises on NumPy >= 2 whenever a copy cannot be avoided.
    x = np.asarray(x, dtype=np.float32)
    # Standardise; a constant image has std == 0 and is only centred, which
    # avoids filling the tensor with NaN.
    std = x.std()
    x = x - x.mean()
    if std > 0:
        x = x / std
    x = torch.from_numpy(x)
    if x.dim() == 2:
        # Swap the two spatial axes first, then replicate to three channels.
        # Expanding before the swap would exchange the channel axis with a
        # spatial one and scramble the rows across channels.
        x = x.transpose(0, 1).unsqueeze(0).expand(3, -1, -1)
    else:
        # Array is (rows, cols, channels): move channels first, swap spatial axes.
        x = x.permute(2, 1, 0)
    return x.contiguous()
How to train and test with own dataset?