> where, to generate the input to the transformer block, you also flatten the image-number dimension along with height and width. So is my understanding correct if I say that the DiT block learns to associate the images through this operation?
The DiT processes all patches from all images jointly. So if there are 8 images, each with a 16 × 16 grid of patches, the transformer attends over a single sequence of 8 × 16 × 16 = 2048 tokens, which is what lets self-attention share information across images.
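A minimal sketch of that flattening, just to make the shapes concrete (the sizes and variable names here are illustrative, not taken from the repo):

```python
import numpy as np

# Illustrative shapes: 8 images, each a 16 x 16 grid of patch tokens,
# with a hypothetical feature dimension of 384.
num_images, h, w, dim = 8, 16, 16, 384
patch_tokens = np.zeros((num_images, h, w, dim))

# Flattening the image-number dimension together with height and width
# turns the whole multi-view set into one token sequence, so self-attention
# in the DiT block mixes information across images as well as within them.
tokens = patch_tokens.reshape(num_images * h * w, dim)
print(tokens.shape)  # (2048, 384)
```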
Hey,
I read through the paper, and one of my doubts is how information is shared between images to predict the rays.
The only hint I found was in the code implementation here: https://github.com/jasonyzhang/RayDiffusion/blob/f53303a2983f04b1ce6e60c967cecd1e1695b74f/ray_diffusion/model/dit.py#L272
where, to generate the input to the transformer block, you also flatten the image-number dimension along with height and width. So is my understanding correct if I say that the DiT block learns to associate the images through this operation?