graldij / transformer-fusion

Official repository of the "Transformer Fusion with Optimal Transport" paper, published as a conference paper at ICLR 2024.

How do you ensure T_qk@T_qk^T=I? #4

Closed: daidaiershidi closed this issue 1 month ago

daidaiershidi commented 1 month ago

Thank you for sharing such an interesting piece of work. In the paper, I noticed that you would like T_qk@T_qk^T=I to hold. How do you ensure this condition? I couldn't find an implementation of it in the code. Is this constraint necessary?

graldij commented 1 month ago

Hi! Thanks for the interest in our work! If you set the eq_t_map option for the qk_fusion entry in the YAML config files, the Q and K activations/weights are considered together in a single optimal transport problem. The hard alignment solution then returns a single permutation matrix, which satisfies the given constraint (a permutation matrix P is orthogonal, so P@P^T=I).
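To illustrate the idea, here is a minimal NumPy sketch. It is not the repository's code. The function name `hard_alignment`, the toy shapes, and the brute-force search are assumptions made for illustration; a real implementation would use an optimal transport or assignment solver instead of enumerating permutations. The sketch stacks toy Q and K features so that one shared map is solved for both, then checks that the resulting permutation matrix satisfies T@T^T=I and leaves the Q@K^T scores unchanged.

```python
import itertools

import numpy as np


def hard_alignment(feats_a, feats_b):
    """Return the permutation matrix P minimising ||feats_a @ P - feats_b||_F^2.

    Hypothetical helper: brute force over all permutations, so it is only
    viable for tiny dimensions. Columns are the units being aligned.
    """
    d = feats_a.shape[1]
    # cost[i, j] = squared distance between column i of A and column j of B
    cost = ((feats_a[:, :, None] - feats_b[:, None, :]) ** 2).sum(axis=0)
    best = min(
        itertools.permutations(range(d)),
        key=lambda perm: sum(cost[i, perm[i]] for i in range(d)),
    )
    P = np.zeros((d, d))
    P[np.arange(d), best] = 1.0  # column i of A is sent to column best[i]
    return P


rng = np.random.default_rng(0)
n, d = 6, 4  # toy sizes: n samples, d head dimensions

# Model B's query/key features; model A is B with shuffled head dimensions.
q_b = rng.normal(size=(n, d))
k_b = rng.normal(size=(n, d))
shuffle = np.array([2, 0, 3, 1])
q_a, k_a = q_b[:, shuffle], k_b[:, shuffle]

# Stack Q and K so that a single map is solved for both (the shared-map idea).
P = hard_alignment(np.vstack([q_a, k_a]), np.vstack([q_b, k_b]))

aligned_q, aligned_k = q_a @ P, k_a @ P

print(np.allclose(P @ P.T, np.eye(d)))                      # orthogonality
print(np.allclose(aligned_q @ aligned_k.T, q_a @ k_a.T))    # scores preserved
```

Because the same P multiplies both Q and K, the product (Q@P)@(K@P)^T equals Q@(P@P^T)@K^T, which reduces to Q@K^T. With two separate soft transport maps for Q and K, this cancellation would not hold in general.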