huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Mismatched tensor size error when generating text with beam_search on mps #30662

Open zoryzhang opened 2 months ago

zoryzhang commented 2 months ago

System Info

Who can help?

@gante, @ArthurZucker and @younesbelkada

Information

Tasks

Reproduction

Code sample:

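import os
import random

import numpy as np
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

def seed_everything(seed: int):
    # Seed every RNG that could influence generation.
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.mps.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # benchmark mode can select kernels non-deterministically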

device = "mps"
model = "google-t5/t5-small"
generator = T5ForConditionalGeneration.from_pretrained(model).to(device)
tokenizer = AutoTokenizer.from_pretrained(model)
tokenized_state = tokenizer(
    "alias lt_of_ofNat_lt_ofNat, ofNat_lt_ofNat_of_lt := ofNat_lt",
    padding="longest",
    max_length=1024,
    truncation=True,
    return_tensors="pt",
)
state_ids = tokenized_state.input_ids.to(device)
state_mask = tokenized_state.attention_mask.to(device)
seed_everything(4)
output = generator.generate(
    input_ids=state_ids,
    attention_mask=state_mask,
    max_length=512,
    num_beams=16,
    length_penalty=0,
    do_sample=False,
    num_return_sequences=16,
    early_stopping=False,
    output_scores=True,
    return_dict_in_generate=True,
)

With device = "mps" the generate call fails; with device = "cpu" it runs without error.

Error message:

File ~/miniforge3/envs/py310/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/miniforge3/envs/py310/lib/python3.10/site-packages/transformers/generation/utils.py:1655, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1648     input_ids, model_kwargs = self._expand_inputs_for_generation(
   1649         input_ids=input_ids,
   1650         expand_size=generation_config.num_beams,
   1651         is_encoder_decoder=self.config.is_encoder_decoder,
   1652         **model_kwargs,
   1653     )
   1654     # 13. run beam search
-> 1655     result = self._beam_search(
   1656         input_ids,
   1657         beam_scorer,
   1658         logits_processor=prepared_logits_processor,
   1659         stopping_criteria=prepared_stopping_criteria,
   1660         pad_token_id=generation_config.pad_token_id,
   1661         output_scores=generation_config.output_scores,
   1662         output_logits=generation_config.output_logits,
   1663         return_dict_in_generate=generation_config.return_dict_in_generate,
   1664         synced_gpus=synced_gpus,
   1665         sequential=generation_config.low_memory,
   1666         **model_kwargs,
   1667     )
   1669 elif generation_mode == GenerationMode.BEAM_SAMPLE:
   1670     # 11. prepare logits warper
   1671     logits_warper = self._get_logits_warper(generation_config)

File ~/miniforge3/envs/py310/lib/python3.10/site-packages/transformers/generation/utils.py:3261, in GenerationMixin._beam_search(self, input_ids, beam_scorer, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, output_logits, return_dict_in_generate, synced_gpus, sequential, **model_kwargs)
   3258     if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
   3259         this_peer_finished = True
-> 3261 sequence_outputs = beam_scorer.finalize(
   3262     input_ids,
   3263     beam_scores,
   3264     next_tokens,
   3265     next_indices,
   3266     pad_token_id=pad_token_id,
   3267     eos_token_id=eos_token_id,
   3268     max_length=stopping_criteria.max_length,
   3269     beam_indices=beam_indices,
   3270     decoder_prompt_len=decoder_prompt_len,
   3271 )
   3273 if return_dict_in_generate:
   3274     if not output_scores:

File ~/miniforge3/envs/py310/lib/python3.10/site-packages/transformers/generation/beam_search.py:404, in BeamSearchScorer.finalize(self, input_ids, final_beam_scores, final_beam_tokens, final_beam_indices, max_length, pad_token_id, eos_token_id, beam_indices, decoder_prompt_len)
    402 print(f"hypo: {hypo}, best_idx: {best_idx}, sent_lengths: {sent_lengths}, sent_max_len: {sent_max_len}")
    403 print(f"size hypo: {hypo.size()}, size decoded: {decoded.size()}")
--> 404 decoded[i, : sent_lengths[i]] = hypo
    406 if indices is not None:
    407     indices[i, : len(best_idx)] = torch.tensor(best_idx)

RuntimeError: The expanded size of the tensor (512) must match the existing size (114) at non-singleton dimension 0.  Target sizes: [512].  Tensor sizes: [114]

My effort (I am unsure whether this is related):

In "transformers/generation/beam_search.py", I added the following two prints immediately before decoded[i, : sent_lengths[i]] = hypo, along with a few other prints:

print(f"hypo: {hypo}, best_idx: {best_idx}, sent_lengths: {sent_lengths}, sent_max_len: {sent_max_len}")
print(f"size hypo: {hypo.size()}, size decoded: {decoded.size()}")

On mps, the output is:

sent_lengths=tensor([-4485327733731293260, -4483870734795238840, -4483349007937621898,
        -4482544311454870149, -4481825153540684182, -4481758581547501948,
        -4481391817110174768, -4481007383177368869,               257025,
                      289153,               321281,               353409,
                      385537,               417665,               449793,
                      481921], device='mps:0')
sent_length[0] = len(best_hyp)=1
sent_length[1] = len(best_hyp)=114
sent_length[2] = len(best_hyp)=114
sent_length[3] = len(best_hyp)=114
sent_length[4] = len(best_hyp)=114
sent_length[5] = len(best_hyp)=114
sent_length[6] = len(best_hyp)=114
sent_length[7] = len(best_hyp)=104
sent_length[8] = len(best_hyp)=114
sent_length[9] = len(best_hyp)=106
sent_length[10] = len(best_hyp)=114
sent_length[11] = len(best_hyp)=114
sent_length[12] = len(best_hyp)=114
sent_length[13] = len(best_hyp)=114
sent_length[14] = len(best_hyp)=114
sent_length[15] = len(best_hyp)=114
hypo: tensor([0], device='mps:0'), best_idx: (tensor(0, device='mps:0'),), sent_lengths: tensor([3240385333, 3240385333, 3240385333, 3240385333, 3240385333, 3240385333,
        3240385333, 3240385333, 3240385333, 3240385333, 3240385333, 3240385333,
        3240385333, 3240385333, 3240385333, 3240385333], device='mps:0'), sent_max_len: 512
size hypo: torch.Size([1]), size decoded: torch.Size([16, 512])
hypo: tensor([    0, 32099,     3,    40,    17,   834,   858,   834,   858,   567,
          144,   834,    40,    17,   834,   858,   567,   144,     6,    13,
          567,   144,   834,    40,    17,   834,   858,   567,   144,   834,
          858,   834,    40,    17,     3,    10,  2423,    13,   567,   144,
          834,    40,    17,   834,   858,   567,   144,   834,   858,   834,
           40,    17,     3,    10,  2423,    13,   567,   144,   834,    40,
           17,   834,   858,   567,   144,   834,   858,   834,    40,    17,
            3,    10,  2423,    13,   567,   144,   834,    40,    17,   834,
          858,   567,   144,   834,   858,   834,    40,    17,     3,    10,
         2423,    13,   567,   144,   834,    40,    17,   834,   858,   567,
          144,   834,   858,   834,    40,    17,     3,    10,  2423,    13,
          567,   144,   834,    40], device='mps:0'), best_idx: (tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(3, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(1, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0'), tensor(1, device='mps:0'), tensor(1, 
device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(0, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(2, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(2, device='mps:0'), tensor(1, device='mps:0'), tensor(1, device='mps:0'), tensor(0, device='mps:0')), sent_lengths: tensor([3240385333, 3240385333, 3240385333, 3240385333, 3240385333, 3240385333,
        3240385333, 3240385333, 3240385333, 3240385333, 3240385333, 3240385333,
        3240385333, 3240385333, 3240385333, 3240385333], device='mps:0'), sent_max_len: 512
size hypo: torch.Size([114]), size decoded: torch.Size([16, 512])
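
The failure in the output above can be reproduced in isolation on CPU. This is a minimal sketch, assuming only the shapes and the garbage length value printed in the log. It does not explain why sent_lengths holds that value on mps, only why such a value leads to the reported RuntimeError: a slice end larger than the dimension is clamped to 512, a length-1 hypothesis broadcasts across it (so the first iteration passes), and a length-114 hypothesis does not.

```python
import torch

# Shapes taken from the report: 16 returned sequences, padded width 512.
decoded = torch.zeros(16, 512, dtype=torch.long)
garbage_len = 3240385333  # value printed for every sent_lengths entry on mps

# A slice end beyond the dimension is clamped, so this selects all 512 columns.
# A length-1 hypothesis broadcasts across them, so this assignment succeeds.
decoded[0, :garbage_len] = torch.tensor([0])

# A length-114 hypothesis cannot be broadcast to 512 columns.
hypo = torch.arange(114)
try:
    decoded[1, :garbage_len] = hypo
    error_message = None
except RuntimeError as exc:
    error_message = str(exc)

print(error_message)

# With the correct length, the same assignment works.
decoded[1, :114] = hypo
```

The printed message should mention both sizes, 512 and 114, as in the traceback.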

All print statements I added:

print(f"sent_lengths={sent_lengths}")

        # retrieve best hypotheses
        for i in range(batch_size):
            beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
            candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
            sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
                best_hyp_tuple = sorted_hyps.pop()
                best_score = best_hyp_tuple[0]
                best_hyp = best_hyp_tuple[1]
                best_index = best_hyp_tuple[2]
                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

                # append hyp to lists
                best.append(best_hyp)

                print(f"sent_length[{self.num_beam_hyps_to_keep * i + j}] = len(best_hyp)={len(best_hyp)}")
                #print(f"best_hyp: {best_hyp}, self.num_beam_hyps_to_keep * i + j: {self.num_beam_hyps_to_keep * i + j}, sent_lengths: {sent_lengths}, max_length: {max_length}")

                # append indices to list
                best_indices.append(best_index)

                best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

        # prepare for adding eos
        sent_lengths_max = sent_lengths.max().item() + 1
        sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
        decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)

        if len(best_indices) > 0 and best_indices[0] is not None:
            indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
        else:
            indices = None

        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            if pad_token_id is None:
                raise ValueError("`pad_token_id` has to be defined")
            decoded.fill_(pad_token_id)

        if indices is not None:
            indices.fill_(-1)

        # fill with hypotheses and eos_token_id if the latter fits in
        for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
            print(f"hypo: {hypo}, best_idx: {best_idx}, sent_lengths: {sent_lengths}, sent_max_len: {sent_max_len}")
            print(f"size hypo: {hypo.size()}, size decoded: {decoded.size()}")
            decoded[i, : sent_lengths[i]] = hypo

            if indices is not None:
                indices[i, : len(best_idx)] = torch.tensor(best_idx)

            if sent_lengths[i] < sent_max_len:
                # inserting only the first eos_token_id
                decoded[i, sent_lengths[i]] = eos_token_id[0]
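
One possible local workaround is to keep the hypothesis lengths as plain Python ints instead of writing them into a device tensor element by element. The sketch below is not the library's implementation or its eventual fix; pack_hypotheses is a hypothetical helper that mirrors the padding logic of the excerpt above under that assumption, taking a single eos_token_id int rather than a list.

```python
import torch

def pack_hypotheses(best, max_length, pad_token_id, eos_token_id):
    """Pad variable-length hypotheses into one tensor (hypothetical helper)."""
    # Lengths stay on the Python side, so no device-side writes are needed.
    lengths = [len(hypo) for hypo in best]
    width = min(max(lengths) + 1, max_length)
    # Allocate an initialized buffer instead of uninitialized memory.
    decoded = torch.full(
        (len(best), width), pad_token_id, dtype=torch.long, device=best[0].device
    )
    for i, (hypo, n) in enumerate(zip(best, lengths)):
        decoded[i, :n] = hypo
        if n < width:
            # Insert a single eos token if it fits.
            decoded[i, n] = eos_token_id
    return decoded, lengths

# Small usage example with made-up token ids.
best = [torch.tensor([0]), torch.arange(5)]
decoded, lengths = pack_hypotheses(best, max_length=512, pad_token_id=0, eos_token_id=1)
print(decoded)
print(lengths)
```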

Expected behavior

There should be no error, and the prints should match the output obtained with device="cpu":

sent_lengths=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
sent_length[0] = len(best_hyp)=1
sent_length[1] = len(best_hyp)=114
sent_length[2] = len(best_hyp)=114
sent_length[3] = len(best_hyp)=114
sent_length[4] = len(best_hyp)=114
sent_length[5] = len(best_hyp)=114
sent_length[6] = len(best_hyp)=114
sent_length[7] = len(best_hyp)=104
sent_length[8] = len(best_hyp)=114
sent_length[9] = len(best_hyp)=106
sent_length[10] = len(best_hyp)=114
sent_length[11] = len(best_hyp)=114
sent_length[12] = len(best_hyp)=114
sent_length[13] = len(best_hyp)=114
sent_length[14] = len(best_hyp)=114
sent_length[15] = len(best_hyp)=114
hypo: tensor([0]), best_idx: (tensor(0),), sent_lengths: tensor([  1, 114, 114, 114, 114, 114, 114, 104, 114, 106, 114, 114, 114, 114,
        114, 114]), sent_max_len: 115
size hypo: torch.Size([1]), size decoded: torch.Size([16, 115])
hypo: tensor([    0, 32099,     3,    40,    17,   834,   858,   834,   858,   567,
          144,   834,    40,    17,   834,   858,   567,   144,     6,    13,
          567,   144,   834,    40,    17,   834,   858,   567,   144,   834,
          858,   834,    40,    17,     3,    10,  2423,    13,   567,   144,
          834,    40,    17,   834,   858,   567,   144,   834,   858,   834,
           40,    17,     3,    10,  2423,    13,   567,   144,   834,    40,
           17,   834,   858,   567,   144,   834,   858,   834,    40,    17,
            3,    10,  2423,    13,   567,   144,   834,    40,    17,   834,
          858,   567,   144,   834,   858,   834,    40,    17,     3,    10,
         2423,    13,   567,   144,   834,    40,    17,   834,   858,   567,
          144,   834,   858,   834,    40,    17,     3,    10,  2423,    13,
          567,   144,   834,    40]), best_idx: (tensor(0), tensor(0), tensor(0), tensor(3), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(0), tensor(1), tensor(0), tensor(0), tensor(1), tensor(1), tensor(1), tensor(1), tensor(0), tensor(0), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(0), tensor(0), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(0), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(1), tensor(2), tensor(1), tensor(1), tensor(1), tensor(2), tensor(1), tensor(1), tensor(0)), sent_lengths: tensor([  1, 114, 114, 114, 114, 114, 114, 104, 114, 106, 114, 114, 114, 114,
        114, 114]), sent_max_len: 115
size hypo: torch.Size([114]), size decoded: torch.Size([16, 115])

instead of sent_lengths being filled with the repeated value 3240385333.
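
The difference in the size of decoded between the two runs ([16, 512] on mps, [16, 115] on cpu) follows from the sent_max_len computation in the excerpt. A small pure-Python sketch of that arithmetic, using the lengths from the two logs (padded_width is a hypothetical name for illustration):

```python
def padded_width(sent_lengths, max_length=None):
    """Mirror the sent_max_len computation: longest hypothesis plus one slot for eos."""
    longest = max(sent_lengths) + 1
    return min(longest, max_length) if max_length is not None else longest

# Lengths reported on cpu and the repeated value reported on mps.
cpu_lengths = [1, 114, 114, 114, 114, 114, 114, 104, 114, 106, 114, 114, 114, 114, 114, 114]
mps_lengths = [3240385333] * 16

print(padded_width(cpu_lengths, max_length=512))  # 115
print(padded_width(mps_lengths, max_length=512))  # 512
```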

gante commented 1 month ago

Hi @zoryzhang 👋

We are refactoring several generate sections, and beam search is one of the targets -- it's likely that this bug will be squashed in the process. If it is not fixed on main a month from now, please ping me again 🤗

daria-kashina commented 1 week ago

Hi @gante, I am still facing the same beam search problem when using the mps GPU. Could you please check again whether it has been fixed?

Platform: macOS-14.4.1-arm64 (Apple M2 Pro), Python 3.11.9

Versions: transformers 4.41.2, huggingface-hub 0.23.4, safetensors 0.4.3, tensorflow 2.16.1, torch 2.2.2

Thank you in advance