huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

isin() received an invalid combination of arguments #31040

Closed pseudotensor closed 5 months ago

pseudotensor commented 5 months ago

System Info

transformers 4.41.0, Python 3.10

Who can help?

@ArthurZucker @gante

Reproduction

This change: https://github.com/huggingface/transformers/commit/7130a22db9033e47b34a5e836b6014d531179f02

Here:

https://github.com/huggingface/transformers/blob/main/src/transformers/generation/utils.py#L486-L493

is causing a lot of downstream software to fail with the error below:

 isin() received an invalid combination of arguments - got (test_elements=int, elements=Tensor, ), but expected one of:
 * (Tensor elements, Tensor test_elements, *, bool assume_unique, bool invert, Tensor out)
 * (Number element, Tensor test_elements, *, bool assume_unique, bool invert, Tensor out)
 * (Tensor elements, Number test_element, *, bool assume_unique, bool invert, Tensor out)
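The overload list above suggests why the call is rejected: the scalar overload names its keyword `test_element` (singular), so passing a plain int as `test_elements=` matches none of the three signatures. A minimal sketch of that behaviour, using made-up token ids and assuming a PyTorch build like the one in this report:

```python
import torch

# Made-up example data: 0 plays the role of the pad token id.
inputs = torch.tensor([[0, 5, 7], [0, 0, 9]])
pad_token_id = 0  # plain Python int, as passed by code that bypasses generate()

# Keyword call with an int: on the PyTorch build from this report it raises
# TypeError, because no overload accepts `test_elements=<int>`.
try:
    torch.isin(elements=inputs, test_elements=pad_token_id)
    keyword_call_failed = False
except TypeError as err:
    keyword_call_failed = True
    print(f"keyword call rejected: {err}")

# Wrapping the id in a tensor on the same device matches the
# (Tensor elements, Tensor test_elements) overload.
pad_token_tensor = torch.tensor([pad_token_id], device=inputs.device)
mask = torch.isin(elements=inputs, test_elements=pad_token_tensor)
print(mask)
```
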

A work-around patch is:

--- /home/jon/miniconda3/envs/h2ogpt/lib/python3.10/site-packages/transformers/generation/utils.py      2024-05-26 01:04:39.151177467 -0700
+++ new.py      2024-05-26 01:02:53.993095157 -0700
@@ -468,12 +468,14 @@
             raise ValueError(
                 "Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device."
             )
+        pad_token_tensor = torch.tensor([pad_token_id], device=inputs.device) if pad_token_id is not None else None
+        eos_token_tensor = torch.tensor([eos_token_id], device=inputs.device) if eos_token_id is not None else None

         is_pad_token_in_inputs = (pad_token_id is not None) and (
-            torch.isin(elements=inputs, test_elements=pad_token_id).any()
+            torch.isin(elements=inputs, test_elements=pad_token_tensor).any()
         )
         is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
-            torch.isin(elements=eos_token_id, test_elements=pad_token_id).any()
+            torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
         )
         can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id
         attention_mask_from_padding = inputs.ne(pad_token_id).long()
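The same idea as the patch can be sketched as a standalone function that accepts either ints or tensors. This is an illustration only, not the code in transformers: `to_token_tensor` and `infer_attention_mask` are hypothetical names, and the logic is a simplified version of the snippet patched above.

```python
import torch


def to_token_tensor(token_id, device):
    """Hypothetical helper: return an int, list, or tensor token id as a 1-D tensor."""
    if token_id is None:
        return None
    return torch.as_tensor(token_id, device=device).reshape(-1)


def infer_attention_mask(inputs, pad_token_id=None, eos_token_id=None):
    """Simplified sketch: mask out padding only when that is unambiguous."""
    pad = to_token_tensor(pad_token_id, inputs.device)
    eos = to_token_tensor(eos_token_id, inputs.device)
    default_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
    if pad is None:
        return default_mask

    # Both arguments are tensors here, so the keyword call is accepted.
    pad_in_inputs = bool(torch.isin(elements=inputs, test_elements=pad).any())
    pad_differs_from_eos = eos is None or not bool(
        torch.isin(elements=eos, test_elements=pad).any()
    )
    if pad_in_inputs and pad_differs_from_eos:
        return (~torch.isin(elements=inputs, test_elements=pad)).long()
    return default_mask


# Usage with made-up ids: 0 is the pad token, 2 is the EOS token.
example = torch.tensor([[0, 0, 5, 6], [7, 8, 9, 2]])
print(infer_attention_mask(example, pad_token_id=0, eos_token_id=2))
```

Normalising the ids once at the top keeps every later `torch.isin` call on the tensor-tensor overload, regardless of what the caller passed in.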

For example, Coqui XTTS (no longer maintained) fails this way without the patch above.

What is going on?

Expected behavior

No failure. I expect a Number (which the error message lists as allowed) to be converted to a tensor on the same device, so that the call succeeds.

zucchini-nlp commented 5 months ago

@pseudotensor hi! Could you share a minimal reproducible example?

If you're calling generate(), all special tokens are converted to tensors here, before the attention mask is prepared.

0xWOLAND commented 5 months ago

+1

0xWOLAND commented 5 months ago

@zucchini-nlp I have the same issue. This is effectively what I am doing:

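from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

model_name = 'xtts-v2'
prompt = 'This is a test prompt'

# Placeholder values, not part of the original report: point them at a
# local XTTS-v2 checkpoint and a reference speaker recording.
checkpoint_dir = 'path/to/xtts-v2'
path = f'{checkpoint_dir}/config.json'
samples = ['path/to/speaker.wav']
language = 'en'

config = XttsConfig()
# Load the configuration from the model's config.json file
config.load_json(path)
model = Xtts.init_from_config(config)

model.load_checkpoint(
    config, checkpoint_dir=checkpoint_dir, eval=True
)

# Compute the speaker conditioning from the reference audio
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
    audio_path=samples
)
# Streaming inference returns a generator of audio chunks
gen = model.inference_stream(
    prompt,
    language=language,
    gpt_cond_latent=gpt_cond_latent,
    speaker_embedding=speaker_embedding,
)

# The isin() error is raised while iterating over the generator
for chunk in gen:
    print(chunk)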

And I get the error shown in the attached screenshot (Screenshot 2024-05-28 at 12 09 13 AM).

eginhard commented 5 months ago

I opened https://github.com/idiap/coqui-ai-TTS/issues/31 in our Coqui fork. This is due to the XTTS streaming code modifying generate() and calling internal methods that were changed in #30624. PRs to fix it on our side are welcome; I'm not very familiar with that code.

0xWOLAND commented 5 months ago

@eginhard Are you aware of an older version of transformers where this streaming works?

eginhard commented 5 months ago

Anything below 4.41.0 should work. I've just released coqui-tts version 0.24.1 that limits transformers to lower versions to temporarily fix that until someone properly updates the streaming code. This issue can probably be closed because I don't think any action is needed here.
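For projects that cannot pin the dependency, a runtime guard is one way to act on the "below 4.41.0" rule of thumb. The sketch below is an assumption-laden illustration: `parse_release` and `xtts_streaming_expected_to_work` are hypothetical helpers, and the version parsing is deliberately naive (it reads only the leading digits of the first three components).

```python
from importlib import metadata

# Per the comment above: releases below 4.41.0 should work with XTTS streaming.
FIRST_BROKEN = (4, 41, 0)


def parse_release(version: str) -> tuple:
    """Naive parse: leading digits of the first three dot-separated parts."""
    numbers = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        numbers.append(int(digits) if digits else 0)
    while len(numbers) < 3:
        numbers.append(0)
    return tuple(numbers)


def xtts_streaming_expected_to_work(version: str) -> bool:
    """True when the given transformers version predates 4.41.0."""
    return parse_release(version) < FIRST_BROKEN


if __name__ == "__main__":
    try:
        installed = metadata.version("transformers")
    except metadata.PackageNotFoundError:
        print("transformers is not installed")
    else:
        print(installed, xtts_streaming_expected_to_work(installed))
```

Pinning the dependency in the package requirements, as the coqui-tts release does, remains the simpler fix; a check like this only helps to produce a clearer error message.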

pseudotensor commented 5 months ago

No problem, I patched Coqui too to handle it:

https://github.com/h2oai/h2ogpt/blob/main/docs/xtt.patch