BAAI-DCAI / Bunny

A family of lightweight multimodal models.
Apache License 2.0

How to output last hidden state during inference? #62

Closed: GewelsJI closed this issue 2 months ago

GewelsJI commented 2 months ago

Hi, Bunny team,

Thanks for providing such a nice project.

I am trying to extract the last hidden state of Bunny-Phi, so I added two extra parameters to the generate call here, like this:

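output_ids = model.generate(
    input_ids,
    images=image_tensor.unsqueeze(0).to(dtype=model.dtype, device='cuda', non_blocking=True),
    do_sample=args.temperature > 0,
    temperature=args.temperature,
    top_p=args.top_p,
    num_beams=args.num_beams,
    # no_repeat_ngram_size=3,
    max_new_tokens=1024,
    use_cache=True,
    # what I added:
    return_dict_in_generate=True,  # return an output object instead of a plain tensor of token IDs
    output_hidden_states=True)     # also return the hidden states of every layer at every step
# With return_dict_in_generate=True, the generated token IDs are in output_ids.sequences.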

As a result, I get back a lot of tensors: output_ids['hidden_states'] is a tuple of length 1024. I do not know how to extract the last hidden state of Phi from it, i.e., the tensor that is fed into lm_head.

Thanks, Daniel

Isaachhh commented 2 months ago

Hi Daniel,

Bunny calls the forward method of Phi here, so I think you can save the last hidden state in a variable here and return it by adding a field to CausalLMOutputWithPast.
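A minimal sketch of the first suggestion, assuming a recent `transformers` release: `CausalLMOutputWithPast` is a dataclass, so one way to "add an item" is to subclass it with an extra optional field and fill that field at the end of the forward pass. The class name `CausalLMOutputWithLastHidden`, the field `last_hidden_state`, and the helper `build_output` are hypothetical, not part of Bunny; the random tensors stand in for the real decoder output.

```python
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast


@dataclass
class CausalLMOutputWithLastHidden(CausalLMOutputWithPast):
    # Hypothetical extra field: the decoder output right before lm_head.
    last_hidden_state: Optional[torch.FloatTensor] = None


def build_output(hidden_states, lm_head):
    """Hypothetical tail of a *ForCausalLM.forward that keeps the lm_head input."""
    logits = lm_head(hidden_states)
    return CausalLMOutputWithLastHidden(logits=logits, last_hidden_state=hidden_states)


lm_head = torch.nn.Linear(8, 16, bias=False)
h = torch.randn(2, 5, 8)  # (batch, seq_len, hidden), made-up sizes
out = build_output(h, lm_head)
print(out.logits.shape, out.last_hidden_state.shape)
# torch.Size([2, 5, 16]) torch.Size([2, 5, 8])
```

This only changes what `forward` returns. `generate` assembles its own output object from each step, so the extra field is not guaranteed to show up in its result; it is most useful when you call the model's forward pass directly.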

And I also think output_ids['hidden_states'][step][-1], i.e. the last layer at each generation step, is what you need.
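A sketch of how that tuple is nested, assuming Bunny's `generate` follows the standard Hugging Face output format: `hidden_states[step]` is itself a tuple with one tensor per layer (embedding output first, final layer last), so the last layer at a given step is `hidden_states[step][-1]`, while `hidden_states[-1]` alone is the whole per-layer tuple for the final step. The fake tensors and the helper `final_layer_per_step` are illustrative; all sizes except the width of 2048 reported in this thread are made up.

```python
import torch

batch, prompt_len, n_steps, n_layers, hidden = 2, 5, 3, 4, 2048

# Stand-in for output_ids['hidden_states']: one entry per generation step,
# each a tuple of (n_layers + 1) tensors (embeddings + every decoder layer).
# Step 0 covers the whole prompt; later steps cover only the newest token.
hidden_states = tuple(
    tuple(
        torch.randn(batch, prompt_len if step == 0 else 1, hidden)
        for _ in range(n_layers + 1)
    )
    for step in range(n_steps)
)


def final_layer_per_step(hidden_states):
    """Return the last-layer tensor (the one fed to lm_head) for every step."""
    return [layers[-1] for layers in hidden_states]


final = final_layer_per_step(hidden_states)
print([tuple(t.shape) for t in final])
# [(2, 5, 2048), (2, 1, 2048), (2, 1, 2048)]
```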

Note, however, that this last hidden state has already been passed through the final layer norm here.
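If you need the state before that final layer norm, one option that avoids editing Bunny's code is a PyTorch forward hook on the norm module: the hook receives the module's input (pre-norm) and output (post-norm). The toy `TinyDecoder` below is a hypothetical stand-in for Phi; in the real model, the attribute path of the final norm is an assumption you should look up with `print(model)`. During `generate` the hook fires once per step, which is why it appends to a list.

```python
import torch
from torch import nn


class TinyDecoder(nn.Module):
    """Hypothetical stand-in for a decoder: blocks -> final LayerNorm -> lm_head."""

    def __init__(self, hidden=8, vocab=16, n_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(n_layers))
        self.final_layernorm = nn.LayerNorm(hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))
        return self.lm_head(self.final_layernorm(x))


pre_norm_states = []


def save_pre_norm(module, inputs, output):
    # inputs is the tuple of positional args; inputs[0] is the un-normalized state.
    pre_norm_states.append(inputs[0].detach())


model = TinyDecoder()
handle = model.final_layernorm.register_forward_hook(save_pre_norm)
with torch.no_grad():
    logits = model(torch.randn(2, 5, 8))
handle.remove()  # detach the hook once you are done

print(pre_norm_states[0].shape)  # torch.Size([2, 5, 8])
```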

GewelsJI commented 2 months ago

Hi, @Isaachhh

I appreciate your suggestions.

Actually, since output_ids['hidden_states'] is a tuple, I printed the shape of every item in it. During causal-decoder inference it produces 1024 items: the first one, output_ids['hidden_states'][0], holds tensors of shape (bs, token-nums, 2048), and all the remaining ones hold tensors of shape (bs, 1, 2048). I assume this is because a causal decoder-only model predicts one token at a time. Is that right?
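That pattern matches cached decoding: with `use_cache=True` the first step runs over the whole prompt, and every later step feeds only the newest token while reusing the cached keys and values, so each later step contributes a single position. A sketch, assuming the Hugging Face nesting `hidden_states[step][layer]`, that stitches the last-layer states into one tensor; `stitch_last_layer` is a hypothetical helper and the tensors are fake. It only lines up for greedy or sampling decoding: with `num_beams > 1`, beams are reordered between steps, so rows from different steps do not correspond.

```python
import torch


def stitch_last_layer(hidden_states):
    """Concatenate hidden_states[step][-1] along the sequence axis.

    The result covers the prompt plus every generated token except the last:
    the final sampled token is never fed back through the model, so it has
    no hidden state of its own.
    """
    return torch.cat([layers[-1] for layers in hidden_states], dim=1)


# Fake generate() output: 3 steps, 2 layers + embeddings, a 7-token prompt.
fake = tuple(
    tuple(torch.randn(1, 7 if step == 0 else 1, 16) for _ in range(3))
    for step in range(3)
)
stitched = stitch_last_layer(fake)
print(stitched.shape)  # torch.Size([1, 9, 16])
```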

Also, why is the length of output_ids['hidden_states'] 1024?

Best, Daniel

Isaachhh commented 2 months ago

Please see here.

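1024 is your max_new_tokens: the tuple has one entry per generation step, so its length is at most max_new_tokens (it is shorter if generation stops early at an end-of-sequence token).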
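A toy, pure-Python sketch of the decoding loop, assuming a single sequence and one new token per step: it records one entry per step, like the outer `hidden_states` tuple, and stops either at an end-of-sequence token or after `max_new_tokens` steps. `toy_generate`, `next_token`, and `eos_id` are hypothetical names; each recorded number is how many positions the model processes at that step (the whole prompt first, then one token at a time).

```python
def toy_generate(next_token, prompt, max_new_tokens, eos_id):
    """Hypothetical single-sequence decoding loop with a KV cache.

    Returns the new tokens and, for each step, how many positions the model
    processed (a stand-in for that step's entry in hidden_states).
    """
    tokens = list(prompt)
    per_step = []
    for _ in range(max_new_tokens):
        # Step 0 runs over the whole prompt; later steps only see the newest token.
        per_step.append(len(tokens) if not per_step else 1)
        tok = next_token(tokens)
        tokens.append(tok)
        if tok == eos_id:
            break
    return tokens[len(prompt):], per_step


# A "model" that emits token 7 three times, then EOS (id 0): stops early.
script = iter([7, 7, 7, 0])
new, records = toy_generate(lambda toks: next(script), [1, 2, 3], 1024, eos_id=0)
print(len(records), records)  # 4 [3, 1, 1, 1]

# A "model" that never emits EOS: runs for exactly max_new_tokens steps.
new2, records2 = toy_generate(lambda toks: 7, [1, 2, 3], 5, eos_id=0)
print(len(records2))  # 5
```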

GewelsJI commented 2 months ago

I see. Nice discussion, and thank you again.