MilaNLProc / simple-generation

A Python package to run inference with HuggingFace language and vision-language checkpoints, wrapping many convenient features.
Other
25 stars 2 forks

Truncation by tokenizer not working correctly #4

Closed: lorelupo closed this issue 2 months ago

lorelupo commented 7 months ago

Hello,

Truncation of the input_ids during tokenization (i.e., line 336) does not work properly and throws the following warning:

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.

Then, in the generation loop:

Error The size of tensor a (8192) must match the size of tensor b (10824) at non-singleton dimension 3 Generation failed. Skipping batch.
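One plausible reading of these two messages: because no effective max_length reached the tokenizer, the prompt was not truncated, and a sequence of 10824 positions then met a tensor sized for an 8192-position window. A cheap pre-flight check can flag such prompts before generation starts. The helper below, find_overflowing_inputs, is a hypothetical pure-Python sketch (not part of simple-generation); it compares prompt length plus the max_new_tokens generation budget against an assumed position limit.

```python
from typing import List, Sequence


def find_overflowing_inputs(
    batch_input_ids: Sequence[Sequence[int]],
    max_positions: int,
    max_new_tokens: int,
) -> List[int]:
    """Return indices of prompts whose length plus the generation budget exceeds max_positions.

    Hypothetical helper for illustration; not part of simple-generation.
    """
    overflowing = []
    for i, input_ids in enumerate(batch_input_ids):
        # Positions the model must cover: prompt tokens plus tokens still to be generated.
        if len(input_ids) + max_new_tokens > max_positions:
            overflowing.append(i)
    return overflowing


# Made-up lengths: a 10824-token prompt alone already exceeds an 8192-position window.
batch = [[0] * 500, [0] * 10824]
print(find_overflowing_inputs(batch, max_positions=8192, max_new_tokens=1000))  # [1]
```

Such a check only reports the problem; truncating the prompt, as proposed in this issue, is what actually prevents the failure.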

I suggest replacing lambda x: self.tokenizer(x["text"], truncation=True) with:

lambda x: self.tokenizer(
    x["text"],
    truncation=True,
    # Leave room for the generated tokens plus a small safety margin of 8
    max_length=self.model.config.max_position_embeddings - current_generation_args["max_new_tokens"] - 8,
)

and modifying the _prepare_generation_args method accordingly:

def _prepare_generation_args(self, **generation_kwargs):
    current_generation_args = self.generation_config.to_dict()

    logger.info("Setting pad_token_id to eos_token_id for open-end generation")
    current_generation_args["pad_token_id"] = self.tokenizer.eos_token_id
    current_generation_args["eos_token_id"] = self.tokenizer.eos_token_id

    # Some models default to the outdated "max_length" parameter; normalize it to "max_new_tokens"
    if "max_new_tokens" in current_generation_args:
        if "max_length" in current_generation_args:
            logger.warning(
                "Found 'max_length' in the model's default generation config. Using 'max_new_tokens' instead."
            )
            current_generation_args.pop("max_length")
    elif "max_length" in current_generation_args:
        logger.warning(
            "Found 'max_length' in the model's default generation config. Renaming it 'max_new_tokens'."
        )
        current_generation_args["max_new_tokens"] = current_generation_args.pop("max_length")
    else:
        # No length setting at all: fall back to a default generation budget
        current_generation_args["max_new_tokens"] = 1000

    if len(generation_kwargs) > 0:
        logger.info(
            "Custom generation args passed. Any named parameters will override the same default one."
        )
        current_generation_args.update(generation_kwargs)

    # Postprocess generation kwargs: a temperature of exactly 0 is replaced by a tiny positive value
    if (
        "temperature" in current_generation_args
        and current_generation_args["temperature"] == 0
    ):
        logger.info("Temperature cannot be 0. Setting it to 1e-4.")
        current_generation_args["temperature"] = 1e-4

    return current_generation_args

I can do a PR if needed :)
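One caveat worth verifying with the method above: GenerationConfig.to_dict() in transformers appears to return every attribute, including ones left at their defaults, so "max_new_tokens" may be present with the value None. The membership test would then take the first branch and leave max_new_tokens as None, and the truncation arithmetic (max_position_embeddings - max_new_tokens - 8) would raise a TypeError. The sketch below assumes that behavior; it uses a plain dict in place of the real config and a hypothetical helper name, resolve_max_new_tokens, keeping the original branching but treating None as "unset".

```python
import logging
from typing import Any, Dict

logger = logging.getLogger(__name__)


def resolve_max_new_tokens(generation_args: Dict[str, Any], default: int = 1000) -> Dict[str, Any]:
    """Return a copy of generation_args with 'max_new_tokens' set and 'max_length' removed.

    Hypothetical helper: unlike a plain `in` test, keys whose value is None count as unset.
    """
    args = dict(generation_args)
    max_new_tokens = args.get("max_new_tokens")
    max_length = args.pop("max_length", None)

    if max_new_tokens is not None:
        if max_length is not None:
            logger.warning("Ignoring 'max_length' in favor of 'max_new_tokens'.")
    elif max_length is not None:
        # Same renaming as in the proposal above
        logger.warning("Renaming 'max_length' to 'max_new_tokens'.")
        max_new_tokens = max_length
    else:
        max_new_tokens = default

    args["max_new_tokens"] = max_new_tokens
    return args


# A config dump where max_new_tokens was never set (value None)
print(resolve_max_new_tokens({"max_new_tokens": None, "max_length": 512}))  # {'max_new_tokens': 512}
```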

lorelupo commented 7 months ago

Actually, max_length=self.model.config.max_position_embeddings-current_generation_args["max_new_tokens"]-8 fails if max_position_embeddings is not present in the model config. This likely happens when a model uses relative position embeddings.

To avoid this, we could instead do:

# Set the maximum length of the input text for the model, leaving room for generated tokens;
# None when the config does not define max_position_embeddings
max_position_embeddings = getattr(self.model.config, "max_position_embeddings", None)
max_length = (
    max_position_embeddings - current_generation_args["max_new_tokens"] - 8
    if max_position_embeddings
    else None
)

# Process the input text
dataset = Dataset.from_dict({"text": texts})
dataset = dataset.map(
    lambda x: self.tokenizer(
        x["text"],
        truncation=True,
        max_length=max_length,
    ),
    batched=True,
    remove_columns=["text"],
    desc="Tokenizing texts",
)
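A further hedge: some configs expose the context window under a different attribute name (GPT-2-style configs, for instance, use n_positions), so looking up several candidate names may cover more models than max_position_embeddings alone. The sketch below is plain Python with SimpleNamespace objects standing in for model configs; context_window, truncation_length, and the candidate list are hypothetical, not part of simple-generation or transformers. It keeps the issue's budget arithmetic (limit minus max_new_tokens minus a margin of 8) and returns None when either the limit or the generation budget is unknown, matching the fallback in the snippet above.

```python
from types import SimpleNamespace
from typing import Any, Optional, Sequence

# Assumed attribute names; extend for other architectures as needed.
CANDIDATE_LIMIT_ATTRS = ("max_position_embeddings", "n_positions")


def context_window(config: Any, candidates: Sequence[str] = CANDIDATE_LIMIT_ATTRS) -> Optional[int]:
    """Return the first positive integer found among the candidate config attributes, else None."""
    for name in candidates:
        value = getattr(config, name, None)
        if isinstance(value, int) and value > 0:
            return value
    return None


def truncation_length(config: Any, max_new_tokens: Optional[int], margin: int = 8) -> Optional[int]:
    """Maximum input length for the tokenizer, or None when it cannot be derived."""
    limit = context_window(config)
    if limit is None or max_new_tokens is None:
        return None
    return limit - max_new_tokens - margin


# Stand-in configs, not real model objects
print(truncation_length(SimpleNamespace(max_position_embeddings=8192), max_new_tokens=1000))  # 7184
print(truncation_length(SimpleNamespace(n_positions=1024), max_new_tokens=200))  # 816
print(truncation_length(SimpleNamespace(relative_attention_num_buckets=32), max_new_tokens=200))  # None
```

The result can be passed as max_length in the dataset.map call; a non-positive value would mean max_new_tokens alone fills the window, which is worth checking separately.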
github-actions[bot] commented 3 months ago

Stale issue message