VeriSilicon / TIM-VX

VeriSilicon Tensor Interface Module
Other

Provide GatherElements operator? #248

Closed Nullkooland closed 2 years ago

Nullkooland commented 2 years ago

The GatherElements op is similar to the Gather op, but it indexes at the element level (instead of selecting whole slices of the tensor).

# Given a 3-D (CHW) input tensor, an index tensor of the same rank, and an axis:
out[i][j][k] = input[index[i][j][k]][j][k]  # if axis == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if axis == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if axis == 2

See ONNX GatherElements and torch.gather.
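To make the three formulas concrete, here is a small NumPy sketch (not TIM-VX code). It assumes that `np.take_along_axis` matches GatherElements when `indices` has the same rank as the input and the same size in every non-gather dimension, and it checks that against a literal loop transcription of the formulas. The function names `gather_elements` and `gather_elements_loop` are hypothetical, chosen for illustration.

```python
import numpy as np


def gather_elements(data: np.ndarray, indices: np.ndarray, axis: int) -> np.ndarray:
    """GatherElements sketch: the output has the shape of `indices`.

    Along `axis`, each output element reads `data` at the position stored in
    `indices`; every other coordinate is copied from the output position.
    Assumes indices has the same size as data in all non-`axis` dimensions.
    """
    return np.take_along_axis(data, indices, axis=axis)


def gather_elements_loop(data: np.ndarray, indices: np.ndarray, axis: int) -> np.ndarray:
    """Literal transcription of the three 3-D formulas (3-D inputs only)."""
    out = np.empty(indices.shape, dtype=data.dtype)
    for i, j, k in np.ndindex(*indices.shape):
        src = [i, j, k]
        src[axis] = indices[i, j, k]  # only the `axis` coordinate is looked up
        out[i, j, k] = data[tuple(src)]
    return out


# Compare both versions on a small C=2, H=3, W=4 tensor for every axis.
data = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
rng = np.random.default_rng(0)
for axis in range(3):
    index = rng.integers(0, data.shape[axis], size=data.shape)
    assert np.array_equal(gather_elements(data, index, axis),
                          gather_elements_loop(data, index, axis))

# 2-D example with axis=1: row 0 reads columns [0, 0], row 1 reads columns [1, 0].
print(gather_elements(np.array([[1, 2], [3, 4]]), np.array([[0, 0], [1, 0]]), axis=1))
# -> [[1, 1], [4, 3]]
```

Only the coordinate along `axis` comes from `indices`; that is the key difference from Gather, which replaces a whole dimension with the index tensor.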

Currently TIM-VX does not provide the GatherElements op, and it's very complicated to express it indirectly via the GatherNd op. Is there any plan to provide this op?

Nullkooland commented 2 years ago

Also the ScatterElements op, which is the inverse of GatherElements. See ONNX ScatterElements.
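A minimal NumPy sketch of ScatterElements without a reduction, assuming `np.put_along_axis` has the matching write semantics when `indices` and `updates` share a shape. It reproduces the axis=0 example from the ONNX ScatterElements documentation. `scatter_elements` is a hypothetical name, and the sketch does not define an order for duplicate indices.

```python
import numpy as np


def scatter_elements(data: np.ndarray, indices: np.ndarray,
                     updates: np.ndarray, axis: int) -> np.ndarray:
    """ScatterElements sketch (no reduction), e.g. for axis=0 and 2-D tensors:
    out[indices[i][j]][j] = updates[i][j]; all other elements keep `data`."""
    out = data.copy()  # return a new tensor; the input is left unmodified
    np.put_along_axis(out, indices, updates, axis=axis)
    return out


# Example from the ONNX ScatterElements documentation (axis=0).
data = np.zeros((3, 3), dtype=np.float32)
indices = np.array([[1, 0, 2], [0, 2, 1]])
updates = np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=np.float32)
result = scatter_elements(data, indices, updates, axis=0)
print(result)  # expected: [[2.0, 1.1, 0.0], [1.0, 0.0, 2.2], [0.0, 2.1, 1.2]]

# Inverse relation: gathering with the same (non-duplicated) indices recovers `updates`.
assert np.allclose(np.take_along_axis(result, indices, axis=0), updates)
```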

sunshinemyson commented 2 years ago

@Goose-Bomb, thanks for the feedback. Since it's an ONNX operator and we haven't opened up an ONNX Runtime execution provider yet, are you working on one? Or could you share why you need this? If you have a clear requirement, we can double-check the mapping status of the ONNX operators.

Nullkooland commented 2 years ago

No, I'm not working on an ONNX backend for TIM-VX. PyTorch's torch.gather is exported as the ONNX GatherElements op, which in turn can be mapped to TVM's relay.gather, and I'm working on a TIM-VX backend for TVM.

In my application, what I need is this: given an input feature map of shape [N, C, L] (N: batch size, C: number of features, L: length of each feature) and indices of shape [N, K], gather K out of the C features for each batch. This operation appears in the post-processing of some object detection models, where the indices tensor is produced by applying TopK to the confidence map and is then used to gather the K sets of bounding-box parameters (per batch) from the output feature map.

Here is the demonstration: [image]

output[i, j, ...] = input[i, indices[i, j], ...]

In TensorFlow, this can be done by using tf.gather with axis=1, batch_dims=1.
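The batched gather `output[i, j, ...] = input[i, indices[i, j], ...]` can be sketched in NumPy with advanced indexing: a `[N, 1]` batch index broadcasts against the `[N, K]` indices, and the trailing dimensions are carried along. This should compute the same thing as `tf.gather(..., axis=1, batch_dims=1)` for these shapes; `batched_gather` is a hypothetical helper name and the feature map mirrors the one in the TVM sample further down the thread.

```python
import numpy as np


def batched_gather(data: np.ndarray, indices: np.ndarray) -> np.ndarray:
    """output[i, j, ...] = data[i, indices[i, j], ...]

    data: [N, C, ...], indices: [N, K] with values in [0, C) -> output: [N, K, ...]
    """
    n = data.shape[0]
    batch = np.arange(n)[:, None]  # [N, 1], broadcasts against indices [N, K]
    return data[batch, indices]    # trailing dims (L, ...) are kept as-is


n, c, l = 2, 10, 3
# features[i, c, :] == c, so each gathered row should equal its own index.
features = np.tile(np.arange(c, dtype=np.float32).reshape(1, c, 1), (n, 1, l))  # [N, C, L]
indices = np.array([[9, 4, 0, 7, 2], [1, 1, 8, 3, 6]])                           # [N, K]
out = batched_gather(features, indices)
print(out.shape)  # (2, 5, 3)
```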

However, PyTorch's equivalent op, torch.index_select, does not support batching (BTW, TIM-VX's Gather op has no batch_dims parameter, so it cannot support batching either). There is a workaround, though: first expand the indices from [N, K] to [N, K, L], then apply torch.gather (equivalent to the ONNX GatherElements op) to get the same output.
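The expand-then-gather workaround can be checked in NumPy, with `np.take_along_axis` standing in for torch.gather / GatherElements and `np.repeat` standing in for the expand step (the same expand_dims plus repeat pattern appears as Option 1 in the TVM sample below). This is a sketch under those substitutions, not PyTorch or TIM-VX code; `gather_via_expanded_indices` is a hypothetical name.

```python
import numpy as np


def gather_via_expanded_indices(features: np.ndarray, indices: np.ndarray) -> np.ndarray:
    """Batched gather along axis 1 built from an element-wise gather.

    features: [N, C, L], indices: [N, K] -> output: [N, K, L]
    """
    l = features.shape[2]
    expanded = np.repeat(indices[:, :, None], l, axis=2)   # [N, K] -> [N, K, 1] -> [N, K, L]
    return np.take_along_axis(features, expanded, axis=1)  # GatherElements along axis 1


rng = np.random.default_rng(0)
features = rng.standard_normal((2, 10, 3)).astype(np.float32)  # [N, C, L]
indices = rng.integers(0, 10, size=(2, 5))                     # [N, K]

# Reference: direct batched indexing, output[i, j, :] = features[i, indices[i, j], :]
direct = features[np.arange(2)[:, None], indices]
assert np.array_equal(gather_via_expanded_indices(features, indices), direct)
```

The cost of the workaround is materializing an index tensor L times larger than the original [N, K] indices, which is one reason a native batch_dims option on Gather would be preferable.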

Also, I noticed that TIM-VX's TopK op (#193) does not seem to support batching or specifying an axis, and it is implemented on the CPU. Would it be possible to provide an accelerated version that runs on the GPU (which I guess you call the PPU)? See torch.topk.
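For reference, this is the behavior a batched TopK with an `axis` argument would need, sketched in NumPy: sort along one axis in descending order, keep the first k positions, and treat every other axis as a batch axis. `topk` here is a hypothetical helper, not the torch.topk or TIM-VX API, and negating the input assumes a signed or floating-point dtype.

```python
import numpy as np


def topk(x: np.ndarray, k: int, axis: int = -1):
    """Return (values, indices) of the k largest entries along `axis`,
    in descending order; all other axes are treated as batch axes."""
    order = np.argsort(-x, axis=axis, kind="stable")  # descending (signed/float dtypes only)
    idx = np.take(order, np.arange(k), axis=axis)     # keep the first k positions
    values = np.take_along_axis(x, idx, axis=axis)
    return values, idx


confidences = np.array([[0.1, 0.9, 0.5, 0.3],
                        [0.7, 0.2, 0.3, 0.8]], dtype=np.float32)  # [N, C]
values, indices = topk(confidences, k=2, axis=1)
print(indices)  # -> [[1, 2], [3, 0]]
```

The resulting `[N, K]` indices are exactly the input the batched gather above needs, which is the TopK-then-gather pipeline described in this thread.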

In a nutshell, the requirement is essentially batching support for the Gather/Scatter ops. I understand that the batch size is usually set to 1 for inference, but it would be nicer to have a general op mapping so that users don't need to adjust their models when deploying to the NPU.

sunshinemyson commented 2 years ago

@Goose-Bomb, do you have sample code for TVM? I'd like to ask the team to try to support this.

Nullkooland commented 2 years ago

Sure, I'll provide a TVM Relay example tomorrow.

Nullkooland commented 2 years ago

Here's the sample code with TVM.

import tvm
import tvm.contrib.graph_executor  # needed so tvm.contrib.graph_executor is available below
from tvm import relay

import numpy as np

if __name__ == "__main__":
    n = 2   # batch size
    c = 10  # number of features
    l = 3   # feature length

    k = 5   # top-K

    # Given input confidence of [N, C] and feature map of [N, C, L].
    confidence_data = np.random.uniform(
        0.0, 1.0, size=(n, c)).astype(np.float32) # -> [n, c]
    features_data = np.tile(
        np.arange(c, dtype=np.float32).reshape((1, c, 1)), reps=(n, 1, l)
    ) # -> [n, c, l]

    # Build Relay graph.
    confidences = relay.var("confidences", shape=(n, c), dtype="float32")
    features = relay.var("features", shape=(n, c, l), dtype="float32")

    # Apply TopK to find K highest confidences and their indices.
    confidences_topk, indices_topk = relay.topk(
        data=confidences,
        k=k,
        axis=1,
        ret_type="both"
    )  # -> [n, k], [n, k]

    # Option 1: use relay.gather (equivalent to ONNX GatherElements)
    # to select K out of C features.

    # indices_expanded = relay.op.expand_dims(indices_topk, axis=2)  # -> [n, k, 1]
    # indices_expanded = relay.op.repeat(
    #     indices_expanded, repeats=l, axis=2)  # -> [n, k, l]
    # features_topk = relay.op.gather(
    #     data=features, indices=indices_expanded, axis=1)  # -> [n, k, l]

    # Option 2: use relay.take (equivalent to ONNX Gather) with batch_dims=1
    # to select K out of C features per batch.

    features_topk = relay.op.take(
        data=features, indices=indices_topk, axis=1, batch_dims=1
    )

    # Get module from expr.
    output = relay.Tuple((indices_topk, features_topk))
    func = relay.Function([confidences, features], output)
    mod = tvm.IRModule.from_expr(func)
    mod = relay.transform.InferType()(mod)

    # Show Relay graph.
    print("[Relay Graph]")
    print(mod.astext(show_meta_data=False))

    # Build and run model.
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm", params=None)

    device = tvm.cpu()
    rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](device))

    confidences_input = tvm.nd.array(confidence_data, device)
    features_input = tvm.nd.array(features_data, device)

    rt_mod.set_input("confidences", confidences_input)
    rt_mod.set_input("features", features_input)
    rt_mod.run()

    indices_topk_data = rt_mod.get_output(0).numpy()
    features_topk_data = rt_mod.get_output(1).numpy()

    # Print results.
    print("Input confidences:")
    print(np.around(confidence_data, decimals=2))

    print("Input features:")
    print(features_data)

    print("TopK indices:")
    print(indices_topk_data)

    print("TopK features:")
    print(features_topk_data)

sunshinemyson commented 2 years ago

@Goose-Bomb, thanks for the demo. We are working on a solution now and will let you know what it is once we finish our internal discussion.

sunshinemyson commented 2 years ago

We are going to support this in the next release.