baaivision / Emu

Emu Series: Generative Multimodal Models from BAAI
https://baaivision.github.io/emu2/
Apache License 2.0

Implementation in generate_image() #22

Open jwh97nn opened 1 year ago

jwh97nn commented 1 year ago

Hi, I have a question regarding the image generation process, specifically the generate_image function at https://github.com/baaivision/Emu/blob/main/models/modeling_emu.py#L185

According to this function and the description in the paper, it forces the model to generate n_causal=32 <image> tokens consecutively. In each decoding step, the function concatenates an <image> token to the generated sequence and passes it through the model. Finally, outputs.hidden_state[-1][:, -n_causal:, :] is projected through self.decoder.lm.stu_regress_head.

So my confusion is: what is the difference between this and directly appending n_causal=32 <image> tokens to the text prompt and decoding the output in a single forward pass? Since each <image> token is identical and no intermediate outputs appear to be stored, the two should be equivalent.
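Concretely, the one-shot alternative I have in mind is something like this sketch (hypothetical names, assuming an HF-style forward pass that accepts inputs_embeds):

```python
import torch

def generate_image_one_shot(model, regress_head, prompt_embeds,
                            img_placeholder, n_causal=32):
    # Append 32 identical <image> placeholder embeddings in a single shot...
    slots = img_placeholder.expand(prompt_embeds.size(0), n_causal, -1)
    out = model(inputs_embeds=torch.cat([prompt_embeds, slots], dim=1),
                output_hidden_states=True)
    # ...and regress all 32 visual embeddings from one forward pass.
    return regress_head(out.hidden_states[-1][:, -n_causal:, :])
```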

Please correct me if I have misunderstood.

ryanzhangfan commented 1 year ago

The core idea here is that the image tokens should be generated in an auto-regressive manner, the same as in training: each token is generated conditioned on all previous tokens (both prompt tokens and previously generated image tokens).
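Schematically, the loop behaves like the following sketch (hypothetical names, assuming an HF-style causal LM; see generate_image for the real code). The key line is the last one in the loop body: the generated embedding, not the identical placeholder, is what the next step conditions on.

```python
import torch

def generate_image_autoregressive(model, regress_head, prompt_embeds,
                                  img_placeholder, n_causal=32):
    seq = prompt_embeds  # [B, L, D] prompt embeddings (text + image features)
    generated = []
    for _ in range(n_causal):
        # Append one transient <image> placeholder for the slot being decoded.
        slot = img_placeholder.expand(seq.size(0), 1, -1)
        out = model(inputs_embeds=torch.cat([seq, slot], dim=1),
                    output_hidden_states=True)
        # Regress the visual embedding from the hidden state at that slot.
        nxt = regress_head(out.hidden_states[-1][:, -1:, :])  # [B, 1, D]
        generated.append(nxt)
        # Feed the *generated* embedding back into the sequence, so that
        # step i+1 conditions on the result of step i.
        seq = torch.cat([seq, nxt], dim=1)
    return torch.cat(generated, dim=1)  # [B, n_causal, D]
```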

The <image> token is just a special placeholder that is only used when tokenizing the input sequence. It is useful when calling generate_image on a batch of inputs, to ensure that all sequences have the same length after the image tokens are appended. All <image> token features are replaced by the image features of prompt_image and result_image before being fed into the LMM to generate the next token.

The main difference between the current implementation and inferring all 32 image tokens at once is that the current implementation is auto-regressive while the latter is not. More specifically, with one-shot decoding the generation of the i-th token does not depend on the results of tokens 0 through i-1, which is inconsistent with how the LMM performs inference.

Besides, you can of course append one [IMG] token and 31 <image> tokens at once. But you still need to run inference 32 times, and at each step you should replace the first i <image> tokens with the corresponding generated tokens to preserve the auto-regressive property, as sketched below.
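In sketch form, using the same hypothetical names as above (the leading [IMG] marker is omitted for simplicity):

```python
import torch

@torch.no_grad()
def generate_image_fixed_length(model, regress_head, prompt_embeds,
                                img_placeholder, n_causal=32):
    B, L, _ = prompt_embeds.shape
    # The placeholder block is appended up front; sequence length never changes.
    slots = img_placeholder.expand(B, n_causal, -1).clone()
    for i in range(n_causal):
        out = model(inputs_embeds=torch.cat([prompt_embeds, slots], dim=1),
                    output_hidden_states=True)
        # Regress the i-th visual embedding from its position, then overwrite
        # the i-th placeholder so that steps i+1 onward condition on it.
        slots[:, i] = regress_head(out.hidden_states[-1][:, L + i, :])
    return slots  # [B, n_causal, D]
```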