entity-neural-network / incubator

Collection of in-progress libraries for entity neural networks.
Apache License 2.0

Implement relative positional encoding #139

Closed cswinter closed 2 years ago

cswinter commented 2 years ago

Implements a version of relative positional encoding for n-dimensional grids. Relative positional encoding with e.g. an 11 x 13 extent (relative offsets of up to ±5 in x and ±6 in y) for an environment with a 2d grid can be enabled by passing --relpos-encoding='{"extent": [5, 6], "position_features": ["x", "y"], "per_entity_values": true}' to enn_ppo/train.py.
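To illustrate how the extent relates to the size of the table, here is a minimal sketch. The helper `relpos_index` is hypothetical and not taken from the PR, and clamping out-of-range offsets (as Shaw et al. do for sequences) is an assumption; the actual implementation may handle them differently.

```python
import json

# Flag value copied from the description above.
config = json.loads(
    '{"extent": [5, 6], "position_features": ["x", "y"], "per_entity_values": true}'
)
extent = config["extent"]

# Each axis covers offsets -e..e, so it has 2 * e + 1 entries.
table_shape = [2 * e + 1 for e in extent]


def relpos_index(pos_a, pos_b, extent):
    """Hypothetical helper: flat table index for the offset pos_b - pos_a."""
    index = 0
    for a, b, e in zip(pos_a, pos_b, extent):
        # Assumption: clamp out-of-range offsets to [-e, e], then shift to [0, 2e].
        shifted = max(-e, min(e, b - a)) + e
        index = index * (2 * e + 1) + shifted
    return index


print(table_shape)                             # [11, 13]
print(relpos_index((3, 3), (3, 3), extent))    # 71, the centre of the table
print(relpos_index((0, 0), (20, 20), extent))  # 142, clamped to the last entry
```

With this layout, an extent of [5, 6] needs 11 * 13 = 143 learned embeddings per table.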

There are many variations and refinements of relative positional encoding. This implementation mostly follows the original formulation described in Shaw et al. (2018). In particular, here is a non-exhaustive list of somewhat arbitrary design choices that we may want to revisit once we have some good benchmarks to test against:

The current implementation requires ds^2 memory, where d is the dimension of the heads and s is the sequence length. Since our sequences are relatively short so far, this does not present a major issue. The usual trick used to reduce memory usage by a factor of s only works for sequences and not for our more general version, where entities can be at arbitrary grid points. We could still achieve the same savings with a custom GPU kernel though.
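A rough back-of-the-envelope sketch of the ds^2 term, assuming one dense float32 tensor of shape (s, s, d) per sequence and a head dimension of 64. Both numbers are illustrative assumptions; batch size, number of layers, separate key/value tensors and autograd buffers would multiply the total.

```python
def relpos_tensor_bytes(seq_len, head_dim, bytes_per_element=4):
    """Size of one dense (seq_len, seq_len, head_dim) tensor of relative embeddings."""
    return seq_len * seq_len * head_dim * bytes_per_element


for s in (32, 256, 2048):
    mib = relpos_tensor_bytes(s, head_dim=64) / 2**20
    print(f"s={s}: {mib:.2f} MiB")  # 0.25, 16.00 and 1024.00 MiB respectively
```

At a few dozen entities the tensor is negligible, but the quadratic growth becomes noticeable quickly for longer sequences.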

cswinter commented 2 years ago

Some basic ablations here: https://wandb.ai/entity-neural-network/enn-ppo/reports/Relative-positional-encoding-ablations--VmlldzoxNDM0MzIx Relative positional encoding somewhat outperforms a tuned baseline using translation and greatly outperforms policies that only see the raw position features.

A big caveat is that, at least on this task, we seem to need per-entity relative positional values to get good performance. I believe the reason for this is that per-entity values allow a single attention head/layer to easily access/compute per-entity positional information in a way that is impossible without per-entity values. Imagine a single head that attends from the actor entity to two entities equally: a snake segment entity and a food entity. The output of the attention head will be 0.5 * (value[snake] + value[food] + relposvalue[snake.pos] + relposvalue[food.pos]) where value[x] is the normal value vector of x derived from the embedding of entity x, and relposvalue[x] is the relative positional embedding value of the position of entity x. The actor now has access to the following information:

While per-entity relative positional values are a good solution in the case of this environment, they are also quite limited. If, instead of separate food and snake entities, there was a single entity type with a feature that identifies whether it is "food" or "snake segment", per-entity relpos values wouldn't apply. This seems wrong; we ought to have a solution that works just as well in that case. More generally, the relevant property might not be the entity type, but some arbitrary feature of the entity learned by the network.
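The snake/food example can be made concrete with a small sketch. This is one reading of the argument, not code from the PR: all names, dimensions and the random tables are made up for illustration. It computes the output of a head that attends with weight 0.5 to one snake segment and one food entity, once with a shared relpos value table and once with a table per entity type, and checks what happens when the two entities swap positions.

```python
import math
import random

random.seed(0)
DIM = 4
OFFSETS = [(-1, 0), (1, 0), (0, -1), (0, 1)]


def rand_vec():
    return [random.gauss(0.0, 1.0) for _ in range(DIM)]


# Ordinary value vectors, one per entity (illustrative random numbers).
value = {"snake": rand_vec(), "food": rand_vec()}
# One relpos value table shared by all entities ...
shared_table = {off: rand_vec() for off in OFFSETS}
# ... versus one table per entity type.
per_entity_table = {kind: {off: rand_vec() for off in OFFSETS} for kind in value}


def head_output(offsets, per_entity):
    """Output of a head attending equally (weight 0.5) to the snake and the food."""
    out = [0.0] * DIM
    for kind, off in offsets.items():
        relpos = per_entity_table[kind][off] if per_entity else shared_table[off]
        for k in range(DIM):
            out[k] += 0.5 * (value[kind][k] + relpos[k])
    return out


def same(u, v):
    return all(math.isclose(x, y, abs_tol=1e-9) for x, y in zip(u, v))


layout_a = {"snake": (-1, 0), "food": (0, 1)}
layout_b = {"snake": (0, 1), "food": (-1, 0)}  # the two entities swap positions

shared_same = same(head_output(layout_a, False), head_output(layout_b, False))
per_entity_same = same(head_output(layout_a, True), head_output(layout_b, True))
print(shared_same, per_entity_same)
```

With the shared table the sum is symmetric in the two positions, so the head output cannot distinguish the two layouts; with per-entity tables the outputs differ.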

I believe we can come up with a new type of relative positional encoding which is fully general by allowing for a non-linear combination of (a projection of) the entity embeddings and the positional features. Since there are N^2 relative positional values, we probably can't afford a full matmul, but there are some cheap element-wise operations that I think could work well. In particular, a good approach could be to perform an element-wise multiplication of the relative positional values and a projection of the corresponding entity embedding using one of the GLU variants described in Shazeer (2020). This would effectively allow entities to apply an arbitrary gating function to any of the relative positional values, and should be strictly more powerful than per-entity positional encodings.
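A minimal sketch of what such a gate might look like, assuming the plain sigmoid GLU form (Shazeer also describes other variants such as GEGLU and SwiGLU). The function name, the projection matrix and the toy feature encoding are all hypothetical; nothing here is implemented in this PR.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def gated_relpos_value(entity_embedding, relpos_value, gate_proj):
    """Gate a relative positional value with a projection of the attended entity.

    gate_proj has one row per relpos dimension. The gate depends only on the
    entity, so it can be computed once per entity; the per-pair cost is a
    single element-wise multiply.
    """
    gate = [
        sigmoid(sum(w * x for w, x in zip(row, entity_embedding)))
        for row in gate_proj
    ]
    return [g * r for g, r in zip(gate, relpos_value)]


# Toy setup: a single entity type with one feature, +1 for food and -1 for snake.
gate_proj = [[10.0], [-10.0]]  # dim 0 opens for food, dim 1 opens for snake
relpos_value = [1.0, 1.0]

food_out = gated_relpos_value([1.0], relpos_value, gate_proj)
snake_out = gated_relpos_value([-1.0], relpos_value, gate_proj)
print(food_out, snake_out)
```

In this toy case the gate routes the positional value into different dimensions depending on a feature of the entity rather than its type, which is the situation where per-entity tables do not apply.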

An important related question is whether all of this is even necessary. In principle, a multi-layer or multi-head attention network ought to be able to perform the same operation. E.g., a two-head attention layer could retrieve all snake entities with one of the heads and all food entities with the other head, which allows it to separately access and project the positions of the different entity types. Empirically, I haven't been able to get good performance even with networks with multiple layers/heads. It would be good to understand this better. Some thoughts: