lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch
MIT License

rotary embedding values #413

Closed sklin93 closed 2 years ago

sklin93 commented 2 years ago

The stated reason for using the value -10 here is that "image axial positions have range [-1, 1]". In fact, though, the image frequencies are built through the RotaryEmbedding class like this:

img_axial_pos_emb = RotaryEmbedding(dim = rot_dim, freqs_for = 'pixel')
img_freqs_axial = img_axial_pos_emb(torch.linspace(-1, 1, steps = image_fmap_size))

Inside that class, freqs (which has its default value here) and t (which ranges over [-1, 1]) are combined with an einsum, and the results are not in the range [-1, 1]. For example, with head_dim=64 (i.e. rot_dim=21), img_freqs ranges from -15.7 to 15.7. I think this makes -10 an unsuitable choice?
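To make the arithmetic concrete, here is a minimal sketch in plain Python (no torch) that reproduces the ±15.7 figure. It assumes the default 'pixel' frequencies are rot_dim // 2 values evenly spaced from 1 to max_freq / 2 and scaled by pi, with max_freq = 10; that is my reading of the defaults, not something confirmed in this thread. The helper names and image_fmap_size = 32 are hypothetical.

```python
import math

def linspace(start, stop, steps):
    # Evenly spaced values from start to stop inclusive (stand-in for torch.linspace).
    if steps == 1:
        return [float(start)]
    step = (stop - start) / (steps - 1)
    return [start + i * step for i in range(steps)]

def pixel_freqs(rot_dim, max_freq=10.0):
    # ASSUMPTION: 'pixel' frequencies are linspace(1, max_freq / 2, rot_dim // 2) * pi.
    return [f * math.pi for f in linspace(1.0, max_freq / 2, rot_dim // 2)]

def rotation_angles(positions, freqs):
    # Outer product of positions and frequencies, i.e. what the einsum of t and freqs yields.
    return [[t * f for f in freqs] for t in positions]

head_dim = 64
rot_dim = head_dim // 3          # 21, as in the example above
image_fmap_size = 32             # hypothetical feature-map size

positions = linspace(-1.0, 1.0, image_fmap_size)
angles = rotation_angles(positions, pixel_freqs(rot_dim))

flat = [a for row in angles for a in row]
angle_min, angle_max = min(flat), max(flat)

# Under the stated assumption this is about -15.71 and 15.71 (i.e. -5*pi and 5*pi).
print(f"min angle: {angle_min:.2f}, max angle: {angle_max:.2f}")
```

So the positions passed in are within [-1, 1], but the values that come out are positions multiplied by frequencies, and under this assumption those span roughly [-5π, 5π].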