openvinotoolkit / openvino

OpenVINO™ is an open-source toolkit for optimizing and deploying AI inference
https://docs.openvino.ai
Apache License 2.0

[Good First Issue][JAX FE]: Support jax.lax.lt and jax.lax.neg operation for JAX #26577

Closed rkazants closed 1 month ago

rkazants commented 1 month ago

Context

The OpenVINO component responsible for supporting JAX/Flax models is called the JAX Frontend (JAX FE). JAX FE converts a JAX/Flax model, represented by a ClosedJaxpr graph object with operations from the jax.lax opset, to OpenVINO IR containing operations from the OpenVINO opset.

To run inference with OpenVINO on JAX/Flax models that contain the jax.lax.lt operation, JAX FE needs to be extended to support this operation.

What needs to be done?

To support the jax.lax.lt operation, you need to implement the corresponding loader in the JAX FE op directory and register it in the dictionary of loaders. Each loader is responsible for the conversion (or decomposition) of one type of JAX operation.
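The registration pattern described above can be sketched in plain Python. This is only a toy model of the idea, not JAX FE code: the real loaders are C++ functions, and the names LOADERS, translate_lt, translate_neg, and convert_op, as well as the tuple-based node representation, are hypothetical. Mapping lt to an OpenVINO Less node and neg to a Negative node is an assumption about the natural decomposition, not something confirmed in this issue.

```python
from typing import Callable, Dict, List, Tuple

# Toy stand-in for a graph node: (OpenVINO op name, input names).
# Hypothetical representation, used only for illustration.
Node = Tuple[str, List[str]]
Loader = Callable[[List[str]], Node]


def translate_lt(inputs: List[str]) -> Node:
    """Toy loader for jax.lax.lt: checks for two inputs, emits one Less node."""
    if len(inputs) != 2:
        raise ValueError(f"lt expects 2 inputs, got {len(inputs)}")
    return ("Less", list(inputs))


def translate_neg(inputs: List[str]) -> Node:
    """Toy loader for jax.lax.neg: checks for one input, emits one Negative node."""
    if len(inputs) != 1:
        raise ValueError(f"neg expects 1 input, got {len(inputs)}")
    return ("Negative", list(inputs))


# The dictionary of loaders: one entry per supported JAX operation.
LOADERS: Dict[str, Loader] = {
    "lt": translate_lt,
    "neg": translate_neg,
}


def convert_op(op_name: str, inputs: List[str]) -> Node:
    """Look up the loader registered for op_name and apply it."""
    try:
        loader = LOADERS[op_name]
    except KeyError:
        raise NotImplementedError(f"no loader registered for '{op_name}'") from None
    return loader(inputs)
```

In this sketch an unregistered operation fails at lookup time, which mirrors why a model containing jax.lax.lt cannot be converted until a loader for it is registered.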

Here is an example of a loader implementation for the jax.lax.reshape operation:

OutputVector translate_reshape(const NodeContext& context) {
    num_inputs_check(context, 1, 1);
    Output<Node> input = context.get_input(0);
    auto new_sizes = context.const_named_param<std::vector<int64_t>>("new_sizes");
    if (context.has_param("dimensions")) {
        auto dimensions = context.const_named_param<std::vector<int64_t>>("dimensions");
        // transpose the input first.
        auto permutation_node = std::make_shared<v0::Constant>(element::i64, Shape{dimensions.size()}, dimensions);
        input = std::make_shared<v1::Transpose>(input, permutation_node);
    }

    auto new_shape_node = std::make_shared<v0::Constant>(element::i64, Shape{new_sizes.size()}, new_sizes);
    Output<Node> res = std::make_shared<v1::Reshape>(input, new_shape_node, false);
    return {res};
};

In this example, translate_reshape expresses jax.lax.reshape using the OpenVINO opset. Since, according to the JAX documentation, jax.lax.reshape performs a transposition followed by a tensor reshape, the resulting decomposition contains OpenVINO Transpose and Reshape operations. To build the Transpose and Reshape nodes, the conversion parses two constant parameters: dimensions, which specifies how to permute the input tensor, and new_sizes, which is the target shape of the result.
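The transpose-then-reshape order can be illustrated with a small pure-Python reference on nested lists. This is a sketch under simplifying assumptions: it handles only 2-D inputs, and the helper names (transpose_2d, flatten, reshape, lax_reshape_reference) are hypothetical and unrelated to the JAX or OpenVINO APIs.

```python
from math import prod
from typing import List, Optional, Sequence


def transpose_2d(rows: List[list], dimensions: Sequence[int]) -> List[list]:
    """Permute the axes of a 2-D nested list; only (0, 1) and (1, 0) are valid."""
    if tuple(dimensions) == (0, 1):
        return [list(row) for row in rows]
    if tuple(dimensions) == (1, 0):
        return [list(col) for col in zip(*rows)]
    raise ValueError(f"unsupported permutation for 2-D input: {dimensions}")


def flatten(nested) -> list:
    """Flatten a nested list in row-major order."""
    if not isinstance(nested, list):
        return [nested]
    out = []
    for item in nested:
        out.extend(flatten(item))
    return out


def reshape(flat: list, new_sizes: Sequence[int]) -> list:
    """Rebuild a nested list of shape new_sizes from a flat row-major list."""
    if not new_sizes or any(size <= 0 for size in new_sizes):
        raise ValueError("new_sizes must be a non-empty sequence of positive ints")
    if prod(new_sizes) != len(flat):
        raise ValueError("element count does not match new_sizes")
    if len(new_sizes) == 1:
        return list(flat)
    step = len(flat) // new_sizes[0]
    return [
        reshape(flat[i * step:(i + 1) * step], new_sizes[1:])
        for i in range(new_sizes[0])
    ]


def lax_reshape_reference(rows: List[list],
                          new_sizes: Sequence[int],
                          dimensions: Optional[Sequence[int]] = None) -> list:
    """Transpose first (only if dimensions is given), then reshape."""
    if dimensions is not None:
        rows = transpose_2d(rows, dimensions)
    return reshape(flatten(rows), new_sizes)
```

For example, with x = [[1, 2, 3], [4, 5, 6]], calling lax_reshape_reference(x, [6], dimensions=(1, 0)) returns [1, 4, 2, 5, 3, 6], while omitting dimensions returns [1, 2, 3, 4, 5, 6]. The difference shows why the loader inserts a Transpose node before the Reshape node only when the dimensions parameter is present.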

Once you are done implementing the translator, you need to implement the corresponding layer test, test_lt.py, and put it into the layer_tests/jax_tests directory. Here is an example of how to run a layer test:

export TEST_DEVICE=CPU
export JAX_TRACE_MODE=JAXPR
export 
cd openvino/tests/layer_tests/jax_tests
pytest test_reshape.py
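The same setup can be scripted from Python, which makes the environment explicit and repeatable. The following is a minimal sketch: the helper names (layer_test_env, pytest_command, run_layer_test) are hypothetical, and only TEST_DEVICE and JAX_TRACE_MODE come from the commands above. TEST_PRECISION is optional here because it is mentioned later in this thread rather than in the commands above, and the incomplete export line is deliberately not filled in.

```python
import os
import subprocess
import sys
from typing import Dict, List, Mapping, Optional


def layer_test_env(device: str = "CPU",
                   trace_mode: str = "JAXPR",
                   precision: Optional[str] = None,
                   base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Return a copy of the environment with the layer-test variables set."""
    env = dict(os.environ if base is None else base)
    env["TEST_DEVICE"] = device
    env["JAX_TRACE_MODE"] = trace_mode
    if precision is not None:
        # Set only when requested; the required value is an assumption.
        env["TEST_PRECISION"] = precision
    return env


def pytest_command(test_file: str) -> List[str]:
    """Build the command that runs pytest with the current interpreter."""
    return [sys.executable, "-m", "pytest", test_file]


def run_layer_test(test_file: str, tests_dir: str, **env_kwargs) -> int:
    """Run one layer test file from tests_dir and return pytest's exit code."""
    completed = subprocess.run(
        pytest_command(test_file),
        cwd=tests_dir,
        env=layer_test_env(**env_kwargs),
    )
    return completed.returncode
```

A call such as run_layer_test("test_reshape.py", "openvino/tests/layer_tests/jax_tests") would then correspond to the shell commands above, assuming pytest and the layer-test dependencies are installed.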

Example Pull Requests

Resources

Contact points

Ticket

No response

aku221b commented 1 month ago

.take Hi @rkazants, I would like to work on this!

github-actions[bot] commented 1 month ago

Thank you for looking into this issue! Please let us know if you have any questions or require any help.

github-actions[bot] commented 1 month ago

Thanks for being interested in this issue. It looks like this ticket is already assigned to a contributor. Please communicate with the assigned contributor to confirm the status of the issue.

aku221b commented 1 month ago

@rkazants After building, I ran pytest test_reshape.py, but some test cases failed.

EDIT: Debugged. I needed to set the appropriate test precision: export TEST_PRECISION="FP16"

hub-bla commented 1 month ago

Hi @aku221b, there's a PR that just got merged and I thought you might find it helpful ;)

#26719

aku221b commented 1 month ago

Thanks a lot, @hub-bla! I will look into it.

aku221b commented 1 month ago

@rkazants Please review PR #26771 for this issue.

aku221b commented 1 month ago

@rkazants By mistake, I made commits with the wrong email 😅, and it's becoming tricky to find those commits and reset the email, as doing so may alter the commit history. I have raised a fresh PR with the latest code. Sorry for the unnecessary trouble!

EDIT: Fresh PR: #26847