huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Allow infer_framework_load_model to use the originally specified config. #31580

Open inf3rnus opened 3 months ago

inf3rnus commented 3 months ago

What does this PR do?

Allows infer_framework_load_model to use the originally specified config. Currently, if you specify config in model_kwargs, you get a duplicate key error.

I don't have time to test this, but want to point out where there's a problem/inconsistency. Just check the diff, it's one line!

Fixes # (issue): currently, if you specify config in model_kwargs, you get a duplicate key error.

amyeroberts commented 3 months ago

Hi @inf3rnus, thanks for opening a PR!

Could you share a code snippet which reproduces the error this PR resolves?

inf3rnus commented 3 months ago

@amyeroberts

Sure!

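from transformers import pipeline, AutoConfig

# The original snippet imported MODEL and TASK from a local constants module,
# plus an unused SingleTokenStreamer; they are defined inline or dropped here.
MODEL = "TheBloke/samantha-mistral-instruct-7B-GPTQ"
TASK = "text-generation"  # assumed; the original value was not shown

config = AutoConfig.from_pretrained(MODEL)

# Set disable_exllama on the quantization config, if the model has one.
if hasattr(config, "quantization_config") and config.quantization_config is not None:
    config.quantization_config["disable_exllama"] = True

pipe = pipeline(
    TASK,
    model=MODEL,
    config=config,
    ### params ###
    device_map="auto",
    # Passing config here as well is what triggers the duplicate key error.
    model_kwargs={"config": config},
    trust_remote_code=True,
)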

The problem occurs at line 297 in base.py, in infer_framework_load_model(), because model_kwargs is unpacked into the call while config is also a required argument, so config ends up being supplied twice.
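The failure mode can be reproduced without transformers at all. The sketch below uses hypothetical stand-ins, load_model_sketch and call_like_pipeline, whose signatures only mimic the shape described above (a required config parameter plus unpacked model_kwargs). They are assumptions for illustration, not the library's actual code.

```python
def load_model_sketch(model, config, **model_kwargs):
    """Hypothetical stand-in: a required `config` plus catch-all kwargs."""
    return {"model": model, "config": config, **model_kwargs}


def call_like_pipeline(model, config, model_kwargs):
    """Hypothetical caller: passes config explicitly AND unpacks model_kwargs."""
    return load_model_sketch(model, config=config, **model_kwargs)


# Works: config is supplied exactly once.
ok = call_like_pipeline("some-model", {"a": 1}, {"device_map": "auto"})

# Fails: config arrives both explicitly and through the unpacked dict.
try:
    call_like_pipeline("some-model", {"a": 1}, {"config": {"a": 1}})
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

The second call raises a TypeError complaining about multiple values for the keyword argument config, which is the same kind of duplicate-key failure reported here.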

The goal is to be able to do just this, with the original config simply passed along:

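from transformers import pipeline, AutoConfig

# The original snippet imported MODEL and TASK from a local constants module,
# plus an unused SingleTokenStreamer; they are defined inline or dropped here.
MODEL = "TheBloke/samantha-mistral-instruct-7B-GPTQ"
TASK = "text-generation"  # assumed; the original value was not shown

config = AutoConfig.from_pretrained(MODEL)

# Set disable_exllama on the quantization config, if the model has one.
if hasattr(config, "quantization_config") and config.quantization_config is not None:
    config.quantization_config["disable_exllama"] = True

pipe = pipeline(
    TASK,
    model=MODEL,
    # config is passed once, at the top level only.
    config=config,
    ### params ###
    device_map="auto",
    trust_remote_code=True,
)

amyeroberts commented 3 months ago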

@inf3rnus OK, thanks for sharing, makes sense!

If I run the snippet with this change applied, assuming a text-generation pipeline, the run still fails because config is in the model_kwargs. Next steps for the PR would be adding tests that do something similar to the snippet, but with a smaller model, and fixing every area of the pipeline that needs to be addressed.

inf3rnus commented 3 months ago

@amyeroberts NP! Are you guys going to do that? If I have time I'll do it, but my time is really constrained right now...

If it's just adding a test to guard against regressions, I might be able to get that done pretty easily. I'm somewhat familiar with this code base and generally understand what is happening, but I don't know every place this issue may be present. All I know is that the issue is present when using the pipeline function.

inf3rnus commented 3 months ago

And I would actually propose modifying the change so that the dupe cannot happen by design. E.g. keep config in model_kwargs, pull it out inside the function, and drop config as a required param. Or something in a similar vein; there might be an even more elegant solution...

Another thought would be to not unpack model_kwargs until it's needed, and to default its config key's value to the top-level config... which I like better than my initial suggestion.
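The two ideas above can be sketched side by side. Both functions below, load_model_pop and load_model_packed, are hypothetical illustrations and not the signature of infer_framework_load_model. The choice that an explicit model_kwargs["config"] wins over the top-level config in the second option is also an assumption.

```python
def load_model_pop(model, **model_kwargs):
    """Option 1 (hypothetical): config is not a named parameter.

    It travels inside model_kwargs and is pulled out in the function,
    so it can never be supplied twice.
    """
    config = model_kwargs.pop("config", None)
    return model, config, model_kwargs


def load_model_packed(model, config=None, model_kwargs=None):
    """Option 2 (hypothetical): model_kwargs stays a plain dict.

    Its "config" entry defaults to the top-level config and the dict is
    only unpacked later, at the point where the model is actually loaded.
    """
    kwargs = dict(model_kwargs or {})  # copy so the caller's dict is untouched
    kwargs.setdefault("config", config)
    return model, kwargs


top = {"source": "top-level"}
inner = {"source": "model_kwargs"}

# Option 1: config is extracted, the remaining kwargs pass through.
_, cfg, rest = load_model_pop("m", config=inner, device_map="auto")

# Option 2: the top-level config fills in when model_kwargs has none...
_, kw_default = load_model_packed("m", config=top, model_kwargs={"device_map": "auto"})

# ...and an explicit entry in model_kwargs is kept as-is.
_, kw_explicit = load_model_packed("m", config=top, model_kwargs={"config": inner})
```

In either variant there is a single place where config can come from at call time, so the duplicate keyword cannot arise.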

amyeroberts commented 3 months ago

Time is unfortunately something we're all poor in! Maybe @Rocketknight1 if you'd have time to look into this, as you're the current pipeline master 🧙‍♂️

LysandreJik commented 2 months ago

This seems like a reasonable change, thanks @inf3rnus! Agree with Amy that adding a test or two to ensure no-regression would be good.

amyeroberts commented 1 month ago

cc @Rocketknight1

Rocketknight1 commented 1 month ago

Hey! Sorry for the delay, but I've added a test that should cover this case. However, there are some test issues caused by the fork being out of date - @inf3rnus, can you sync upstream on your fork's main branch, then rebase the PR branch onto main, and finally force-push?

inf3rnus commented 1 month ago

@Rocketknight1 Amazing! Sure, I'll do that later today when I'm back at my computer
