`torch` -> `pt`, `tensorflow` -> `tf`, `flax` -> `flax`. This is intended, not a bug.
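For reference, the alias convention described above looks roughly like this (an illustrative sketch of the naming scheme, not BentoML's actual code):

```python
# Illustrative sketch of the framework-name aliases described above;
# these are HF Transformers' short codes, not BentoML's real lookup table.
FRAMEWORK_ALIASES = {
    "torch": "pt",
    "tensorflow": "tf",
    "flax": "flax",
}
```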
Is this a custom runnable that you have?
I am running into this same issue. Given the following code:
```python
from transformers import FuyuProcessor, FuyuForCausalLM
import bentoml
import torch

model_id = "adept/fuyu-8b"
model = FuyuForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
processor = FuyuProcessor.from_pretrained(model_id)

bentoml.transformers.save_model(
    "fuyu-8b_processor",
    processor,
)
bentoml.transformers.save_model(
    "fuyu-8b_model",
    model,
    signatures={"generate": {"batchable": False}},
    torch_dtype=torch.bfloat16,
)
```
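As a quick sanity check, you can inspect what the save step actually recorded; this reads the same metadata field the service code below pokes at:

```python
# Inspect the saved model's metadata; on my setup this prints "pt",
# not "torch" (same "_framework" field the service code reads below).
import bentoml

saved = bentoml.transformers.get("fuyu-8b_model")
print(saved.info.metadata["_framework"])
```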
And the service:
```python
import bentoml
from PIL import Image

model = bentoml.transformers.get("fuyu-8b_model")
model.info.metadata["_framework"] = "torch"
model_runner = model.to_runner()
# This doesn't work, complaining that tensorflow doesn't exist:
# model_runner = bentoml.transformers.get("fuyu-8b_model").to_runner()
processor_runner = bentoml.transformers.get("fuyu-8b_processor").to_runner()

svc = bentoml.Service(
    name="visual-question-answering", runners=[model_runner, processor_runner]
)

input_spec = bentoml.io.Multipart(text=bentoml.io.Text(), image=bentoml.io.Image())


@svc.api(input=input_spec, output=bentoml.io.Text())
async def describe(text: str, image: Image.Image) -> str:
    inputs = await processor_runner.async_run(text=text, images=image, return_tensors="pt")
    generated = await model_runner.generate.async_run(**inputs, max_new_tokens=16)
    print(generated)
    generated_text = await processor_runner.batch_decode.async_run(
        generated[:, -16:], skip_special_tokens=True
    )
    print(generated_text)
    return generated_text[0]
```
I've had to explicitly set the metadata `_framework` to `torch`, since it is saved as `pt`; if I don't, I run into the following block:
https://github.com/bentoml/BentoML/blob/main/src/bentoml/_internal/frameworks/transformers.py#L1153-L1180
Since the value of `_framework` is `pt`, execution falls into the `else` branch, which tries to run `tensorflow` code; because I do not have `tensorflow` installed, it crashes with an `ImportError`. I'm not sure whether that `if` statement should be checking for `pt` instead of `torch`, or whether the bug is in the code that saves the pretrained model and doesn't set the framework correctly.
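For anyone skimming, the linked block boils down to something like this (a simplified paraphrase of the dispatch, not the verbatim source):

```python
# Simplified paraphrase of the dispatch in transformers.py (not verbatim).
framework = bento_model.info.metadata["_framework"]  # "pt" for saved PyTorch models
if framework == "torch":  # never true: the metadata says "pt", not "torch"
    ...  # PyTorch setup path
else:
    # Models saved as "pt" fall through to here; this import raises
    # ImportError when TensorFlow is not installed.
    import tensorflow as tf
```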
OK, that seems to be a bug then. Can you create a PR to fix this? Thanks.
Describe the bug
In `frameworks/transformers.py`, when initializing a runnable, the code uses `"torch"` to detect the framework of the Transformers pretrained model. But when we prepare the Bento models, the default framework name in the metadata is `"pt"`, so the check should be changed accordingly:
"torch" == bento_model.info.metadata["_framework"] -> "pt" == bento_model.info.metadata["_framework"]
To reproduce
No response
Expected behavior
No response
Environment
```
bentoml: 1.1.6
python: 3.10.11
platform: Linux-5.15.90.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
uid_gid: 1000:1000
conda: 23.1.0
in_conda_env: True
```