pytorch / text

Models, data loaders and abstractions for language processing, powered by PyTorch
https://pytorch.org/text
BSD 3-Clause "New" or "Revised" License

using a Sampler() with Iterator() #283

Open · michaelcapizzi opened this issue 6 years ago

michaelcapizzi commented 6 years ago

Hi all -

Hoping that there is someone out there who has figured out a solution to my problem.

I am trying to train a mean-teacher (https://github.com/CuriousAI/mean-teacher) network that uses both labeled and unlabeled data. It's very important that each batch have some labeled data, and the implementation linked above builds a custom Sampler() to ensure that.
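(For reference, here is a minimal sketch of that kind of batch sampler, under the assumption that you know up front which dataset indices are labeled; the class and argument names are illustrative, not taken from the mean-teacher code.)

import random
from torch.utils.data import Sampler

class MixedBatchSampler(Sampler):
    """Hypothetical batch sampler: every batch contains exactly `n_labeled` labeled examples."""

    def __init__(self, labeled_idx, unlabeled_idx, batch_size, n_labeled):
        self.labeled_idx = list(labeled_idx)
        self.unlabeled_idx = list(unlabeled_idx)
        self.n_labeled = n_labeled
        self.n_unlabeled = batch_size - n_labeled

    def __len__(self):
        # number of full batches we can build from both index pools
        return min(len(self.labeled_idx) // self.n_labeled,
                   len(self.unlabeled_idx) // self.n_unlabeled)

    def __iter__(self):
        random.shuffle(self.labeled_idx)
        random.shuffle(self.unlabeled_idx)
        for i in range(len(self)):
            labeled = self.labeled_idx[i * self.n_labeled:(i + 1) * self.n_labeled]
            unlabeled = self.unlabeled_idx[i * self.n_unlabeled:(i + 1) * self.n_unlabeled]
            yield labeled + unlabeled   # one batch = list of dataset indices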

But I have found that I can't use the torch.utils.data.DataLoader class because it does not know what to do with the Example instances built from a torchtext.Dataset.

So my question has two parts:

  1. Is there a way to make torchtext.Datasets play nice with torch.DataLoader? or.....
  2. Is there something I can do to torchtext.Iterator to force a particular label distribution in each batch?
ProKil commented 5 years ago

For 1

You can write a customized collate_fn to deal with torchtext.data.Example objects.

Here I wrote a new collate_fn function for torchtext.data.Batch. You can easily adapt it to torchtext.data.Example.

import collections
import re

import torch
from torch.utils import data

import torchtext

# remember to import these (internal helpers from older torch.utils.data.dataloader;
# re and numpy_type_map are needed by the numpy branch below)
from torch.utils.data.dataloader import _use_shared_memory, int_classes, string_classes, numpy_type_map

def torchtext_collate(batch):
    r"""Slightly different from default_collate: add torchtext.data.Batch to it.
        Puts each data field into a tensor with outer dimension batch size"""

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))
            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int_classes):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], collections.Mapping):
        return {key: torchtext_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], torchtext.data.Batch):   # difference from default_collate: handle torchtext Batch objects
        return {key: torchtext_collate([getattr(d, key) for d in batch]) for key in batch[0].dataset.fields.keys()}
    elif isinstance(batch[0], collections.Sequence):
        transposed = zip(*batch)
        return [torchtext_collate(samples) for samples in transposed]

    raise TypeError((error_msg.format(type(batch[0]))))
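If the adaptation for Example works out, the pieces would be wired together roughly as below. This is a hypothetical sketch: train_dataset, labeled_indices, unlabeled_indices and MixedBatchSampler are placeholders (the sampler is the illustrative one sketched in the question above), not code from this thread.

# Hypothetical wiring with placeholder names: `train_dataset` stands in for a
# torchtext.data.Dataset whose fields are already numericalized to tensors.
from torch.utils.data import DataLoader

labeled_indices = [...]      # placeholder: indices of labeled examples
unlabeled_indices = [...]    # placeholder: indices of unlabeled examples

loader = DataLoader(
    train_dataset,
    batch_sampler=MixedBatchSampler(labeled_indices, unlabeled_indices,
                                    batch_size=32, n_labeled=8),
    collate_fn=torchtext_collate,   # adapted for Example objects as noted above
)

for batch in loader:
    pass  # every batch contains exactly 8 labeled examples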