rwth-i6 / returnn

The RWTH extensible training framework for universal recurrent neural networks
http://returnn.readthedocs.io/

HuggingFace datasets wrapper #1257

Open albertz opened 1 year ago

albertz commented 1 year ago

We should be able to use HuggingFace datasets directly in RETURNN.

I guess the most canonical way would be to write a RETURNN Dataset for this. Maybe derived from CachedDataset2.

A separate, more direct PyTorch dataset wrapper might also make sense. Or perhaps it is not needed, as HuggingFace datasets already supports PyTorch directly?

dthulke commented 1 year ago

> I guess the most canonical way would be to write a RETURNN Dataset for this. Maybe derived from CachedDataset2.

I already implemented this as a custom dataset some time ago (see implementation below).

Things to discuss are how to handle tokenisation and other preprocessing steps. The current implementation offers two options: define a map function (see https://huggingface.co/docs/datasets/about_map_batch) in the config file, or provide a preprocessed dataset stored with save_to_disk.
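For illustration, a possible map function in the config, written as a factory because the wrapper calls `map_func(**map_func_args)` when `map_func_args` is given. This is a sketch under assumptions: the dataset is assumed to have a whitespace-separated `text` column and an `id` column, the vocabulary is a toy one, and `tokenize_batch`, `make_tokenize_map` and the dataset name are hypothetical. `Dataset.map(batched=True)` passes a dict of column lists to the function, and `fn_kwargs` forwards extra keyword arguments to it.

```python
def tokenize_batch(batch, vocab, unk="<unk>"):
    """Hypothetical: map {"text": [str, ...]} to {"data": [[int, ...], ...]}."""
    unk_id = vocab[unk]
    return {
        "data": [[vocab.get(word, unk_id) for word in text.split()] for text in batch["text"]]
    }


def make_tokenize_map(vocab):
    """Hypothetical factory; the wrapper calls it with map_func_args and gets map_func(dataset)."""

    def map_func(dataset):
        # batched=True: tokenize_batch receives a dict of column lists;
        # fn_kwargs forwards the vocabulary to tokenize_batch
        return dataset.map(tokenize_batch, batched=True, fn_kwargs={"vocab": vocab})

    return map_func


train = {
    "class": "HuggingfaceDataset",
    "dataset_opts": {"path": "my_text_dataset", "split": "train"},  # hypothetical dataset
    "map_func": "make_tokenize_map",  # resolved via config.typed_value()
    "map_func_args": {"vocab": {"<unk>": 0, "hello": 1, "world": 2}},
    "data_key": "data",
    "seq_tag_key": "id",  # assumes the dataset has an "id" column
    "features": ["data"],  # keep the raw string column out of the RETURNN data keys
}
```

Listing `features` explicitly avoids exposing the string column `text`, which has no meaningful class dimension.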

num_outputs and data_type are currently extracted from the dataset's features attribute (which may not always be available).
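To illustrate what gets extracted, here is a dependency-free sketch of the same loop as in the implementation. `Value`, `ClassLabel` and `Sequence` below are simplified stand-ins named after the classes in `datasets.features` (only the attributes the loop reads); they are not the real library classes, and `infer_dim` is a hypothetical helper.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union


@dataclass
class Value:  # stand-in for datasets.features.Value
    dtype: str


@dataclass
class ClassLabel:  # stand-in for datasets.features.ClassLabel
    names: List[str]
    dtype: str = "int64"

    @property
    def num_classes(self) -> int:
        return len(self.names)


@dataclass
class Sequence:  # stand-in for datasets.features.Sequence
    feature: "Feature"
    length: int = -1  # -1 means variable length


Feature = Union[Value, ClassLabel, Sequence]


def infer_dim(feature: Feature) -> Tuple[Optional[int], int, str]:
    """Return (num_classes, number of spatial dims, dtype), mirroring the wrapper's loop."""
    num_classes = None
    spatial_dims = 0
    # Each nested Sequence adds one spatial dim; a fixed length is taken as the dim
    while isinstance(feature, Sequence):
        spatial_dims += 1
        if feature.length != -1:
            num_classes = feature.length
        feature = feature.feature
    if isinstance(feature, ClassLabel):
        return feature.num_classes, spatial_dims, feature.dtype
    if isinstance(feature, Value):
        return num_classes, spatial_dims, feature.dtype
    raise TypeError(f"Unsupported feature type {type(feature)}")
```

A variable-length sequence of class labels yields its number of classes and one spatial dim, whereas a plain sequence of token ids such as `Sequence(Value("int32"))` carries no class count, so `num_classes` stays `None`. That is the case where the features metadata is not sufficient.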

Another point to discuss is how to handle caching for large datasets (currently, the default caching mechanism of HF datasets is used, which does not fit our setups very well).
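One possible mitigation, sketched as a config fragment under assumptions: `datasets.load_dataset` accepts a `cache_dir` argument and `Dataset.map` accepts `cache_file_name`, so both the download cache and the mapped result can be placed in a setup-specific directory. The directory, the file name and `lowercase_batch` are hypothetical; `split` is needed because the wrapper asserts a single `datasets.Dataset`, not a `DatasetDict`.

```python
import os

HF_CACHE_DIR = "/path/to/work/hf_cache"  # hypothetical setup-specific directory


def lowercase_batch(batch):
    # batch: dict of column name -> list of token lists (batched=True)
    return {"tokens": [[tok.lower() for tok in toks] for toks in batch["tokens"]]}


def preprocess(dataset):
    # Explicit, readable location for the mapped Arrow file
    return dataset.map(
        lowercase_batch,
        batched=True,
        cache_file_name=os.path.join(HF_CACHE_DIR, "conll2003_train_lowercased.arrow"),
    )


train = {
    "class": "HuggingfaceDataset",
    # split="train" makes load_dataset return a datasets.Dataset, as asserted in initialize()
    "dataset_opts": {"path": "conll2003", "split": "train", "cache_dir": HF_CACHE_DIR},
    "map_func": "preprocess",
    # data_key, features etc. omitted here
}
```

Be aware that `map` may reuse an existing file at `cache_file_name`, so the name should change whenever the preprocessing changes.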

```python
import numpy

from returnn.datasets.basic import DatasetSeq
from returnn.datasets.cached2 import CachedDataset2
from returnn.util.basic import OptionalNotImplementedError


class HuggingfaceDataset(CachedDataset2):
    @staticmethod
    def kwargs_update_from_config(config, kwargs):
        # Zero-argument super() does not work in a staticmethod, so call the base class explicitly
        CachedDataset2.kwargs_update_from_config(config, kwargs)
        # map_func may be given as the name of a function defined in the config
        if 'map_func' in kwargs:
            if isinstance(kwargs['map_func'], str):
                kwargs['map_func'] = config.typed_value(kwargs['map_func'])

    def __init__(self, dataset_opts, map_func=None, map_func_args=None, data_key='data', seq_tag_key='id',
                 features=None, **kwargs):
        """
        :param dict|str dataset_opts: kwargs for datasets.load_dataset, or a path for datasets.load_from_disk
        :param callable|str|None map_func: dataset -> dataset, applied after loading
        :param dict|None map_func_args: if given, map_func(**map_func_args) is used as the actual map_func
        :param str data_key: column used to compute sequence lengths
        :param str|None seq_tag_key: column used as sequence tag
        :param list[str]|None features: columns exposed as data keys (default: all except seq_tag_key)
        """
        super(HuggingfaceDataset, self).__init__(**kwargs)
        self._seq_order = None
        self.dataset_opts = dataset_opts
        if isinstance(map_func, str):
            from returnn.config import get_global_config
            config = get_global_config(raise_exception=False)
            map_func = config.typed_value(map_func)
        if map_func_args is not None:
            map_func = map_func(**map_func_args)
        self.map_func = map_func
        self.dataset = None
        self.data_key = data_key
        self.seq_tag_key = seq_tag_key
        self.feature_keys = features
        self.data_dtype = {}

    def initialize(self):
        # Load the dataset
        import datasets
        if isinstance(self.dataset_opts, dict):
            self.dataset = datasets.load_dataset(**self.dataset_opts)
        else:
            self.dataset = datasets.load_from_disk(self.dataset_opts)
        assert isinstance(self.dataset, datasets.Dataset)
        if self.map_func is not None:
            self.dataset = self.map_func(self.dataset)
        if self.feature_keys is None:
            self.feature_keys = list(self.dataset.features.keys())
            if self.seq_tag_key is not None and self.seq_tag_key in self.feature_keys:
                self.feature_keys.remove(self.seq_tag_key)
            else:
                assert False, "Dataset does not have a seq_tag"
        self.dataset.set_format('numpy')
        if self.seq_tag_key is not None:
            assert self.seq_tag_key in self.dataset.column_names

        # Derive dim, number of spatial dims and dtype from the feature metadata
        self.labels = {}
        self.num_outputs = {}
        for key in self.feature_keys:
            feature = self.dataset.features[key]
            dtype = None
            num_classes = None
            spatial_dims = 0
            while type(feature) is datasets.features.Sequence:
                spatial_dims += 1
                if feature.length != -1:
                    num_classes = feature.length
                feature = feature.feature
            if type(feature) is datasets.features.ClassLabel:
                self.labels[key] = feature.names
                dtype = feature.dtype
                num_classes = feature.num_classes
            elif type(feature) is datasets.features.Value:
                dtype = feature.dtype
            elif isinstance(feature, (datasets.features.Array2D, datasets.features.Array3D,
                                      datasets.features.Array4D)):
                dtype = feature.dtype
                num_classes = feature.shape[-1]
                spatial_dims += len(feature.shape)
            else:
                assert False, f"Unsupported feature type {type(feature)}"
            len_shape = spatial_dims
            self.num_outputs[key] = [num_classes, len_shape]
            self.data_dtype[key] = dtype

        super().initialize()

    def get_data_dim(self, key):
        if key in self.num_outputs:
            return self.num_outputs[key][0]
        return super().get_data_dim(key)

    def get_data_dtype(self, key):
        return self.data_dtype[key]

    def _get_seq_len(self, seq_idx):
        return len(self.dataset[seq_idx][self.data_key])

    @property
    def num_seqs(self):
        assert self._seq_order is not None, "num_seqs is only known after calling init_seq_order()"
        return len(self._seq_order)

    def get_tag(self, sorted_seq_idx):
        return self.dataset[int(self.get_corpus_seq_idx(sorted_seq_idx))][self.seq_tag_key]

    def get_all_tags(self):
        return list(self.dataset[self.seq_tag_key])

    def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
        """
        :param int|None epoch:
        :param list[str]|None seq_list: List of sequence tags, to set a predefined order.
        :param list[int]|None seq_order: List of corpus sequence indices, to set a predefined order.
        :rtype: bool
        :returns whether the order changed (True is always safe to return)
        """
        super(HuggingfaceDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
        if seq_order:
            self._seq_order = seq_order
            # TODO can we return False?
            return True
        if seq_list:
            all_tags = self.get_all_tags()
            self._seq_order = [all_tags.index(tag) for tag in seq_list]
            # TODO can we return False?
            return True
        try:
            self._seq_order = self.get_seq_order_for_epoch(
                epoch=epoch, num_seqs=self.dataset.num_rows, get_seq_len=self._get_seq_len)
        except OptionalNotImplementedError:
            # only support seq_orderings that need no length here
            assert self.seq_ordering in ["default", "reverse", "random"]
            self._seq_order = self.get_seq_order_for_epoch(
                epoch=epoch, num_seqs=self.dataset.num_rows, get_seq_len=None)
        return True

    def _collect_single_seq(self, seq_idx):
        """
        :param int seq_idx: sorted seq idx
        :rtype: DatasetSeq
        """
        corpus_seq_idx = self.get_corpus_seq_idx(seq_idx)

        def ensure_numpy(x):
            if not isinstance(x, numpy.ndarray):
                return numpy.array(x)
            return x

        dataset_item = self.dataset[int(corpus_seq_idx)]
        features = {f: ensure_numpy(dataset_item[f]) for f in self.feature_keys}
        return DatasetSeq(seq_idx, features=features, targets=None, seq_tag=dataset_item[self.seq_tag_key])

    def get_current_seq_order(self):
        """
        :rtype: list[int]
        """
        assert self._seq_order is not None
        return self._seq_order

    def get_corpus_seq_idx(self, sorted_seq_idx):
        """
        :param int sorted_seq_idx:
        :return corpus_seq_idx
        :rtype: int
        """
        return self._seq_order[sorted_seq_idx]

    def can_serialize_data(self, key):
        return True

    def serialize_data(self, key, data):
        # For ClassLabel features, self.labels holds the label names used by the base class
        if key in self.labels:
            return super().serialize_data(key, data)
        if isinstance(data, numpy.ndarray):
            data = data.tolist()
        return data
```
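A hedged usage sketch for the second mode: when `dataset_opts` is a string, `initialize()` calls `datasets.load_from_disk` on it, so a dataset that was preprocessed once and stored with `save_to_disk` can be used without a `map_func`. The path and the column names (`data`, `ner_tags`, `id`) are hypothetical, and it is assumed that the class is available to RETURNN under the name `HuggingfaceDataset`.

```python
# Hypothetical directory written earlier by datasets.Dataset.save_to_disk()
PREPROCESSED_DIR = "/path/to/work/conll2003_train_preprocessed"

train = {
    "class": "HuggingfaceDataset",
    "dataset_opts": PREPROCESSED_DIR,  # str -> datasets.load_from_disk(...)
    "data_key": "data",  # column whose length is used for length-based seq ordering
    "seq_tag_key": "id",  # column used as RETURNN seq tag
    "features": ["data", "ner_tags"],  # columns exposed as RETURNN data keys
    "seq_ordering": "laplace:.1000",  # needs _get_seq_len, i.e. the data_key column
}
```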
albertz commented 1 year ago

Related is the Sisyphus job to prepare HuggingFace datasets (https://github.com/rwth-i6/i6_core/pull/253). Doesn't this handle the caching? Ideally we should prepare our dataset wrapper here such that it works properly together with this download preparation job.

albertz commented 1 year ago

@dthulke Can you give some examples of which HF datasets you use?

dthulke commented 1 year ago

> Related is the Sisyphus job to prepare HuggingFace datasets (https://github.com/rwth-i6/i6_core/pull/253). Doesn't this handle the caching? Ideally we should prepare our dataset wrapper here such that it works properly together with this download preparation job.

Yes, this handles the caching of the initial dataset download, but not the caching of the processed version (via dataset.map). But we could add a separate job for this.
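If such a job is added, one option is to key its output directory on the preprocessing options, so that a processed dataset is only rebuilt when the options change, similar in spirit to how Sisyphus hashes job inputs. `processed_dataset_dir` and the directory layout are hypothetical; the build step itself would be `map_func(datasets.load_dataset(**dataset_opts)).save_to_disk(path)`, and the resulting path can then be passed as `dataset_opts`. Note that only the function name is hashed, so code changes inside the map function are not detected; a version entry in `map_func_args` could cover that.

```python
import hashlib
import json
import os


def processed_dataset_dir(base_dir, dataset_opts, map_func_name, map_func_args=None):
    """Hypothetical helper: deterministic output directory for a preprocessed HF dataset."""
    # sort_keys=True makes the key independent of dict insertion order (also for nested dicts)
    key = json.dumps(
        {"dataset_opts": dataset_opts, "map_func": map_func_name, "map_func_args": map_func_args},
        sort_keys=True,
    )
    digest = hashlib.sha256(key.encode("utf-8")).hexdigest()[:16]
    return os.path.join(base_dir, f"{map_func_name}-{digest}")
```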

> @dthulke Can you give some examples of which HF datasets you use?

In RETURNN, I mainly use HF datasets for sequence classification (e.g. sentiment analysis) or sequence tagging tasks (e.g. named entity recognition). For example: https://huggingface.co/datasets/conll2003

In addition, I have a few datasets that I load with custom dataset loading scripts or the default json dataset implementation.
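For the default json loader, `datasets.load_dataset("json", data_files=...)` reads JSON Lines files (one object per line), so a custom corpus can be exported with the standard library and then loaded through `dataset_opts`. The file name, the fields (`id`, `tokens`, `ner_tags`) and the label ids are hypothetical.

```python
import json


def write_jsonl(records, path):
    """Write one JSON object per line (JSON Lines)."""
    with open(path, "w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")


# Hypothetical records; ner_tags are integer label ids
records = [
    {"id": "doc1-0", "tokens": ["EU", "rejects", "German", "call"], "ner_tags": [3, 0, 7, 0]},
]

# Passed to datasets.load_dataset(**dataset_opts) by the wrapper;
# split="train" selects the single split created from data_files
dataset_opts = {"path": "json", "data_files": {"train": "train.jsonl"}, "split": "train"}
```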

Examples of additional preprocessing (beyond tokenisation) are to include document-level/cross-sentence context or to convert NER labels (given as start and end positions) to BIO labels.
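A minimal sketch of the BIO conversion, assuming the spans are given as token indices with an exclusive end, `(start, end, label)`, and do not overlap; character-offset spans would first have to be mapped to token indices. `spans_to_bio` is a hypothetical helper, which could be called per row inside a map function.

```python
from typing import List, Tuple


def spans_to_bio(num_tokens: int, spans: List[Tuple[int, int, str]]) -> List[str]:
    """Convert entity spans (token start, exclusive token end, label) to BIO tags."""
    tags = ["O"] * num_tokens
    for start, end, label in sorted(spans):
        if not 0 <= start < end <= num_tokens:
            raise ValueError(f"invalid span {(start, end, label)}")
        if any(tag != "O" for tag in tags[start:end]):
            raise ValueError(f"overlapping span {(start, end, label)}")
        # First token of the entity gets B-, the remaining ones get I-
        tags[start] = f"B-{label}"
        for i in range(start + 1, end):
            tags[i] = f"I-{label}"
    return tags
```

For example, `spans_to_bio(5, [(0, 2, "PER"), (4, 5, "LOC")])` gives `["B-PER", "I-PER", "O", "O", "B-LOC"]`.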