ohmeow / blurr

A library that integrates huggingface transformers with the world of fastai, giving fastai devs everything they need to train, evaluate, and deploy transformer specific models.
Apache License 2.0
289 stars 34 forks

How to use BatchLossFilter callback with blurr #52

Open maxmatical opened 3 years ago

maxmatical commented 3 years ago

I've been experimenting with tsai's BatchLossFilter callback. If I try to run training with this callback:

model = HF_BaseModelWrapper(hf_model)

learn = Learner(dls, ...)

cbs = [BatchLossFilter(loss_perc=0.4)]

learn.fit_one_cycle(
    3,
    lr_max=3e-5,
    cbs = cbs
)

I get the following error, which is due to the SequenceClassifierOutput object from huggingface


AttributeError                            Traceback (most recent call last)

<ipython-input-13-ce9f3d20977f> in <module>()
      2     3,
      3     lr_max=3e-5,
----> 4     cbs = cbs
      5 )

18 frames

/usr/local/lib/python3.7/dist-packages/fastai/losses.py in __call__(self, inp, targ, **kwargs)
     32         if self.floatify and targ.dtype!=torch.float16: targ = targ.float()
     33         if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long()
---> 34         if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)
     35         return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)

AttributeError: 'SequenceClassifierOutput' object has no attribute 'view'

Is there any way I can adapt BatchLossFilter to work with blurr? I haven't had any issues using other callbacks with blurr.

ohmeow commented 3 years ago

First, cool library ... wasn't aware of that one!

Second, it would be helpful if you could post a gist I can run so I can see the full stack trace. The loss function expects inp to be a tensor, but with Blurr (and HuggingFace) the model output at this point is an object carrying a bunch of info such as the loss, etc. With callbacks, I'm sure this can be altered to work with Blurr/HF.
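To make the mismatch concrete, here is a minimal sketch of one way to unwrap the predictions before they reach a loss function. It is plain Python so it runs without any libraries: `FakeModelOutput`, `unwrap_logits`, `LogitsLoss`, and `toy_loss` are all hypothetical names made up for illustration, not part of blurr, fastai, or tsai. The only real detail it relies on is that HuggingFace output objects such as SequenceClassifierOutput expose a `logits` attribute. Whether this alone is enough for BatchLossFilter is untested.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class FakeModelOutput:
    """Hypothetical stand-in for an output object like SequenceClassifierOutput."""
    logits: Any
    loss: Any = None


def unwrap_logits(output: Any) -> Any:
    """Return `output.logits` if the attribute exists, otherwise `output` unchanged."""
    return getattr(output, "logits", output)


class LogitsLoss:
    """Hypothetical wrapper: passes raw predictions, not the output object, to a loss."""

    def __init__(self, loss_func: Callable[..., Any]):
        self.loss_func = loss_func

    def __call__(self, output: Any, target: Any, **kwargs: Any) -> Any:
        return self.loss_func(unwrap_logits(output), target, **kwargs)


def toy_loss(preds, targets):
    """Toy mean absolute error on lists of floats (stands in for a real tensor loss)."""
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(targets)


wrapped = LogitsLoss(toy_loss)

# Works whether the "model" returned an output object or bare predictions.
print(wrapped(FakeModelOutput(logits=[0.5, 1.0]), [1.0, 1.0]))
print(wrapped([0.5, 1.0], [1.0, 1.0]))
```

In a real setup the predictions would be torch tensors and the wrapped function a fastai loss. The point is only that the unwrapping has to happen before anything calls `.view` on the model output.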

Btw, any particular reason you're using BatchLossFilter? Just curious :)

maxmatical commented 3 years ago

Here is a gist using the BatchLossFilter callback with the standard blurr training script. I'm currently experimenting with BatchLossFilter since it got some traction on Twitter a while back. Plus, intuitively, focusing on the harder examples could improve performance, and extra tools in the toolbox never hurt 😄
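For readers unfamiliar with the idea of focusing on harder examples, the selection step can be sketched roughly as follows. This is a plain-Python illustration of the general technique (keep the highest-loss samples that together account for a given share of the batch loss), not tsai's actual implementation. The function name `hardest_sample_indices` is hypothetical, and `loss_perc` is assumed to mean that share, as in the `loss_perc=0.4` argument above.

```python
from typing import List, Sequence


def hardest_sample_indices(losses: Sequence[float], loss_perc: float) -> List[int]:
    """Indices of the highest-loss samples that together cover `loss_perc` of the total loss."""
    total = sum(losses)
    # Nothing to filter if the batch has no loss or the whole batch is requested.
    if total <= 0 or loss_perc >= 1.0:
        return list(range(len(losses)))

    # Visit samples from hardest (largest loss) to easiest.
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)

    kept: List[int] = []
    covered = 0.0
    for i in order:
        kept.append(i)
        covered += losses[i] / total
        if covered >= loss_perc:
            break
    return kept


per_sample_losses = [0.1, 2.0, 0.4, 1.5]
print(hardest_sample_indices(per_sample_losses, loss_perc=0.4))
print(hardest_sample_indices(per_sample_losses, loss_perc=0.8))
```

A callback built on this idea would compute per-sample losses for the batch, then keep only the selected rows of the inputs and targets for the actual training step.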

I was able to use the example implementation with other forms of data (images, time series, etc.), so it looks to be an issue specific to huggingface, if that helps.