Lightning-Universe / lightning-flash

Your PyTorch AI Factory - Flash enables you to easily configure and run complex AI recipes for over 15 tasks across 7 data domains
https://lightning-flash.readthedocs.io
Apache License 2.0

Support for filetype based kwargs to be passed. #958

Closed: karthikrangasai closed this issue 2 years ago.

karthikrangasai commented 2 years ago

🚀 Feature

Support passing keyword arguments through to the reader functions that are used for specific file types.

Motivation

In the current API, file-type-specific keyword arguments can't be passed to the from_* methods. from_csv, for instance, doesn't accept pandas.read_csv keyword arguments such as sep, so it can't be used to load TSV files.

For example:

>>> datamodule = TextClassificationData.from_csv(
...     "text",
...     "category",
...     train_file=os.path.join(folder, "train.tsv"),
...     val_file=os.path.join(folder, "dev.tsv"),
...     backbone="distilbert-base-uncased",
... )

Produces this error:

Using custom data configuration default-bf9cfe0d06f99407
Downloading and preparing dataset csv/default to /home/karthikrangasai/.cache/huggingface/datasets/csv/default-bf9cfe0d06f99407/0.0.0/9144e0a4e8435090117cea53e6c7537173ef2304525df4a077c435d8ee7828ff...
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2414.68it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 470.21it/s]
0 tables [00:00, ? tables/s]Failed to read file '<_io.BufferedReader name='/home/karthikrangasai/Documents/Projects/ML_DL_Projects/code_mixing_indian_languages/split_data/codalab_hate_speech/tamil/tamil_offensive_train.tsv'>' with error <class 'pandas.errors.ParserError'>: Error tokenizing data. C error: Expected 1 fields in line 12, saw 5

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/karthikrangasai/Open_Source/PyTorch_Lightning/lightning-flash/flash/core/data/data_module.py", line 1095, in from_csv
    return cls.from_input(
  File "/home/karthikrangasai/Open_Source/PyTorch_Lightning/lightning-flash/flash/core/data/data_module.py", line 575, in from_input
    train_dataset, val_dataset, test_dataset, predict_dataset = input.to_datasets(
  File "/home/karthikrangasai/Open_Source/PyTorch_Lightning/lightning-flash/flash/core/data/io/input.py", line 307, in to_datasets
    train_dataset = self.generate_dataset(train_data, RunningStage.TRAINING)
  File "/home/karthikrangasai/Open_Source/PyTorch_Lightning/lightning-flash/flash/core/data/io/input.py", line 342, in generate_dataset
    data = load_data(data, mock_dataset)
  File "/home/karthikrangasai/Open_Source/PyTorch_Lightning/lightning-flash/flash/text/classification/data.py", line 109, in load_data
    hf_dataset, input, *other = self._to_hf_dataset(data)
  File "/home/karthikrangasai/Open_Source/PyTorch_Lightning/lightning-flash/flash/text/classification/data.py", line 94, in _to_hf_dataset
    hf_dataset, *other = self.to_hf_dataset(data)
  File "/home/karthikrangasai/Open_Source/PyTorch_Lightning/lightning-flash/flash/text/classification/data.py", line 165, in to_hf_dataset
    dataset_dict = load_dataset("csv", data_files={"train": str(file)})
  File "/home/karthikrangasai/.virtualenvs/pl_env/lib/python3.8/site-packages/datasets/load.py", line 1112, in load_dataset
    builder_instance.download_and_prepare(
  File "/home/karthikrangasai/.virtualenvs/pl_env/lib/python3.8/site-packages/datasets/builder.py", line 636, in download_and_prepare
    self._download_and_prepare(
  File "/home/karthikrangasai/.virtualenvs/pl_env/lib/python3.8/site-packages/datasets/builder.py", line 726, in _download_and_prepare
    self._prepare_split(split_generator, **prepare_split_kwargs)
  File "/home/karthikrangasai/.virtualenvs/pl_env/lib/python3.8/site-packages/datasets/builder.py", line 1185, in _prepare_split
    for key, table in utils.tqdm(
  File "/home/karthikrangasai/.virtualenvs/pl_env/lib/python3.8/site-packages/tqdm/std.py", line 1180, in __iter__
    for obj in iterable:
  File "/home/karthikrangasai/.virtualenvs/pl_env/lib/python3.8/site-packages/datasets/packaged_modules/csv/csv.py", line 144, in _generate_tables
    for batch_idx, df in enumerate(csv_file_reader):
  File "/home/karthikrangasai/.virtualenvs/pl_env/lib/python3.8/site-packages/pandas/io/parsers.py", line 1034, in __next__
    return self.get_chunk()
  File "/home/karthikrangasai/.virtualenvs/pl_env/lib/python3.8/site-packages/pandas/io/parsers.py", line 1084, in get_chunk
    return self.read(nrows=size)
  File "/home/karthikrangasai/.virtualenvs/pl_env/lib/python3.8/site-packages/pandas/io/parsers.py", line 1057, in read
    index, columns, col_dict = self._engine.read(nrows)
  File "/home/karthikrangasai/.virtualenvs/pl_env/lib/python3.8/site-packages/pandas/io/parsers.py", line 2036, in read
    data = self._reader.read(nrows)
  File "pandas/_libs/parsers.pyx", line 756, in pandas._libs.parsers.TextReader.read
  File "pandas/_libs/parsers.pyx", line 783, in pandas._libs.parsers.TextReader._read_low_memory
  File "pandas/_libs/parsers.pyx", line 827, in pandas._libs.parsers.TextReader._read_rows
  File "pandas/_libs/parsers.pyx", line 814, in pandas._libs.parsers.TextReader._tokenize_rows
  File "pandas/_libs/parsers.pyx", line 1951, in pandas._libs.parsers.raise_parser_error
pandas.errors.ParserError: Error tokenizing data. C error: Expected 1 fields in line 12, saw 5
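The traceback ends in pandas' CSV parser, which the datasets "csv" loader calls under the hood. As a minimal sketch (using a small made-up TSV sample, not the file from this report), the same kind of ParserError can be reproduced with pandas alone, and passing sep="\t" avoids it:

```python
import io

import pandas as pd

# A tiny, made-up TSV sample: tab-separated columns, with commas inside the text.
TSV_SAMPLE = "text\tcategory\nhello world\tpos\nbad, bad, day\tneg\n"


def read_with_default_sep(raw: str) -> pd.DataFrame:
    # pandas.read_csv defaults to sep=",", so the header becomes a single column
    # and the line containing commas is split into extra fields.
    return pd.read_csv(io.StringIO(raw))


def read_with_tab_sep(raw: str) -> pd.DataFrame:
    # sep="\t" is the keyword argument that from_csv currently cannot forward.
    return pd.read_csv(io.StringIO(raw), sep="\t")


try:
    read_with_default_sep(TSV_SAMPLE)
except pd.errors.ParserError as err:
    # A tokenizing error of the same kind as in the traceback above.
    print("default sep failed:", err)

print(read_with_tab_sep(TSV_SAMPLE))
```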

Passing sep directly to from_csv raises an unexpected keyword argument error:

>>> datamodule = TextClassificationData.from_csv(
...     "text",
...     "category",
...     train_file=os.path.join(folder, "train.tsv"),
...     val_file=os.path.join(folder, "dev.tsv"),
...     backbone="distilbert-base-uncased",
...     sep="\t",
... )
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/karthikrangasai/Open_Source/PyTorch_Lightning/lightning-flash/flash/core/data/data_module.py", line 1095, in from_csv
    return cls.from_input(
  File "/home/karthikrangasai/Open_Source/PyTorch_Lightning/lightning-flash/flash/core/data/data_module.py", line 565, in from_input
    input_transform = input_transform or cls.input_transform_cls(
TypeError: __init__() got an unexpected keyword argument 'sep'

However, load_dataset, which from_csv uses internally, does accept it:

>>> datasets = load_dataset("csv", data_files={"train": os.path.join(folder, "train.tsv")}, sep="\t")
Using custom data configuration default-166b74c78bdaaf52
Downloading and preparing dataset csv/default to /home/karthikrangasai/.cache/huggingface/datasets/csv/default-166b74c78bdaaf52/0.0.0/9144e0a4e8435090117cea53e6c7537173ef2304525df4a077c435d8ee7828ff...
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 5059.47it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 909.83it/s]
Dataset csv downloaded and prepared to /home/karthikrangasai/.cache/huggingface/datasets/csv/default-166b74c78bdaaf52/0.0.0/9144e0a4e8435090117cea53e6c7537173ef2304525df4a077c435d8ee7828ff. Subsequent calls will reuse this data.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 211.30it/s]
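Until from_csv can forward reader kwargs, one possible workaround (an assumption, not something Flash documents) is to convert each TSV file into a comma-separated copy with pandas and pass the copies to from_csv instead. The helper name tsv_to_csv below is hypothetical:

```python
import os
import tempfile

import pandas as pd


def tsv_to_csv(tsv_path: str, csv_path: str) -> str:
    """Hypothetical helper: rewrite a tab-separated file as comma-separated.

    DataFrame.to_csv quotes fields that contain commas by default, so text
    columns with commas survive the round trip.
    """
    df = pd.read_csv(tsv_path, sep="\t")
    df.to_csv(csv_path, index=False)
    return csv_path


# Demo with a throwaway file; in practice the input would be train.tsv / dev.tsv.
with tempfile.TemporaryDirectory() as folder:
    tsv_path = os.path.join(folder, "train.tsv")
    with open(tsv_path, "w", encoding="utf-8") as f:
        f.write("text\tcategory\nhello world\tpos\nbad, bad, day\tneg\n")

    csv_path = tsv_to_csv(tsv_path, os.path.join(folder, "train.csv"))
    print(pd.read_csv(csv_path))
```

The converted paths could then be passed as train_file and val_file, assuming the files need no other read_csv options to parse correctly.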

Pitch

Something like this should be possible:

datamodule = TextClassificationData.from_csv(
    "text",
    "category",
    train_file=os.path.join(folder, "train.tsv"),
    val_file=os.path.join(folder, "dev.tsv"),
    backbone="distilbert-base-uncased",
    sep="\t",
)

I suspect other tasks have the same limitation, but I haven't checked personally.

Flash should support this so that it can handle a wider variety of file types out of the box.
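One way the pitch could be implemented is to accept extra keyword arguments, merge them over defaults inferred from the file extension, and forward the result to the reader. The sketch below is hypothetical: DEFAULT_READER_KWARGS, reader_kwargs_for and load_split are illustrative names, not Flash API, and it calls pandas.read_csv directly, whereas Flash would presumably forward the same kwargs to load_dataset("csv", ...), which accepts sep as shown above.

```python
import os
import tempfile
from typing import Any, Dict

import pandas as pd

# Hypothetical per-extension defaults; explicit user kwargs take precedence.
DEFAULT_READER_KWARGS: Dict[str, Dict[str, Any]] = {
    ".csv": {},
    ".tsv": {"sep": "\t"},
}


def reader_kwargs_for(path: str, **user_kwargs: Any) -> Dict[str, Any]:
    """Merge extension-based defaults with explicitly passed reader kwargs."""
    ext = os.path.splitext(path)[1].lower()
    kwargs = dict(DEFAULT_READER_KWARGS.get(ext, {}))
    kwargs.update(user_kwargs)  # e.g. an explicit sep=";" overrides the default
    return kwargs


def load_split(path: str, **user_kwargs: Any) -> pd.DataFrame:
    """Read one data split, forwarding file-type-specific kwargs to pandas."""
    return pd.read_csv(path, **reader_kwargs_for(path, **user_kwargs))


with tempfile.TemporaryDirectory() as folder:
    path = os.path.join(folder, "train.tsv")
    with open(path, "w", encoding="utf-8") as f:
        f.write("text\tcategory\nbad, bad, day\tneg\n")
    # No sep passed: the .tsv extension supplies sep="\t".
    print(load_split(path))
```

Letting explicit kwargs override the extension defaults keeps the common case (plain .tsv files) zero-configuration while still allowing unusual files, such as a .txt file with a "|" separator, to be handled by passing sep explicitly.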

stale[bot] commented 2 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
