BlinkDL / ChatRWKV

ChatRWKV is like ChatGPT but powered by the RWKV (100% RNN) language model, and it is open source.

fix jit #182

Closed · daquexian closed this 11 months ago

daquexian commented 11 months ago

Move some methods to the global scope to fix this error:

  File "/ChatRWKV/v2/../rwkv_pip_package/src/rwkv/model.py", line 484, in RWKV
    def matmul(self, a, b, mx=None, rx=None, my=None, ry=None, output_dtype: Optional[torch.dtype]=None) -> torch.Tensor:
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_script.py", line 1381, in script
    fn = torch._C._jit_script_compile(
RuntimeError:
'Tensor (inferred)' object has no attribute or method 'mm8'.:
  File "/ChatRWKV/v2/../rwkv_pip_package/src/rwkv/model.py", line 491
            return matmul_float(a, b, output_dtype=output_dtype)
        if b.dtype == torch.uint8:
            return self.mm8(a, b, mx, rx, my, ry).to(output_dtype)
                   ~~~~~~~~ <--- HERE

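For context, here is a minimal sketch of the failure mode and the fix. The names and bodies below are illustrative placeholders, not the actual model.py code:

```python
import torch

# Sketch only: torch.jit.script applied to a def inside a class body
# compiles it as a free function, so the un-annotated `self` parameter
# is inferred to be a Tensor and `self.mm8(...)` cannot resolve:
#
#   class RWKV:
#       @torch.jit.script
#       def matmul(self, a, b):    # `self` inferred as Tensor
#           return self.mm8(a, b)  # error: Tensor has no method 'mm8'
#
# Moving the helper to module (global) scope removes `self` entirely,
# so there is nothing left for TorchScript to mis-infer:

@torch.jit.script
def mm8(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # stand-in body; the real mm8 performs a quantized uint8 matmul
    return x @ w.to(x.dtype)
```
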
Add explicit type annotations to fix this error:

  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_script.py", line 1381, in script
    fn = torch._C._jit_script_compile(
RuntimeError:

mm8(Tensor x, Tensor w, Tensor mx, Tensor rx, Tensor my, Tensor ry) -> Tensor:
Expected a value of type 'Tensor (inferred)' for argument 'mx' but instead found type 'Optional[Tensor]'.
Inferred 'mx' to be of type 'Tensor' because it was not annotated with an explicit type.
:
  File "/ChatRWKV/v2/../rwkv_pip_package/src/rwkv/model.py", line 137
        return matmul_float(a, b, output_dtype=output_dtype)
    elif b.dtype == torch.uint8:
        return mm8(a, b, mx, rx, my, ry).to(output_dtype)
               ~~~ <--- HERE
    else:
        raise ValueError("Unsupported dtype")
'matmul' is being compiled since it was called from 'RWKV.att_one'

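A sketch of this second fix (again with placeholder bodies): TorchScript infers every un-annotated parameter as `Tensor`, so a call site that passes `None` produces an `Optional[Tensor]` vs. `Tensor` mismatch. Annotating the optional arguments explicitly, and narrowing them before use, resolves it:

```python
from typing import Optional

import torch

@torch.jit.script
def matmul(a: torch.Tensor, b: torch.Tensor,
           mx: Optional[torch.Tensor] = None,
           rx: Optional[torch.Tensor] = None,
           my: Optional[torch.Tensor] = None,
           ry: Optional[torch.Tensor] = None,
           output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    if output_dtype is None:
        output_dtype = a.dtype
    if b.dtype == torch.uint8:
        # TorchScript requires Optional[Tensor] to be narrowed to Tensor
        # before the values can be used
        assert mx is not None
        assert rx is not None
        assert my is not None
        assert ry is not None
        # stand-in for the real mm8 dequantize-and-matmul (shapes and
        # math are illustrative only)
        dq = (b.to(a.dtype) + 0.5) * ry * rx
        return ((a + mx) @ dq + my).to(output_dtype)
    return (a @ b).to(output_dtype)
```
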
Tested on v4 and v5 models with JIT on and off.