Closed brando90 closed 1 year ago
full_omniglot.py new file:
#!/usr/bin/env python3
import os
from torch.utils.data import Dataset, ConcatDataset
from torchvision.datasets.omniglot import Omniglot
def target_transform_omni(x, offset=10):
    """Shift an Omniglot class label by ``offset``.

    Kept as a named module-level function (not a lambda) so it can be
    pickled by DataLoader worker processes.

    NOTE(review): the historical default of 10 looks like a debugging
    leftover — the evaluation-split offset actually used elsewhere is the
    background character count (964). The default is preserved for
    backward compatibility; pass ``offset`` explicitly for real use.

    **Arguments**

    * **x** (int) - Original class label.
    * **offset** (int, *optional*, default=10) - Amount added to the label.

    **Returns**: the shifted label ``x + offset``.
    """
    return x + offset
class FullOmniglot(Dataset):
    """
    [[Source]]()

    **Description**

    Interface to the full Omniglot dataset (Lake et al., 2015): 1623
    character classes drawn from 50 alphabets, 20 samples per class.
    Torchvision ships Omniglot as separate background and evaluation
    splits; this class concatenates the two and leaves the class-split
    choice to the user, as in Ravi and Larochelle, 2017.

    Evaluation-split labels are shifted up by the number of background
    classes so the two splits do not collide. The shift is a bound
    method (not a lambda) so the dataset stays picklable for
    DataLoader multiprocessing.

    **References**

    1. Lake et al. 2015. “Human-Level Concept Learning through Probabilistic Program Induction.” Science.
    2. Ravi and Larochelle. 2017. “Optimization as a Model for Few-Shot Learning.” ICLR.

    **Arguments**

    * **root** (str) - Path to download the data.
    * **transform** (Transform, *optional*, default=None) - Input pre-processing.
    * **target_transform** (Transform, *optional*, default=None) - Target pre-processing.
    * **download** (bool, *optional*, default=False) - Whether to download the dataset.

    **Example**
    ~~~python
    omniglot = l2l.vision.datasets.FullOmniglot(root='./data',
                                                transform=transforms.Compose([
                                                    transforms.Resize(28, interpolation=LANCZOS),
                                                    transforms.ToTensor(),
                                                    lambda x: 1.0 - x,
                                                ]),
                                                download=True)
    omniglot = l2l.data.MetaDataset(omniglot)
    ~~~
    """

    def __init__(self, root, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        # Background split: its class count determines the label offset
        # applied to the evaluation split below.
        background = Omniglot(self.root, background=True, download=download)
        self.len_background_chars = len(background._characters)
        # Evaluation-split labels also start at 0, so shift them past the
        # background classes (by 964 in practice) to keep labels disjoint.
        # A bound method is used instead of a lambda so it can be pickled.
        evaluation = Omniglot(
            self.root,
            background=False,
            download=download,
            target_transform=self.target_transform_omni,
        )
        self.dataset = ConcatDataset((background, evaluation))
        self._bookkeeping_path = os.path.join(self.root, 'omniglot-bookkeeping.pkl')

    def target_transform_omni(self, x):
        # Offset evaluation labels by the number of background classes.
        return x + self.len_background_chars

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        # Fetch the raw sample, then apply the optional user transforms.
        image, character_class = self.dataset[item]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            character_class = self.target_transform(character_class)
        return image, character_class
Make callable objects instead of lambdas: callable objects can be pickled, whereas lambdas cannot.
omniglot_benchmark.py file:
#!/usr/bin/env python3
import random
import learn2learn as l2l
from torchvision import transforms
from PIL.Image import LANCZOS
# def one_minus_x(x):
# return 1.0 - x
from typing import Callable
class OneMinusX(Callable):
    """Picklable replacement for ``lambda x: 1.0 - x``.

    Defined as a callable class (rather than a lambda) so transform
    pipelines that carry it can be pickled by DataLoader worker
    processes.

    Fix: the original defined ``__int__`` — a typo for ``__init__`` —
    which registered a broken int-conversion hook (``int(obj)`` would
    raise because it returned None). The method did nothing useful, so
    it is removed rather than renamed.
    """

    def __call__(self, x):
        """Return ``1.0 - x`` (e.g. invert a [0, 1] image tensor)."""
        return 1.0 - x
def omniglot_tasksets(
    train_ways,
    train_samples,
    test_ways,
    test_samples,
    root,
    device=None,
    **kwargs,
):
    """
    Benchmark definition for Omniglot.

    Builds the full (background + evaluation) Omniglot dataset, shuffles
    its 1623 classes into a 1100/100/423 train/validation/test class
    split, and returns per-split task transforms.

    **Arguments**

    * **train_ways** (int) - Classes per training task.
    * **train_samples** (int) - Samples per class in training tasks.
    * **test_ways** (int) - Classes per validation/test task.
    * **test_samples** (int) - Samples per class in validation/test tasks.
    * **root** (str) - Path to download the data.
    * **device** (torch.device, *optional*, default=None) - If given, the
      dataset is moved on-device before meta-dataset wrapping.

    **Returns**: ``((train, validation, test) datasets,
    (train, validation, test) transforms)``.
    """
    data_transforms = transforms.Compose([
        transforms.Resize(28, interpolation=LANCZOS),
        transforms.ToTensor(),
        # Callable object instead of a lambda so the pipeline is picklable.
        OneMinusX(),
    ])
    omniglot = l2l.vision.datasets.FullOmniglot(
        root=root,
        transform=data_transforms,
        download=True,
    )
    # Bug fix: the original created OnDeviceDataset and then discarded it by
    # unconditionally wrapping the raw `omniglot`. Wrap the (possibly
    # on-device) dataset instead.
    if device is not None:
        omniglot = l2l.data.OnDeviceDataset(omniglot, device=device)
    dataset = l2l.data.MetaDataset(omniglot)

    # Random class split: 1100 train / 100 validation / 423 test.
    classes = list(range(1623))
    random.shuffle(classes)
    train_dataset = l2l.data.FilteredMetaDataset(dataset, labels=classes[:1100])
    validation_dataset = l2l.data.FilteredMetaDataset(dataset, labels=classes[1100:1200])
    test_dataset = l2l.data.FilteredMetaDataset(dataset, labels=classes[1200:])

    def _task_transforms(ways, samples):
        # Build the standard Omniglot task pipeline: sample N ways x K shots,
        # load, remap/relabel consecutively, and augment with class rotations.
        return [
            l2l.data.transforms.FusedNWaysKShots(dataset, n=ways, k=samples),
            l2l.data.transforms.LoadData(dataset),
            l2l.data.transforms.RemapLabels(dataset),
            l2l.data.transforms.ConsecutiveLabels(dataset),
            l2l.vision.transforms.RandomClassRotation(dataset, [0.0, 90.0, 180.0, 270.0]),
        ]

    train_transforms = _task_transforms(train_ways, train_samples)
    validation_transforms = _task_transforms(test_ways, test_samples)
    test_transforms = _task_transforms(test_ways, test_samples)

    _datasets = (train_dataset, validation_dataset, test_dataset)
    _transforms = (train_transforms, validation_transforms, test_transforms)
    return _datasets, _transforms
Make callable objects instead of lambdas: callable objects can be pickled by the PyTorch DataLoader (which lambdas cannot), as required by its multiprocessing.
Summary: for functions that need parameters, use a callable object. For pure functions, defining a named function is fine. For examples see: https://stackoverflow.com/a/74282085/1601580
I'll close this one for now. There's a proposal for a Taskset
which is compatible with torch's DataLoader but it needs a bit more testing before we can merge it. See #255.
code:
error
lines to remove:
and
related: https://stackoverflow.com/questions/74249550/how-does-one-find-the-name-of-a-local-variable-that-is-a-lambda-function-in-a-me