tenstorrent / tt-tvm

TVM for Tenstorrent ASICs
Apache License 2.0

No proper support for `aten::triu` OP inputs when inputs are CallNodes [Bug] #3

Closed. JushBJJ closed this issue 1 month ago.

JushBJJ commented 5 months ago

Hi, I'm working on the bounties with @marty1885 and @JonathanALevine.

Context

While attempting to implement Qwen 1.5 (0.5B) (https://github.com/tenstorrent/tt-buda-demos/issues/20), we encountered the following error:

Traceback (most recent call last):
  File "/home/jush/qwen/.venv/lib/python3.8/site-packages/numpy/lib/twodim_base.py", line 536, in triu
    mask = tri(*m.shape[-2:], k=k-1, dtype=bool)
  File "/home/jush/qwen/.venv/lib/python3.8/site-packages/tvm/relay/expr.py", line 146, in __sub__
    raise TypeError(f'convert "{str(other)}" with `const` first')
TypeError: convert "1" with `const` first
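The error comes from NumPy computing k - 1 inside np.triu while k is a Relay expression, and relay.Expr.__sub__ refuses plain Python numbers. Below is a minimal sketch (not from the issue) that reproduces the same failure mode; the concrete k expression is made up for illustration:

import numpy as np
from tvm import relay

# inputs[1] reaching the triu converter can be a Relay CallNode rather than a
# plain Python int, e.g. something equivalent to:
k = relay.const(256) + relay.const(1)

try:
    # np.triu internally calls tri(..., k=k - 1), which hits relay.Expr.__sub__
    np.triu(np.ones((4, 4)), k)
except TypeError as err:
    print(err)  # convert "1" with `const` first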

The Qwen 1.5 (0.5B) implementation demo can be found at https://github.com/tenstorrent/tt-buda-demos/pull/37

The relevant `aten::triu` op in the traced graph is:

%aten::triu_0_0 : Int(256, 256, strides=[256, 1], requires_grad=0, device=cpu) = aten::triu(%aten::ones_like_0_0, %aten::Int_7_0), scope: utils.modeling_qwen2.Qwen2ForCausalLM::/utils.modeling_qwen2.Qwen2Model::model # /home/jush/qwen/.venv/lib/python3.8/site-packages/transformers/modeling_attn_mask_utils.py:169:0

*Requires transformers >=v4.37.0

Root Cause

The triu converter in the PyTorchOpConverter class in tvm/relay/frontend/pytorch.py uses np.triu to compute the mask:

mask = np.triu(np.ones(x_shape), inputs[1]).astype(np.bool)

However, np.triu cannot handle the case where inputs[1] is not a Python integer but a nested Relay expression. In this case, inputs[1] is:

CallNode(Op(add), [CallNode(Op(subtract), [CallNode(Op(subtract), [CallNode(Op(add), [Constant(256), Constant(0)], (nullptr), []), CallNode(Op(multiply), [Constant(1), Constant(256)], (nullptr), [])], (nullptr), []), CallNode(Op(multiply), [Constant(1), Constant(32768)], (nullptr), [])], (nullptr), []), Constant(1)], (nullptr), [])
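Every leaf of this expression is a Constant, so in principle the value could be folded down to a plain integer before np.triu ever sees it. A hedged sketch of such a fallback, assuming the infer_value helper from tvm.relay.frontend.common (which I believe pytorch.py imports as _infer_value) and a hypothetical resolve_k helper name:

from tvm import relay
from tvm.relay.frontend.common import infer_value

def resolve_k(k_input):
    # Hypothetical helper: fold a constant-only Relay expression to a Python int.
    if isinstance(k_input, int):
        return k_input
    if isinstance(k_input, relay.Constant):
        return int(k_input.data.numpy())
    if isinstance(k_input, relay.Expr):
        # Only works when the expression depends on constants alone.
        return int(infer_value(k_input, {}).numpy())
    return 0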

Workaround

In draft PR https://github.com/tenstorrent/tt-tvm/pull/2, I used self.trilu(inputs, input_types, mode="triu"), since _op.trilu can produce the upper triangle instead of the lower one when upper=True.

def trilu(self, inputs, input_types, mode):
    data = inputs[0]
    # k may be a plain int or a Relay expression; default to 0 when absent
    k = inputs[1] if inputs[1] else 0
    # mode selects the upper ("triu") or lower ("tril") triangle
    upper = True if mode == "triu" else False
    return _op.trilu(data, k, upper)
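Since _op.trilu accepts k as a Relay expression, the nested CallNode above can presumably be passed through symbolically instead of being materialized via NumPy, which is why this path avoids the original TypeError.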

This workaround allows the Qwen 1.5 (0.5B) model to compile, but the model still fails to run properly.

See https://github.com/tenstorrent/tt-tvm/pull/2 for the issues with this workaround.

Environment

OS: Ubuntu 20.04
Pybuda Version: v0.10.5.gs.240315
TVM Version (from latest Pybuda): 0.14.0

Steps to reproduce

See https://github.com/tenstorrent/tt-buda-demos/pull/37

Triage

frontend:pytorch

staylorTT commented 4 months ago

Hi there, thanks for filing this issue. I will do my best to get someone from the TT side to take a look at this.

AleksKnezevic commented 3 months ago

To clarify, @JushBJJ, are the inputs to the op constant (even if they're CallNodes)? That is, would we be able to extract them using _infer_value?

JushBJJ commented 3 months ago

To clarify, @JushBJJ, are the inputs to the op constant (even if they're CallNodes)? That is, would we be able to extract them using _infer_value?

I will get into this later today; my e75 just arrived yesterday.

staylorTT commented 1 month ago

@JushBJJ Do you have any updates here?

JushBJJ commented 1 month ago

@JushBJJ Do you have any updates here?

Hi, I'm incredibly sorry for forgetting about this. Since I don't have much else going on at the moment, this is full priority now. I'm currently working through some Buda hurdles with the more recent updates, but so far I haven't hit this issue again... yet.

staylorTT commented 1 month ago

Appreciate the update. If you hit this again, please report back.

JushBJJ commented 1 month ago

@staylorTT Yup, this can stay closed; I can confirm that this is no longer an issue.