makcedward / nlpaug

Data augmentation for NLP
https://makcedward.github.io/
MIT License

TypeError with ContextualWordEmbsAug #317

Open zhangyaqi20 opened 1 year ago

zhangyaqi20 commented 1 year ago

Hi, I was using nlpaug.augmenter.word.context_word_embs.ContextualWordEmbsAug to augment my text with BERT embeddings.

Here is my code:

import torch
import nlpaug.augmenter.word.context_word_embs as naw

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
aug = naw.ContextualWordEmbsAug(
    model_path='bert-base-uncased',
    model_type='bert',
    action='substitute',
    aug_p=0.1,
    aug_min=1,
    aug_max=10,
    device=device.type, # 'cpu' or 'cuda'
    )

text = "try this text for aug"
augmented_text = aug.augment(data=text, n=10)
print("Original:")
print(text)
print("Augmented Text:")
print(augmented_text)

But I got this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-36-49516817fab2> in <module>
     12 
     13 text = "try this text for aug"
---> 14 augmented_text = aug.augment(data=text, n=10)
     15 print("Original:")
     16 print(text)

3 frames
/usr/local/lib/python3.8/dist-packages/nlpaug/base_augmenter.py in augment(self, data, n, num_thread)
     96             elif self.__class__.__name__ in ['AbstSummAug', 'BackTranslationAug', 'ContextualWordEmbsAug', 'ContextualWordEmbsForSentenceAug']:
     97                 for _ in range(aug_num):
---> 98                     result = action_fx(clean_data)
     99                     if isinstance(result, list):
    100                         augmented_results.extend(result)

/usr/local/lib/python3.8/dist-packages/nlpaug/augmenter/word/context_word_embs.py in substitute(self, data)
    469                 continue
    470 
--> 471             outputs = self.model.predict(masked_texts, target_words=original_tokens, n=2)
    472 
    473             # Update doc

/usr/local/lib/python3.8/dist-packages/nlpaug/model/lang_models/bert.py in predict(self, texts, target_words, n)
    111                 seed = {'temperature': self.temperature, 'top_k': self.top_k, 'top_p': self.top_p}
    112                 target_token_logits = self.control_randomness(target_token_logits, seed)
--> 113                 target_token_logits, target_token_idxes = self.filtering(target_token_logits, seed)
    114                 if len(target_token_idxes) != 0:
    115                     new_tokens = self.pick(target_token_logits, target_token_idxes, target_word=target_token, n=10)

/usr/local/lib/python3.8/dist-packages/nlpaug/model/lang_models/language_models.py in filtering(self, logits, seed)
    142                 logits = logits.index_select(0, idxes)
    143                 # TODO: Externalize to util for checking
--> 144                 if 'cuda' in self.device:
    145                     idxes = idxes.cpu()
    146                 idxes = idxes.detach().numpy().tolist()

TypeError: argument of type 'torch.device' is not iterable

Could you help me with this?

Thank you! Best wishes, Yaqi

zhangyaqi20 commented 1 year ago

Installing the latest version with pip install numpy git+https://github.com/makcedward/nlpaug.git solved this issue.
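
For anyone who cannot upgrade right away: the traceback shows the check `'cuda' in self.device` failing because `self.device` was a `torch.device` object rather than a string, and the `in` operator needs a container on its right-hand side. Below is a minimal sketch of a device check that tolerates both forms. It is not how nlpaug itself handles this; the helpers `device_type` and `is_cuda` and the `FakeDevice` class are hypothetical names used only for illustration, and `FakeDevice` merely stands in for `torch.device` (which exposes a string `.type` attribute) so the snippet runs without PyTorch installed.

```python
from typing import Any


def device_type(device: Any) -> str:
    """Return the device kind ('cpu', 'cuda', ...) as a plain string.

    Accepts a string such as 'cuda:0', or any object exposing a string
    `.type` attribute, as torch.device does.
    """
    if isinstance(device, str):
        # Drop an optional index suffix: 'cuda:0' -> 'cuda'
        return device.split(":", 1)[0]
    return str(device.type)


def is_cuda(device: Any) -> bool:
    """True if the device refers to a CUDA device, whatever form it takes."""
    return device_type(device) == "cuda"


class FakeDevice:
    """Hypothetical stand-in for torch.device (assumption: only `.type` is needed)."""

    def __init__(self, type_: str) -> None:
        self.type = type_


# The membership test from the traceback fails on a non-string device object.
try:
    "cuda" in FakeDevice("cuda")
except TypeError as exc:
    print("membership test failed:", type(exc).__name__)

# The normalized check works for strings and device-like objects alike.
print(is_cuda("cuda:0"), is_cuda(FakeDevice("cuda")), is_cuda("cpu"))
```

With real PyTorch, passing `torch.device("cuda")` to `is_cuda` should behave like the `FakeDevice("cuda")` case above, since the helper reads only the `.type` attribute.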