juho-lee / set_transformer

PyTorch implementation of set transformer
MIT License

Question about Deep Sets Implementation #4

Closed · arnavs closed 4 years ago

arnavs commented 4 years ago

Hi @juho-lee,

First of all, thanks for making this code publicly available. It's very useful.

One question, though. I am looking at your implementation of the Zaheer et al. network ("Deep Sets"). In that paper, the model has the form rho(sum_x phi(x)), where the sum runs over the elements of the set (I believe you call this a set pooling method in your paper).

In your DeepSet class, there is a succession of Linear -> ReLU -> Linear -> ReLU layers that operate on the entire set at once, with pooling applied at the end (roughly as sketched below).
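For concreteness, here is a minimal sketch of the structure I mean (the widths and the final layer are illustrative placeholders, not the repo's exact values):

```python
import torch
import torch.nn as nn

class DeepSetSketch(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out):
        super().__init__()
        # phi: applied independently to every element of the set
        self.phi = nn.Sequential(
            nn.Linear(dim_in, dim_hidden), nn.ReLU(),
            nn.Linear(dim_hidden, dim_hidden), nn.ReLU(),
        )
        # rho: applied after pooling over the set dimension
        self.rho = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):        # x: (batch_size, num_elements, dim_in)
        h = self.phi(x)          # per-element features: (batch_size, num_elements, dim_hidden)
        pooled = h.sum(dim=1)    # sum-pool over the set: (batch_size, dim_hidden)
        return self.rho(pooled)  # rho(sum_x phi(x))
```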

Could you explain a little about why these are equivalent?

juho-lee commented 4 years ago

Hi, Linear layers act on individual elements, so the stack is equivalent to applying the same operation phi(x) to each element of the set. Also, Linear layers in PyTorch support batched operation, so the same holds for batched inputs (tensors of shape batch_size × num_elements × dim).
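A quick self-contained check of this point (illustrative, not code from the repo): applying an `nn.Linear` to the full (batch, set, dim) tensor gives the same result as applying it to each element separately, because Linear only transforms the last dimension.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(4, 8)
x = torch.randn(2, 5, 4)  # (batch_size, num_elements, dim)

# Apply to the whole set at once...
out_batched = lin(x)

# ...and to each element separately, then restack along the set dimension.
out_per_element = torch.stack([lin(x[:, i]) for i in range(x.size(1))], dim=1)

print(torch.allclose(out_batched, out_per_element))  # True
```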

arnavs commented 4 years ago

Roger that, thank you @juho-lee for the response.