dmahurin opened 1 year ago
This seems to be because rwkvMaster calls `.clone()` (the PyTorch method) instead of `.copy()` (the JAX/NumPy equivalent).
Might it be worth abstracting the copy into a backend-agnostic op?
FWIW, I have a fork that works with JAX at https://github.com/tensorpro/rwkvstic, but it does so by breaking the other backends.
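A minimal sketch of the backend-agnostic copy op suggested above, assuming only that PyTorch tensors expose `.clone()` and NumPy/JAX arrays expose `.copy()`. The name `agnostic_copy` is hypothetical and is not taken from rwkvstic; the actual fix on master may look different.

```python
def agnostic_copy(x):
    """Return a copy of x regardless of which array backend produced it.

    Hypothetical helper, not part of rwkvstic: it duck-types on the
    method name instead of importing torch, jax or numpy.
    """
    # PyTorch tensors copy via .clone()
    if hasattr(x, "clone"):
        return x.clone()
    # NumPy and JAX arrays copy via .copy()
    if hasattr(x, "copy"):
        return x.copy()
    raise TypeError(f"don't know how to copy a {type(x).__name__}")
```

Used in place of the call from the traceback, this would read `self.emptyState = agnostic_copy(emptyState)`. Duck typing keeps the helper free of hard dependencies on every backend; a per-backend ops table would be the more explicit alternative.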
Can you please test the latest master push?
Works for me, thanks for the fix!
Seems to work for me as well using JAX.
The README example worked on macOS (M2) with the JAX backend in versions including 1.2.4, but stopped working in the 2.x.x releases, including 2.1.2.
With macOS + JAX, every current 2.x.x release fails with:
```
rwkvstic/rwkvMaster.py", line 20, in __init__
    self.emptyState = emptyState.clone()
AttributeError: 'Array' object has no attribute 'clone'
```