makcedward / nlpaug

Data augmentation for NLP
https://makcedward.github.io/
MIT License
4.46k stars, 463 forks

*** ValueError: expected sequence of length 43 at dim 1 (got 56) when using batch_size with ContextualWordEmbsForSentenceAug #266

Open kgarg8 opened 2 years ago

kgarg8 commented 2 years ago

Hi,

I encounter the error in the title when I try to supply a batch to nas.ContextualWordEmbsForSentenceAug. After checking other posts, I expected that passing a list of batch_size texts would work, but it doesn't. Any suggestions would be appreciated.

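import nlpaug.augmenter.sentence as nas

# df1: my dataframe
batch_size = 32
# Build the augmenter once instead of reloading GPT-2 on every iteration
aug = nas.ContextualWordEmbsForSentenceAug(model_path='gpt2', device='cuda', batch_size=batch_size)
for i in range(0, len(df1), batch_size):
    rows = df1.iloc[i:i + batch_size]
    # Raises: ValueError: expected sequence of length 43 at dim 1 (got 56)
    aug_text = aug.augment(rows['Column1'].tolist())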

Thanks
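One way to keep batching while not crashing on a bad batch is to chunk the list yourself and retry only a failing chunk one string at a time. The sketch below is a hypothetical helper, not part of nlpaug: `augment_in_batches` and its `augment` parameter are illustrative names, and it assumes the augment callable returns one augmented string per input text.

```python
from typing import Callable, List


def augment_in_batches(
    texts: List[str],
    augment: Callable[[List[str]], List[str]],
    batch_size: int = 32,
) -> List[str]:
    """Augment `texts` in chunks of `batch_size` (hypothetical helper).

    `augment` stands in for something like `aug.augment`. If a whole
    chunk raises ValueError (as in this issue), that chunk alone is
    retried one string at a time.
    """
    results: List[str] = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        try:
            results.extend(augment(chunk))
        except ValueError:
            # Fall back to single-string calls for the failing chunk only
            for text in chunk:
                results.extend(augment([text]))
    return results
```

With nlpaug this might be called as `augment_in_batches(df1['Column1'].tolist(), aug.augment, batch_size=32)`. Catching every `ValueError` is deliberately broad here; chunks that succeed keep the speed of batched inference, and only the failing ones pay for per-string calls.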

AlexandrePieroux commented 2 years ago

I've got the same issue. Did you find a solution?

On my side, I'm just sending lists of texts to a pipeline that includes ContextualWordEmbsForSentenceAug:

import pandas as pd
import nlpaug.augmenter.word as naw
import nlpaug.augmenter.sentence as nas
import nlpaug.flow as naf

# df, label_col and text_col come from my own data
pipeline = [
    naw.SynonymAug(),
    naw.AntonymAug(),
    naw.ContextualWordEmbsAug(),
    nas.ContextualWordEmbsForSentenceAug()
]
# aug_p: probability that each augmenter in the pipeline is applied
aug = naf.Sometimes(pipeline, aug_p=1/len(pipeline), verbose=1)

res = []
for label, data in df.groupby(label_col):
    # Augment all texts of one label in a single call
    aug_data = aug.augment(data[text_col].tolist(), num_thread=5)
    a_data = pd.DataFrame(aug_data, columns=['text'])
    a_data['label'] = label
    res.append(a_data)
aug_data = pd.concat(res, ignore_index=True)

kgarg8 commented 2 years ago

Unfortunately, no.

AlexandrePieroux commented 2 years ago

Does anyone know if this has been fixed?

nmendozam commented 2 years ago

This seems to be related to an error in the tokenizer, as shown here. Meanwhile, I have worked around it by passing each string to the augmenter individually:

from nlpaug.augmenter.sentence import ContextualWordEmbsForSentenceAug

aug = ContextualWordEmbsForSentenceAug()
# df: your dataframe; one string per call, so no multi-text batch is built
for text in df["column"].tolist():
    print(aug.augment(text, num_thread=5))
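The error text matches what PyTorch's `torch.tensor` raises when it is given rows of unequal length (for example, `torch.tensor([[1, 2], [1, 2, 3]])` fails with "expected sequence of length 2 at dim 1 (got 3)"). That fits the tokenizer explanation: if tokenized sentences in a batch are not padded to a common length, they cannot be stacked into one tensor. Note also that GPT-2's tokenizer defines no pad token by default; whether that is the root cause inside nlpaug is an assumption. The sketch below is a pure-Python illustration of right-padding with an attention mask, not nlpaug's code; `pad_batch` and `pad_id=0` are hypothetical.

```python
from typing import List, Tuple


def pad_batch(
    sequences: List[List[int]], pad_id: int = 0
) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad token-id lists to a common length (hypothetical helper).

    Returns the padded ids and an attention mask (1 = real token,
    0 = padding). `pad_id=0` is an assumption; real tokenizers define
    their own pad token id.
    """
    max_len = max((len(seq) for seq in sequences), default=0)
    padded = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
    mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences]
    return padded, mask


# Two "tokenized sentences" of lengths 2 and 3 become a rectangular 2x3 batch
ids, attention = pad_batch([[5, 6], [7, 8, 9]])
```

Once every row has the same length, the lists can be turned into a single tensor; the mask lets the model ignore the padded positions.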