acerbilab / relational-neural-processes

Practical Equivariances via Relational Conditional Neural Processes (Huang et al., NeurIPS 2023)
MIT License

Permutation equivariance of input feature dimensions #6

Closed lacerbi closed 10 months ago

lacerbi commented 1 year ago

How do we implement permutation equivariance of input features? (see this as an example).

Note that this is not necessarily something we want to implement for all models. When dealing with specific models, the input feature dimensions are not equivariant, since each has a specific meaning. However, we do want this equivariance, e.g., for generic GP models.

I have a couple of ideas for this. I will write them below as two separate comments.

lacerbi commented 1 year ago

Solution 1: Permutation equivariance via explicit summation

Specifically, let $\pi \in \mathcal{S}_D$ be a permutation, where $\mathcal{S}_2 = \{ (1,2), (2,1) \}$, $\mathcal{S}_3 = \{ (1,2,3), (1,3,2), (2,1,3), (2,3,1), (3,1,2), (3,2,1) \}$, etc. We write $\pi \mathbf{x}$ as the vector where the permutation $\pi$ is applied to the elements of the vector $\mathbf{x}$.
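Just to fix ideas, here is a minimal sketch of one way the explicit summation could be realized: symmetrize an arbitrary CNP-style predictor by averaging its predictions over all $D!$ permutations of the feature dimensions, applied jointly to context and target inputs, so that the prediction is unchanged under any consistent permutation of the input features. Here `cnp_predict` is a hypothetical stand-in for the base model; the obvious drawback is the factorial cost in $D$, and in practice the summation would more likely live inside the encoder/decoder rather than around the whole model.

```python
import itertools
import numpy as np

def symmetrize_over_permutations(cnp_predict, xc, yc, xt):
    """Permutation-invariant prediction by explicit summation over S_D.

    cnp_predict: hypothetical base model, (xc, yc, xt) -> predictive means (M, d_y)
    xc: context inputs (N, D); yc: context outputs (N, d_y); xt: target inputs (M, D)
    """
    D = xc.shape[1]
    means = []
    for perm in itertools.permutations(range(D)):
        p = list(perm)
        # Apply the same feature permutation to context and target inputs.
        means.append(cnp_predict(xc[:, p], yc, xt[:, p]))
    # Averaging the means gives a permutation-invariant prediction; a full
    # predictive distribution could instead be formed as a mixture over permutations.
    return np.mean(means, axis=0)
```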

Encoder

Decoder

Comments

lacerbi commented 1 year ago

Solution 2: Canonical ordering

Encoder

Decoder

Comments

lacerbi commented 1 year ago

Solution 3: Permutation invariance via bi-dimensional deep set

In conclusion, I think that this approach is very appealing, but I am not 100% sure how to apply this in the context of the standard CNP architecture (and ours).

st-- commented 1 year ago

I was wondering why a permutation-equivariant architecture (like in the AHGP paper) might not work. Luigi's intuition is that the way they implement it scrambles the information up too much (dimensions belonging together can only be recovered based on the $y$-value).

One useful set of synthetic test cases might be non-axis-aligned periodic functions in higher dimensions: here, learning the correlations between dimensions (i.e., the direction of the wave) turns the problem into a simple 1D problem, and similar $y$-values repeat often.
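For reference, a quick sketch of how such a synthetic test case could be generated (the particular functional form, ranges, and names are just one possible choice):

```python
import numpy as np

def sample_nonaligned_periodic(n_points, dim, wavelength=1.0, seed=None):
    """Periodic function varying along a random, non-axis-aligned direction w:
    y = sin(2*pi*(w.x)/wavelength). Along w the problem is effectively 1D,
    and similar y-values repeat often across the input space."""
    rng = np.random.default_rng(seed)
    w = rng.normal(size=dim)
    w /= np.linalg.norm(w)                      # direction of the wave
    x = rng.uniform(-2.0, 2.0, size=(n_points, dim))
    y = np.sin(2 * np.pi * (x @ w) / wavelength)
    return x, y[:, None], w
```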

lacerbi commented 1 year ago

Solution 4: Permutation equivariance via relational bi-dimensional deep set

OK, I think I cracked this (at least theoretically, not sure how it is going to work in practice).

I'll first explain how to introduce permutation equivariance (with respect to features/input dimensions) in a standard CNP. This also affords simultaneous training on inputs with different numbers of feature dimensions.

Previous approaches using bi-dimensional deep sets/attention

As a reminder, the current approach, suggested by the AHGP paper and similarly by the neural diffusion process paper, is to use a bi-dimensional deep set / attention mechanism. To my knowledge, none of this has been applied specifically to CNPs, but the application would be a trivial extension (especially given the neural diffusion process paper).

In the following, I denote by xn(i) (in math, $x_n^{(i)}$) the $i$-th element of the input vector $\mathbf{x}_n \in \mathbb{R}^{d_x}$, where $d_x$ is the number of input features, and by yn the output vector $\mathbf{y}_n \in \mathbb{R}^{d_y}$, for $1 \le n \le N$, with $N$ the size of the context set. For simplicity, we can restrict ourselves to the case $d_y = 1$, but there should be no difference for the multi-output case.

Let's put our context set in a table of pairs (the vector yn is repeated for each input dimension):

(x1(1), y1),  (x1(2), y1), ..., (x1(d_x),y1)
(x2(1), y2),  (x2(2), y2), ..., (x2(d_x),y2)
...
(xN(1), yN),  (xN(2), yN), ..., (xN(d_x),yN)

In a nutshell, what previous methods do is first build a permutation-invariant representation for each column of this table. First, we embed each pair (xn(i), yn) into a higher-dimensional vector $\mathbf{z}_{n,i}$, then we aggregate over the data dimension (i.e., over each column). Doing this operation in parallel for each column yields $d_x$ column embeddings $\mathbf{h}_1, \ldots, \mathbf{h}_{d_x}$. Finally, a transformer or a DeepSet is applied to these representations to obtain either an equivariant or an invariant output.
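As a rough sketch of this column-wise construction (a toy embedding network stands in for the learned one, $d_y = 1$ is assumed as above, and all names are placeholders):

```python
import numpy as np

rng = np.random.default_rng(0)

def pair_table(X, Y):
    """Build the table of pairs (xn(i), yn): shape (N, d_x, 1 + d_y),
    with the output vector yn repeated along the feature dimension i."""
    N, d_x = X.shape
    Y_rep = np.repeat(Y[:, None, :], d_x, axis=1)            # (N, d_x, d_y)
    return np.concatenate([X[:, :, None], Y_rep], axis=-1)

# Toy stand-in for the learned pair-embedding network (xn(i), yn) -> z_{n,i}.
d_z = 8
W = rng.normal(size=(2, d_z))   # input size 1 + d_y = 2, i.e. d_y = 1
b = rng.normal(size=d_z)

def column_embeddings(T):
    """Embed each pair, then aggregate (sum) over the data index n separately
    for each column i, yielding h_1, ..., h_{d_x} stacked as a (d_x, d_z) array.
    A DeepSet / transformer across these rows would then produce an invariant /
    equivariant representation over the feature dimensions."""
    Z = np.tanh(T @ W + b)      # (N, d_x, d_z): pair embeddings z_{n,i}
    return Z.sum(axis=0)        # (d_x, d_z): column embeddings h_i
```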

The problem with this approach is that it introduces more invariances than you would want. It is easy to show that you can apply a distinct permutation of the data separately to each input dimension of the representation above and still get the same output. In other words, we have killed the correlation across input features; this approach only preserves the correlation between each feature xn(i) and the outputs yn.

It is true that if the yn are unique in the context set, then in theory it is possible to reconstruct the correlations among the xn(i) using yn as the binding feature. However, this asks the network to do a lot of work in an unnatural way, and it can break in situations like periodic functions (see @st--'s comment above).
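A tiny check of this claim, reusing the `pair_table` / `column_embeddings` sketch above: applying a different data permutation to each column of the pair table leaves every column embedding, and hence anything computed from them, unchanged.

```python
X = rng.uniform(size=(5, 3))    # N = 5 context points, d_x = 3 features
Y = rng.normal(size=(5, 1))     # d_y = 1

T = pair_table(X, Y)            # (5, 3, 2)

# Apply a *different* permutation of the data index n to each column i.
T_scrambled = T.copy()
for i in range(T.shape[1]):
    T_scrambled[:, i] = T[rng.permutation(T.shape[0]), i]

# Identical column embeddings, even though the binding across features is destroyed.
assert np.allclose(column_embeddings(T), column_embeddings(T_scrambled))
```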

The new proposal for CNPs

In short, we want to keep information about each point, and one way to do that is again via a sort of relational encoding of the features.

RCNPs with equivariant inputs

Finally, we would like to apply the procedure above to RCNPs (specifically, for the translational-equivariant case; no need for the isotropic case since it is already equivariant to permutations of input features!). I describe it in a separate comment below, Solution 4 (part b).

manuelhaussmann commented 1 year ago

I'll give it closer thought once my experiments are (finally) running. But on a first reading it sounds like a reasonable approach.

lacerbi commented 1 year ago

Solution 4 (part b): Permutation equivariance via relational bi-dimensional deep set

See Solution 4 above for the first part, about how to implement permutation equivariance of input features in CNPs. Here I discuss the application to RCNPs, in particular for translational invariance.

RCNPs with equivariant inputs

The execution is actually quite simple. Where in the standard CNP we would define the table above:

(x1(1), y1),  (x1(2), y1), ..., (x1(d_x),y1)
(x2(1), y2),  (x2(2), y2), ..., (x2(d_x),y2)
...
(xN(1), yN),  (xN(2), yN), ..., (xN(d_x),yN)

Instead, for the RCNP we have a similar table:

(rho1(1), y1),  (rho1(2), y1), ..., (rho1(d_x),y1)
(rho2(1), y2),  (rho2(2), y2), ..., (rho2(d_x),y2)
...
(rhoN(1), yN),  (rhoN(2), yN), ..., (rhoN(d_x),yN)

where rhon(i), in math $\rho_{n,i}$, is the relational encoding for the $i$-th feature of the $n$-th data point, obtained as: $$\rho_{n,i} \equiv \rho\left(x_n^{(i)}\right) = \bigoplus_{n^\prime=1}^N h_\theta\left( x_{n^\prime}^{(i)} - x_n^{(i)}, \mathbf{y}_{n^\prime} \right),$$ where we used the difference encoding (i.e., the comparison function $g(\cdot, \cdot)$ is the difference, which is suitable to encode translational equivariance), and $h_\theta$ is, as usual, the relational encoding network.

This expression is almost identical to the standard relational encoding, with the only difference that here it is applied separately for each input feature $i$, so every data point $\mathbf{x}_n$ ends up having $d_x$ separate relational encodings (the encoding network is the same for all features, and operates on them in parallel).
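A minimal sketch of this per-feature relational encoding, with a toy network standing in for $h_\theta$, sum as the aggregation $\bigoplus$, and $d_y = 1$ (all names are placeholders):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for the relational encoding network h_theta: (x_{n'}(i) - x_n(i), y_{n'}) -> R^{d_r}.
d_r = 8
W = rng.normal(size=(2, d_r))   # input size 1 + d_y = 2, i.e. d_y = 1
b = rng.normal(size=d_r)
h_theta = lambda pairs: np.tanh(pairs @ W + b)

def relational_feature_encodings(X, Y):
    """Compute rho_{n,i} = sum_{n'} h_theta(x_{n'}(i) - x_n(i), y_{n'}):
    the difference-based relational encoding applied separately to each input
    feature i, with the same network shared across features.
    X: (N, d_x) context inputs; Y: (N, 1) context outputs. Returns (N, d_x, d_r)."""
    N, d_x = X.shape
    # Pairwise feature differences: diffs[n, n', i] = x_{n'}(i) - x_n(i)
    diffs = X[None, :, :] - X[:, None, :]                           # (N, N, d_x)
    Y_rep = np.broadcast_to(Y[None, :, None, :], (N, N, d_x, 1))    # y_{n'} per pair
    pairs = np.concatenate([diffs[..., None], Y_rep], axis=-1)      # (N, N, d_x, 2)
    return h_theta(pairs).sum(axis=1)                               # aggregate over n'
```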

After we obtain the table above, everything proceeds exactly like for the implementation of permutation equivariance for CNPs described in Solution 4 above.