tlc-pack / relax


[Discuss] Extract Op Info from Primfunc #278

Open sunggg opened 2 years ago

sunggg commented 2 years ago

1. Motivation

Problem

Currently, once we lower an op to its PrimFunc implementation, it is hard to exploit op-level info (e.g., op name, op kind, op attributes), even though the PrimFunc is supposed to contain it. This has been fine in Relay since the pipeline lowers abstractions strictly in one direction, dealing with one abstraction level at a time.

In Relax, we are unlocking interaction between different abstraction levels. The new design of TIR-level layout planning is a good example: by manipulating both the graph level and the TIR level at the same time, we could eliminate the need for InferCorrectLayout, which has been a source of complexity and issues. However, this design requires lowering to happen before layout planning, and the loss of convenient op-level information during lowering makes the BYOC mechanism difficult. For instance, the following snippet shows how the TensorRT BYOC converts a Relay/Relax conv2d op to its TensorRT equivalent by using op-level info (e.g., the op name and its attributes, such as data_layout, strides, etc.). This info may not be easily accessible in the current PrimFunc design.

```cpp
class Conv2DOpConverter : public TensorRTOpConverter {
 public:
  // ....
  void Convert(TensorRTOpConverterParams* params) const {
    auto input_tensor = params->inputs.at(0).tensor;
    auto input_dims = TrtDimsToVector(input_tensor->getDimensions());
    auto weight_shape = params->inputs.at(1).weight_shape;
    ICHECK_EQ(params->node.GetAttr<std::vector<std::string>>("data_layout")[0], "NCHW");
    ICHECK(params->node.GetAttr<std::vector<std::string>>("out_layout")[0] == "" ||
           params->node.GetAttr<std::vector<std::string>>("out_layout")[0] == "NCHW");
    ICHECK_EQ(params->node.GetAttr<std::vector<std::string>>("kernel_layout")[0], "OIHW");
    auto str_strides = params->node.GetAttr<std::vector<std::string>>("strides");
    auto str_dilation = params->node.GetAttr<std::vector<std::string>>("dilation");
    auto str_padding = params->node.GetAttr<std::vector<std::string>>("padding");
    int groups = std::stoi(params->node.GetAttr<std::vector<std::string>>("groups")[0]);
    int channels = weight_shape[0];
    if (params->node.HasAttr("channels") &&
        !params->node.GetAttr<std::vector<std::string>>("channels")[0].empty()) {
      channels = std::stoi(params->node.GetAttr<std::vector<std::string>>("channels")[0]);
    }
    // ...
    const auto kernel_size = nvinfer1::DimsHW(weight_shape[2], weight_shape[3]);
    const nvinfer1::DataType weight_type = params->inputs.at(1).weight.type;
    nvinfer1::Weights bias{weight_type, nullptr, 0};
    auto conv_layer = params->network->addConvolution(*input_tensor, channels, kernel_size,
                                                      params->inputs.at(1).weight, bias);
    //...
  }
};
```

Goal

To solve this problem, i.e., to keep the benefits of TIR-level planning while supporting BYOC, this doc investigates whether it is possible to access op-level info at the TIR level in a convenient form. Specifically, this op-level info includes the operator name, the operator kind, and the operator attributes.

Please note that tir::PatternKindAnalyzer in Relax is already able to deduce the operator kind from a TIR PrimFunc. This doc examines whether a similar approach is achievable for the other info.

At the end of the day, we may provide a convenient interface to access this info. This doc does not try to settle on the best design; a couple of options are conceivable.

2. Findings

Operator Name

This can be obtained during lowering and easily annotated on the PrimFunc.
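For illustration, here is a minimal sketch of what such an annotation could look like, using the existing `with_attr` API on functions. The attribute key `"op_name"` is a made-up convention for this sketch, not an established one:

```python
from tvm.script import tir as T

@T.prim_func
def add_one(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (8,), "float32")
    B = T.match_buffer(b, (8,), "float32")
    for i in T.serial(0, 8):
        with T.block("add_one"):
            vi = T.axis.spatial(8, i)
            B[vi] = A[vi] + T.float32(1)

# During lowering, the compiler could stamp the source op's name onto the
# PrimFunc; later passes (e.g., BYOC converters) could read it back.
func = add_one.with_attr("op_name", "relax.add")
print(func.attrs["op_name"])  # "relax.add"
```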

Operator Kinds

Already supported by tir::PatternKindAnalyzer in Relax.

Operator Attributes

By using attributes, TVM lowers each operator into a valid implementation. Therefore, this section assumes the PrimFunc implementation embeds the attribute information in some way, and examines whether we can extract it. Since layout transformation at the TIR level might affect certain attributes (we call these layout-sensitive attributes), we also look into which attributes should be updated accordingly on layout transformation.
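To make concrete why extraction is non-trivial, consider a strided convolution expressed in TE (a sketch with made-up shapes). After lowering, the graph-level attribute strides=(2, 2) survives only as a constant multiplier in the index arithmetic, so recovering it means pattern-matching index expressions rather than reading a field:

```python
from tvm import te

# Made-up shapes for a small conv2d; "stride" is the graph-level attribute
# that gets baked into the compute below.
N, C, H, W, K, R, S = 1, 3, 32, 32, 8, 3, 3
stride = 2
OH, OW = (H - R) // stride + 1, (W - S) // stride + 1

data = te.placeholder((N, C, H, W), name="data")
weight = te.placeholder((K, C, R, S), name="weight")
c = te.reduce_axis((0, C), name="c")
r = te.reduce_axis((0, R), name="r")
s = te.reduce_axis((0, S), name="s")

# Note: stride appears only as the constant factor in h * stride + r;
# no structure in the resulting PrimFunc is labeled "strides" anymore.
conv = te.compute(
    (N, K, OH, OW),
    lambda n, k, h, w: te.sum(
        data[n, c, h * stride + r, w * stride + s] * weight[k, c, r, s],
        axis=[c, r, s],
    ),
    name="conv",
)
```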

Case Study

Representative Ops w/o Attributes

Representative Ops w/ Attributes

Summary

3. Suggestion for Relax Layout Planner

With access to op-level info in the PrimFunc, there are two options to make the Relax layout planner work with BYOC.

masahi commented 1 year ago

Some thought on this problem:

To guarantee the soundness of graph-level layout transformation, we need to be able to infer the new layout-sensitive attributes for all ops, 100% reliably. That might be difficult, especially at the beginning of development.

To guarantee soundness while making gradual development possible, the TIR-level transformation pass can materialize transform_layout back to the original layout when it encounters a TIR PrimFunc for which we cannot infer the attributes of the corresponding graph-level op. We need this behavior only when the TIR-level pass is invoked for the purpose of graph-level transformation ("graph mode").

This way, we don't have to worry about tricky ops like strided_slice until we commit to implementing the attribute inference rule for them.
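In pseudo-Python, the fallback could look like the sketch below; infer_new_attrs, rewrite_with_new_layout, and insert_inverse_transform are hypothetical helpers used only to show the control flow, not existing APIs:

```python
def plan_op_layout(call, index_map):
    """Sketch of the "graph mode" fallback described above (hypothetical)."""
    new_attrs = infer_new_attrs(call, index_map)  # None if no rule exists yet
    if new_attrs is not None:
        # We know how to update the layout-sensitive attributes: rewrite the
        # op to consume/produce the transformed layout directly.
        return rewrite_with_new_layout(call, new_attrs, index_map)
    # No inference rule (e.g., strided_slice): stay sound by materializing a
    # transform back to the original layout and leaving the op untouched.
    return insert_inverse_transform(call, index_map)
```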

psrivas2 commented 1 year ago

That could be an interesting direction @masahi! Clarification question: If we materialize the layout in such cases, but are not able to raise the transformed PrimFunc back to operator level, would BYOC backends be able to pattern match such operators?

masahi commented 1 year ago

Interesting question. For backends that can do their own layout transform internally (e.g., DNNL), the TVM-side layout transform is always optional (it only improves performance), so pattern matching is agnostic to layouts. Other backends (e.g., CUTLASS) expect the right layout for pattern matching to succeed, so we need to break the graph there.

But I expect there would be no need to "infer" the new attributes for most compute-intensive ops that we want to offload to BYOC, since their layouts are typically fixed by users. We only need to worry about the layout-sensitive ops in between, like reduction and other shape-changing ops, which might not be offloaded to BYOC anyway.

quic-sanirudh commented 1 year ago

Are there any plans to add support for extracting op info as mentioned here at some point? Was there a final decision on how this is going to be supported?

psrivas2 commented 1 year ago

Hello @quic-sanirudh!

Are there any plans to add support for extracting op info as mentioned here at some point?

Yes, extracting op info will be supported. As mentioned in the comments above, @masahi and @sunggg have laid out some possible approaches. A lot of details still need to be figured out.

Was there a final decision on how this is going to be supported?

It is going to be supported, but the design of how exactly this will work has not been decided yet. If you are interested in this problem, please feel free to start a discussion on the design here or in a separate thread.

quic-sanirudh commented 1 year ago

Thanks @psrivas2 for the quick reply. I was curious how this would work in the presence of fusion. Basically, if we extract the op info before fusion, we either have to assume that it is only valid until fusion is performed, or we need some way to iterate through the attributes of each individual op that is part of a fused PrimFunc.

I'll think about this a bit more and explain with an example.

psrivas2 commented 1 year ago

Looking forward to the example. However, the use case that we have (graph operator level layout transformation) does not need to preserve this information in the presence of fusion, because the graph operator level layout transformation for BYOC and the pattern matching for BYOC would happen before fusion.

quic-sanirudh commented 1 year ago

@psrivas2, thanks for the reply. Actually my question is not related to just layout transformation. Please correct me if I'm mistaken here, but I thought the point of this op info extraction is to make these op-specific attributes available during other transformations as well.

For example, if we have the strides or padding information for conv2d/pooling ops, that might be useful for writing passes that target those specific ops. If the information is available during scheduling, it would help improve scheduling as well (automated based on rules, or manual).

Say, for example, a user would like to write a new schedule_rule that targets a particular type of op, such that the number of tiles can be decided based on its attributes; that might turn out to be really useful (just a random thought, I don't have a concrete example yet). My idea was that, if we need something like that, we might need to extract op attributes through a pass before fusion and retain them in some way after fusion, perhaps in the form of attributes on that fused PrimFunc.
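One way that retention could look, as a sketch rather than a committed design (the "op_attrs" key and its schema are made up): attach a structured attribute to the fused PrimFunc recording each constituent op's name and attributes.

```python
from tvm import tir

# Stand-in for a fused PrimFunc; a trivial body keeps the sketch runnable.
fused = tir.PrimFunc(params=[], body=tir.Evaluate(0))

# "op_attrs" is a made-up key; with_attr converts the Python list/dicts
# into TVM Array/Map containers, so the info survives serialization.
fused = fused.with_attr("op_attrs", [
    {"op": "nn.conv2d", "strides": [1, 1], "padding": [1, 1, 1, 1]},
    {"op": "nn.relu"},
])

for op_info in fused.attrs["op_attrs"]:
    print(op_info["op"])  # nn.conv2d, then nn.relu
```

A schedule rule could then look up these entries to pick tiling factors per op without re-deriving them from the loop structure.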