Closed jmwoloso closed 2 years ago
After some playing around, I figured this out. I needed to implement all of the methods and properties that the normal data modules have. My text was processed ahead of time, so I just needed to return the dict with the input_ids. My class implementation is below:
import functools
import json
from typing import Dict

from datasets import Dataset, load_dataset, set_caching_enabled, ClassLabel
import pandas as pd
import numpy as np
import torch.multiprocessing
from lightning_transformers.task.nlp.text_classification import TextClassificationDataModule

# https://github.com/pytorch/pytorch/issues/11201
# too many open files error
torch.multiprocessing.set_sharing_strategy("file_system")


class MyTextClassificationDataModule(TextClassificationDataModule):
    def process_data(self, dataset, stage):
        dataset = Dataset.from_parquet(
            {"train": self.cfg.train_file,
             "validation": self.cfg.validation_file},
            columns=["input_ids", "label"]
        )
        dataset = self.preprocess(dataset)
        dataset.set_format("pytorch", columns=["input_ids", "labels"])
        self.labels = dataset["train"].features["labels"]
        return dataset

    @property
    def num_classes(self) -> int:
        return self.labels.num_classes

    @property
    def model_data_kwargs(self) -> Dict[str, int]:
        return {"num_labels": self.num_classes}

    @staticmethod
    def convert_to_features(example_batch, input_feature_fields, **fn_kwargs):
        return {"input_ids": example_batch["input_ids"]}

    @staticmethod
    def preprocess(ds, **fn_kwargs):
        ds = ds.map(
            # todo: change this to self.convert_to_features for users to override
            MyTextClassificationDataModule.convert_to_features,
            batched=True,
            with_indices=True,
            fn_kwargs=fn_kwargs,
        )
        ds.rename_column_("label", "labels")
        return ds
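
For anyone unfamiliar with how the batched map in preprocess hands data to convert_to_features, here is a rough pure-Python sketch of the idea. It does not use the datasets library: batched_map and rename_column are hypothetical stand-ins written only for illustration, and the token ids are made up. The function receives a dict of column slices plus the row indices, and whatever columns it returns are merged over the existing ones, which is why returning only input_ids is enough when the text is already tokenized.

```python
from typing import Callable, Dict, List

Batch = Dict[str, List]


def convert_to_features(example_batch: Batch, indices: List[int]) -> Batch:
    # The text was tokenized ahead of time, so pass input_ids through unchanged.
    return {"input_ids": example_batch["input_ids"]}


def batched_map(columns: Batch, fn: Callable[[Batch, List[int]], Batch],
                batch_size: int = 2) -> Batch:
    """Hypothetical stand-in for a batched map with indices: apply fn to
    slices of the columns and merge the returned columns over the originals."""
    n_rows = len(next(iter(columns.values())))
    out: Batch = {}
    for start in range(0, n_rows, batch_size):
        stop = min(start + batch_size, n_rows)
        batch = {name: values[start:stop] for name, values in columns.items()}
        # Columns returned by fn replace same-named columns; the rest are kept.
        updated = {**batch, **fn(batch, list(range(start, stop)))}
        for name, values in updated.items():
            out.setdefault(name, []).extend(values)
    return out


def rename_column(columns: Batch, old: str, new: str) -> Batch:
    # Hypothetical helper mirroring the "label" -> "labels" rename in preprocess.
    return {(new if name == old else name): values
            for name, values in columns.items()}


# Made-up pre-tokenized rows, for illustration only.
raw = {
    "input_ids": [[101, 7592, 102], [101, 2088, 102], [101, 999, 102]],
    "label": [0, 1, 0],
}
processed = rename_column(batched_map(raw, convert_to_features), "label", "labels")
print(processed)
```

The label column survives the map untouched and is only renamed afterwards, matching the order of operations in preprocess above.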
❓ Questions and Help
What is your question?
This is probably due to my unfamiliarity with datasets and PyTorch (I've only used the TF implementations of Transformers to date), but I'm getting the error below when trying to load custom data from CSV files.
Code
What have you tried?
This is my current implementation of subclassing the TextClassificationDataModule
What's your environment?
requirements.txt: