Expected behavior
In the source code of masked_fill (below), the call values = _ffi_api.full_like(x, value) is supposed to work without a dtype argument.
def masked_fill(x: Expr, mask: Expr, value: Expr):
    """Fill a tensor by a specified value in places defined by a mask.

    Parameters
    ----------
    x : relax.Expr
        The input data to the operator.
    mask : relax.Expr
        The mask.
    value : relax.Expr
        The value to set in the input tensor.

    Returns
    -------
    result : relax.Expr
        The filled tensor.
    """
    values = _ffi_api.full_like(x, value)  # type: ignore
    return _ffi_api.where(mask, values, x)  # type: ignore
Actual behavior
During compilation, relax.op.full_like must receive 3 arguments, so the 2-argument call inside masked_fill fails:
TVMError: Function relax.op.full_like(0: RelayExpr, 1: RelayExpr, 2: DataType) -> RelayExpr expects 3 arguments, but 2 were provided.
As a workaround, I can manually replace masked_fill(x, mask, value) with full_like(x, value, x.struct_info.dtype) and where(mask, value, x) to get the right result.
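To make the intended semantics of the workaround concrete, here is a minimal plain-Python sketch of masked_fill as "full_like with an explicit dtype, then where". It runs on flat lists and does not use TVM; the three functions are stand-ins written for illustration only, and the dtype parameter is modeled as a Python callable (an assumption, not how TVM represents dtypes).

```python
# Plain-Python stand-ins for the two ops. These are NOT the TVM implementations.

def full_like(x, value, dtype):
    """Return a list with the same length as x, filled with dtype(value)."""
    return [dtype(value) for _ in x]


def where(mask, a, b):
    """Select a[i] where mask[i] is true, otherwise b[i]."""
    return [ai if m else bi for m, ai, bi in zip(mask, a, b)]


def masked_fill(x, mask, value, dtype=float):
    # The dtype is passed explicitly, mirroring the workaround described above.
    values = full_like(x, value, dtype)
    return where(mask, values, x)


scores = [0.5, 1.5, 2.5]
mask = [False, True, False]
# Only the masked position (index 1) is replaced by the fill value.
print(masked_fill(scores, mask, -1e9))
```

In the real workaround the dtype comes from x.struct_info.dtype, so the filled tensor matches the input tensor's element type.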
    attention_scores = masked_fill(attn_weights, astype(attention_mask, dtype="bool"), fill_value)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/mlc-chat-venv/lib/python3.11/site-packages/tvm/relax/op/mask.py", line 37, in masked_fill
    values = _ffi_api.full_like(x, value) # type: ignore
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/_ffi/_cython/./packed_func.pxi", line 331, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 262, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 251, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 181, in tvm._ffi._cy3.core.CHECK_CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
  3: TVMFuncCall
  2: _ZN3tvm7runtime13PackedFun
  1: tvm::runtime::TypedPackedFunc<tvm::RelayExpr (tvm::RelayExpr, tvm::RelayExpr, tvm::runtime::DataType)>::AssignTypedLambda<tvm::RelayExpr (*)(tvm::RelayExpr, tvm::RelayExpr, tvm::runtime::DataType)>(tvm::RelayExpr (*)(tvm::RelayExpr, tvm::RelayExpr, tvm::runtime::DataType), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  0: _ZN3tvm7runtime6deta
  File "/workspace/tvm/include/tvm/runtime/packed_func.h", line 1731
TVMError: Function relax.op.full_like(0: RelayExpr, 1: RelayExpr, 2: DataType) -> RelayExpr expects 3 arguments, but 2 were provided.
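The failure mode in the traceback can be reproduced in miniature without TVM. The sketch below is a hypothetical model, not TVM's packed-function machinery: make_typed_func and ffi_full_like are names invented here, and the idea that a Python-side wrapper could forward a default dtype of None is an assumption about one possible fix, not a statement of how TVM resolves it.

```python
def make_typed_func(name, param_types, impl):
    """Wrap impl so calls with the wrong number of positional arguments are rejected."""
    def call(*args):
        if len(args) != len(param_types):
            raise TypeError(
                f"Function {name} expects {len(param_types)} arguments, "
                f"but {len(args)} were provided."
            )
        return impl(*args)
    return call


# Hypothetical stand-in for the registered 3-argument function.
ffi_full_like = make_typed_func(
    "full_like",
    ("Expr", "Expr", "DataType"),
    lambda x, value, dtype: ("full_like", x, value, dtype),
)


def full_like(x, value, dtype=None):
    # A Python-side wrapper always forwards all three arguments,
    # so callers may omit dtype without tripping the arity check.
    return ffi_full_like(x, value, dtype)


try:
    ffi_full_like("x", 0)  # two arguments, as in masked_fill
except TypeError as err:
    print(err)

print(full_like("x", 0))  # the wrapper supplies the third argument
```

The point of the sketch is only that the arity check happens at the function boundary: a direct 2-argument call fails, while a call routed through a wrapper with a default third argument passes the check.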
Environment
Steps to reproduce
Call relax.op.masked_fill(x, mask, value); the internal 2-argument call to full_like raises the TVMError shown above.
Triage
relax:op