Aleph-Alpha / magma

MAGMA - a GPT-style multimodal model that can understand any combination of images and language. NOTE: The freely available model from this repo is only a demo. For the latest multimodal and multilingual models from Aleph Alpha check out our website https://app.aleph-alpha.com

(#9) Improved inference interface #12

Closed Mayukhdeb closed 2 years ago

Mayukhdeb commented 2 years ago

Contains the following changes:

1. The model, tokenizer, and transforms are now contained under a unified wrapper, `Magma()`, which can be used as shown below:

```python
from multimodal_fewshot import Magma

magma = Magma(
    checkpoint_path = 'mp_rank_00_model_states.pt',  ## downloads automatically if not present at this path
    config_path = 'configs/MAGMA_v1.yml',
)
magma.to('cuda:0')
```


2. Image inputs are now handled by `ImageInput()`, which supports both URLs and local image paths (see the loader sketch after this list):

```python
inputs = [
    ImageInput('https://www.art-prints-on-demand.com/kunst/thomas_cole/woods_hi.jpg'),
    'Describe the painting:'
]
```
3. `Magma()` supports both low-level and high-level inference (see the decoding and sampling sketches after this list):

```python
## low-level inference: forward pass
embeddings = magma.preprocess_inputs(inputs = inputs)  ## returns a torch tensor of shape (1, sequence_length, hidden_dim)
outputs = magma(embeddings)  ## output logits shape: torch.Size([1, 150, 50400])
```

```python
## high-level inference
completion = magma.generate(inputs = inputs, num_tokens = 4, topk = 1)
## completion: "A cabin on a lake"
```
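For reference, the dual-source loading in item 2 (URL or local path) typically looks something like the sketch below. `UrlOrPathImage` is a hypothetical stand-in, and the `requests`/`PIL` approach is an assumption; the actual `ImageInput` in this repo may work differently.

```python
from io import BytesIO

import requests
from PIL import Image

class UrlOrPathImage:  ## hypothetical stand-in for ImageInput, for illustration only
    def __init__(self, url_or_path: str):
        self.url_or_path = url_or_path

    def get_image(self) -> Image.Image:
        ## fetch over HTTP if the string looks like a URL, otherwise read from disk
        if self.url_or_path.startswith(('http://', 'https://')):
            response = requests.get(self.url_or_path, timeout=10)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert('RGB')
        return Image.open(self.url_or_path).convert('RGB')
```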
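To make the low-level path in item 3 concrete: the raw logits from the forward pass can be turned into a next-token prediction by hand. A minimal greedy step, assuming the wrapper exposes its tokenizer as `magma.tokenizer` (not confirmed by this PR):

```python
import torch

## take the logits at the final sequence position and pick the most likely token
next_token_logits = outputs[0, -1, :]           ## shape: (vocab_size,), here 50400
next_token_id = torch.argmax(next_token_logits).item()
print(magma.tokenizer.decode([next_token_id]))  ## first token of the continuation
```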
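On the high-level path, `topk = 1` amounts to greedy decoding; with `topk > 1`, each step samples from the k most likely tokens. A generic top-k sampling step looks roughly like this (an illustration of the technique, not necessarily the repo's implementation):

```python
import torch

def sample_top_k(logits: torch.Tensor, k: int) -> int:
    """Sample a token id from the k highest-probability logits; k = 1 reduces to argmax."""
    top_logits, top_ids = torch.topk(logits, k)
    probs = torch.softmax(top_logits, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    return top_ids[choice].item()
```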

Mayukhdeb commented 2 years ago

> Why not just add these functions into MultimodalLM?

This is a good idea, but the forward function in `MultimodalLM` is not exactly straightforward, so I thought of creating a separate wrapper that simplifies things for inference purposes (given that a large fraction of its users will be doing inference only).

Of course, one thing I missed is the internal `generate()` function :sweat_smile: -- I'll make the necessary changes to use that instead.

> I don't think it's a good idea to call `model.eval()` at init, might cause some problems when training

You're right, I'll make that a default argument on `__init__` instead: `eval = True`. If someone wants to train the model, they can just set `eval = False` on init.
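A sketch of the proposed change (hypothetical, just paraphrasing the suggestion above):

```python
import torch

class Magma(torch.nn.Module):
    def __init__(self, checkpoint_path, config_path, eval = True):
        super().__init__()
        ...  ## load config, build model, load checkpoint (omitted)
        if eval:
            self.eval()  ## inference mode by default; pass eval = False when training
```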

> you can wrap this whole function in a no grad, instead of just doing no grad over the forward pass

Yes, rookie mistake on my side :slightly_smiling_face:
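i.e. something along these lines, moving the `torch.no_grad()` from the inner forward pass to the whole generation function (a sketch, not the final code):

```python
import torch

class Magma(torch.nn.Module):
    @torch.no_grad()  ## disables gradient tracking for the entire generation loop
    def generate(self, inputs, num_tokens, topk = 1):
        ...
```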

So the to-do is as follows:

- [ ] use the internal `generate()` function instead of a custom one
- [ ] make `eval = True` a default argument on `__init__`
- [ ] wrap the whole of `generate()` in `torch.no_grad()` instead of just the forward pass

Feel free to let me know if you want any changes to the to-do list :eyes: