Lightning-Universe / lightning-transformers

Flexible components pairing 🤗 Transformers with ⚡ PyTorch Lightning
https://lightning-transformers.readthedocs.io
Apache License 2.0
610 stars · 77 forks

What format does the model expect for the data(set)? #216

Closed jmwoloso closed 2 years ago

jmwoloso commented 2 years ago

❓ Questions and Help

What is your question?

This is probably due to my unfamiliarity with datasets and PyTorch (I've used the TF implementations of Transformers to date), but I'm getting the error below when trying to load custom data from CSVs.

Traceback (most recent call last):
  File "train.py", line 10, in hydra_entry
    main(cfg)
  File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/cli/train.py", line 69, in main
    run(
  File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/cli/train.py", line 60, in run
    trainer.fit(model, datamodule=data_module)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in fit
    self._call_and_handle_interrupt(
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
    self._dispatch()
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1275, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1285, in run_stage
    return self._run_train()
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1307, in _run_train
    self._run_sanity_check(self.lightning_module)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1371, in _run_sanity_check
    self._evaluation_loop.run()
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 110, in advance
    dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 122, in advance
    output = self._evaluation_step(batch, batch_idx, dataloader_idx)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 217, in _evaluation_step
    output = self.trainer.accelerator.validation_step(step_kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 236, in validation_step
    return self.training_type_plugin.validation_step(*step_kwargs.values())
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 896, in validation_step
    return self.model(*args, **kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1606, in forward
    loss = self.module(*inputs, **kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 73, in forward
    return super().forward(*inputs, **kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 92, in forward
    output = self.module.validation_step(*inputs, **kwargs)
  File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/task/nlp/text_classification/model.py", line 61, in validation_step
    return self.common_step("val", batch)
  File "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/lightning_transformers/task/nlp/text_classification/model.py", line 50, in common_step
    outputs = self.model(**batch)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/anaconda/envs/pml/lib/python3.8/site-packages/transformers/models/longformer/modeling_longformer.py", line 1854, in forward
    global_attention_mask = torch.zeros_like(input_ids)
TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not list

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
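The failing line in the traceback shows `torch.zeros_like` being handed a plain Python list, which means the batch's `input_ids` were never converted to tensors before reaching the model. A minimal sketch (plain PyTorch, independent of the datamodule) reproducing the failure and the tensor conversion that avoids it:

```python
import torch

# Longformer's forward builds its global attention mask from input_ids:
ids_as_list = [[101, 2023, 102], [101, 2003, 102]]  # a raw batch of Python lists
caught = False
try:
    torch.zeros_like(ids_as_list)  # same TypeError as in the traceback above
except TypeError:
    caught = True  # zeros_like requires a Tensor, not a list

# Once the batch is a Tensor (what Dataset.set_format("torch", ...) yields), it works:
ids_as_tensor = torch.tensor(ids_as_list)
mask = torch.zeros_like(ids_as_tensor)
print(mask.shape)  # torch.Size([2, 3])
```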

What have you tried?

This is my current implementation, subclassing TextClassificationDataModule:

from datasets import Dataset, load_dataset

from lightning_transformers.task.nlp.text_classification import TextClassificationDataModule


class PrismTextClassificationDataModule(TextClassificationDataModule):
    def __init__(self, cfg, tokenizer, **kwargs):
        super().__init__(tokenizer, cfg)
        self.train_file = "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/train.csv"
        self.validation_file = "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/val.csv"
        self.test_file = "/mnt/batch/tasks/shared/LS_root/mounts/clusters/nd24s/code/Users/JWolosonov/lightning-transformers/test.csv"
        self.data_files = {
            "train": self.train_file,
            "validation": self.validation_file,
        }

    def process_data(self, dataset, stage) -> Dataset:
        # Load the CSVs, rename the columns to what the model expects,
        # and ask for PyTorch tensors back.
        ds = load_dataset("csv", data_files=self.data_files)
        ds = ds.rename_column("InputIds", "input_ids")
        ds = ds.rename_column("Label", "labels")
        ds.set_format("pytorch", columns=["input_ids", "labels"])
        return ds

jmwoloso commented 2 years ago

After some playing around, I figured this out. I needed to implement all of the methods and properties the stock data modules have. My text was tokenized ahead of time, so I just needed to return the dict with the input_ids. My class implementation is below:

from typing import Dict

from datasets import Dataset
import torch.multiprocessing

from lightning_transformers.task.nlp.text_classification import TextClassificationDataModule

# https://github.com/pytorch/pytorch/issues/11201
# works around the "too many open files" error with many dataloader workers
torch.multiprocessing.set_sharing_strategy("file_system")

class MyTextClassificationDataModule(TextClassificationDataModule):    
    def process_data(self, dataset, stage):

        dataset = Dataset.from_parquet(
            {"train": self.cfg.train_file,
             "validation": self.cfg.validation_file},
            columns=["input_ids", "label"]
        )

        dataset = self.preprocess(dataset)

        dataset.set_format("pytorch", columns=["input_ids", "labels"])
        self.labels = dataset["train"].features["labels"]

        return dataset

    @property
    def num_classes(self) -> int:
        return self.labels.num_classes

    @property
    def model_data_kwargs(self) -> Dict[str, int]:
        return {"num_labels": self.num_classes}

    @staticmethod
    def convert_to_features(example_batch, input_feature_fields, **fn_kwargs):
        return {"input_ids": example_batch["input_ids"]}

    @staticmethod
    def preprocess(ds, **fn_kwargs):
        ds = ds.map(
            # todo: change this to self.convert_to_features for users to override
            MyTextClassificationDataModule.convert_to_features,
            batched=True,
            with_indices=True,
            fn_kwargs=fn_kwargs,
        )
        ds.rename_column_("label", "labels")
        return ds