ohmeow / blurr

A library that integrates huggingface transformers with the world of fastai, giving fastai devs everything they need to train, evaluate, and deploy transformer-specific models.
https://ohmeow.github.io/blurr
Apache License 2.0

BlearnerForSummarization with t5-base errors out 'function' object has no attribute 'setup' #87

Status: Open. SiddharthPant opened this issue 1 year ago.

SiddharthPant commented 1 year ago

I tried running the example code from the summarization page of the docs with the 't5-base' model, but it errors out. I have tried both the latest release and master of blurr and fastcore, but the issue persists. Here's the sample and the error it produces:

learn = BlearnerForSummarization.from_data(
    cnndm_df,
    "t5-base",
    text_attr="article",
    summary_attr="highlights",
    max_length=256,
    max_target_length=130,
    dblock_splitter=RandomSplitter(),
    dl_kwargs={"bs": 2},
).to_fp16()

The error I get:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In [7], line 1
----> 1 learn = BlearnerForSummarization.from_data(
      2     cnndm_df,
      3     "t5-base",
      4     text_attr="article",
      5     summary_attr="highlights",
      6     max_length=256,
      7     max_target_length=130,
      8     dblock_splitter=RandomSplitter(),
      9     dl_kwargs={"bs": 2},
     10 ).to_fp16()

File ~/mambaforge/envs/aranya/lib/python3.8/site-packages/blurr/text/modeling/seq2seq/summarization.py:146, in BlearnerForSummarization.from_data(cls, data, pretrained_model_name_or_path, text_attr, summary_attr, max_length, max_target_length, dblock_splitter, hf_tok_kwargs, text_gen_kwargs, dl_kwargs, learner_kwargs)
    143 get_y = ItemGetter(summary_attr)
    145 if hf_arch == "t5":
--> 146     get_x.add(cls._add_t5_prefix)
    148 # define our DataBlock and DataLoaders
    149 batch_tokenize_tfm = Seq2SeqBatchTokenizeTransform(
    150     hf_arch,
    151     hf_config,
   (...)
    156     text_gen_kwargs=text_gen_kwargs,
    157 )

File ~/mambaforge/envs/aranya/lib/python3.8/site-packages/fastcore/transform.py:204, in Pipeline.add(self, ts, items, train_setup)
    202 def add(self,ts, items=None, train_setup=False):
    203     if not is_listy(ts): ts=[ts]
--> 204     for t in ts: t.setup(items, train_setup)
    205     self.fs+=ts
    206     self.fs = self.fs.sorted(key='order')

AttributeError: 'function' object has no attribute 'setup'
SiddharthPant commented 1 year ago

I was able to work around this by using the mid-level API instead of the high-level function above, but I'd still like to flag the issue for the devs.
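The traceback points at the likely cause. `Pipeline.add` (fastcore/transform.py, lines 202-206 above) calls `t.setup(items, train_setup)` on every object it is given, but `from_data` passes it `cls._add_t5_prefix`, a plain function with no `setup` method. The sketch below reproduces that logic with a simplified, stdlib-only stand-in so the failure and one possible fix can be run without fastai installed. `SimplePipeline`, `SimpleTransform` and `add_t5_prefix` are hypothetical names invented for illustration, and the `"summarize: "` prefix is assumed to match what blurr's T5 helper adds; this is not blurr's or fastcore's actual code.

```python
# Simplified stand-in for the fastcore Pipeline.add logic shown in the traceback.
# All class/function names here are hypothetical and for illustration only.

def add_t5_prefix(text):
    # Hypothetical helper: prepends the T5 summarization task prefix (assumed).
    return "summarize: " + text


class SimpleTransform:
    """Wraps a plain function so it exposes the setup()/order interface add() expects."""
    order = 0

    def __init__(self, func):
        self.func = func

    def setup(self, items=None, train_setup=False):
        # A plain function has no state to set up, so this is a no-op.
        pass

    def __call__(self, x):
        return self.func(x)


class SimplePipeline:
    def __init__(self, funcs):
        self.fs = list(funcs)

    def add(self, ts, items=None, train_setup=False):
        # Mirrors the traceback: every added object must provide .setup().
        if not isinstance(ts, (list, tuple)):
            ts = [ts]
        for t in ts:
            t.setup(items, train_setup)
        self.fs += ts
        self.fs = sorted(self.fs, key=lambda f: getattr(f, "order", 0))

    def __call__(self, x):
        # Apply each transform in order.
        for f in self.fs:
            x = f(x)
        return x


# Roughly analogous to get_x: pull the "article" field from a row.
get_x = SimplePipeline([SimpleTransform(lambda row: row["article"])])

# Adding the bare function reproduces the reported AttributeError.
try:
    get_x.add(add_t5_prefix)
except AttributeError as err:
    print(err)

# Wrapping the function first lets add() call setup() successfully.
get_x.add(SimpleTransform(add_t5_prefix))
print(get_x({"article": "Some news text."}))
```

By analogy, a plausible (untested) fix in blurr would be to wrap the helper before adding it, for example `get_x.add(Transform(cls._add_t5_prefix))` using fastcore's `Transform`, which accepts a function as its encode step. `Pipeline.__init__` already wraps plain functions this way, which is why passing a function at construction time works while `add` fails.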