mlc-ai / relax

Apache License 2.0

[Bug] Error in relax.op.masked_fill and _ffi_api.full_like function #278

Closed: sbwww closed this issue 11 months ago.

sbwww commented 11 months ago

Expected behavior

In the source of masked_fill below, the call values = _ffi_api.full_like(x, value)  # type: ignore passes no dtype argument, and it is expected to work that way.

def masked_fill(x: Expr, mask: Expr, value: Expr):
    """Fill a tensor by a specified value in places defined by a mask.
    Parameters
    ----------
    x : relax.Expr
        The input data to the operator.
    mask : relax.Expr
        The mask.
    value : relax.Expr
        The value to set in the input tensor.
    Returns
    -------
    result : relax.Expr
        The filled tensor.
    """
    values = _ffi_api.full_like(x, value)  # type: ignore
    return _ffi_api.where(mask, values, x)  # type: ignore
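For reference, the semantics masked_fill intends (build a tensor filled with value shaped like x, then select between it and x by mask) can be sketched in NumPy. This is an illustration only, assuming relax's full_like and where behave like their NumPy namesakes; masked_fill_reference is a hypothetical name, not TVM code. np.full_like keeps the input's dtype when dtype is omitted, which is the default that masked_fill relies on.

```python
import numpy as np


def masked_fill_reference(x: np.ndarray, mask: np.ndarray, value) -> np.ndarray:
    """Hypothetical NumPy reference for masked_fill (not TVM code)."""
    # With dtype omitted, np.full_like reuses x's dtype.
    values = np.full_like(x, value)
    # Take the fill value where mask is True and the original element elsewhere.
    return np.where(mask, values, x)


scores = np.array([[0.5, 1.0], [2.0, 3.0]], dtype=np.float32)
mask = np.array([[False, True], [True, False]])
print(masked_fill_reference(scores, mask, -1e4))
```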

Actual behavior

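During compilation, masked_fill fails at this call. The FFI function relax.op.full_like expects 3 arguments (the input, the fill value, and a DataType), but masked_fill passes only 2:

TVMError: Function relax.op.full_like(0: RelayExpr, 1: RelayExpr, 2: DataType) -> RelayExpr expects 3 arguments, but 2 were provided.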

As a workaround, I manually replace masked_fill(x, mask, value) with values = full_like(x, value, x.struct_info.dtype) followed by where(mask, values, x), which gives the correct result.

    attention_scores = masked_fill(attn_weights, astype(attention_mask, dtype="bool"), fill_value)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/relax/op/mask.py", line 37, in masked_fill
    values = _ffi_api.full_like(x, value)  # type: ignore
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/_ffi/_cython/./packed_func.pxi", line 331, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 262, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 251, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 181, in tvm._ffi._cy3.core.CHECK_CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
  3: TVMFuncCall
  2: _ZN3tvm7runtime13PackedFun
  1: tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::RelayExpr, tvm::RelayExpr, tvm::runtime::DataType)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::RelayExpr, tvm::RelayExpr, tvm::runtime::DataType)>(tvm::RelayExpr (*)(tvm::RelayExpr, tvm::RelayExpr, tvm::runtime::DataType), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  0: _ZN3tvm7runtime6deta
  File "/workspace/tvm/include/tvm/runtime/packed_func.h", line 1731
TVMError: Function relax.op.full_like(0: RelayExpr, 1: RelayExpr, 2: DataType) -> RelayExpr expects 3 arguments, but 2 were provided.

Environment

mlc-ai-nightly-cu116==0.12.dev1365
mlc-chat-nightly-cu116==0.1.dev309

Steps to reproduce

Call relax.op.masked_fill(x, mask, value). The error is raised inside masked_fill at the _ffi_api.full_like(x, value) call.

Triage

relax:op