kohjingyu / fromage

🧀 Code and models for the ICML 2023 paper "Grounding Language Models to Images for Multimodal Inputs and Outputs".
https://jykoh.com/fromage
Apache License 2.0

Huggingface pipeline #15

Closed Marcusntnu closed 1 year ago

Marcusntnu commented 1 year ago

Hi, I'm trying to build a Hugging Face pipeline so I can run batched inference. Any ideas? Here's my attempt at a pipeline so far; with the current examples I'm not able to do batching.

class FromagePipeline(Pipeline):
    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        if "meta_model" in kwargs:
            self.meta_model = kwargs["meta_model"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs):
        text = inputs["text"]
        image = inputs["image"]

        input_embs = []
        input_ids = []

        # Encode the image into visual embeddings for captioning.
        pixel_values = utils.get_pixel_values_for_model(self.meta_model.model.feature_extractor, image)
        pixel_values = pixel_values.to(device=self.meta_model.model.logit_scale.device, dtype=self.meta_model.model.logit_scale.dtype)
        pixel_values = pixel_values[None, ...]
        visual_embs = self.meta_model.model.get_visual_embs(pixel_values, mode='captioning')  # (1, n_visual_tokens, D)
        input_embs.append(visual_embs)

        # Embed the text prompt and append it after the visual tokens.
        text_ids = self.meta_model.model.tokenizer(text, add_special_tokens=True, return_tensors="pt").input_ids.to(self.meta_model.model.logit_scale.device)
        text_embs = self.meta_model.model.input_embeddings(text_ids)  # (1, T, D)
        input_embs.append(text_embs)
        input_ids.append(text_ids)

        input_embs = torch.cat(input_embs, dim=1)
        input_ids = torch.cat(input_ids, dim=1)
        return input_embs

    def _forward(self, model_inputs):
        generated_ids, generated_embeddings, _ = self.meta_model.model.generate(model_inputs, ret_scale_factor=0)
        return generated_ids

    def postprocess(self, model_outputs):
        caption = self.meta_model.model.tokenizer.batch_decode(model_outputs, skip_special_tokens=True)[0]
        return utils.truncate_caption(caption)
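For context, here is roughly how I construct and call it. This is only a sketch: I'm assuming the Pipeline base class can be handed the underlying language model and tokenizer (meta_model.model.lm and meta_model.model.tokenizer are my guesses at the right attributes), and that extra init kwargs get routed to _sanitize_parameters as in the custom-pipeline docs.

pipe = FromagePipeline(
    model=meta_model.model.lm,          # assumption: the underlying HF language model
    tokenizer=meta_model.model.tokenizer,
    meta_model=meta_model,              # picked up by _sanitize_parameters above
)
captions = pipe(
    [{"text": "A picture of", "image": img} for img in images],
    batch_size=4,                       # batched inference is the part that fails
)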

Marcusntnu commented 1 year ago

Sorry for posting this as an issue if it's not the most suitable place. What I'm really trying to do is optimize an inference pipeline; it could also be done with a DataLoader, roughly along the lines of the sketch below.

If this isn't suitable, feel free to close.
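To make the DataLoader idea concrete, here's a minimal, untested sketch. It assumes images is a Python list of PIL images and meta_model is the loaded FROMAGe model (the collate function and loop structure are mine, not from the repo):

import torch
from torch.utils.data import DataLoader

from fromage import utils  # the repo's utils module

def collate(batch):
    # Preprocess each PIL image into pixel values and stack into one batch tensor.
    pixel_values = [utils.get_pixel_values_for_model(
        meta_model.model.feature_extractor, im) for im in batch]
    return torch.stack(pixel_values, dim=0)

loader = DataLoader(images, batch_size=8, collate_fn=collate)
for pixel_values in loader:
    pixel_values = pixel_values.to(
        device=meta_model.model.logit_scale.device,
        dtype=meta_model.model.logit_scale.dtype)
    # (B, n_visual_tokens, D) visual embeddings for the whole batch.
    visual_embs = meta_model.model.get_visual_embs(pixel_values, mode='captioning')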

kohjingyu commented 1 year ago

This looks roughly correct to me. Can you explain what the issue is (or why it doesn't run at the moment)?

This is some code that I've used for (batched) captioning. It is roughly the same as what you have, but maybe it will still be helpful:

embeddings = meta_model.model.get_visual_embs(images, mode='captioning')  # (B, n_visual_tokens, D)
input_ids = meta_model.model.tokenizer('A picture of', add_special_tokens=True, return_tensors="pt").input_ids.to(images.device)
prompt_embeddings = meta_model.model.input_embeddings(input_ids)  # (1, T, D)
prompt_embeddings = prompt_embeddings.repeat([images.shape[0], 1, 1])  # tile the prompt across the batch
embeddings = torch.cat([embeddings, prompt_embeddings], dim=1)

gen_id, _, _ = meta_model.model.generate(embeddings, max_len=32, min_word_tokens=32)
gen_cap = meta_model.model.tokenizer.batch_decode(gen_id, skip_special_tokens=True)

gen_cap will be a list of text captions. Each one is decoded to 32 words (for batching purposes), so you will have to do some postprocessing to truncate each caption (e.g., to the first period).
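Something like the repo's utils.truncate_caption (which your pipeline's postprocess already calls) should handle that; a minimal standalone version, just as a sketch, might look like this:

def truncate_to_first_period(caption: str) -> str:
    # Keep everything up to and including the first period, if there is one.
    idx = caption.find('.')
    return caption[:idx + 1] if idx != -1 else caption

captions = [truncate_to_first_period(c) for c in gen_cap]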