
HPT - Open Multimodal Large Language Models #15

runhani opened this issue 3 months ago

runhani commented 3 months ago

HPT - Open Multimodal Large Language Models

- GitHub: https://github.com/HyperGAI/HPT
- Hugging Face: https://huggingface.co/HyperGAI/HPT
- technical blog


Pretrained models used

- Pretrained LLM: Yi-6B-Chat
- Pretrained visual encoder: clip-vit-large-patch14-336
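To make the ingredients concrete, here is a minimal sketch of pulling both backbones from the Hugging Face Hub (the repo ids `01-ai/Yi-6B-Chat` and `openai/clip-vit-large-patch14-336` are the public upstream checkpoints; HPT's own weights live under `HyperGAI/HPT`):

```python
# Minimal sketch: loading the two pretrained backbones from the Hub.
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          CLIPVisionModel, CLIPImageProcessor)

llm = AutoModelForCausalLM.from_pretrained("01-ai/Yi-6B-Chat")   # language model
tokenizer = AutoTokenizer.from_pretrained("01-ai/Yi-6B-Chat")
visual_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336")
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
```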

What does HPT's architecture look like?


runhani commented 3 months ago

Let's take a close look.

LLM

Visual Encoder

Projector

Hformer
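In short, the H-former is a hybrid projector: a Q-Former branch compresses the 784 patch features into 32 learned query tokens, an MLP connector keeps the full 784-token path, and the two are fused before a residual FFN: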

```python
    def forward(self, x_):
        if self.gradient_checkpointing and self.training:
            print('Not support gradient checkpointing')
        x = self.ln_vision(x_)
        # 32 learnable query tokens, broadcast across the batch
        query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)

        # x = [1, 784, 1024], query_tokens = [1, 32, 768]
        query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=x,
                return_dict=True,
        )
        # project Q-Former output into the LLM embedding space:
        # query_output.last_hidden_state = [1, 32, 768] → [1, 32, 4096]
        q_feat = self.llm_proj(query_output.last_hidden_state)

        # parallel MLP path over the raw patch features:
        # x_ = [1, 784, 1024] → mlp_outputs = [1, 784, 4096]
        mlp_outputs = self.connector(x_)
        mlp_feat = mlp_outputs

        # mean over the 32 queries: [1, 4096] → broadcast as [1, 1, 4096]
        int_feat = mlp_feat + q_feat.mean(dim=1)[:, None]

        # FFN with a residual connection
        out = int_feat + self.ffn(int_feat)

        return out
```
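To sanity-check the fusion shapes, a standalone snippet with dummy tensors (dimensions taken from the comments above):

```python
import torch

q_feat = torch.randn(1, 32, 4096)     # Q-Former queries after llm_proj
mlp_feat = torch.randn(1, 784, 4096)  # MLP connector output

# mean over the 32 queries → [1, 4096], broadcast over all 784 patch positions
int_feat = mlp_feat + q_feat.mean(dim=1)[:, None]
print(int_feat.shape)  # torch.Size([1, 784, 4096])
```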

HPT

```python
# image: (3, width, height) → (3, 392, 392)
visual_outputs = self.visual_encoder(image, output_hidden_states=True)
# 'last_hidden_state' = [1, 785, 1024]
# 'pooler_output'     = [1, 1024]
# 'hidden_states'     = 25-element list of [1, 785, 1024] tensors
# self.visual_select_layer = -2 (penultimate layer)
# visual_outputs.hidden_states[-2] = [1, 785, 1024]
# [:, 1:] drops the CLS token → [1, 784, 1024]
pixel_values = self.projector(visual_outputs.hidden_states[self.visual_select_layer][:, 1:])
mm_inputs = prepare_inputs_labels_for_multimodal(llm=self.llm, input_ids=ids, pixel_values=pixel_values)

generate_output = self.llm.generate(
    **mm_inputs,
    generation_config=gen_config,
    streamer=None,
    bos_token_id=self.tokenizer.bos_token_id,
    stopping_criteria=self.stop_criteria)
predict = self.tokenizer.decode(generate_output[0], skip_special_tokens=True)
```
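So the generation path is: image → CLIP visual encoder (penultimate hidden layer, CLS token dropped, 784 patch features) → H-former projector (784 visual tokens in the LLM's 4096-dim embedding space) → spliced into the prompt embeddings → `self.llm.generate`.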

Combine

```python
def prepare_inputs_labels_for_multimodal():
    ...
```
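The body isn't shown here, but following the common LLaVA/xtuner pattern, this is where the projected visual tokens get spliced into the text embedding sequence at the image placeholder position. A hypothetical minimal sketch (the `IMAGE_TOKEN_INDEX` value and the single-image, batch-size-1 layout are assumptions, not HPT's actual code, which also handles labels and attention masks):

```python
import torch

IMAGE_TOKEN_INDEX = -200  # assumed placeholder id marking where the image goes

def prepare_inputs_labels_for_multimodal(llm, input_ids, pixel_values):
    embed = llm.get_input_embeddings()
    pos = (input_ids[0] == IMAGE_TOKEN_INDEX).nonzero()[0].item()
    before = embed(input_ids[:, :pos])      # text embeddings before the image
    after = embed(input_ids[:, pos + 1:])   # text embeddings after the image
    # pixel_values here are the projector outputs, e.g. [1, 784, 4096]
    inputs_embeds = torch.cat([before, pixel_values, after], dim=1)
    return {"inputs_embeds": inputs_embeds}
```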
runhani commented 3 months ago

MMMU overall: 41.2 (hpt-air-demo-local)
