Blaizzy / mlx-vlm

MLX-VLM is a package for running Vision LLMs locally on your Mac using MLX.
MIT License

[Feature Request] Direct Python Interface for mlx_vlm.generate #8

Closed. s-smits closed this issue 1 month ago.

s-smits commented 2 months ago

Currently, the mlx_vlm.generate function can only be called from the command line using python -m mlx_vlm.generate. I would like to request a direct Python interface for this function, allowing me to call it from my Python code without having to use the command line.
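
Right now the only programmatic option is to shell out to that CLI, roughly like the sketch below (the exact flag names are an assumption on my part and are taken to mirror the function's arguments):

import subprocess
import sys

# Current workaround: invoke the CLI in a subprocess.
# Flag names are assumed to mirror the mlx_vlm.generate arguments.
result = subprocess.run(
    [
        sys.executable, "-m", "mlx_vlm.generate",
        "--model", "qnguyen3/nanoLLaVA",
        "--image", "http://images.cocodataset.org/val2017/000000039769.jpg",
        "--prompt", "Describe this image.",
        "--max-tokens", "100",
    ],
    capture_output=True,
    text=True,
)
caption = result.stdout.strip()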

The desired API would be similar to the following:

from mlx_vlm import generate # Maybe add a load function just as in mlx_lm.generate?

image = "http://images.cocodataset.org/val2017/000000039769.jpg"
caption = generate(model='qnguyen3/nanoLLaVA',
                image=image,
                # processor = Automatically determined by the model choice
                # image_processor = Automatically determined by the model choice
                prompt="Describe this image.",
                temp=0.0,
                max_tokens=100,
                verbose=False,
                formatter=None,
                repetition_penalty=None,
                repetition_context_size=None,
                top_p=1
                )

This would allow me to easily integrate the mlx_vlm.generate function into my Python code without calling a subprocess, and use it to generate captions for images programmatically.
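
For illustration, this is the kind of loop such an interface would make possible (it uses the requested signature sketched above, which does not exist yet):

from mlx_vlm import generate  # requested API, not available yet

image_urls = [
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    # ... more images
]

captions = {}
for url in image_urls:
    # One direct call per image, no subprocess round-trip.
    captions[url] = generate(
        model="qnguyen3/nanoLLaVA",
        image=url,
        prompt="Describe this image.",
        temp=0.0,
        max_tokens=100,
    )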

s-smits commented 2 months ago
from mlx_vlm import generate as vlm_generate
from transformers import (
    AutoProcessor,
    PreTrainedTokenizer,
)

model = 'qnguyen3/nanoLLaVA'
image_path = "path/to/image.png"  # placeholder; set elsewhere in my pipeline
caption = vlm_generate(model=model,
        image=image_path,
        processor=PreTrainedTokenizer,
        image_processor=AutoProcessor,
        prompt="Describe this image.",
        temp=0.0,
        max_tokens=100,
        verbose=True,
        formatter=None,
        repetition_penalty=None,
        repetition_context_size=None,
        top_p=1
)

gives

File "/Users/air/Repositories/test-repo/src/main.py", line 19, in <module>
  main()
File "/Users/air/Repositories/test-repo/src/main.py", line 13, in main
  all_data = process_data(folders, args)
File "/Users/air/Repositories/test-repo/src/data_processing_and_inference.py", line 11, in process_data
  file_processor(folders, args)
File "/Users/air/Repositories/test-repo/src/prepare_data.py", line 205, in file_processor
  concatenate_files(folders, args.word_limit, file_queue, args.images, args)
File "/Users/air/Repositories/test-repo/src/prepare_data.py", line 178, in concatenate_files
  file_text = process_file(file_path, file_type, images, args)
File "/Users/air/Repositories/test-repo/src/prepare_data.py", line 157, in process_file
  return read_pdf_file(file_path, args, images=images)
File "/Users/air/Repositories/test-repo/src/prepare_data.py", line 100, in read_pdf_file
  caption = caption_image_file(image_index, args)  
File "/Users/air/Repositories/test-repo/src/prepare_data.py", line 36, in caption_image_file
  caption = vlm_generate(model=model,
File "/opt/homebrew/lib/python3.10/site-packages/mlx_vlm/utils.py", line 691, in generate
  prompt_tokens = mx.array(processor.encode(prompt))
TypeError: PreTrainedTokenizerBase.encode() missing 1 required positional argument: 'text'
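
Presumably this happens because the classes PreTrainedTokenizer and AutoProcessor are passed where instances are expected, so processor.encode(prompt) calls the unbound method with prompt as self and the required text argument goes missing. A minimal illustration (the from_pretrained construction here is an assumption on my part):

from transformers import AutoTokenizer

# Passing the class itself means processor.encode(prompt) binds `prompt` to
# `self`, leaving `text` empty, which is exactly the TypeError above.
# An instance works, e.g. (assumed construction):
tokenizer = AutoTokenizer.from_pretrained("qnguyen3/nanoLLaVA", trust_remote_code=True)
tokens = tokenizer.encode("Describe this image.")
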
Blaizzy commented 2 months ago

This is currently possible, and I will add more documentation examples.

Here's how you do it:

from mlx_vlm import load, generate
# The helpers below are assumed to live in mlx_vlm.utils.
from mlx_vlm.utils import get_model_path, load_config, load_image_processor

model_path = "mlx-community/nanoLLaVA"
model_path = get_model_path(model_path)
model, processor = load(model_path)
config = load_config(model_path)
image_processor = load_image_processor(config)

prompt = processor.apply_chat_template(
    [{"role": "user", "content": "<image>\nWhat's so funny about this image?"}],
    tokenize=False,
    add_generation_prompt=True,
)

image_path = "./assets/image.png"
output = generate(model, processor, image_path, prompt, image_processor, verbose=False)
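
If you want to hide that setup, it can all sit behind one small wrapper (caption_image is just a hypothetical helper name, and the mlx_vlm.utils import paths are assumed):

from mlx_vlm import load, generate
from mlx_vlm.utils import get_model_path, load_config, load_image_processor

def caption_image(model_id, image_path, question):
    # Hypothetical convenience wrapper bundling the steps shown above.
    model_path = get_model_path(model_id)
    model, processor = load(model_path)
    config = load_config(model_path)
    image_processor = load_image_processor(config)

    prompt = processor.apply_chat_template(
        [{"role": "user", "content": f"<image>\n{question}"}],
        tokenize=False,
        add_generation_prompt=True,
    )
    return generate(model, processor, image_path, prompt, image_processor, verbose=False)

caption = caption_image("mlx-community/nanoLLaVA", "./assets/image.png", "What's so funny about this image?")
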
Blaizzy commented 2 months ago

Let me know if it works for you :)

s-smits commented 2 months ago

It works for me; however, it's still quite a bit of overhead to import five functions from two different libraries. Do you happen to have any plans for a more high-level solution where only model, processor = load(...) and generate(...) are needed?

Additionally, setting trust_remote_code automatically for this would be a plus: The repository for /Users/air/.cache/huggingface/hub/models--mlx-community--nanoLLaVA/snapshots/hash contains custom code which must be executed to correctly load the model. You can inspect the repository content at https://hf.co//Users/air/.cache/huggingface/hub/models--mlx-community--nanoLLaVA/snapshots/hash. You can avoid this prompt in future by passing the argument trust_remote_code=True.

I didn't get this prompt when generating from the command line.
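
In plain transformers the prompt goes away when trust_remote_code=True is passed while building the tokenizer or processor; whether mlx_vlm's load can forward such a flag is unclear to me. For example:

from transformers import AutoTokenizer

# trust_remote_code=True skips the interactive confirmation for repos that
# ship custom code; how/whether mlx_vlm's load() exposes this is not confirmed.
tokenizer = AutoTokenizer.from_pretrained("mlx-community/nanoLLaVA", trust_remote_code=True)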

Blaizzy commented 2 months ago

It works for me; however, it's still quite a bit of overhead to import five functions from two different libraries.

Please show me an example.

Do you happen to have any plans for a more high-level solution where only model, processor = load(...) and generate(...) are needed?

Yes, I do.

NanoLLaVA-style models use an image_processor alongside the processor, whereas LLaVA doesn't, so I want to make that explicit.

Perhaps this would be better:

from mlx_vlm import load, generate
from mlx_vlm.utils import load_image_processor  # assumed import location

model_path = "mlx-community/nanoLLaVA"
model, processor = load(model_path)
image_processor = load_image_processor(model_path)  # None for Llava

prompt = processor.apply_chat_template(
    [{"role": "user", "content": "<image>\nWhat's so funny about this image?"}],
    tokenize=False,
    add_generation_prompt=True,
)

image_path = "./assets/image.png"
output = generate(model, processor, image_path, prompt, image_processor, verbose=False)

What do you think?

s-smits commented 2 months ago

Great, integrating get_model_path into load and config into load_image_processor will reduce the overhead quite a bit. Looks good!
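
Concretely, the merged helpers could look roughly like this (a rough sketch built on the existing functions named in this thread, with mlx_vlm.utils assumed as their import path; not the actual implementation):

from mlx_vlm import utils

def load(path_or_hf_repo):
    # get_model_path() folded in, so callers pass the repo id directly.
    model_path = utils.get_model_path(path_or_hf_repo)
    return utils.load(model_path)

def load_image_processor(path_or_hf_repo):
    # load_config() folded in as well.
    model_path = utils.get_model_path(path_or_hf_repo)
    config = utils.load_config(model_path)
    return utils.load_image_processor(config)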

Blaizzy commented 2 months ago

Awesome!

This was already done a day ago ✅ https://github.com/Blaizzy/mlx-vlm/pull/7

It will be a part of the v0.0.4 release.