harrisonvanderbyl / rwkvstic

Framework agnostic python runtime for RWKV models
https://hazzzardous-rwkv-instruct.hf.space
MIT License

macos+JAX failure with 2.x.x #18

Open dmahurin opened 1 year ago

dmahurin commented 1 year ago

The example in the README worked on macOS (M2) with the JAX backend on versions up to and including 1.2.4, but it stopped working with the 2.x.x versions, including 2.1.2.

The failure on all current 2.x.x versions with macOS + JAX is:

rwkvstic/rwkvMaster.py", line 20, in __init__
    self.emptyState = emptyState.clone()
AttributeError: 'Array' object has no attribute 'clone'

tensorpro commented 1 year ago

This seems to happen because rwkvMaster calls .clone() (a PyTorch tensor method) instead of .copy() (the JAX/NumPy equivalent). JAX arrays have no .clone() method, hence the AttributeError.

It might be worth abstracting the copy behind a backend-agnostic op?
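
As a rough illustration of that suggestion, here is a minimal sketch of a backend-agnostic copy. The name `copy_state` is hypothetical and not part of rwkvstic, and this is not necessarily how the project handles it. It only assumes that PyTorch tensors expose `.clone()` and that JAX/NumPy arrays expose `.copy()`.

```python
import copy


def copy_state(state):
    """Return a copy of `state` without assuming a backend (hypothetical helper)."""
    # PyTorch tensors provide .clone().
    clone = getattr(state, "clone", None)
    if callable(clone):
        return clone()
    # JAX and NumPy arrays provide .copy().
    copier = getattr(state, "copy", None)
    if callable(copier):
        return copier()
    # Anything else: fall back to a generic deep copy.
    return copy.deepcopy(state)
```

With something like this, the failing line could become `self.emptyState = copy_state(emptyState)`, so the caller would not need to know which backend produced the state.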

tensorpro commented 1 year ago

FWIW, I have a fork that works with JAX at https://github.com/tensorpro/rwkvstic, but it does so at the cost of breaking the other backends.

harrisonvanderbyl commented 1 year ago

Can you please test the latest master push?

tensorpro commented 1 year ago

Works for me, thanks for the fix!

dmahurin commented 1 year ago

It seems to work for me as well with JAX.