HVision-NKU / StoryDiffusion


What's the purpose of bool_matrix1024 and bool_matrix4096 returned by the cal_attn_mask_xl function? #97

Open parryppp opened 1 month ago

parryppp commented 1 month ago

The images of bool_matrix1024 and bool_matrix4096 are shown below.

[images: visualizations of bool_matrix1024 and bool_matrix4096]
Z-YuPeng commented 1 month ago

Our method is based on sampling tokens so that images can interact with each other through attention operations. The mask identifies which tokens are to be sampled. To reduce memory consumption, we switched from using the mask to using indices.
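As an aside for readers, a minimal sketch of the mask-versus-indices distinction (illustrative only; the tensor names, shapes, and the 0.3 sample rate are assumptions, not values from cal_attn_mask_xl):

```python
import torch

num_images, tokens_per_image = 4, 1024       # 1024 = 32x32 latent tokens in SDXL
total = num_images * tokens_per_image

# Mask form: a boolean vector where True marks tokens that other
# images are allowed to attend to.
bool_mask = torch.rand(total) < 0.3

# Index form: store only the positions of the sampled tokens.
indices = bool_mask.nonzero(as_tuple=True)[0]

# Gathering sampled keys by index avoids carrying the dense mask
# through the attention computation, which is where the memory goes.
keys = torch.randn(total, 64)                # stand-in for attention keys
sampled_keys = keys.index_select(0, indices)
```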

parryppp commented 1 month ago
[image: reshaped attention mask]

The reshaped attention mask is shown above. Do you mean that, for example, if I want to generate 4 consistent images, the yellow zone in the attention map would not be masked? If so, what does 'randsample' in the paper mean?

Z-YuPeng commented 1 month ago

The generated random mask is exactly how random sampling is implemented. Once a mask has been randomized, only the tokens marked with mask = 1 will be considered.
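For intuition, here is a hedged sketch of that rule using PyTorch's scaled_dot_product_attention, where a boolean attn_mask with True means "this token takes part in attention" (the shapes and the 0.3 rate are illustrative, not the repository's values):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64)              # (batch, heads, queries, head_dim)
k = torch.randn(1, 8, 4096, 64)              # keys/values over all images' tokens
v = torch.randn(1, 8, 4096, 64)

# Randomized boolean mask: only positions with True (mask = 1) are considered.
mask = torch.rand(1024, 4096) < 0.3
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```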

Z-YuPeng commented 1 month ago

The yellow zone corresponds to the concatenation operation below. We found that we cannot drop an image's own tokens, as doing so leads to a significant decline in image quality.
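In other words (a sketch of the idea, not the repository's code): for each image, the key/value tokens are its own tokens concatenated with a random sample of the other images' tokens:

```python
import torch

own = torch.randn(1024, 64)                  # the current image's own tokens
others = torch.randn(3 * 1024, 64)           # tokens from the other 3 images
keep = torch.rand(others.shape[0]) < 0.3     # randomly sample the others

# The image's own tokens are never dropped; only the cross-image
# tokens are subsampled.
kv_tokens = torch.cat([own, others[keep]], dim=0)
```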

Z-YuPeng commented 1 month ago

The mask operation combines random sampling and concatenation into a single step. We initially adopted it because it was faster and equivalent to random sampling, but it also led to greater memory usage. We later reverted to the original approach because of the memory-consumption concerns raised in the issues. https://github.com/HVision-NKU/StoryDiffusion/blob/main/utils/gradio_utils.py#L258
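A back-of-the-envelope calculation shows why the dense mask is a concern (assuming 4 images at the 64x64 resolution, i.e. 4096 tokens each, and 1 byte per bool in PyTorch):

```python
total = 4 * 4096                              # tokens across all 4 images
mask_bytes = total * total                    # dense (total x total) boolean mask
print(f"{mask_bytes / 2**20:.0f} MiB")        # ~256 MiB for a single dense mask
```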

parryppp commented 1 month ago

Thank you for the explanation; it's much clearer now. But I still have a question about the shape of the attention mask. Why does the attention mask ensure that the squares on the diagonal remain set to 1, as shown in the figure below? Is that the purpose of the code at L249-L252?

[image: attention mask with diagonal blocks kept at 1]

https://github.com/HVision-NKU/StoryDiffusion/blob/b67d205784a3acba6194f13e941f9b5a4dbdc34c/utils/gradio_utils.py#L249

Why not simply generate a fully random attention mask, like the figure below?

https://github.com/HVision-NKU/StoryDiffusion/blob/b67d205784a3acba6194f13e941f9b5a4dbdc34c/utils/gradio_utils.py#L261

[image: fully random attention mask]
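For readers following along, the effect being asked about can be sketched as follows; this is an illustration of what forcing the diagonal blocks to 1 amounts to, not the actual code at L249-L252. It matches the earlier comment that an image's own tokens must never be dropped:

```python
import torch

num_images, n = 4, 1024
total = num_images * n

mask = torch.rand(total, total) < 0.3        # start from a purely random mask
for i in range(num_images):
    # Force each diagonal block to 1: image i always attends
    # to all of its own tokens.
    mask[i*n:(i+1)*n, i*n:(i+1)*n] = True
```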