juho-lee / set_transformer

PyTorch implementation of Set Transformer
MIT License
537 stars 101 forks

A little puzzle about the implementation details. #10

Open Jack-mi opened 3 years ago

Jack-mi commented 3 years ago

Hi juho-lee! I have two small questions about your paper. In Section 1 (Introduction), you write: "A model for set-input problems should satisfy two critical requirements. First, it should be permutation invariant: the output of the model should not change under any permutation of the elements in the input set. Second, such a model should be able to process input sets of any size." But after reading the whole paper, I still don't understand how you address these two requirements. For the first, I guess you removed the position embedding from the original Transformer? For the second, I have no idea how you achieve it. Thank you!

yoonholee commented 3 years ago

Hi, it seems your two questions are: (1) why is ST permutation invariant, (2) how can ST process input sets of any size.

(1) You’re exactly right. Simply removing the position embedding from Transformers corresponds to SAB in our paper and is permutation invariant. We also propose a new attention-based block called ISAB, which has lower computation cost and outperforms SAB in our experiments. SAB and ISAB are both permutation invariant because they determine outputs based only on input features and not their order.
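A minimal NumPy sketch of the point above, assuming a single attention head and leaving out layer norm, the row-wise feed-forward layer, and the multi-head split. The repository itself uses PyTorch modules; `mab`, `sab`, `isab`, and the random weights here are illustrative only. With no position embedding, attention sees only the feature rows, so permuting the input rows permutes the output rows in the same way (strictly speaking, permutation equivariance). ISAB routes attention through m inducing points `I`, so its two attention matrices are m×n and n×m instead of n×n.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def mab(Q, K, Wq, Wk, Wv):
    # Simplified single-head attention block: each row of Q attends to the rows of K.
    # Assumption: no multiple heads, no layer norm, no row-wise feed-forward layer.
    q, k, v = Q @ Wq, K @ Wk, K @ Wv
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]))  # (len(Q), len(K)), rows sum to 1
    return q + attn @ v  # residual on the projected queries

def sab(X, params):
    # SAB: the set attends to itself; no position embedding is added to X.
    return mab(X, X, *params)

def isab(X, I, params1, params2):
    # ISAB: m inducing points I summarize the n inputs (an m x n attention),
    # then the inputs attend to that summary (an n x m attention).
    H = mab(I, X, *params1)
    return mab(X, H, *params2)

rng = np.random.default_rng(0)
n, m, d = 6, 3, 4
params1 = tuple(rng.normal(size=(d, d)) for _ in range(3))
params2 = tuple(rng.normal(size=(d, d)) for _ in range(3))
X = rng.normal(size=(n, d))
I = rng.normal(size=(m, d))  # learned parameters in a real model
perm = rng.permutation(n)

# Permuting the input rows permutes the output rows in the same way.
print(np.allclose(sab(X[perm], params1), sab(X, params1)[perm]))  # True
print(np.allclose(isab(X[perm], I, params1, params2),
                  isab(X, I, params1, params2)[perm]))  # True
```

Both checks hold for any weights: row-wise softmax commutes with reordering rows and columns together, so only the order of the output rows changes.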

(2) This is possible because of PMA, our attention-based pooling module. PMA takes as input a set of any size and outputs a set of size k. You can read more about this in section 3.2 of our paper (https://arxiv.org/pdf/1810.00825.pdf).
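A minimal NumPy sketch of why PMA accepts sets of any size, assuming a single head and leaving out the row-wise feed-forward layer and layer norm. `pma` and the random seed vectors `S` are illustrative, not the repository's code. Every parameter has a shape that depends only on the feature dimension d and the number of seeds k, never on the set size n, so the same weights apply to a set of 5 or of 50 elements, and the output always has k rows.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def pma(Z, S, Wq, Wk, Wv):
    # Simplified single-head PMA: k seed vectors S (k, d) attend to the n set
    # elements Z (n, d). Assumption: no rFF, no layer norm, no multiple heads.
    q, k, v = S @ Wq, Z @ Wk, Z @ Wv
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]))  # (k, n): one row per seed
    return q + attn @ v  # (k, d): depends on k, not on n

rng = np.random.default_rng(0)
d, k = 4, 2
S = rng.normal(size=(k, d))  # learned parameters in a real model
W = [rng.normal(size=(d, d)) for _ in range(3)]

for n in (5, 50):  # sets of different sizes, same weights
    Z = rng.normal(size=(n, d))
    out = pma(Z, S, *W)
    print(n, out.shape)  # (2, 4) for both set sizes
    perm = rng.permutation(n)
    print(np.allclose(pma(Z[perm], S, *W), out))  # True: order does not matter
```

Each seed takes a softmax-weighted sum over all n elements, so reordering the elements only reorders the terms of that sum; this is also why PMA's output is permutation invariant.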

susu1210 commented 2 years ago

As you state in the paper, SAB and ISAB are permutation equivariant, not permutation invariant; it is PMA that is permutation invariant.
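Written out, with P any n×n permutation matrix acting on the rows of an n×d input set X, the two properties and what happens when they are composed:

```latex
\begin{aligned}
\text{SAB, ISAB (equivariant):}\quad & f(PX) = P\,f(X) \\
\text{PMA (invariant):}\quad & g(PX) = g(X) \\
\text{composition:}\quad & g\big(f(PX)\big) = g\big(P\,f(X)\big) = g\big(f(X)\big)
\end{aligned}
```

So a stack of equivariant blocks followed by an invariant pooling step yields an output that does not depend on the order of the input elements.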

rpeys commented 1 month ago

Hi! Great paper. Following up on this old thread with a question: does your code actually handle sets of variable size? I would like to apply it to such a dataset, and I expected to see masks or a similar mechanism to handle the variable sizes when computing attention. Before implementing this myself, I wanted to check whether I was missing something in your code. Thanks!
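For the masking question, a minimal NumPy sketch of the common padding approach, assuming every set is non-empty. `pad_sets` and `masked_attention` are hypothetical helpers, not functions from this repository, and this does not describe what the repository's code does. Sets are zero-padded to a common length, and padded keys get a logit of -inf, so they receive exactly zero attention weight.

```python
import numpy as np

def masked_attention(Q, K, V, key_mask):
    # Q: (b, nq, d); K, V: (b, nk, d); key_mask: (b, nk), True = real element.
    logits = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])  # (b, nq, nk)
    # Padded keys get -inf, so exp() turns their weight into exactly 0.
    logits = np.where(key_mask[:, None, :], logits, -np.inf)
    logits = logits - logits.max(axis=-1, keepdims=True)  # finite: each set is non-empty
    w = np.exp(logits)
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ V

def pad_sets(sets):
    # Stack variable-size (n_i, d) arrays into (b, n_max, d) plus a boolean mask.
    n_max = max(len(s) for s in sets)
    d = sets[0].shape[1]
    X = np.zeros((len(sets), n_max, d))
    mask = np.zeros((len(sets), n_max), dtype=bool)
    for i, s in enumerate(sets):
        X[i, : len(s)] = s
        mask[i, : len(s)] = True
    return X, mask

rng = np.random.default_rng(0)
sets = [rng.normal(size=(3, 4)), rng.normal(size=(5, 4))]
X, mask = pad_sets(sets)  # X: (2, 5, 4), mask: (2, 5)
seed = np.repeat(rng.normal(size=(1, 1, 4)), 2, axis=0)  # one PMA-style query per set
batched = masked_attention(seed, X, X, mask)  # (2, 1, 4)

# Padding does not change the result for the smaller set.
alone = masked_attention(seed[:1], sets[0][None], sets[0][None],
                         np.ones((1, 3), dtype=bool))
print(np.allclose(batched[0], alone[0]))  # True
```

In a full SAB, padded rows also act as queries; their outputs carry no meaning and would need to be ignored (for example, masked again) before any pooling step.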

rpeys commented 1 month ago

Sorry, I just found a separate issue that discusses my exact question. Never mind!