thunlp / OpenPrompt

An Open-Source Framework for Prompt-Learning.
https://thunlp.github.io/OpenPrompt/
Apache License 2.0

Not compatible with huggingface dataset #289

Open · quang-anh-nguyen opened this issue 1 year ago


Hello, I was trying to use PromptDataLoader with an instance of the Hugging Face datasets.Dataset class, as in the code below:

train_loader = opr.PromptDataLoader(
    dataset=datasets['train'], 
    template=template, 
    tokenizer=tokenizer,
    tokenizer_wrapper_class=wrapper_plm
)

But I always get the following error:

NotImplementedError                       Traceback (most recent call last)
Cell In[253], line 3
      1 from openprompt.data_utils import InputExample
----> 3 train_loader = opr.PromptDataLoader(
      4     dataset=datasets['train'], 
      5     template=template, 
      6     tokenizer=tokenizer,
      7     tokenizer_wrapper_class=wrapper_plm
      8 )

File /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/openprompt/pipeline_base.py:100, in PromptDataLoader.__init__(self, dataset, template, tokenizer_wrapper, tokenizer, tokenizer_wrapper_class, verbalizer, max_seq_length, batch_size, shuffle, teacher_forcing, decoder_max_length, predict_eos_token, truncate_method, drop_last, **kwargs)
     96 assert hasattr(self.template, 'wrap_one_example'), "Your prompt has no function variable \
     97                                                  named wrap_one_example"
     99 # process
--> 100 self.wrap()
    101 self.tokenize()
    103 if self.shuffle:

File /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/openprompt/pipeline_base.py:129, in PromptDataLoader.wrap(self)
    127         self.wrapped_dataset.append(wrapped_example)
    128 else:
--> 129     raise NotImplementedError

NotImplementedError: 

Looking at the source code, the cause appears to be the type check in PromptDataLoader.wrap: the dataset must be a torch.utils.data.Dataset or a List[InputExample]. A datasets.Dataset is neither, so wrap falls through to raise NotImplementedError. However, changing my dataset's class would be very complicated.

    def wrap(self):
        r"""A simple interface to pass the examples to prompt, and wrap the text with template.
        """
        if isinstance(self.raw_dataset, Dataset) or isinstance(self.raw_dataset, List):
            assert len(self.raw_dataset) > 0, 'The dataset to be wrapped is empty.'
            # for idx, example in tqdm(enumerate(self.raw_dataset),desc='Wrapping'):
            for idx, example in enumerate(self.raw_dataset):
                if self.verbalizer is not None and hasattr(self.verbalizer, 'wrap_one_example'): # some verbalizer may also process the example.
                    example = self.verbalizer.wrap_one_example(example)
                wrapped_example = self.template.wrap_one_example(example)
                self.wrapped_dataset.append(wrapped_example)
        else:
            raise NotImplementedError
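Until such support exists, one workaround is to convert the rows into a plain Python list before building the loader, since a list passes the isinstance check in wrap above. The sketch below is an illustration under stated assumptions, not OpenPrompt code: it defines a minimal stand-in for openprompt.data_utils.InputExample so it runs without OpenPrompt installed, and the column names "text" and "label" are hypothetical. Iterating a datasets.Dataset yields one dict per row, which is the input to_input_examples expects.

```python
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional


# Minimal stand-in for openprompt.data_utils.InputExample so this sketch is
# self-contained. With OpenPrompt installed, import the real class instead:
#   from openprompt.data_utils import InputExample
@dataclass
class InputExample:
    guid: Any = None
    text_a: str = ""
    text_b: str = ""
    label: Optional[int] = None


def to_input_examples(
    rows: Iterable[Dict[str, Any]],
    text_a_key: str = "text",          # hypothetical column name
    text_b_key: Optional[str] = None,  # set for sentence-pair tasks
    label_key: Optional[str] = "label",
) -> List[InputExample]:
    """Turn dict rows (e.g. from iterating a datasets.Dataset) into a list."""
    examples = []
    for idx, row in enumerate(rows):
        examples.append(
            InputExample(
                guid=idx,
                text_a=row[text_a_key],
                text_b=row[text_b_key] if text_b_key else "",
                label=row[label_key] if label_key else None,
            )
        )
    return examples


if __name__ == "__main__":
    # Plain dicts stand in for the rows of datasets['train'].
    rows = [{"text": "great movie", "label": 1}, {"text": "boring plot", "label": 0}]
    examples = to_input_examples(rows)
    print(len(examples), examples[0].text_a)
```

With OpenPrompt and datasets installed, the idea would be to pass dataset=to_input_examples(datasets['train']) to PromptDataLoader, adjusting text_a_key and label_key to the dataset's actual column names.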

Could you please make it compatible with datasets.Dataset? I believe many people use Hugging Face. Thank you.
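On the library side, one possible direction (a hypothetical sketch, not the maintainers' design) is to replace the strict isinstance check with duck typing: accept any object that supports len() and integer indexing, which covers lists, map-style torch datasets, and datasets.Dataset, and convert dict rows through a caller-supplied adapter. The function name wrap_dataset and the row_to_example parameter are invented for illustration, and the verbalizer step of the original wrap is omitted for brevity.

```python
from typing import Any, Callable, Dict, List, Optional


def wrap_dataset(
    raw_dataset: Any,
    wrap_one_example: Callable[[Any], Any],
    row_to_example: Optional[Callable[[Dict[str, Any]], Any]] = None,
) -> List[Any]:
    """Hypothetical duck-typed variant of PromptDataLoader.wrap (not OpenPrompt code).

    Accepts anything with __len__ and integer __getitem__: a list, a map-style
    torch Dataset, or a Hugging Face datasets.Dataset (whose rows are dicts).
    """
    if not (hasattr(raw_dataset, "__len__") and hasattr(raw_dataset, "__getitem__")):
        raise TypeError(f"Unsupported dataset type: {type(raw_dataset).__name__}")
    assert len(raw_dataset) > 0, "The dataset to be wrapped is empty."

    wrapped = []
    for idx in range(len(raw_dataset)):
        example = raw_dataset[idx]
        if isinstance(example, dict):
            # Dict rows need an adapter that builds an InputExample-like object.
            if row_to_example is None:
                raise TypeError("Dict rows found; pass row_to_example to convert them.")
            example = row_to_example(example)
        # In OpenPrompt this would be self.template.wrap_one_example(example).
        wrapped.append(wrap_one_example(example))
    return wrapped
```

Duck typing keeps datasets an optional dependency: the loader never has to import the library just to recognize its type, and existing list or torch inputs keep working unchanged.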