
Multidimensional indexing for tensors

eindex

Concept of multidimensional indexing for tensors

Example of K-means clustering

Plain numpy

def kmeans(init_centers, X, n_iterations: int):
    n_clusters, n_dim = init_centers.shape
    n_observations, n_dim = X.shape

    centers = init_centers.copy()
    for _ in range(n_iterations):
        d = cdist(centers, X)
        clusters = np.argmin(d, axis=0)
        new_centers_sum = np.zeros_like(centers)
        indices_dim = np.arange(n_dim)[None, :]
        np.add.at(new_centers_sum, (clusters[:, None], indices_dim), X)
        cluster_counts = np.bincount(clusters, minlength=n_clusters)
        centers = new_centers_sum / cluster_counts[:, None]
    return centers
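The trickiest part of the numpy version is the scatter-add via `np.add.at` followed by normalization with the per-cluster counts. A tiny self-contained illustration of that scatter-then-normalize step (toy data, not from the repo):

```python
import numpy as np

# three observations, two clusters: rows 0 and 2 go to cluster 0, row 1 to cluster 1
X = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
clusters = np.array([0, 1, 0])

# scatter-add: accumulate each row of X into its cluster's slot
sums = np.zeros((2, 2))
np.add.at(sums, (clusters[:, None], np.arange(2)[None, :]), X)
# sums now holds [1+2, 0+2] for cluster 0 and [0, 1] for cluster 1

# divide by cluster sizes to get the per-cluster mean
means = sums / np.bincount(clusters, minlength=2)[:, None]
print(sums)
print(means)
```

`np.add.at` is used instead of `sums[clusters] += X` because the unbuffered variant correctly accumulates repeated indices.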

With eindex

def kmeans_eindex(init_centers, X, n_iterations: int):
    centers = init_centers
    for _ in range(n_iterations):
        d = cdist(centers, X)
        clusters = EX.argmin(d, 'cluster i -> [cluster] i')
        centers = EX.scatter(X, clusters, 'i c, [cluster] i -> cluster c',  
                             agg='mean', cluster=len(centers))
    return centers
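For readers new to the bracket notation: in a pattern like `'cluster i -> [cluster] i'`, the indices produced by `argmin` are stacked along a leading bracketed axis (here of length 1, since a single axis is indexed). A plain-numpy reading of that line, under this assumed semantics:

```python
import numpy as np

# distances: cluster x i (2 clusters, 2 observations)
d = np.array([[0.1, 3.0],
              [2.0, 0.2]])

# 'cluster i -> [cluster] i': argmin over the cluster axis,
# with the winning index stored along a leading axis of length 1
clusters = np.argmin(d, axis=0)[None, :]
print(clusters)  # [[0 1]]
```

With several bracketed axes (as in `'[h, w] b token'` below), that leading axis holds one index per bracketed axis instead.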

Tutorial notebook

Goals

Non-goals: there is no goal to develop 'the shortest notation', 'the most advanced/comprehensive tool for indexing', to 'cover as many operations as possible', or to 'completely replace default indexing'.

Examples

Follow the tutorial first to learn about all of the provided operations.

#### - how do I select a single embedding from every image in a batch?

Let's say you have pairs of images and captions, and you want to take the closest image embedding for every token:

```python
score = einsum(images_bhwc, sentences_btc, 'b h w c, b token c -> b h w token')
closest_index = argmax(score, 'b h w token -> [h, w] b token')
closest_emb = gather(images_bhwc, closest_index, 'b h w c, [h, w] b token -> b token c')
```

To adjust this example for video instead of images, replace 'h w' with 'h w t'. Yes, that simple.

#### - how to collect the top-1 or top-3 predicted words for every position in audio/text?

```python
[most_likely_words] = argmax(prob_tbc, 't b w -> [w] t b')
[top_words] = argsort(prob_tbc, 't b w -> [w] t b order')[..., -3:]
```

#### - how to average embeddings over neighbors in a graph?

```python
# without batch (single graph)
gatherscatter(embeddings, edges, 'vin c, [vin, vout] edge -> vout c')
# with batch (multiple graphs)
gatherscatter(embeddings, edges, 'b vin c, [b, vin, vout] edge -> b vout c')
```

#### - can eindex help with (complex) positional embeddings?

If we're speaking about trainable abspos, it can simply be saved as `emb_hwc` and added to every batch; no indexing is needed. But eindex can be very helpful in complex scenarios. Let's take T5 relpos as an example, where a bias is added to every attention logit before softmax-ing. That's simple to implement in 1d, and *much* harder in 2d/3d. Let's implement T5 relpos in 2d with `eindex`:

```python
N = None
pos  # [I, J] i j
pos1 = pos[:, :, :, N, N]
pos2 = pos[:, N, N, :, :]
xy_diff = (pos1 - pos2) % image_side  # we make shifts positive by wrapping
attention_bias = gather(biases, xy_diff, 'i j head, [i, j] i1 j1 i2 j2 -> i1 j1 i2 j2 head')
```

Note that we use a 2d relative position (shift in x and y), while most implementations just use the sequence shift.

In a similar way we could produce vector-shift attention (another common version of relpos):

```python
vector_shift = gather(vectors, xy_diff, 'i j head c, [i, j] i1 j1 i2 j2 -> i1 j1 i2 j2 head c')
```
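The argmax + gather pattern from the first example can be sanity-checked in plain numpy. Here is a sketch with toy shapes (the names and sizes are illustrative, not eindex API):

```python
import numpy as np

rng = np.random.default_rng(0)
b, h, w, c, t = 2, 3, 4, 5, 6
images = rng.normal(size=(b, h, w, c))      # b h w c
sentences = rng.normal(size=(b, t, c))      # b token c

# 'b h w c, b token c -> b h w token'
score = np.einsum('bhwc,btc->bhwt', images, sentences)

# argmax over the (h, w) axes: flatten them, argmax, keep the flat index
flat = score.reshape(b, h * w, t)
best = flat.argmax(axis=1)                  # b token, flattened (h, w) index

# gather: pick the winning embedding for every (batch, token) pair
batch_idx = np.arange(b)[:, None]
closest = images.reshape(b, h * w, c)[batch_idx, best]  # b token c
print(closest.shape)  # (2, 6, 5)
```

The flatten/unflatten bookkeeping is exactly what the `[h, w]` bracket notation hides.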

Implementation

The repo provides two implementations:

Development Status

The API looks solid, but breaking changes are still possible, so pin the version in your projects (e.g. `eindex==0.1.0`).

Related projects

Other projects you likely want to look at:

Contributing

We welcome the following contributions:

Discussions

Use the GitHub discussions for this project: https://github.com/arogozhnikov/eindex/discussions