Move some methods to global scope to fix this TorchScript error:
File "/ChatRWKV/v2/../rwkv_pip_package/src/rwkv/model.py", line 484, in RWKV
def matmul(self, a, b, mx=None, rx=None, my=None, ry=None, output_dtype: Optional[torch.dtype]=None) -> torch.Tensor:
File "/usr/local/lib/python3.8/dist-packages/torch/jit/_script.py", line 1381, in script
fn = torch._C._jit_script_compile(
RuntimeError:
'Tensor (inferred)' object has no attribute or method 'mm8'.:
File "/ChatRWKV/v2/../rwkv_pip_package/src/rwkv/model.py", line 491
        return matmul_float(a, b, output_dtype=output_dtype)
    if b.dtype == torch.uint8:
        return self.mm8(a, b, mx, rx, my, ry).to(output_dtype)
               ~~~~~~~~ <--- HERE
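A minimal sketch of the first fix, with a simplified stand-in for the real mm8 kernel: the method is moved from the class body to module (global) scope, so torch.jit.script can resolve the call directly instead of looking up `mm8` on a `self` whose type it only infers as Tensor. The dequantization formula below is illustrative, not the package's exact kernel.

```python
import torch

# mm8 at global scope: a plain scripted function, no `self` involved,
# so the "'Tensor (inferred)' object has no attribute or method 'mm8'"
# error cannot occur.
@torch.jit.script
def mm8(x: torch.Tensor, w: torch.Tensor, mx: torch.Tensor,
        rx: torch.Tensor, my: torch.Tensor, ry: torch.Tensor) -> torch.Tensor:
    # Dequantize the uint8 weight with the scale/offset tensors, then matmul.
    return x @ ((w.to(x.dtype) + 0.5) * ry * rx + my + mx)
```

Call sites inside the class then use the free function (`mm8(...)`) rather than `self.mm8(...)`.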
Add explicit type annotations to fix this error:
File "/usr/local/lib/python3.8/dist-packages/torch/jit/_script.py", line 1381, in script
fn = torch._C._jit_script_compile(
RuntimeError:
mm8(Tensor x, Tensor w, Tensor mx, Tensor rx, Tensor my, Tensor ry) -> Tensor:
Expected a value of type 'Tensor (inferred)' for argument 'mx' but instead found type 'Optional[Tensor]'.
Inferred 'mx' to be of type 'Tensor' because it was not annotated with an explicit type.
:
File "/ChatRWKV/v2/../rwkv_pip_package/src/rwkv/model.py", line 137
        return matmul_float(a, b, output_dtype=output_dtype)
    elif b.dtype == torch.uint8:
        return mm8(a, b, mx, rx, my, ry).to(output_dtype)
               ~~~ <--- HERE
    else:
        raise ValueError("Unsupported dtype")
'matmul' is being compiled since it was called from 'RWKV.att_one'
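A sketch of the second fix, assuming a simplified mm8 and float path: without explicit `Optional[torch.Tensor]` annotations, TorchScript infers `Tensor` for mx/rx/my/ry and rejects the `None` defaults passed from call sites like `RWKV.att_one`. Annotating them as Optional, then narrowing with asserts before the uint8 branch, satisfies the compiler.

```python
import torch
from typing import Optional

@torch.jit.script
def mm8(x: torch.Tensor, w: torch.Tensor, mx: torch.Tensor,
        rx: torch.Tensor, my: torch.Tensor, ry: torch.Tensor) -> torch.Tensor:
    # Illustrative dequantize-and-multiply standing in for the real kernel.
    return x @ ((w.to(x.dtype) + 0.5) * ry * rx + my + mx)

@torch.jit.script
def matmul(a: torch.Tensor, b: torch.Tensor,
           mx: Optional[torch.Tensor] = None,
           rx: Optional[torch.Tensor] = None,
           my: Optional[torch.Tensor] = None,
           ry: Optional[torch.Tensor] = None,
           output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    if output_dtype is None:
        output_dtype = a.dtype
    if b.dtype in [torch.float16, torch.bfloat16, torch.float32]:
        # Simplified float path (the package routes this through matmul_float).
        return (a @ b).to(output_dtype)
    elif b.dtype == torch.uint8:
        # Each assert refines an Optional[Tensor] to Tensor, so the mm8
        # call type-checks under TorchScript.
        assert mx is not None
        assert rx is not None
        assert my is not None
        assert ry is not None
        return mm8(a, b, mx, rx, my, ry).to(output_dtype)
    else:
        raise ValueError("Unsupported dtype")
```

TorchScript only performs this Optional refinement through explicit `is not None` checks, which is why the asserts sit directly before the call.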
Tested on v4/v5 models with JIT on and off.